SPM / spm.py
prasannareddyp's picture
Upload 4 files
c3d8cc1 verified
raw
history blame
6.44 kB
from PIL import Image
import numpy as np
def create_feather_mask(height, width, feather_size=4):
    """
    Build a 2D float32 mask of shape (height, width) that is 1.0 in the
    interior and ramps linearly down to 0.0 at every edge.

    Args:
        height: Number of rows in the mask.
        width: Number of columns in the mask.
        feather_size: Length in pixels of the linear ramp applied on each
            side. Values <= 0 disable feathering and an all-ones mask is
            returned.

    Returns:
        np.ndarray of dtype float32 and shape (height, width). Corner
        regions receive the product of the vertical and horizontal ramps.

    Notes:
        If the mask is smaller than `feather_size` along an axis, the ramp
        is truncated to the available length on that axis (the edge pixel
        is still 0.0 and the per-pixel slope is unchanged) instead of
        failing with a broadcasting error.
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask

    ramp = np.linspace(0.0, 1.0, feather_size, dtype=np.float32)

    # Never feather more rows/columns than the mask actually has.
    rows = min(feather_size, height)
    cols = min(feather_size, width)

    # Top / Bottom. Explicit end indices avoid the `-0:` slicing pitfall.
    if rows > 0:
        mask[:rows, :] *= ramp[:rows, None]
        mask[height - rows:, :] *= ramp[:rows][::-1, None]

    # Left / Right
    if cols > 0:
        mask[:, :cols] *= ramp[None, :cols]
        mask[:, width - cols:] *= ramp[:cols][None, ::-1]

    return mask
def _to_divisible_by(img, N):
    """
    Crop `img` so its width and height are multiples of N (top-left anchored).

    Args:
        img: Image-like object exposing `.size` as (width, height) and
            `.crop((left, upper, right, lower))`, as PIL images do.
        N: Positive integer grid size.

    Returns:
        Tuple (img, W, H): the (possibly cropped) image and its resulting
        width and height. The input object is returned as-is when no crop
        is needed.

    Raises:
        ValueError: If N is not positive, or if the image is smaller than
            N pixels along either axis (so a cropped side would be 0).
    """
    # Guard explicitly: N == 0 would otherwise surface as ZeroDivisionError
    # and a negative N would floor-divide to a crop box larger than the image.
    if N <= 0:
        raise ValueError("N must be a positive integer.")

    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if W != w or H != h:
        img = img.crop((0, 0, W, H))
    return img, W, H
def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Compute the overlap-extended bounds of grid cell (i, j).

    The un-extended cell is rows [i*ph, (i+1)*ph) and columns
    [j*pw, (j+1)*pw). With a positive `overlap`, each axis is widened
    towards the inside of the grid:

    - first cell on the axis (index 0): far end grows by 2*overlap;
    - last cell on the axis (index N-1): near end grows by 2*overlap;
    - any other cell: both ends grow by `overlap`.

    Every cell therefore gains 2*overlap per axis, keeping patch areas
    comparable. With overlap <= 0 the plain cell bounds are returned.

    Returns:
        (start_h, end_h, start_w, end_w). The values are NOT clamped to the
        image bounds; callers are expected to clamp.
    """
    def _extent(index, size):
        # One axis at a time: the row and column rules are identical.
        lo = index * size
        hi = lo + size
        if overlap <= 0:
            return lo, hi
        if index == 0:
            return lo, hi + 2 * overlap
        if index == N - 1:
            return lo - 2 * overlap, hi
        return lo - overlap, hi + overlap

    top, bottom = _extent(i, ph)
    left, right = _extent(j, pw)
    return top, bottom, left, right
def _sample_mix_alpha(rng, beta_a, beta_b):
    """Draw a mixing weight from Beta(beta_a, beta_b); 1.0 if either shape parameter is non-positive."""
    if beta_a > 0 and beta_b > 0:
        return float(rng.beta(beta_a, beta_b))
    return 1.0


def _spm_shuffle_grid(arr, N, ph, pw, mix_prob, beta_a, beta_b, rng):
    """Non-overlap path: shuffle/mix the N×N cells of a uint8 HxWxC array and return a new uint8 array."""
    origins = [(i * ph, j * pw) for i in range(N) for j in range(N)]  # row-major
    # Views into `arr` (never into `out`), so writes below cannot leak into sources.
    patches = [arr[y0:y0 + ph, x0:x0 + pw] for y0, x0 in origins]

    perm = rng.permutation(len(origins))
    alpha = _sample_mix_alpha(rng, beta_a, beta_b)  # one alpha per image
    selected = rng.random(len(origins)) < float(mix_prob)

    out = arr.copy()
    for idx, (y0, x0) in enumerate(origins):
        if not selected[idx]:
            continue  # `out` already holds the original cell
        shuffled = patches[perm[idx]]
        if 0.0 < alpha < 1.0:
            mixed = (
                alpha * shuffled.astype(np.float32)
                + (1.0 - alpha) * patches[idx].astype(np.float32)
            )
            out[y0:y0 + ph, x0:x0 + pw] = np.clip(mixed, 0, 255).astype(np.uint8)
        else:
            # Degenerate alpha: plain replacement by the shuffled cell.
            out[y0:y0 + ph, x0:x0 + pw] = shuffled
    return out


def _spm_blend_overlap(arr_u8, N, ph, pw, ov, mix_prob, beta_a, beta_b, rng):
    """Overlap path: mix overlap-extended patches and feather-blend them into a new uint8 array."""
    arr = arr_u8.astype(np.float32)
    H, W = arr.shape[:2]
    ramp = np.linspace(0.0, 1.0, ov, dtype=np.float32)

    def axis_weights(start, end, limit):
        # 1D feather profile for one axis of a patch. A side that lies on the
        # image boundary is NOT feathered: no neighbouring patch covers those
        # pixels, so a ramp starting at 0.0 would leave them with zero total
        # weight and the output would get a black frame.
        weights = np.ones(end - start, dtype=np.float32)
        if start > 0:
            weights[:ov] *= ramp
        if end < limit:
            weights[-ov:] *= ramp[::-1]
        return weights

    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            # Clamp to image bounds
            coords.append((max(0, sh), min(H, eh), max(0, sw), min(W, ew)))
    patches = [arr[sh:eh, sw:ew] for sh, eh, sw, ew in coords]

    perm = rng.permutation(len(patches))
    canvas = np.zeros_like(arr)
    weight = np.zeros((H, W), dtype=np.float32)

    for k, (sh, eh, sw, ew) in enumerate(coords):
        patch = patches[k]
        if rng.random() < float(mix_prob):
            # Alpha is sampled per patch on this path.
            lam = _sample_mix_alpha(rng, beta_a, beta_b)
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * patch
        # Separable 2D mask; corners get the product of both ramps.
        mask2d = np.outer(axis_weights(sh, eh, H), axis_weights(sw, ew, W))
        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    # Normalize by accumulated weight (guarded against division by zero).
    out = canvas / np.maximum(weight, 1e-8)[..., None]
    # Round rather than truncate so float error in the weighted average
    # cannot darken untouched pixels by one level.
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_px=0,
    seed=None
):
    """
    SPM-style augmentation with optional overlap + feathered blending.

    The image is converted to RGB and cropped (top-left anchored) so both
    sides are divisible by `num_patches`, then split into an N×N grid.

    When the effective overlap is 0:
      - A single random permutation of the N×N patches is drawn.
      - Each location is selected with probability `mix_prob`; selected
        locations are mixed with their permuted partner using one
        alpha ~ Beta(beta_a, beta_b) shared by the whole image. Unselected
        locations keep the original content.

    When the effective overlap is > 0:
      - Each base cell expands by +/-overlap (2*overlap inward at grid
        borders), clipped to the image.
      - Each location is selected with probability `mix_prob` and mixed with
        its permuted partner using a per-patch alpha ~ Beta(beta_a, beta_b).
      - Patches are accumulated with a linear feather of width `overlap` on
        every side that is interior to the image and normalized by the
        accumulated weight. Sides on the image boundary are not feathered,
        so border pixels keep full weight.

    Args:
        image: PIL image or numpy array accepted by `Image.fromarray`.
        num_patches: Grid size N (N×N patches). Must be >= 1.
        mix_prob: Probability that a given grid location is mixed.
        beta_a, beta_b: Beta distribution shape parameters. If either is
            non-positive, alpha is fixed to 1.0 (pure replacement).
        overlap_px: Requested overlap in pixels; None is treated as 0. It is
            clamped to [0, min(patch_h, patch_w) // 2 - 1].
        seed: Seed passed to `np.random.default_rng`.

    Returns:
        A new RGB PIL image of the cropped size.

    Raises:
        ValueError: If `num_patches` is < 1, or the image is too small for
            the requested grid.
    """
    N = int(num_patches)
    if N < 1:
        raise ValueError("num_patches must be a positive integer.")

    # Normalize to PIL and ensure divisibility
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    rng = np.random.default_rng(seed)
    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    # Clamp overlap to < half patch size
    requested = 0 if overlap_px is None else int(overlap_px)
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = min(max(requested, 0), max_ov)

    if ov <= 0:
        out = _spm_shuffle_grid(arr_u8, N, ph, pw, mix_prob, beta_a, beta_b, rng)
    else:
        out = _spm_blend_overlap(arr_u8, N, ph, pw, ov, mix_prob, beta_a, beta_b, rng)
    return Image.fromarray(out)