Spaces:
Sleeping
Sleeping
File size: 6,094 Bytes
5dbf895 c3d8cc1 5dbf895 c3d8cc1 4e2c1e5 c3d8cc1 5dbf895 4e2c1e5 5dbf895 c3d8cc1 4e2c1e5 c3d8cc1 4e2c1e5 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 4e2c1e5 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 5dbf895 c3d8cc1 4e2c1e5 c3d8cc1 4e2c1e5 c3d8cc1 5dbf895 c3d8cc1 5dbf895 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
from PIL import Image
import numpy as np
def create_feather_mask(height, width, feather_size=4):
    """
    Build an HxW float32 weight mask that is 1.0 in the interior and ramps
    linearly down to 0.0 at each of the four edges over `feather_size` pixels.

    The ramp is ``np.linspace(0, 1, n)``, so the outermost row/column on each
    side has weight exactly 0.0; corner weights are the product of the row
    and column ramps.

    Args:
        height: Mask height in pixels.
        width: Mask width in pixels.
        feather_size: Ramp width in pixels. ``<= 0`` returns an all-ones mask.
            The ramp is clamped to each dimension independently, so a feather
            wider than the mask is accepted instead of failing to broadcast.

    Returns:
        np.ndarray of shape (height, width), dtype float32.
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask
    # Clamp per axis: when feather_size > height, mask[:feather_size] has only
    # `height` rows and would not broadcast against a length-feather_size ramp.
    fh = min(feather_size, height)
    fw = min(feather_size, width)
    if fh > 0:
        ramp_h = np.linspace(0.0, 1.0, fh, dtype=np.float32)
        mask[:fh, :] *= ramp_h[:, None]  # top
        mask[-fh:, :] *= ramp_h[::-1, None]  # bottom
    if fw > 0:
        ramp_w = np.linspace(0.0, 1.0, fw, dtype=np.float32)
        mask[:, :fw] *= ramp_w[None, :]  # left
        mask[:, -fw:] *= ramp_w[None, ::-1]  # right
    return mask
def _to_divisible_by(img, N):
    """Crop an image so its width and height are both multiples of N.

    The crop is anchored at the top-left corner: only right/bottom pixels
    are discarded. The image is returned unchanged when no crop is needed.

    Args:
        img: Image exposing ``.size -> (width, height)`` and ``.crop(box)``
            (e.g. a PIL image).
        N: Grid size; must be a positive integer.

    Returns:
        Tuple ``(img, W, H)`` with the (possibly cropped) image and its new
        width and height.

    Raises:
        ValueError: If N < 1, or if N exceeds the width or height.
    """
    # Guard before any arithmetic: N == 0 would divide by zero, and a negative
    # N makes (w // N) * N >= w, producing a crop box outside the image.
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}.")
    w, h = img.size
    W = (w // N) * N
    H = (h // N) * N
    if W == 0 or H == 0:
        raise ValueError("N is larger than the image grid (image too small).")
    if W != w or H != h:
        img = img.crop((0, 0, W, H))
    return img, W, H
def _edgelogic(i, j, ph, pw, N, overlap):
    """
    Return the (start_h, end_h, start_w, end_w) bounds of grid cell (i, j),
    optionally grown by `overlap` pixels.

    The un-expanded cell spans rows [i*ph, i*ph + ph) and columns
    [j*pw, j*pw + pw). With overlap > 0, each axis is grown as follows:
      - first cell (index 0): extend the far end by 2*overlap;
      - last cell (index N-1): extend the near end by 2*overlap;
      - any other cell: extend both ends by overlap.
    Border cells are thus pushed inward rather than past the image edge.
    Returned bounds are not clamped to the image.
    """

    def _axis_span(idx, size):
        # Expand one axis of the cell according to its position in the grid.
        lo = idx * size
        hi = lo + size
        if overlap <= 0:
            return lo, hi
        if idx == 0:
            return lo, hi + 2 * overlap
        if idx == N - 1:
            return lo - 2 * overlap, hi
        return lo - overlap, hi + overlap

    top, bottom = _axis_span(i, ph)
    left, right = _axis_span(j, pw)
    return top, bottom, left, right
def _feather_profile(length, feather, ramp_start, ramp_end):
    """Return a 1D float32 blending profile of `length` samples.

    Every sample is 1.0, except that the first (if `ramp_start`) and/or last
    (if `ramp_end`) ``min(feather, length)`` samples are scaled by the linear
    ramp ``np.linspace(0, 1, f)`` (reversed at the end).
    """
    profile = np.ones(length, dtype=np.float32)
    f = min(feather, length)
    if f <= 0:
        return profile
    ramp = np.linspace(0.0, 1.0, f, dtype=np.float32)
    if ramp_start:
        profile[:f] *= ramp
    if ramp_end:
        profile[-f:] *= ramp[::-1]
    return profile


def _sample_alpha(rng, beta_a, beta_b):
    """Draw a mixing weight ~ Beta(beta_a, beta_b); 1.0 (pure swap) if either is <= 0."""
    if beta_a > 0 and beta_b > 0:
        return float(rng.beta(beta_a, beta_b))
    return 1.0


def _spm_grid_shuffle(arr, N, ph, pw, rng, mix_prob, beta_a, beta_b):
    """Non-overlap SPM: blend randomly selected tiles with permuted tiles.

    `arr` is an (N*ph, N*pw, C) uint8 array. One alpha is shared by the whole
    image; each tile is selected independently with probability `mix_prob`.
    Returns a new uint8 array of the same shape.
    """
    H, W, C = arr.shape
    total = N * N
    # Row-major tile stack of shape (N*N, ph, pw, C); tile k = (k // N, k % N).
    tiles = arr.reshape(N, ph, N, pw, C).swapaxes(1, 2).reshape(total, ph, pw, C)
    # Draw order (permutation, alpha, selection) is kept fixed so a given seed
    # picks the same permutation/alpha/selection as before.
    perm = rng.permutation(total)
    alpha = _sample_alpha(rng, beta_a, beta_b)
    selected = rng.random(total) < float(mix_prob)
    src = tiles.astype(np.float32)
    blended = alpha * src[perm] + (1.0 - alpha) * src
    # Round rather than truncate so float error (e.g. 254.99998) cannot drop a level.
    blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    out_tiles = np.where(selected[:, None, None, None], blended, tiles)
    return out_tiles.reshape(N, N, ph, pw, C).swapaxes(1, 2).reshape(H, W, C)


def _spm_feathered(arr_u8, N, ph, pw, ov, rng, mix_prob, beta_a, beta_b):
    """Overlap SPM: mix expanded patches per location and feather-blend them.

    Each grid cell is expanded by `ov` pixels via `_edgelogic` and clamped to
    the image. Selected patches are blended with a permuted patch using a
    per-patch alpha. Patches are then accumulated with feather weights and
    normalized by the total weight. Returns a uint8 array shaped like `arr_u8`.
    """
    H, W, _ = arr_u8.shape
    arr = arr_u8.astype(np.float32)
    coords = []
    for i in range(N):
        for j in range(N):
            sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
            coords.append((max(0, sh), min(H, eh), max(0, sw), min(W, ew)))
    # For N >= 2 and ov <= min(ph, pw)//2 - 1 every expanded cell fits inside
    # the image with size (ph + 2*ov, pw + 2*ov), so any two patches can be mixed.
    patches = [arr[sh:eh, sw:ew] for sh, eh, sw, ew in coords]
    perm = rng.permutation(len(patches))

    canvas = np.zeros_like(arr)
    weight = np.zeros((H, W), dtype=np.float32)
    for k, (sh, eh, sw, ew) in enumerate(coords):
        if rng.random() >= float(mix_prob):
            patch = patches[k]
        else:
            lam = _sample_alpha(rng, beta_a, beta_b)
            patch = lam * patches[int(perm[k])] + (1.0 - lam) * patches[k]
        # Feather only the sides shared with a neighbouring patch. A side on
        # the image border has nothing to blend with, and a ramp starting at
        # 0.0 there would leave the outermost pixels with zero weight (black).
        mask2d = np.outer(
            _feather_profile(eh - sh, ov, sh > 0, eh < H),
            _feather_profile(ew - sw, ov, sw > 0, ew < W),
        )
        canvas[sh:eh, sw:ew] += patch * mask2d[..., None]
        weight[sh:eh, sw:ew] += mask2d

    # Epsilon guard only; with border-aware feathering every pixel has weight >= 1.
    out = canvas / np.clip(weight, 1e-8, None)[..., None]
    return np.clip(np.rint(out), 0, 255).astype(np.uint8)


def spm_augment(
    image,
    num_patches=4,  # N for an N×N grid
    mix_prob=0.5,
    beta_a=2.0,
    beta_b=2.0,
    overlap_pct=0.0,  # percentage of patch size (0..49 typically)
    seed=None
):
    """
    SPM-style augmentation with optional overlap + feathered blending.

    The input is converted to RGB and cropped from the top-left so both sides
    are divisible by `num_patches`, then split into an N×N grid of
    (H//N)×(W//N) cells.

    Effective overlap == 0 (overlap_pct <= 0, or it rounds to 0 pixels):
      - One random permutation of the N*N tiles. Each tile is selected with
        probability `mix_prob` and replaced by
        alpha * permuted_tile + (1 - alpha) * tile. A single
        alpha ~ Beta(a, b) is used for the whole image.
    Effective overlap > 0:
      - Each cell expands by overlap_px (from the percentage, clamped to
        < half the cell size) and is clipped to the image. Patches are mixed
        per location with a per-patch alpha.
      - Patches are blended into the canvas with a linear feather of width
        overlap_px on sides shared with a neighbouring patch. Image-border
        sides are not feathered.
    If beta_a or beta_b is <= 0, alpha is 1.0 (selected tiles are swapped).

    Args:
        image: PIL image, or numpy array accepted by ``Image.fromarray``.
        num_patches: Grid size N.
        mix_prob: Per-tile probability of being mixed.
        beta_a, beta_b: Beta distribution parameters for alpha.
        overlap_pct: Overlap as a percentage of the cell size, clamped to [0, 49].
        seed: Seed for ``np.random.default_rng``.

    Returns:
        RGB PIL image of the cropped size.

    Raises:
        ValueError: (from ``_to_divisible_by``) if the grid is larger than the image.
    """
    # Normalize to PIL and ensure divisibility
    if isinstance(image, np.ndarray):
        img = Image.fromarray(image).convert("RGB")
    else:
        img = image.convert("RGB")
    N = int(num_patches)
    rng = np.random.default_rng(seed)
    img, W, H = _to_divisible_by(img, N)
    arr_u8 = np.array(img, dtype=np.uint8)
    ph = H // N
    pw = W // N

    # Convert percentage to pixel overlap; keep it below half a cell so that
    # neighbouring feather ramps never overlap each other.
    pct = max(0.0, min(float(overlap_pct), 49.0))
    overlap_px = int(round((pct / 100.0) * min(ph, pw)))
    max_ov = max(0, min(ph, pw) // 2 - 1)
    ov = int(np.clip(overlap_px, 0, max_ov))

    if ov <= 0:
        out = _spm_grid_shuffle(arr_u8, N, ph, pw, rng, mix_prob, beta_a, beta_b)
    else:
        out = _spm_feathered(arr_u8, N, ph, pw, ov, rng, mix_prob, beta_a, beta_b)
    return Image.fromarray(out)
|