from PIL import Image
import numpy as np


def create_feather_mask(height, width, feather_size=4, *,
                        top=True, bottom=True, left=True, right=True):
    """
    Build a 2D float32 mask of shape (height, width) that is 1.0 in the
    interior and ramps linearly down to 0.0 over `feather_size` pixels
    towards each enabled side.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.
        feather_size: Ramp length in pixels. Values <= 0 return an all-ones
            mask. The ramp is shortened per axis if it is longer than the
            mask itself, so small masks no longer raise a broadcasting error.
        top, bottom, left, right: Keyword-only switches selecting which sides
            are feathered. All default to True (feather every side).

    Returns:
        np.ndarray of dtype float32 and shape (height, width).

    Note:
        The ramp is `linspace(0, 1, n)`, so the outermost pixel of every
        feathered side has weight exactly 0.0. Callers that normalise by the
        accumulated weight must make sure such pixels are covered by another
        mask, or disable feathering on that side.
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask

    # Never let the ramp be longer than the axis it is applied to.
    rows = min(feather_size, height)
    cols = min(feather_size, width)
    ramp_rows = np.linspace(0.0, 1.0, rows, dtype=np.float32)
    ramp_cols = np.linspace(0.0, 1.0, cols, dtype=np.float32)

    # Explicit `height - rows` / `width - cols` offsets avoid the `-0:` slice
    # pitfall (which would select the whole axis) when a ramp has length 0.
    if top:
        mask[:rows, :] *= ramp_rows[:, None]
    if bottom:
        mask[height - rows:, :] *= ramp_rows[::-1, None]
    if left:
        mask[:, :cols] *= ramp_cols[None, :]
    if right:
        mask[:, width - cols:] *= ramp_cols[None, ::-1]
    return mask


def _to_divisible_by(img, N):
    """
    Crop `img` (top-left anchored) so its width and height are multiples of N.

    Returns:
        (img, W, H): the possibly cropped image and its final width/height.

    Raises:
        ValueError: if N is not positive, or if the image is smaller than N
            pixels along either axis.
    """
    if N <= 0:
        raise ValueError("N must be a positive integer.")
    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if (W, H) != (w, h):
        img = img.crop((0, 0, W, H))
    return img, W, H


def _extend_span(start, end, index, N, overlap):
    """Grow one axis of a grid cell by a total of 2*overlap, biased inward."""
    if index == 0:
        return start, end + 2 * overlap
    if index == N - 1:
        return start - 2 * overlap, end
    return start - overlap, end + overlap


def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Compute the overlap-extended bounds of grid cell (i, j).

    The base (no-overlap) cell is [i*ph:(i+1)*ph, j*pw:(j+1)*pw]. With a
    positive overlap, interior cells grow by `overlap` on both sides of an
    axis, while first/last cells grow by 2*overlap towards the inside only.
    Before clamping, every extended cell therefore spans ph + 2*overlap rows
    and pw + 2*overlap columns.

    Returns:
        (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
    """
    start_h, start_w = i * ph, j * pw
    end_h, end_w = start_h + ph, start_w + pw
    if overlap <= 0:
        return start_h, end_h, start_w, end_w

    start_h, end_h = _extend_span(start_h, end_h, i, N, overlap)
    start_w, end_w = _extend_span(start_w, end_w, j, N, overlap)
    return start_h, end_h, start_w, end_w


def _sample_alpha(rng, beta_a, beta_b):
    """Draw a mixing weight from Beta(a, b); 1.0 if either parameter is <= 0."""
    if beta_a > 0 and beta_b > 0:
        return float(rng.beta(beta_a, beta_b))
    return 1.0


def _mix_without_overlap(arr, N, ph, pw, mix_prob, beta_a, beta_b, rng):
    """Shuffle-and-mix over a plain N x N grid; returns a uint8 array."""
    origins = [(i * ph, j * pw) for i in range(N) for j in range(N)]
    patches = [arr[y0:y0 + ph, x0:x0 + pw] for y0, x0 in origins]

    # RNG call order (permutation, alpha, selection mask) is kept stable so
    # seeded results stay reproducible.
    perm = rng.permutation(len(origins))
    alpha = _sample_alpha(rng, beta_a, beta_b)  # one alpha for the image
    selected = rng.random(len(origins)) < float(mix_prob)

    out = arr.copy()
    if alpha <= 0.0:
        # Zero weight on the shuffled patch means "keep the source patch".
        return out

    for idx, (y0, x0) in enumerate(origins):
        if not selected[idx]:
            continue
        # `patches` are views into `arr`, not `out`, so earlier writes never
        # contaminate later reads.
        partner = patches[perm[idx]]
        if alpha >= 1.0:
            out[y0:y0 + ph, x0:x0 + pw] = partner
        else:
            blend = (alpha * partner.astype(np.float32)
                     + (1.0 - alpha) * patches[idx].astype(np.float32))
            out[y0:y0 + ph, x0:x0 + pw] = np.clip(blend, 0, 255).astype(np.uint8)
    return out


def _mix_with_overlap(arr_u8, N, ph, pw, ov, mix_prob, beta_a, beta_b, rng):
    """Shuffle-and-mix over overlapping patches with feathered blending."""
    H, W = arr_u8.shape[:2]
    arr = arr_u8.astype(np.float32)

    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            coords.append((max(0, sh), min(H, eh), max(0, sw), min(W, ew)))
    patches = [arr[sh:eh, sw:ew] for sh, eh, sw, ew in coords]

    perm = rng.permutation(len(coords))
    threshold = float(mix_prob)

    canvas = np.zeros_like(arr)
    weight = np.zeros((H, W), dtype=np.float32)
    mask_cache = {}

    for k, (sh, eh, sw, ew) in enumerate(coords):
        if rng.random() >= threshold:
            # Keep the original content in this region.
            patch = patches[k]
        else:
            lam = _sample_alpha(rng, beta_a, beta_b)  # one alpha per patch
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * patches[k]

        # Feather only the sides that face a neighbouring patch. A side lying
        # on the image boundary is covered by no other patch, so feathering
        # it to 0.0 would leave zero total weight there and produce a black
        # frame after normalisation.
        key = (eh - sh, ew - sw, sh > 0, eh < H, sw > 0, ew < W)
        mask2d = mask_cache.get(key)
        if mask2d is None:
            mask2d = create_feather_mask(
                eh - sh, ew - sw, feather_size=ov,
                top=sh > 0, bottom=eh < H, left=sw > 0, right=ew < W,
            )
            mask_cache[key] = mask2d

        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    # Guard against division by zero, then normalise the weighted sum.
    weight = np.clip(weight, 1e-8, None)
    out = canvas / weight[..., None]
    # Round rather than truncate: float noise from the normalisation (e.g.
    # 254.99998) must not darken otherwise untouched pixels by one level.
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def spm_augment(
    image,
    num_patches=4,   # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_px=0,
    seed=None
):
    """
    SPM-style patch shuffle/mix augmentation with optional overlap and
    feathered blending.

    The image is converted to RGB, cropped (top-left anchored) so both sides
    are divisible by `num_patches`, and split into an N×N grid.

    When the effective overlap is 0:
        - A single global permutation of the N×N patches is drawn.
        - Each patch is mixed with its permuted partner with probability
          `mix_prob`, using one alpha ~ Beta(beta_a, beta_b) for the whole
          image (`alpha * shuffled + (1 - alpha) * source`).

    When the effective overlap is > 0:
        - Each grid cell is extended by the overlap (2*overlap inward for
          first/last cells), then clamped to the image.
        - Each patch is mixed with probability `mix_prob`, with alpha sampled
          per patch.
        - Patches are accumulated with a feather mask whose ramp length is
          the overlap; sides on the image boundary are not feathered so
          border pixels keep full weight.

    Args:
        image: PIL image or NumPy array accepted by `Image.fromarray`.
        num_patches: Grid size N (must be >= 1).
        mix_prob: Probability that a given patch is mixed.
        beta_a, beta_b: Beta distribution parameters; if either is <= 0 the
            mixing weight is fixed to 1.0 (full swap).
        overlap_px: Requested overlap in pixels; None is treated as 0. It is
            clamped to [0, min(patch_h, patch_w) // 2 - 1].
        seed: Seed passed to `np.random.default_rng`.

    Returns:
        PIL.Image.Image in RGB mode with the cropped image's size.

    Raises:
        ValueError: if `num_patches` is < 1 or the image is smaller than the
            grid.
    """
    # Normalize to PIL and ensure divisibility.
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    N = int(num_patches)
    if N < 1:
        raise ValueError("num_patches must be a positive integer.")

    rng = np.random.default_rng(seed)
    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph, pw = H // N, W // N

    # Clamp the overlap to strictly less than half a patch.
    requested = 0 if overlap_px is None else int(overlap_px)
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = int(np.clip(requested, 0, max_ov))

    if ov <= 0:
        out = _mix_without_overlap(
            arr_u8, N, ph, pw, mix_prob, beta_a, beta_b, rng)
    else:
        out = _mix_with_overlap(
            arr_u8, N, ph, pw, ov, mix_prob, beta_a, beta_b, rng)
    return Image.fromarray(out)