Spaces:
Sleeping
Sleeping
File size: 2,456 Bytes
5dbf895 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
from PIL import Image
import numpy as np
def _to_divisible_by(img, N):
    """Crop so width and height are divisible by N (top-left anchored).

    Args:
        img: Image object exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, upper, right, lower)``.
        N: Positive integer the output width and height must be a multiple of.

    Returns:
        A tuple ``(img, W, H)`` where ``W`` and ``H`` are the largest
        multiples of ``N`` not exceeding the input width and height. ``img``
        is the input itself when no cropping was needed, otherwise the result
        of ``img.crop((0, 0, W, H))``.

    Raises:
        ValueError: If ``N`` is not positive, or if the image is smaller than
            ``N`` pixels along either axis.
    """
    # Reject bad grid sizes up front: N == 0 would divide by zero and a
    # negative N floors the wrong way, yielding a box larger than the image.
    if N < 1:
        raise ValueError("N must be a positive integer.")

    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")

    # Only crop when a remainder actually has to be trimmed off.
    if (W, H) != (w, h):
        img = img.crop((0, 0, W, H))
    return img, W, H
def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    seed=None
):
    """Apply SPM-style augmentation using a global shuffle over an N×N grid.

    Steps:
      1) Divide the image into N×N patches (cropping to be divisible by N
         if needed).
      2) Globally permute the patch indices.
      3) Per patch, with probability ``mix_prob``, replace it by a convex
         blend ``alpha * shuffled + (1 - alpha) * original`` where
         ``alpha ~ Beta(beta_a, beta_b)`` is drawn once per image.

    Args:
        image: A NumPy array (converted via ``Image.fromarray``) or an object
            with a ``convert("RGB")`` method such as a PIL image.
        num_patches: Grid size N; must be a positive integer after ``int()``.
        mix_prob: Per-patch probability of being mixed; each patch is
            selected when a uniform draw in [0, 1) is below this value.
        beta_a: First Beta parameter for ``alpha``.
        beta_b: Second Beta parameter for ``alpha``. If either parameter is
            not positive, ``alpha`` is fixed to 1.0 (selected patches are
            fully replaced by their shuffled counterpart).
        seed: Passed to ``np.random.default_rng``.

    Returns:
        The augmented RGB image built with ``Image.fromarray``; its size is
        the input size cropped down to multiples of N.

    Raises:
        ValueError: If ``num_patches`` is not positive, or if the image is
            too small for the requested grid.
    """
    grid = int(num_patches)
    # Validate before doing any work: a zero grid divides by zero and a
    # negative one would silently skip augmentation altogether.
    if grid < 1:
        raise ValueError("num_patches must be a positive integer.")

    # Normalize input to an RGB PIL image.
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    rng = np.random.default_rng(seed)

    # Ensure divisibility and compute the patch size.
    img, W, H = _to_divisible_by(img, grid)
    arr = np.array(img, dtype=np.uint8)
    ph = H // grid
    pw = W // grid

    # Top-left corner of every patch, row-major; patches are views into arr.
    origins = [(i * ph, j * pw) for i in range(grid) for j in range(grid)]
    patches = [arr[y0:y0 + ph, x0:x0 + pw] for y0, x0 in origins]

    # Draw order (permutation, alpha, mask) is kept fixed so that seeded
    # results stay reproducible.
    total = grid * grid
    perm = rng.permutation(total)
    if beta_a > 0 and beta_b > 0:
        alpha = float(rng.beta(beta_a, beta_b))
    else:
        alpha = 1.0
    mask = rng.random(total) < float(mix_prob)

    # out starts as a copy of the input, so unselected patches need no work.
    out = arr.copy()
    for idx, (y0, x0) in enumerate(origins):
        # alpha <= 0 means "0% shuffled": the original patch is kept as is.
        if not mask[idx] or alpha <= 0.0:
            continue
        shuffled = patches[perm[idx]]
        if alpha >= 1.0:
            # Full replacement; no float round-trip needed.
            out[y0:y0 + ph, x0:x0 + pw] = shuffled
        else:
            src = patches[idx].astype(np.float32)
            shf = shuffled.astype(np.float32)
            mixed = alpha * shf + (1.0 - alpha) * src
            out[y0:y0 + ph, x0:x0 + pw] = np.clip(mixed, 0, 255).astype(np.uint8)
    return Image.fromarray(out)
|