prasannareddyp committed on
Commit
c3d8cc1
·
verified ·
1 Parent(s): c0b47d8

Upload 4 files

Browse files
Files changed (3) hide show
  1. README.md +11 -9
  2. app.py +13 -4
  3. spm.py +161 -37
README.md CHANGED
@@ -1,36 +1,38 @@
1
  ---
2
- title: SPM
3
- emoji: 🚀
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
  sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
- short_description: Shuffle PatchMix Augmentation
12
  ---
13
 
14
  arxiv.org/abs/2505.24216
15
 
16
  # Shuffle PatchMix (SPM) — Hugging Face Space
17
 
18
- A minimal interactive demo for SPM-style augmentation. Upload an image (or a .zip of images), set **grid size (N×N)**, and download the augmented outputs.
19
 
20
  GitHub repo: https://github.com/PrasannaPulakurthi/SPM
21
 
22
  ## Parameters
23
  - **Grid (N×N):** Choose one of **2×2, 4×4, 8×8, 16×16**. The image is cropped (top-left) so its width and height are divisible by N.
 
 
24
  - **Mix probability:** Per-patch probability to mix original and a shuffled patch.
25
- - **Beta α, β:** Shape parameters for a single per-image alpha sampled from Beta(α,β).
26
  - **Seed:** Optional deterministic seed.
27
 
28
  ## Batch Mode
29
  Upload a `.zip` containing images (`.png`, `.jpg`, `.jpeg`). The app returns a `.zip` of augmented results with the same folder structure.
30
 
31
  ## Notes
32
- - This uses a **global** patch permutation and per-patch mixing with a **single** alpha per image (tweak in `spm.py` if you want per-patch alpha or different strategies).
33
- - If you want parity with a specific paper version, swap in your official implementation but keep `spm_augment(image, num_patches, mix_prob, beta_a, beta_b, seed)`.
 
 
34
 
35
  ## Local Development
36
  ```bash
 
1
  ---
2
+ title: Shuffle PatchMix
3
+ colorFrom: purple
4
+ colorTo: red
 
5
  sdk: gradio
6
  sdk_version: 5.44.1
7
  app_file: app.py
8
  pinned: false
9
  license: mit
 
10
  ---
11
 
12
  arxiv.org/abs/2505.24216
13
 
14
  # Shuffle PatchMix (SPM) — Hugging Face Space
15
 
16
+ A minimal interactive demo for SPM-style augmentation. Now supports **overlap with feathered blending**.
17
 
18
  GitHub repo: https://github.com/PrasannaPulakurthi/SPM
19
 
20
  ## Parameters
21
  - **Grid (N×N):** Choose one of **2×2, 4×4, 8×8, 16×16**. The image is cropped (top-left) so its width and height are divisible by N.
22
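+ - **Enable overlap (feather blend):** When enabled, each interior grid cell is extended by `overlap` pixels on every side, while cells on the image border are extended by `2 × overlap` pixels inward only (so all patches keep a comparable area). The extended patches are then blended together with a feather mask.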
23
+ - **Overlap (px):** Pixel overlap per side. Automatically clipped to `< ½ * patch size`.
24
  - **Mix probability:** Per-patch probability to mix original and a shuffled patch.
25
+ - **Beta α, β:** Shape parameters for the Beta distribution used for blending weights.
26
  - **Seed:** Optional deterministic seed.
27
 
28
  ## Batch Mode
29
  Upload a `.zip` containing images (`.png`, `.jpg`, `.jpeg`). The app returns a `.zip` of augmented results with the same folder structure.
30
 
31
  ## Notes
32
+ - The non-overlap path uses **one alpha per image**; the overlap path samples **one alpha per patch**, mirroring the reference snippet (edit `spm.py` if you prefer one alpha per image).
33
+ - Feather size equals the overlap (you can decouple by adjusting `create_feather_mask` calls).
34
+ - If you want parity with a specific paper version, swap in your official implementation but keep the signature:
35
+ `spm_augment(image, num_patches, mix_prob, beta_a, beta_b, overlap_px, seed)`.
36
 
37
  ## Local Development
38
  ```bash
app.py CHANGED
@@ -7,6 +7,7 @@ from spm import spm_augment
7
  TITLE = "Shuffle PatchMix (SPM) Augmentation"
8
  DESC = """
9
  Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
 
10
  For batch processing, upload a .zip of images (PNG/JPG/JPEG), and download a .zip of outputs.
11
  """
12
 
@@ -18,12 +19,13 @@ def _parse_grid(grid_choice: str) -> int:
18
  except Exception:
19
  return 4
20
 
21
- def run_single(image, grid_choice, mix_prob, beta_a, beta_b, num_augs, seed):
22
  if image is None:
23
  return []
24
  outs = []
25
  base_seed = int(seed) if seed is not None else None
26
  N = _parse_grid(grid_choice)
 
27
  for i in range(num_augs):
28
  s = (base_seed + i) if base_seed is not None else None
29
  out_img = spm_augment(
@@ -32,12 +34,13 @@ def run_single(image, grid_choice, mix_prob, beta_a, beta_b, num_augs, seed):
32
  mix_prob=float(mix_prob),
33
  beta_a=float(beta_a),
34
  beta_b=float(beta_b),
 
35
  seed=s
36
  )
37
  outs.append(out_img)
38
  return outs
39
 
40
- def run_batch(zip_file, grid_choice, mix_prob, beta_a, beta_b, seed):
41
  if zip_file is None:
42
  return None, "Please upload a .zip file with images."
43
  tempdir = tempfile.mkdtemp()
@@ -50,6 +53,7 @@ def run_batch(zip_file, grid_choice, mix_prob, beta_a, beta_b, seed):
50
  valid_exts = {".png", ".jpg", ".jpeg"}
51
  count_in, count_out = 0, 0
52
  N = _parse_grid(grid_choice)
 
53
  for root_dir, _, files in os.walk(tempdir):
54
  for f in files:
55
  if f.lower().endswith(tuple(valid_exts)):
@@ -65,6 +69,7 @@ def run_batch(zip_file, grid_choice, mix_prob, beta_a, beta_b, seed):
65
  mix_prob=float(mix_prob),
66
  beta_a=float(beta_a),
67
  beta_b=float(beta_b),
 
68
  seed=int(seed) if seed is not None else None
69
  )
70
  rel = os.path.relpath(in_path, tempdir)
@@ -92,6 +97,8 @@ with gr.Blocks() as demo:
92
  with gr.Column(scale=1):
93
  inp = gr.Image(label="Input image", type="pil")
94
  grid_choice = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
 
 
95
  mix_prob = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
96
  with gr.Row():
97
  beta_a = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
@@ -103,7 +110,7 @@ with gr.Blocks() as demo:
103
  gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
104
  run_btn.click(
105
  fn=run_single,
106
- inputs=[inp, grid_choice, mix_prob, beta_a, beta_b, num_augs, seed],
107
  outputs=[gallery]
108
  )
109
  with gr.TabItem("Batch (.zip)"):
@@ -111,6 +118,8 @@ with gr.Blocks() as demo:
111
  with gr.Column(scale=1):
112
  zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
113
  grid_choice_b = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
 
 
114
  mix_prob_b = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
115
  with gr.Row():
116
  beta_a_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
@@ -122,7 +131,7 @@ with gr.Blocks() as demo:
122
  status = gr.Markdown()
123
  run_b.click(
124
  fn=run_batch,
125
- inputs=[zip_in, grid_choice_b, mix_prob_b, beta_a_b, beta_b_b, seed_b],
126
  outputs=[zip_out, status]
127
  )
128
 
 
7
  TITLE = "Shuffle PatchMix (SPM) Augmentation"
8
  DESC = """
9
  Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
10
+ You can optionally **enable overlap** with feathered blending for smoother seams.
11
  For batch processing, upload a .zip of images (PNG/JPG/JPEG), and download a .zip of outputs.
12
  """
13
 
 
19
  except Exception:
20
  return 4
21
 
22
+ def run_single(image, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed):
23
  if image is None:
24
  return []
25
  outs = []
26
  base_seed = int(seed) if seed is not None else None
27
  N = _parse_grid(grid_choice)
28
+ ov = int(overlap_px) if use_overlap else 0
29
  for i in range(num_augs):
30
  s = (base_seed + i) if base_seed is not None else None
31
  out_img = spm_augment(
 
34
  mix_prob=float(mix_prob),
35
  beta_a=float(beta_a),
36
  beta_b=float(beta_b),
37
+ overlap_px=ov,
38
  seed=s
39
  )
40
  outs.append(out_img)
41
  return outs
42
 
43
+ def run_batch(zip_file, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, seed):
44
  if zip_file is None:
45
  return None, "Please upload a .zip file with images."
46
  tempdir = tempfile.mkdtemp()
 
53
  valid_exts = {".png", ".jpg", ".jpeg"}
54
  count_in, count_out = 0, 0
55
  N = _parse_grid(grid_choice)
56
+ ov = int(overlap_px) if use_overlap else 0
57
  for root_dir, _, files in os.walk(tempdir):
58
  for f in files:
59
  if f.lower().endswith(tuple(valid_exts)):
 
69
  mix_prob=float(mix_prob),
70
  beta_a=float(beta_a),
71
  beta_b=float(beta_b),
72
+ overlap_px=ov,
73
  seed=int(seed) if seed is not None else None
74
  )
75
  rel = os.path.relpath(in_path, tempdir)
 
97
  with gr.Column(scale=1):
98
  inp = gr.Image(label="Input image", type="pil")
99
  grid_choice = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
100
+ use_overlap = gr.Checkbox(value=False, label="Enable overlap (feather blend)")
101
+ overlap_px = gr.Slider(1, 64, value=8, step=1, label="Overlap (px)")
102
  mix_prob = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
103
  with gr.Row():
104
  beta_a = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
 
110
  gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
111
  run_btn.click(
112
  fn=run_single,
113
+ inputs=[inp, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed],
114
  outputs=[gallery]
115
  )
116
  with gr.TabItem("Batch (.zip)"):
 
118
  with gr.Column(scale=1):
119
  zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
120
  grid_choice_b = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
121
+ use_overlap_b = gr.Checkbox(value=False, label="Enable overlap (feather blend)")
122
+ overlap_px_b = gr.Slider(1, 64, value=8, step=1, label="Overlap (px)")
123
  mix_prob_b = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
124
  with gr.Row():
125
  beta_a_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
 
131
  status = gr.Markdown()
132
  run_b.click(
133
  fn=run_batch,
134
+ inputs=[zip_in, grid_choice_b, use_overlap_b, overlap_px_b, mix_prob_b, beta_a_b, beta_b_b, seed_b],
135
  outputs=[zip_out, status]
136
  )
137
 
spm.py CHANGED
@@ -1,6 +1,23 @@
1
  from PIL import Image
2
  import numpy as np
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def _to_divisible_by(img, N):
5
  """Crop so width and height are divisible by N (top-left anchored)."""
6
  w, h = img.size
@@ -12,22 +29,64 @@ def _to_divisible_by(img, N):
12
  img = img.crop((0, 0, W, H))
13
  return img, W, H
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def spm_augment(
16
  image,
17
  num_patches=4, # N for an N×N grid
18
  mix_prob=0.5,
19
  beta_a=2.0,
20
  beta_b=2.0,
 
21
  seed=None
22
  ):
23
  """
24
- SPM-style augmentation using a global shuffle over an N×N patch grid.
25
- 1) Divide image into N×N patches (cropping to be divisible by N if needed).
26
- 2) Globally permute patch indices.
27
- 3) Per patch, with probability `mix_prob`, replace by a convex blend of
28
- original and a shuffled patch using alpha~Beta(beta_a,beta_b) (one alpha per image).
 
 
 
 
 
 
29
  """
30
- # Normalize input
31
  if isinstance(image, np.ndarray):
32
  img = Image.fromarray(image).convert("RGB")
33
  else:
@@ -36,47 +95,112 @@ def spm_augment(
36
  N = int(num_patches)
37
  rng = np.random.default_rng(seed)
38
 
39
- # Ensure divisibility and compute patch size
40
  img, W, H = _to_divisible_by(img, N)
41
- arr = np.array(img, dtype=np.uint8)
42
  ph = H // N
43
  pw = W // N
44
 
45
- # Build patch list (row-major)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  patches = []
 
47
  for i in range(N):
48
  for j in range(N):
49
- y0 = i * ph
50
- x0 = j * pw
51
- patches.append(arr[y0:y0+ph, x0:x0+pw])
 
 
 
52
 
53
- total = N * N
54
  perm = rng.permutation(total)
55
 
56
- # Sample one alpha for the whole image
57
- if beta_a > 0 and beta_b > 0:
58
- alpha = float(rng.beta(beta_a, beta_b))
59
- else:
60
- alpha = 1.0
61
 
62
- # Patchwise mix
63
- out = arr.copy()
64
- mask = rng.random(total) < float(mix_prob)
65
- idx = 0
66
- for i in range(N):
67
- for j in range(N):
68
- y0 = i * ph
69
- x0 = j * pw
70
- if mask[idx]:
71
- src = patches[idx].astype(np.float32)
72
- shf = patches[perm[idx]].astype(np.float32)
73
- if 0.0 < alpha < 1.0:
74
- mixed = alpha * shf + (1.0 - alpha) * src
75
- out[y0:y0+ph, x0:x0+pw] = np.clip(mixed, 0, 255).astype(np.uint8)
76
- else:
77
- out[y0:y0+ph, x0:x0+pw] = patches[perm[idx]]
78
- else:
79
- out[y0:y0+ph, x0:x0+pw] = patches[idx]
80
- idx += 1
 
 
 
 
 
 
81
 
 
 
 
 
82
  return Image.fromarray(out)
 
1
  from PIL import Image
2
  import numpy as np
3
 
4
def create_feather_mask(height, width, feather_size=4):
    """
    Build a 2D float32 blending mask of shape (height, width).

    The mask is 1.0 in the interior and ramps linearly down to 0.0 at each of
    the four edges over `feather_size` pixels. The vertical and horizontal
    ramps are multiplied together, so corners fall off in both directions.

    Args:
        height: Number of mask rows.
        width: Number of mask columns.
        feather_size: Ramp length in pixels. Values <= 0 disable feathering
            and an all-ones mask is returned. The ramp length is clamped per
            axis to that axis' size, so a feather larger than the mask does
            not raise a broadcasting error.

    Returns:
        np.ndarray of dtype float32 with shape (height, width).
    """
    mask = np.ones((height, width), dtype=np.float32)
    if feather_size <= 0:
        return mask

    # Clamp per axis: slicing `mask[:feather_size]` silently truncates to the
    # axis length, but the ramp would keep `feather_size` entries and the
    # in-place multiply would then fail to broadcast.
    feather_h = min(feather_size, height)
    feather_w = min(feather_size, width)

    # NOTE: the ramp starts at exactly 0.0, so the outermost row/column of the
    # mask carries zero weight. A caller that normalizes by accumulated weight
    # gets no contribution from this mask at those pixels.
    if feather_h > 0:
        ramp_h = np.linspace(0.0, 1.0, feather_h, dtype=np.float32)
        # Top / Bottom
        mask[:feather_h, :] *= ramp_h[:, None]
        mask[-feather_h:, :] *= ramp_h[::-1, None]
    if feather_w > 0:
        ramp_w = np.linspace(0.0, 1.0, feather_w, dtype=np.float32)
        # Left / Right
        mask[:, :feather_w] *= ramp_w[None, :]
        mask[:, -feather_w:] *= ramp_w[None, ::-1]
    return mask
20
+
21
  def _to_divisible_by(img, N):
22
  """Crop so width and height are divisible by N (top-left anchored)."""
23
  w, h = img.size
 
29
  img = img.crop((0, 0, W, H))
30
  return img, W, H
31
 
32
+ def _edgelogic(i, j, ph, pw, N, overlap):
33
+ """
34
+ Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
35
+ Extend with overlap, biasing inward.
36
+ Uses 2*overlap for edges to keep patch areas roughly comparable.
37
+ Returns (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
38
+ """
39
+ start_h = i * ph
40
+ start_w = j * pw
41
+ end_h = start_h + ph
42
+ end_w = start_w + pw
43
+
44
+ if overlap <= 0:
45
+ return start_h, end_h, start_w, end_w
46
+
47
+ # Vertical
48
+ if i == 0:
49
+ end_h += 2 * overlap
50
+ elif i == N - 1:
51
+ start_h -= 2 * overlap
52
+ else:
53
+ start_h -= overlap
54
+ end_h += overlap
55
+
56
+ # Horizontal
57
+ if j == 0:
58
+ end_w += 2 * overlap
59
+ elif j == N - 1:
60
+ start_w -= 2 * overlap
61
+ else:
62
+ start_w -= overlap
63
+ end_w += overlap
64
+
65
+ return start_h, end_h, start_w, end_w
66
+
67
  def spm_augment(
68
  image,
69
  num_patches=4, # N for an N×N grid
70
  mix_prob=0.5,
71
  beta_a=2.0,
72
  beta_b=2.0,
73
+ overlap_px=0,
74
  seed=None
75
  ):
76
  """
77
+ SPM-style augmentation with optional overlap + feathered blending.
78
+
79
+ When overlap_px <= 0:
80
+ - Standard global shuffle over N×N patches;
81
+ - Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
82
+
83
+ When overlap_px > 0:
84
+ - Each base cell (N×N grid) expands by +/-overlap_px (2*overlap at borders),
85
+ clipped to the image. Patches are mixed per location and alpha sampled per-patch
86
+ for a bit more stochasticity (can be changed to per-image alpha by editing below).
87
+ - Patches are blended into the canvas with a feather mask of size `overlap_px`.
88
  """
89
+ # Normalize to PIL and ensure divisibility
90
  if isinstance(image, np.ndarray):
91
  img = Image.fromarray(image).convert("RGB")
92
  else:
 
95
  N = int(num_patches)
96
  rng = np.random.default_rng(seed)
97
 
 
98
  img, W, H = _to_divisible_by(img, N)
99
+ arr_u8 = np.array(img, dtype=np.uint8)
100
  ph = H // N
101
  pw = W // N
102
 
103
+ # Clamp overlap to < half patch size
104
+ if overlap_px is None:
105
+ overlap_px = 0
106
+ overlap_px = int(overlap_px)
107
+ max_ov = max(0, min(ph, pw) // 2 - 1)
108
+ ov = int(np.clip(overlap_px, 0, max_ov))
109
+
110
+ if ov <= 0:
111
+ # === Non-overlap path ===
112
+ arr = arr_u8
113
+ # Build patches (row-major)
114
+ patches = []
115
+ for i in range(N):
116
+ for j in range(N):
117
+ y0 = i * ph
118
+ x0 = j * pw
119
+ patches.append(arr[y0:y0+ph, x0:x0+pw])
120
+
121
+ total = N * N
122
+ perm = rng.permutation(total)
123
+
124
+ # One alpha per image
125
+ if beta_a > 0 and beta_b > 0:
126
+ alpha = float(rng.beta(beta_a, beta_b))
127
+ else:
128
+ alpha = 1.0
129
+
130
+ out = arr.copy()
131
+ mask = rng.random(total) < float(mix_prob)
132
+ idx = 0
133
+ for i in range(N):
134
+ for j in range(N):
135
+ y0 = i * ph
136
+ x0 = j * pw
137
+ if mask[idx]:
138
+ src = patches[idx].astype(np.float32)
139
+ shf = patches[perm[idx]].astype(np.float32)
140
+ if 0.0 < alpha < 1.0:
141
+ mixed = alpha * shf + (1.0 - alpha) * src
142
+ out[y0:y0+ph, x0:x0+pw] = np.clip(mixed, 0, 255).astype(np.uint8)
143
+ else:
144
+ out[y0:y0+ph, x0:x0+pw] = patches[perm[idx]]
145
+ else:
146
+ out[y0:y0+ph, x0:x0+pw] = patches[idx]
147
+ idx += 1
148
+
149
+ return Image.fromarray(out)
150
+
151
+ # === Overlap path with feather blending ===
152
+ arr = arr_u8.astype(np.float32)
153
+ # Precompute feather mask for max size patch
154
+ feather_full = create_feather_mask(ph + 2*ov, pw + 2*ov, feather_size=ov)
155
+
156
  patches = []
157
+ coords = []
158
  for i in range(N):
159
  for j in range(N):
160
+ sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
161
+ # Clamp to image bounds
162
+ sh = max(0, sh); sw = max(0, sw)
163
+ eh = min(H, eh); ew = min(W, ew)
164
+ patches.append(arr[sh:eh, sw:ew])
165
+ coords.append((sh, eh, sw, ew))
166
 
167
+ total = len(patches)
168
  perm = rng.permutation(total)
169
 
170
+ # We'll sample alpha per-patch to echo your overlap snippet
171
+ def sample_alpha():
172
+ if beta_a > 0 and beta_b > 0:
173
+ return float(rng.beta(beta_a, beta_b))
174
+ return 1.0
175
 
176
+ canvas = np.zeros_like(arr, dtype=np.float32)
177
+ weight = np.zeros((H, W), dtype=np.float32)
178
+
179
+ for k, (sh, eh, sw, ew) in enumerate(coords):
180
+ if rng.random() >= float(mix_prob):
181
+ # keep original content in that region
182
+ src = patches[k]
183
+ patch = src
184
+ else:
185
+ lam = sample_alpha()
186
+ src = patches[k].astype(np.float32)
187
+ shf = patches[int(perm[k])].astype(np.float32)
188
+ patch = lam * shf + (1.0 - lam) * src
189
+
190
+ ph_k, pw_k, _ = patch.shape
191
+ # Slice feather mask down if needed (near borders)
192
+ mask2d = feather_full[:ph_k, :pw_k]
193
+ if arr.shape[2] == 1:
194
+ mask3d = mask2d[..., None]
195
+ else:
196
+ mask3d = np.repeat(mask2d[..., None], arr.shape[2], axis=2)
197
+
198
+ # Accumulate
199
+ canvas[sh:eh, sw:ew] += patch * mask3d
200
+ weight[sh:eh, sw:ew] += mask2d
201
 
202
+ # Normalize
203
+ weight = np.clip(weight, 1e-8, None)
204
+ out = (canvas / weight[..., None])
205
+ out = np.clip(out, 0, 255).astype(np.uint8)
206
  return Image.fromarray(out)