# Spaces: Sleeping  (page-capture residue; not part of the module)
from PIL import Image | |
import numpy as np | |
def create_feather_mask(height, width, feather_size=4):
    """
    Build a 2D float32 mask of shape (height, width) that is 1.0 in the
    interior and ramps linearly down to 0.0 on the outermost row/column.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.
        feather_size: Width of the ramp in pixels on each side. Values <= 0
            disable feathering and an all-ones mask is returned. The value is
            coerced with ``int()``. If it exceeds ``height`` or ``width`` the
            ramp on that axis is shortened to the axis length instead of
            raising a broadcasting error.

    Returns:
        np.ndarray of dtype float32 and shape (height, width).

    Notes:
        The ramp is ``np.linspace(0, 1, n)``, so the outermost pixel has
        weight exactly 0.0 (and with ``n == 1`` the whole ramp is 0.0).
        Where opposite ramps overlap (axis shorter than 2 * feather_size),
        their factors are multiplied together.
    """
    mask = np.ones((height, width), dtype=np.float32)
    feather_size = int(feather_size)
    if feather_size <= 0:
        return mask

    # Clamp per axis: a ramp longer than the axis cannot be broadcast onto
    # the (shorter) slice it is supposed to scale.
    fh = min(feather_size, height)
    fw = min(feather_size, width)

    # Top / Bottom
    if fh > 0:
        ramp_h = np.linspace(0.0, 1.0, fh, dtype=np.float32)
        mask[:fh, :] *= ramp_h[:, None]
        mask[-fh:, :] *= ramp_h[::-1, None]

    # Left / Right
    if fw > 0:
        ramp_w = np.linspace(0.0, 1.0, fw, dtype=np.float32)
        mask[:, :fw] *= ramp_w[None, :]
        mask[:, -fw:] *= ramp_w[None, ::-1]

    return mask
def _to_divisible_by(img, N):
    """
    Crop `img` so that its width and height are both multiples of N.

    The crop is anchored at the top-left corner, so only the right-most
    columns and bottom-most rows are dropped.

    Args:
        img: Object exposing ``.size`` as ``(width, height)`` and
            ``.crop((left, upper, right, lower))`` (e.g. a PIL image).
        N: Grid size; must be >= 1.

    Returns:
        Tuple ``(img, W, H)`` where W and H are the (possibly reduced) width
        and height. The input object is returned unchanged when no crop is
        needed.

    Raises:
        ValueError: If N < 1, or if either dimension is smaller than N so
            that nothing would remain after cropping.
    """
    # N == 0 would divide by zero; a negative N makes floor division round
    # the size *up*, yielding a crop box larger than the image.
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}.")

    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if (W, H) != (w, h):
        img = img.crop((0, 0, W, H))
    return img, W, H
def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Compute the extent of grid cell (i, j) after growing it by `overlap`.

    The un-grown cell is rows [i*ph, (i+1)*ph) and columns [j*pw, (j+1)*pw).
    Growth is biased inward so every cell gains the same total amount:
      * first cell on an axis (index 0): far side grows by 2*overlap;
      * last cell on an axis (index N-1): near side grows by 2*overlap;
      * any other cell: both sides grow by overlap.
    With overlap <= 0 the un-grown cell is returned.

    Returns (start_h, end_h, start_w, end_w). The values are NOT clamped to
    the image bounds; the caller is expected to do that.
    """

    def _span(index, size):
        # One axis at a time: rows and columns follow the same rule.
        lo = index * size
        hi = lo + size
        if overlap <= 0:
            return lo, hi
        if index == 0:
            return lo, hi + 2 * overlap
        if index == N - 1:
            return lo - 2 * overlap, hi
        return lo - overlap, hi + overlap

    top, bottom = _span(i, ph)
    left, right = _span(j, pw)
    return top, bottom, left, right
def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_px=0,
    seed=None,
):
    """
    SPM-style patch-shuffle augmentation with optional overlap + feathering.

    The image is converted to RGB, cropped (top-left anchored) so both sides
    are divisible by N, and split into an N×N grid. A random permutation
    assigns each grid location a "donor" patch. Each location is selected
    for mixing with probability `mix_prob`.

    When the effective overlap is 0:
        - One alpha ~ Beta(beta_a, beta_b) is drawn for the whole image.
        - Selected cells become ``alpha * donor + (1 - alpha) * original``;
          if alpha is not strictly inside (0, 1) the donor replaces the cell.
        - Unselected cells keep their original pixels.

    When the effective overlap is > 0:
        - Each cell is expanded by `overlap` on both sides (2*overlap on the
          inward side for first/last cells) and clamped to the image.
        - A lambda ~ Beta(beta_a, beta_b) is drawn per selected patch.
        - Patches are accumulated into a canvas using a feather mask of
          width `overlap` and normalised by the accumulated weight.
        - Patch sides that lie on the image boundary are NOT feathered:
          nothing else covers those pixels, so a zero feather weight there
          would leave the outermost rows/columns black.

    Args:
        image: ``np.ndarray`` (passed to ``Image.fromarray``) or an object
            with ``.convert("RGB")`` such as a PIL image.
        num_patches: Grid size N (coerced with ``int()``); must be >= 1.
        mix_prob: Probability that a grid location is mixed.
        beta_a, beta_b: Beta distribution parameters. If either is <= 0 the
            mixing coefficient is fixed to 1.0 (donor fully replaces).
        overlap_px: Requested overlap in pixels. ``None`` means 0. The value
            is clamped to ``[0, min(ph, pw) // 2 - 1]``.
        seed: Seed forwarded to ``np.random.default_rng``.

    Returns:
        PIL RGB image of the cropped size.

    Raises:
        ValueError: If `num_patches` < 1 or the image is smaller than the
            grid in either dimension.
    """
    N = int(num_patches)
    if N < 1:
        raise ValueError(f"num_patches must be >= 1, got {num_patches!r}.")

    # Normalize to PIL and ensure divisibility
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    rng = np.random.default_rng(seed)
    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    # Clamp overlap to < half patch size so opposite feather ramps of one
    # patch never overlap and every grown patch still fits in the image.
    requested_ov = 0 if overlap_px is None else int(overlap_px)
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = min(max(requested_ov, 0), max_ov)

    mix_prob = float(mix_prob)
    use_beta = beta_a > 0 and beta_b > 0

    if ov <= 0:
        # === Non-overlap path ===
        origins = [(i * ph, j * pw) for i in range(N) for j in range(N)]
        # Views into arr_u8 (row-major); `out` below is a separate copy, so
        # writing into it never alters a donor patch.
        patches = [arr_u8[y0:y0 + ph, x0:x0 + pw] for y0, x0 in origins]
        total = len(origins)

        # RNG draw order (permutation, alpha, selection) is kept stable so
        # seeded runs stay reproducible.
        perm = rng.permutation(total)
        alpha = float(rng.beta(beta_a, beta_b)) if use_beta else 1.0
        selected = rng.random(total) < mix_prob

        out = arr_u8.copy()
        for idx, (y0, x0) in enumerate(origins):
            if not selected[idx]:
                continue  # `out` already holds the original pixels
            donor = patches[perm[idx]]
            if 0.0 < alpha < 1.0:
                mixed = (
                    alpha * donor.astype(np.float32)
                    + (1.0 - alpha) * patches[idx].astype(np.float32)
                )
                out[y0:y0 + ph, x0:x0 + pw] = np.clip(mixed, 0, 255).astype(np.uint8)
            else:
                out[y0:y0 + ph, x0:x0 + pw] = donor
        return Image.fromarray(out)

    # === Overlap path with feather blending ===
    arr = arr_u8.astype(np.float32)

    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            # Clamp to image bounds
            coords.append((max(0, sh), min(H, eh), max(0, sw), min(W, ew)))
    patches = [arr[sh:eh, sw:ew] for sh, eh, sw, ew in coords]
    perm = rng.permutation(len(patches))

    mask_cache = {}

    def border_aware_mask(sh, eh, sw, ew):
        """Feather mask for one patch, left un-feathered on image borders."""
        hk, wk = eh - sh, ew - sw
        # Pad by `ov` on every side that touches the image boundary, build
        # the regular feather mask on the padded size, then cut the padding
        # off again: the ramp lands in the discarded margin.
        pad_t = ov if sh == 0 else 0
        pad_b = ov if eh == H else 0
        pad_l = ov if sw == 0 else 0
        pad_r = ov if ew == W else 0
        key = (hk, wk, pad_t, pad_b, pad_l, pad_r)
        if key not in mask_cache:
            padded = create_feather_mask(
                hk + pad_t + pad_b, wk + pad_l + pad_r, feather_size=ov
            )
            mask_cache[key] = padded[pad_t:pad_t + hk, pad_l:pad_l + wk]
        return mask_cache[key]

    canvas = np.zeros_like(arr, dtype=np.float32)
    weight = np.zeros((H, W), dtype=np.float32)

    for k, (sh, eh, sw, ew) in enumerate(coords):
        # Draw order per patch (selection, then lambda) is kept stable.
        if rng.random() >= mix_prob:
            # keep original content in that region
            patch = patches[k]
        else:
            lam = float(rng.beta(beta_a, beta_b)) if use_beta else 1.0
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * patches[k]

        mask2d = border_aware_mask(sh, eh, sw, ew)
        # Accumulate; the trailing axis broadcasts the mask over channels.
        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    # Normalize. The clip is a safety net against division by zero.
    weight = np.clip(weight, 1e-8, None)
    out = canvas / weight[..., None]
    # Round before casting: float32 round-off in (v * m) / m would otherwise
    # let truncation turn an untouched pixel value v into v - 1.
    out = np.clip(np.rint(out), 0, 255).astype(np.uint8)
    return Image.fromarray(out)