# SPM / spm.py
# Uploaded by prasannareddyp ("Upload 4 files", commit 5dbf895).
from PIL import Image
import numpy as np
def _to_divisible_by(img, N):
"""Crop so width and height are divisible by N (top-left anchored)."""
w, h = img.size
W = (w // N) * N
H = (h // N) * N
if W == 0 or H == 0:
raise ValueError("N is larger than the image grid (image too small).")
if W != w or H != h:
img = img.crop((0, 0, W, H))
return img, W, H
def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    seed=None,
):
    """
    SPM-style augmentation using a global shuffle over an N×N patch grid.

    1) Divide the image into N×N equal patches, first cropping from the
       top-left so width and height are divisible by N.
    2) Draw one global permutation of the patch indices.
    3) Per patch, with probability ``mix_prob``, replace it with the convex
       blend ``alpha * shuffled + (1 - alpha) * original``, where
       ``alpha ~ Beta(beta_a, beta_b)`` is drawn once per image.

    Args:
        image: A PIL image, or a NumPy array accepted by
            ``Image.fromarray``; either way it is converted to RGB.
        num_patches: Grid size N (patches per side); converted with
            ``int`` and must be >= 1.
        mix_prob: Probability that any given patch is mixed.
        beta_a, beta_b: Beta distribution shape parameters. If either is
            not positive, alpha is fixed at 1.0, so selected patches are
            replaced outright by their shuffled counterpart.
        seed: Passed to ``np.random.default_rng`` for reproducible output.

    Returns:
        A new RGB ``PIL.Image``. Its size is the input size rounded down to
        multiples of N, so it may be smaller than the input.

    Raises:
        ValueError: If N < 1 or N exceeds either image dimension.
    """
    # Normalize input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    N = int(num_patches)
    if N < 1:
        raise ValueError(f"num_patches must be >= 1, got {num_patches!r}.")
    rng = np.random.default_rng(seed)

    # Ensure divisibility and compute patch size.
    img, W, H = _to_divisible_by(img, N)
    arr = np.array(img, dtype=np.uint8)  # (H, W, 3) for an RGB image
    ph, pw = H // N, W // N

    # Split into a (N*N, ph, pw, C) stack of patches in row-major order:
    # patch k covers grid row k // N, grid column k % N.
    tiles = arr.reshape(N, ph, N, pw, -1).swapaxes(1, 2).reshape(N * N, ph, pw, -1)
    total = N * N

    # RNG draws happen in a fixed order (permutation, alpha, mask) so a
    # given seed always yields the same augmentation.
    perm = rng.permutation(total)

    # Sample one alpha for the whole image.
    if beta_a > 0 and beta_b > 0:
        alpha = float(rng.beta(beta_a, beta_b))
    else:
        alpha = 1.0

    mask = rng.random(total) < float(mix_prob)

    # Patchwise mix of each selected patch with its permuted partner.
    out_tiles = tiles.copy()
    src = tiles[mask]
    shf = tiles[perm[mask]]
    if alpha >= 1.0:
        # Blend weight is entirely on the shuffled patch.
        out_tiles[mask] = shf
    elif alpha > 0.0:
        mixed = alpha * shf.astype(np.float32) + (1.0 - alpha) * src.astype(np.float32)
        out_tiles[mask] = np.clip(mixed, 0, 255).astype(np.uint8)
    # alpha <= 0.0 (rng.beta can return exactly 0.0 for tiny shape
    # parameters): the blend is entirely the original patch, so keep it.

    # Reassemble the patch stack back into an (H, W, C) image.
    out = out_tiles.reshape(N, N, ph, pw, -1).swapaxes(1, 2).reshape(H, W, -1)
    return Image.fromarray(out)