from PIL import Image
import numpy as np


def _to_divisible_by(img, N):
    """Crop so width and height are divisible by N (top-left anchored).

    Args:
        img: Image-like object exposing ``size`` as ``(width, height)`` and
            ``crop(box)``.
        N: Number of grid cells per side; must be a positive integer.

    Returns:
        Tuple ``(img, W, H)`` where ``W`` and ``H`` are the largest multiples
        of ``N`` not exceeding the original width and height. ``img`` is the
        input itself when no cropping was needed.

    Raises:
        ValueError: If ``N`` is not positive, or if the image is smaller than
            ``N`` pixels along either axis.
    """
    if N < 1:
        # Guard explicitly: N == 0 would otherwise surface as a bare
        # ZeroDivisionError and negative N would yield nonsense sizes.
        raise ValueError(f"N must be a positive integer, got {N!r}.")
    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if W != w or H != h:
        img = img.crop((0, 0, W, H))
    return img, W, H


def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    seed=None,
):
    """Apply SPM-style augmentation using a global shuffle over an N×N grid.

    Steps:
      1) Convert the input to RGB and crop it (top-left anchored) so both
         sides are divisible by ``num_patches``.
      2) Draw one global permutation of the N*N patch indices.
      3) Draw a single ``alpha ~ Beta(beta_a, beta_b)`` for the whole image
         (``alpha = 1.0`` when either Beta parameter is not positive).
      4) Independently for every patch, with probability ``mix_prob``,
         replace it by ``alpha * shuffled + (1 - alpha) * original`` where
         ``shuffled`` is the patch the permutation assigns to that slot.
         Unselected patches are left untouched.

    Args:
        image: A ``numpy.ndarray`` accepted by ``Image.fromarray`` or an
            object with a PIL-style ``convert("RGB")`` method.
        num_patches: Grid size N (converted with ``int``); must be >= 1.
        mix_prob: Per-patch probability of being mixed.
        beta_a: First shape parameter of the Beta distribution.
        beta_b: Second shape parameter of the Beta distribution.
        seed: Seed forwarded to ``numpy.random.default_rng``.

    Returns:
        A new RGB ``PIL.Image.Image`` whose size is the (possibly cropped)
        ``(W, H)`` divisible by ``num_patches``.

    Raises:
        ValueError: If ``num_patches`` is < 1 or exceeds an image dimension.
    """
    # Normalize input
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    N = int(num_patches)
    if N < 1:
        raise ValueError(f"num_patches must be >= 1, got {num_patches!r}.")
    rng = np.random.default_rng(seed)

    # Ensure divisibility and compute patch size
    img, W, H = _to_divisible_by(img, N)
    arr = np.array(img, dtype=np.uint8)
    ph, pw = H // N, W // N
    channels = arr.shape[2]
    total = N * N

    # Keep the draw order (permutation, alpha, mask) fixed so that a given
    # seed keeps producing the same augmentation.
    perm = rng.permutation(total)
    if beta_a > 0 and beta_b > 0:
        alpha = float(rng.beta(beta_a, beta_b))
    else:
        alpha = 1.0
    mask = rng.random(total) < float(mix_prob)

    # (H, W, C) -> (N*N, ph, pw, C), patches in row-major grid order.
    patches = (
        arr.reshape(N, ph, N, pw, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(total, ph, pw, channels)
    )
    out = patches.copy()

    chosen = np.flatnonzero(mask)
    # alpha == 0 means "keep the original patch" under the blend formula,
    # so there is nothing to write in that case.
    if chosen.size and alpha > 0.0:
        shuffled = patches[perm[chosen]]
        if alpha < 1.0:
            blended = (
                alpha * shuffled.astype(np.float32)
                + (1.0 - alpha) * patches[chosen].astype(np.float32)
            )
            # Truncating cast (not rounding) is kept so seeded outputs stay
            # identical to earlier versions of this function.
            out[chosen] = np.clip(blended, 0, 255).astype(np.uint8)
        else:
            # alpha == 1: pure shuffle, skip the float round-trip.
            out[chosen] = shuffled

    # (N*N, ph, pw, C) -> (H, W, C)
    result = (
        out.reshape(N, N, ph, pw, channels)
        .transpose(0, 2, 1, 3, 4)
        .reshape(H, W, channels)
    )
    return Image.fromarray(result)