Spaces:
Sleeping
Sleeping
File size: 6,436 Bytes
5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
from PIL import Image
import numpy as np
def create_feather_mask(height, width, feather_size=4):
    """
    Build a 2D float32 blending mask of shape (height, width).

    The mask is 1.0 in the interior and ramps linearly down towards every
    edge over `feather_size` pixels. The ramp is
    `np.linspace(0.0, 1.0, feather_size)`, so the outermost row/column has
    weight exactly 0.0 and the `feather_size`-th one has weight 1.0. Corners
    receive the product of the vertical and horizontal ramps.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.
        feather_size: Width of the ramp in pixels. Values <= 0 disable
            feathering and an all-ones mask is returned.

    Returns:
        np.ndarray of dtype float32 and shape (height, width).

    Notes:
        When `feather_size` is larger than `height` or `width`, the ramp is
        truncated to the available size on that axis instead of raising a
        broadcasting error. Opposite ramps then overlap and multiply, so
        such a mask may never reach 1.0.
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask

    ramp = np.linspace(0.0, 1.0, feather_size, dtype=np.float32)

    # Per-axis ramp length: never longer than the axis itself, otherwise the
    # in-place multiply below cannot broadcast.
    n_rows = min(feather_size, height)
    n_cols = min(feather_size, width)

    # Top / Bottom
    if n_rows > 0:
        row_ramp = ramp[:n_rows]
        mask[:n_rows, :] *= row_ramp[:, None]
        mask[height - n_rows:, :] *= row_ramp[::-1, None]

    # Left / Right
    if n_cols > 0:
        col_ramp = ramp[:n_cols]
        mask[:, :n_cols] *= col_ramp[None, :]
        mask[:, width - n_cols:] *= col_ramp[None, ::-1]

    return mask
def _to_divisible_by(img, N):
    """
    Trim `img` so that both of its dimensions are exact multiples of N.

    The crop keeps the top-left corner fixed and drops the surplus columns
    and rows on the right and bottom. If the size is already a multiple of N
    the image object is returned untouched.

    Returns:
        (image, width, height) with the resulting pixel dimensions.

    Raises:
        ValueError: if either dimension is smaller than N, so that nothing
            would be left after trimming.
    """
    width, height = img.size
    usable_w = width - width % N
    usable_h = height - height % N

    if 0 in (usable_w, usable_h):
        raise ValueError("N is larger than the image grid (image too small).")

    if (usable_w, usable_h) != (width, height):
        img = img.crop((0, 0, usable_w, usable_h))

    return img, usable_w, usable_h
def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Compute the (possibly overlapping) extent of grid cell (i, j).

    Without overlap the cell is rows [i*ph, (i+1)*ph) and columns
    [j*pw, (j+1)*pw). With a positive `overlap`, each axis is widened:
      - first cell on the axis (index 0): grows by 2*overlap towards the end;
      - last cell on the axis (index N-1): grows by 2*overlap towards the start;
      - any other cell: grows by `overlap` on both sides.
    Border cells get the doubled one-sided growth so that every patch ends up
    with the same nominal size.

    Returns:
        (start_h, end_h, start_w, end_w). The values are NOT clamped to the
        image bounds; the caller is expected to do that.
    """
    def _extent(index, size):
        # Half-open range of one axis, widened according to the cell position.
        lo = index * size
        hi = lo + size
        if overlap <= 0:
            return lo, hi
        if index == 0:
            return lo, hi + 2 * overlap
        if index == N - 1:
            return lo - 2 * overlap, hi
        return lo - overlap, hi + overlap

    top, bottom = _extent(i, ph)
    left, right = _extent(j, pw)
    return top, bottom, left, right
def _spm_side_feather_mask(height, width, feather, top, bottom, left, right):
    """
    Return a (height, width) float32 weight mask feathered only on chosen sides.

    Each enabled side ramps linearly from 0.0 at the patch edge to 1.0 over
    `feather` pixels (`np.linspace(0.0, 1.0, feather)`); disabled sides keep
    full weight right up to the edge. Corner weights are the product of the
    row and column ramps.
    """
    row_w = np.ones(height, dtype=np.float32)
    col_w = np.ones(width, dtype=np.float32)
    ramp = np.linspace(0.0, 1.0, feather, dtype=np.float32)
    # Guard against a ramp longer than the patch (not expected with the
    # overlap clamp used by spm_augment, but cheap to make safe).
    n_rows = min(feather, height)
    n_cols = min(feather, width)
    if top:
        row_w[:n_rows] *= ramp[:n_rows]
    if bottom:
        row_w[height - n_rows:] *= ramp[:n_rows][::-1]
    if left:
        col_w[:n_cols] *= ramp[:n_cols]
    if right:
        col_w[width - n_cols:] *= ramp[:n_cols][::-1]
    return row_w[:, None] * col_w[None, :]


def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_px=0,
    seed=None
):
    """
    SPM-style patch shuffle/mix augmentation with optional overlapping patches.

    The input is converted to RGB, cropped (top-left anchored) so that both
    sides are divisible by `num_patches`, and split into an N×N grid. Every
    grid cell is, with probability `mix_prob`, blended with the cell that a
    random permutation assigns to it:
    `mixed = alpha * shuffled + (1 - alpha) * original`.

    When the effective overlap is 0:
      - cells are disjoint;
      - a single alpha ~ Beta(beta_a, beta_b) is drawn for the whole image;
      - alpha >= 1 is a hard swap, alpha <= 0 leaves the cell unchanged.

    When the effective overlap is > 0:
      - each cell is grown by the overlap (one-sided, doubled, at the image
        border) so neighbouring patches overlap;
      - alpha is drawn independently for every mixed patch;
      - patches are accumulated with feathered weights and normalised by the
        summed weight. Only patch sides that face a neighbouring patch are
        feathered: sides lying on the image boundary keep full weight,
        because no other patch covers those pixels and a zero weight there
        would turn the outer frame of the result black.

    Args:
        image: PIL image or numpy array accepted by `Image.fromarray`.
        num_patches: Grid size N (N×N patches). Must be >= 1.
        mix_prob: Per-patch probability of mixing with its shuffled partner.
        beta_a, beta_b: Beta distribution parameters for alpha. If either is
            <= 0, alpha is fixed to 1.0 (pure shuffle).
        overlap_px: Requested overlap in pixels; `None` means 0. It is clamped
            to the range [0, min(patch_h, patch_w) // 2 - 1].
        seed: Seed passed to `np.random.default_rng`.

    Returns:
        PIL.Image.Image in RGB mode with the cropped size.

    Raises:
        ValueError: if `num_patches` < 1, or if the image is smaller than the
            grid (raised while cropping to a divisible size).
    """
    N = int(num_patches)
    if N < 1:
        raise ValueError("num_patches must be a positive integer.")

    # Normalize to PIL and ensure divisibility
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")

    rng = np.random.default_rng(seed)
    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    # Clamp overlap to < half patch size
    requested_ov = 0 if overlap_px is None else int(overlap_px)
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = int(np.clip(requested_ov, 0, max_ov))
    mix_prob = float(mix_prob)

    def sample_alpha():
        # Non-positive Beta parameters mean "no blending": pure shuffle.
        if beta_a > 0 and beta_b > 0:
            return float(rng.beta(beta_a, beta_b))
        return 1.0

    if ov <= 0:
        # === Non-overlap path ===
        # Row-major list of top-left corners and the matching patch views.
        origins = [(i * ph, j * pw) for i in range(N) for j in range(N)]
        patches = [arr_u8[y0:y0 + ph, x0:x0 + pw] for y0, x0 in origins]
        total = len(patches)

        # Draw order (permutation, alpha, mix mask) is kept fixed so seeded
        # results stay reproducible.
        perm = rng.permutation(total)
        alpha = sample_alpha()  # one alpha per image
        mix_here = rng.random(total) < mix_prob

        # `out` starts as a copy of the input, so unmixed cells need no work.
        out = arr_u8.copy()
        for idx, (y0, x0) in enumerate(origins):
            if not mix_here[idx]:
                continue
            shuffled = patches[perm[idx]]
            if alpha >= 1.0:
                out[y0:y0 + ph, x0:x0 + pw] = shuffled
            elif alpha > 0.0:
                mixed = (
                    alpha * shuffled.astype(np.float32)
                    + (1.0 - alpha) * patches[idx].astype(np.float32)
                )
                out[y0:y0 + ph, x0:x0 + pw] = np.clip(mixed, 0, 255).astype(np.uint8)
            # alpha <= 0.0: the blend is 100% original, nothing to write.
        return Image.fromarray(out)

    # === Overlap path with feather blending ===
    arr = arr_u8.astype(np.float32)

    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            # Clamp to image bounds
            coords.append((max(0, sh), min(H, eh), max(0, sw), min(W, ew)))
    patches = [arr[sh:eh, sw:ew] for sh, eh, sw, ew in coords]

    perm = rng.permutation(len(patches))

    canvas = np.zeros_like(arr, dtype=np.float32)
    weight = np.zeros((H, W), dtype=np.float32)
    mask_cache = {}  # at most 9 distinct masks (corner / edge / interior)

    for k, (sh, eh, sw, ew) in enumerate(coords):
        src = patches[k]
        if rng.random() >= mix_prob:
            # keep original content in that region
            patch = src
        else:
            lam = sample_alpha()  # alpha sampled per patch
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * src

        # Feather only where another patch overlaps this one; sides on the
        # image boundary keep weight 1 so border pixels are never zero-weighted.
        mask_key = (eh - sh, ew - sw, sh > 0, eh < H, sw > 0, ew < W)
        mask2d = mask_cache.get(mask_key)
        if mask2d is None:
            mask2d = _spm_side_feather_mask(
                mask_key[0], mask_key[1], ov,
                top=mask_key[2], bottom=mask_key[3],
                left=mask_key[4], right=mask_key[5],
            )
            mask_cache[mask_key] = mask2d

        # Accumulate
        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    # Normalize (the clip is a defensive guard against division by zero)
    weight = np.clip(weight, 1e-8, None)
    out = canvas / weight[..., None]
    # Round before the uint8 cast: the weighted average of identical values
    # can land a hair below the integer and would otherwise be truncated.
    out = np.clip(np.rint(out), 0, 255).astype(np.uint8)
    return Image.fromarray(out)
|