Spaces:
Sleeping
Sleeping
| from PIL import Image | |
| import numpy as np | |
def _to_divisible_by(img, N):
    """Crop *img* so its width and height are both multiples of *N*.

    The crop is anchored at the top-left corner. Any leftover columns on the
    right and rows at the bottom (fewer than ``N`` of each) are discarded.

    Args:
        img: A PIL ``Image`` (anything exposing ``.size`` as
            ``(width, height)`` and ``.crop((left, upper, right, lower))``).
        N: Grid size. Must be a positive integer.

    Returns:
        ``(img, W, H)``, where ``img`` is the possibly-cropped image and
        ``W``/``H`` are its width/height, both divisible by ``N``. When no
        crop is needed the input object itself is returned.

    Raises:
        ValueError: If ``N < 1``, or if ``N`` exceeds the image width or
            height (so the grid would be empty).
    """
    # Without this guard N == 0 raises ZeroDivisionError, and a negative N
    # makes floor division round *away* from zero, producing a "crop" box
    # larger than the image.
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}.")
    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if (W, H) != (w, h):
        img = img.crop((0, 0, W, H))
    return img, W, H
def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    seed=None
):
    """Apply an SPM-style patch shuffle-and-mix augmentation to an image.

    Steps:
        1) Convert the input to RGB and crop it (top-left anchored) so both
           sides are divisible by N. Then split it into an N×N grid of
           patches in row-major order.
        2) Draw one global random permutation of the N*N patch indices.
        3) Draw one blend weight ``alpha ~ Beta(beta_a, beta_b)`` for the
           whole image. ``alpha`` is 1.0 if either Beta parameter is
           non-positive.
        4) Independently for each patch, with probability ``mix_prob``,
           replace it by ``alpha * shuffled + (1 - alpha) * original``,
           where ``shuffled`` is the patch at ``perm[index]``. All other
           patches are left as-is.

    Args:
        image: A PIL image, or a NumPy array accepted by
            ``PIL.Image.fromarray``.
        num_patches: Grid size N (the grid is N×N). It is truncated with
            ``int()`` and must be >= 1.
        mix_prob: Per-patch probability of mixing.
        beta_a, beta_b: Parameters of the Beta distribution for ``alpha``.
        seed: Passed to ``np.random.default_rng``. A fixed value gives
            reproducible output.

    Returns:
        A new RGB ``PIL.Image``. Its size is the cropped size, which may be
        smaller than the input.

    Raises:
        ValueError: If ``num_patches < 1``, or if the grid is larger than
            the image.
    """
    # Normalize input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    n = int(num_patches)
    if n < 1:
        raise ValueError(f"num_patches must be >= 1, got {num_patches!r}.")
    rng = np.random.default_rng(seed)

    # Ensure divisibility and compute the patch size.
    img, width, height = _to_divisible_by(img, n)
    arr = np.array(img, dtype=np.uint8)  # (height, width, 3) after convert("RGB")
    ph, pw = height // n, width // n
    total = n * n

    # (H, W, 3) -> (n, ph, n, pw, 3) -> (n, n, ph, pw, 3) -> (n*n, ph, pw, 3).
    # Patch k is grid cell (k // n, k % n), i.e. row-major order.
    patches = arr.reshape(n, ph, n, pw, 3).swapaxes(1, 2).reshape(total, ph, pw, 3)

    # Keep this RNG draw order (permutation, beta, mask) so seeded results
    # stay reproducible.
    perm = rng.permutation(total)
    if beta_a > 0 and beta_b > 0:
        alpha = float(rng.beta(beta_a, beta_b))
    else:
        alpha = 1.0
    mask = rng.random(total) < float(mix_prob)

    out = patches.copy()
    if mask.any():
        if alpha >= 1.0:
            # All weight on the shuffled patch: plain replacement.
            out[mask] = patches[perm[mask]]
        elif alpha > 0.0:
            src = patches[mask].astype(np.float32)
            shf = patches[perm[mask]].astype(np.float32)
            mixed = alpha * shf + (1.0 - alpha) * src
            out[mask] = np.clip(mixed, 0, 255).astype(np.uint8)
        # alpha <= 0: zero weight on the shuffled patch, so masked patches
        # keep their original pixels. The previous code wrongly did a full
        # shuffle in this case.

    # Inverse of the patch split: back to (height, width, 3).
    result = out.reshape(n, n, ph, pw, 3).swapaxes(1, 2).reshape(height, width, 3)
    return Image.fromarray(result)