Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files
README.md
CHANGED
@@ -1,36 +1,38 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
|
5 |
-
colorTo: pink
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.44.1
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
11 |
-
short_description: Shuffle PatchMix Augmentation
|
12 |
---
|
13 |
|
14 |
arxiv.org/abs/2505.24216
|
15 |
|
16 |
# Shuffle PatchMix (SPM) — Hugging Face Space
|
17 |
|
18 |
-
A minimal interactive demo for SPM-style augmentation.
|
19 |
|
20 |
GitHub repo: https://github.com/PrasannaPulakurthi/SPM
|
21 |
|
22 |
## Parameters
|
23 |
- **Grid (N×N):** Choose one of **2×2, 4×4, 8×8, 16×16**. The image is cropped (top-left) so its width and height are divisible by N.
|
|
|
|
|
24 |
- **Mix probability:** Per-patch probability to mix original and a shuffled patch.
|
25 |
-
- **Beta α, β:** Shape parameters for
|
26 |
- **Seed:** Optional deterministic seed.
|
27 |
|
28 |
## Batch Mode
|
29 |
Upload a `.zip` containing images (`.png`, `.jpg`, `.jpeg`). The app returns a `.zip` of augmented results with the same folder structure.
|
30 |
|
31 |
## Notes
|
32 |
-
-
|
33 |
-
-
|
|
|
|
|
34 |
|
35 |
## Local Development
|
36 |
```bash
|
|
|
1 |
---
|
2 |
+
title: Shuffle PatchMix
|
3 |
+
colorFrom: purple
|
4 |
+
colorTo: red
|
|
|
5 |
sdk: gradio
|
6 |
sdk_version: 5.44.1
|
7 |
app_file: app.py
|
8 |
pinned: false
|
9 |
license: mit
|
|
|
10 |
---
|
11 |
|
12 |
arxiv.org/abs/2505.24216
|
13 |
|
14 |
# Shuffle PatchMix (SPM) — Hugging Face Space
|
15 |
|
16 |
+
A minimal interactive demo for SPM-style augmentation. Now supports **overlap with feathered blending**.
|
17 |
|
18 |
GitHub repo: https://github.com/PrasannaPulakurthi/SPM
|
19 |
|
20 |
## Parameters
|
21 |
- **Grid (N×N):** Choose one of **2×2, 4×4, 8×8, 16×16**. The image is cropped (top-left) so its width and height are divisible by N.
|
22 |
+
- **Enable overlap (feather blend):** When enabled, each base cell expands by ±overlap pixels (2× at borders to keep areas comparable) and patches are blended with a feather mask.
|
23 |
+
- **Overlap (px):** Pixel overlap per side. Automatically clipped to `< ½ * patch size`.
|
24 |
- **Mix probability:** Per-patch probability to mix original and a shuffled patch.
|
25 |
+
- **Beta α, β:** Shape parameters for the Beta distribution used for blending weights.
|
26 |
- **Seed:** Optional deterministic seed.
|
27 |
|
28 |
## Batch Mode
|
29 |
Upload a `.zip` containing images (`.png`, `.jpg`, `.jpeg`). The app returns a `.zip` of augmented results with the same folder structure.
|
30 |
|
31 |
## Notes
|
32 |
+
- Non-overlap path uses **one alpha per image**; overlap path uses **alpha per patch** to mirror your reference snippet (edit in `spm.py` if you prefer one alpha per image).
|
33 |
+
- Feather size equals the overlap (you can decouple by adjusting `create_feather_mask` calls).
|
34 |
+
- If you want parity with a specific paper version, swap in your official implementation but keep the signature:
|
35 |
+
`spm_augment(image, num_patches, mix_prob, beta_a, beta_b, overlap_px, seed)`.
|
36 |
|
37 |
## Local Development
|
38 |
```bash
|
app.py
CHANGED
@@ -7,6 +7,7 @@ from spm import spm_augment
|
|
7 |
TITLE = "Shuffle PatchMix (SPM) Augmentation"
|
8 |
DESC = """
|
9 |
Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
|
|
|
10 |
For batch processing, upload a .zip of images (PNG/JPG/JPEG), and download a .zip of outputs.
|
11 |
"""
|
12 |
|
@@ -18,12 +19,13 @@ def _parse_grid(grid_choice: str) -> int:
|
|
18 |
except Exception:
|
19 |
return 4
|
20 |
|
21 |
-
def run_single(image, grid_choice, mix_prob, beta_a, beta_b, num_augs, seed):
|
22 |
if image is None:
|
23 |
return []
|
24 |
outs = []
|
25 |
base_seed = int(seed) if seed is not None else None
|
26 |
N = _parse_grid(grid_choice)
|
|
|
27 |
for i in range(num_augs):
|
28 |
s = (base_seed + i) if base_seed is not None else None
|
29 |
out_img = spm_augment(
|
@@ -32,12 +34,13 @@ def run_single(image, grid_choice, mix_prob, beta_a, beta_b, num_augs, seed):
|
|
32 |
mix_prob=float(mix_prob),
|
33 |
beta_a=float(beta_a),
|
34 |
beta_b=float(beta_b),
|
|
|
35 |
seed=s
|
36 |
)
|
37 |
outs.append(out_img)
|
38 |
return outs
|
39 |
|
40 |
-
def run_batch(zip_file, grid_choice, mix_prob, beta_a, beta_b, seed):
|
41 |
if zip_file is None:
|
42 |
return None, "Please upload a .zip file with images."
|
43 |
tempdir = tempfile.mkdtemp()
|
@@ -50,6 +53,7 @@ def run_batch(zip_file, grid_choice, mix_prob, beta_a, beta_b, seed):
|
|
50 |
valid_exts = {".png", ".jpg", ".jpeg"}
|
51 |
count_in, count_out = 0, 0
|
52 |
N = _parse_grid(grid_choice)
|
|
|
53 |
for root_dir, _, files in os.walk(tempdir):
|
54 |
for f in files:
|
55 |
if f.lower().endswith(tuple(valid_exts)):
|
@@ -65,6 +69,7 @@ def run_batch(zip_file, grid_choice, mix_prob, beta_a, beta_b, seed):
|
|
65 |
mix_prob=float(mix_prob),
|
66 |
beta_a=float(beta_a),
|
67 |
beta_b=float(beta_b),
|
|
|
68 |
seed=int(seed) if seed is not None else None
|
69 |
)
|
70 |
rel = os.path.relpath(in_path, tempdir)
|
@@ -92,6 +97,8 @@ with gr.Blocks() as demo:
|
|
92 |
with gr.Column(scale=1):
|
93 |
inp = gr.Image(label="Input image", type="pil")
|
94 |
grid_choice = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
|
|
|
|
|
95 |
mix_prob = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
|
96 |
with gr.Row():
|
97 |
beta_a = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
|
@@ -103,7 +110,7 @@ with gr.Blocks() as demo:
|
|
103 |
gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
|
104 |
run_btn.click(
|
105 |
fn=run_single,
|
106 |
-
inputs=[inp, grid_choice, mix_prob, beta_a, beta_b, num_augs, seed],
|
107 |
outputs=[gallery]
|
108 |
)
|
109 |
with gr.TabItem("Batch (.zip)"):
|
@@ -111,6 +118,8 @@ with gr.Blocks() as demo:
|
|
111 |
with gr.Column(scale=1):
|
112 |
zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
|
113 |
grid_choice_b = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
|
|
|
|
|
114 |
mix_prob_b = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
|
115 |
with gr.Row():
|
116 |
beta_a_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
|
@@ -122,7 +131,7 @@ with gr.Blocks() as demo:
|
|
122 |
status = gr.Markdown()
|
123 |
run_b.click(
|
124 |
fn=run_batch,
|
125 |
-
inputs=[zip_in, grid_choice_b, mix_prob_b, beta_a_b, beta_b_b, seed_b],
|
126 |
outputs=[zip_out, status]
|
127 |
)
|
128 |
|
|
|
7 |
TITLE = "Shuffle PatchMix (SPM) Augmentation"
|
8 |
DESC = """
|
9 |
Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
|
10 |
+
You can optionally **enable overlap** with feathered blending for smoother seams.
|
11 |
For batch processing, upload a .zip of images (PNG/JPG/JPEG), and download a .zip of outputs.
|
12 |
"""
|
13 |
|
|
|
19 |
except Exception:
|
20 |
return 4
|
21 |
|
22 |
+
def run_single(image, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed):
|
23 |
if image is None:
|
24 |
return []
|
25 |
outs = []
|
26 |
base_seed = int(seed) if seed is not None else None
|
27 |
N = _parse_grid(grid_choice)
|
28 |
+
ov = int(overlap_px) if use_overlap else 0
|
29 |
for i in range(num_augs):
|
30 |
s = (base_seed + i) if base_seed is not None else None
|
31 |
out_img = spm_augment(
|
|
|
34 |
mix_prob=float(mix_prob),
|
35 |
beta_a=float(beta_a),
|
36 |
beta_b=float(beta_b),
|
37 |
+
overlap_px=ov,
|
38 |
seed=s
|
39 |
)
|
40 |
outs.append(out_img)
|
41 |
return outs
|
42 |
|
43 |
+
def run_batch(zip_file, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, seed):
|
44 |
if zip_file is None:
|
45 |
return None, "Please upload a .zip file with images."
|
46 |
tempdir = tempfile.mkdtemp()
|
|
|
53 |
valid_exts = {".png", ".jpg", ".jpeg"}
|
54 |
count_in, count_out = 0, 0
|
55 |
N = _parse_grid(grid_choice)
|
56 |
+
ov = int(overlap_px) if use_overlap else 0
|
57 |
for root_dir, _, files in os.walk(tempdir):
|
58 |
for f in files:
|
59 |
if f.lower().endswith(tuple(valid_exts)):
|
|
|
69 |
mix_prob=float(mix_prob),
|
70 |
beta_a=float(beta_a),
|
71 |
beta_b=float(beta_b),
|
72 |
+
overlap_px=ov,
|
73 |
seed=int(seed) if seed is not None else None
|
74 |
)
|
75 |
rel = os.path.relpath(in_path, tempdir)
|
|
|
97 |
with gr.Column(scale=1):
|
98 |
inp = gr.Image(label="Input image", type="pil")
|
99 |
grid_choice = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
|
100 |
+
use_overlap = gr.Checkbox(value=False, label="Enable overlap (feather blend)")
|
101 |
+
overlap_px = gr.Slider(1, 64, value=8, step=1, label="Overlap (px)")
|
102 |
mix_prob = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
|
103 |
with gr.Row():
|
104 |
beta_a = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
|
|
|
110 |
gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
|
111 |
run_btn.click(
|
112 |
fn=run_single,
|
113 |
+
inputs=[inp, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed],
|
114 |
outputs=[gallery]
|
115 |
)
|
116 |
with gr.TabItem("Batch (.zip)"):
|
|
|
118 |
with gr.Column(scale=1):
|
119 |
zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
|
120 |
grid_choice_b = gr.Radio(choices=["2x2","4x4","8x8","16x16"], value="4x4", label="Grid (N×N)")
|
121 |
+
use_overlap_b = gr.Checkbox(value=False, label="Enable overlap (feather blend)")
|
122 |
+
overlap_px_b = gr.Slider(1, 64, value=8, step=1, label="Overlap (px)")
|
123 |
mix_prob_b = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
|
124 |
with gr.Row():
|
125 |
beta_a_b = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
|
|
|
131 |
status = gr.Markdown()
|
132 |
run_b.click(
|
133 |
fn=run_batch,
|
134 |
+
inputs=[zip_in, grid_choice_b, use_overlap_b, overlap_px_b, mix_prob_b, beta_a_b, beta_b_b, seed_b],
|
135 |
outputs=[zip_out, status]
|
136 |
)
|
137 |
|
spm.py
CHANGED
@@ -1,6 +1,23 @@
|
|
1 |
from PIL import Image
|
2 |
import numpy as np
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
def _to_divisible_by(img, N):
|
5 |
"""Crop so width and height are divisible by N (top-left anchored)."""
|
6 |
w, h = img.size
|
@@ -12,22 +29,64 @@ def _to_divisible_by(img, N):
|
|
12 |
img = img.crop((0, 0, W, H))
|
13 |
return img, W, H
|
14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
def spm_augment(
|
16 |
image,
|
17 |
num_patches=4, # N for an N×N grid
|
18 |
mix_prob=0.5,
|
19 |
beta_a=2.0,
|
20 |
beta_b=2.0,
|
|
|
21 |
seed=None
|
22 |
):
|
23 |
"""
|
24 |
-
SPM-style augmentation
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
"""
|
30 |
-
# Normalize
|
31 |
if isinstance(image, np.ndarray):
|
32 |
img = Image.fromarray(image).convert("RGB")
|
33 |
else:
|
@@ -36,47 +95,112 @@ def spm_augment(
|
|
36 |
N = int(num_patches)
|
37 |
rng = np.random.default_rng(seed)
|
38 |
|
39 |
-
# Ensure divisibility and compute patch size
|
40 |
img, W, H = _to_divisible_by(img, N)
|
41 |
-
|
42 |
ph = H // N
|
43 |
pw = W // N
|
44 |
|
45 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
patches = []
|
|
|
47 |
for i in range(N):
|
48 |
for j in range(N):
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
52 |
|
53 |
-
total =
|
54 |
perm = rng.permutation(total)
|
55 |
|
56 |
-
#
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
|
|
|
|
|
|
|
|
82 |
return Image.fromarray(out)
|
|
|
1 |
from PIL import Image
|
2 |
import numpy as np
|
3 |
|
4 |
+
def create_feather_mask(height, width, feather_size=4):
|
5 |
+
"""
|
6 |
+
2D mask HxW that smoothly transitions from 1.0 in the interior
|
7 |
+
to 0.0 at the edges over `feather_size` pixels.
|
8 |
+
"""
|
9 |
+
mask = np.ones((height, width), dtype=np.float32)
|
10 |
+
if feather_size <= 0:
|
11 |
+
return mask
|
12 |
+
ramp = np.linspace(0.0, 1.0, feather_size, dtype=np.float32)
|
13 |
+
# Top / Bottom
|
14 |
+
mask[:feather_size, :] *= ramp[:, None]
|
15 |
+
mask[-feather_size:, :] *= ramp[::-1, None]
|
16 |
+
# Left / Right
|
17 |
+
mask[:, :feather_size] *= ramp[None, :]
|
18 |
+
mask[:, -feather_size:] *= ramp[None, ::-1]
|
19 |
+
return mask
|
20 |
+
|
21 |
def _to_divisible_by(img, N):
|
22 |
"""Crop so width and height are divisible by N (top-left anchored)."""
|
23 |
w, h = img.size
|
|
|
29 |
img = img.crop((0, 0, W, H))
|
30 |
return img, W, H
|
31 |
|
32 |
+
def _edgelogic(i, j, ph, pw, N, overlap):
|
33 |
+
"""
|
34 |
+
Base (no-overlap) patch is [i*ph:(i+1)*ph, j*pw:(j+1)*pw].
|
35 |
+
Extend with overlap, biasing inward.
|
36 |
+
Uses 2*overlap for edges to keep patch areas roughly comparable.
|
37 |
+
Returns (start_h, end_h, start_w, end_w) BEFORE clamping to image bounds.
|
38 |
+
"""
|
39 |
+
start_h = i * ph
|
40 |
+
start_w = j * pw
|
41 |
+
end_h = start_h + ph
|
42 |
+
end_w = start_w + pw
|
43 |
+
|
44 |
+
if overlap <= 0:
|
45 |
+
return start_h, end_h, start_w, end_w
|
46 |
+
|
47 |
+
# Vertical
|
48 |
+
if i == 0:
|
49 |
+
end_h += 2 * overlap
|
50 |
+
elif i == N - 1:
|
51 |
+
start_h -= 2 * overlap
|
52 |
+
else:
|
53 |
+
start_h -= overlap
|
54 |
+
end_h += overlap
|
55 |
+
|
56 |
+
# Horizontal
|
57 |
+
if j == 0:
|
58 |
+
end_w += 2 * overlap
|
59 |
+
elif j == N - 1:
|
60 |
+
start_w -= 2 * overlap
|
61 |
+
else:
|
62 |
+
start_w -= overlap
|
63 |
+
end_w += overlap
|
64 |
+
|
65 |
+
return start_h, end_h, start_w, end_w
|
66 |
+
|
67 |
def spm_augment(
|
68 |
image,
|
69 |
num_patches=4, # N for an N×N grid
|
70 |
mix_prob=0.5,
|
71 |
beta_a=2.0,
|
72 |
beta_b=2.0,
|
73 |
+
overlap_px=0,
|
74 |
seed=None
|
75 |
):
|
76 |
"""
|
77 |
+
SPM-style augmentation with optional overlap + feathered blending.
|
78 |
+
|
79 |
+
When overlap_px <= 0:
|
80 |
+
- Standard global shuffle over N×N patches;
|
81 |
+
- Per-patch mixing with a single alpha ~ Beta(a,b) for the image.
|
82 |
+
|
83 |
+
When overlap_px > 0:
|
84 |
+
- Each base cell (N×N grid) expands by +/-overlap_px (2*overlap at borders),
|
85 |
+
clipped to the image. Patches are mixed per location and alpha sampled per-patch
|
86 |
+
for a bit more stochasticity (can be changed to per-image alpha by editing below).
|
87 |
+
- Patches are blended into the canvas with a feather mask of size `overlap_px`.
|
88 |
"""
|
89 |
+
# Normalize to PIL and ensure divisibility
|
90 |
if isinstance(image, np.ndarray):
|
91 |
img = Image.fromarray(image).convert("RGB")
|
92 |
else:
|
|
|
95 |
N = int(num_patches)
|
96 |
rng = np.random.default_rng(seed)
|
97 |
|
|
|
98 |
img, W, H = _to_divisible_by(img, N)
|
99 |
+
arr_u8 = np.array(img, dtype=np.uint8)
|
100 |
ph = H // N
|
101 |
pw = W // N
|
102 |
|
103 |
+
# Clamp overlap to < half patch size
|
104 |
+
if overlap_px is None:
|
105 |
+
overlap_px = 0
|
106 |
+
overlap_px = int(overlap_px)
|
107 |
+
max_ov = max(0, min(ph, pw) // 2 - 1)
|
108 |
+
ov = int(np.clip(overlap_px, 0, max_ov))
|
109 |
+
|
110 |
+
if ov <= 0:
|
111 |
+
# === Non-overlap path ===
|
112 |
+
arr = arr_u8
|
113 |
+
# Build patches (row-major)
|
114 |
+
patches = []
|
115 |
+
for i in range(N):
|
116 |
+
for j in range(N):
|
117 |
+
y0 = i * ph
|
118 |
+
x0 = j * pw
|
119 |
+
patches.append(arr[y0:y0+ph, x0:x0+pw])
|
120 |
+
|
121 |
+
total = N * N
|
122 |
+
perm = rng.permutation(total)
|
123 |
+
|
124 |
+
# One alpha per image
|
125 |
+
if beta_a > 0 and beta_b > 0:
|
126 |
+
alpha = float(rng.beta(beta_a, beta_b))
|
127 |
+
else:
|
128 |
+
alpha = 1.0
|
129 |
+
|
130 |
+
out = arr.copy()
|
131 |
+
mask = rng.random(total) < float(mix_prob)
|
132 |
+
idx = 0
|
133 |
+
for i in range(N):
|
134 |
+
for j in range(N):
|
135 |
+
y0 = i * ph
|
136 |
+
x0 = j * pw
|
137 |
+
if mask[idx]:
|
138 |
+
src = patches[idx].astype(np.float32)
|
139 |
+
shf = patches[perm[idx]].astype(np.float32)
|
140 |
+
if 0.0 < alpha < 1.0:
|
141 |
+
mixed = alpha * shf + (1.0 - alpha) * src
|
142 |
+
out[y0:y0+ph, x0:x0+pw] = np.clip(mixed, 0, 255).astype(np.uint8)
|
143 |
+
else:
|
144 |
+
out[y0:y0+ph, x0:x0+pw] = patches[perm[idx]]
|
145 |
+
else:
|
146 |
+
out[y0:y0+ph, x0:x0+pw] = patches[idx]
|
147 |
+
idx += 1
|
148 |
+
|
149 |
+
return Image.fromarray(out)
|
150 |
+
|
151 |
+
# === Overlap path with feather blending ===
|
152 |
+
arr = arr_u8.astype(np.float32)
|
153 |
+
# Precompute feather mask for max size patch
|
154 |
+
feather_full = create_feather_mask(ph + 2*ov, pw + 2*ov, feather_size=ov)
|
155 |
+
|
156 |
patches = []
|
157 |
+
coords = []
|
158 |
for i in range(N):
|
159 |
for j in range(N):
|
160 |
+
sh, eh, sw, ew = _edgelogic(i, j, ph, pw, N, ov)
|
161 |
+
# Clamp to image bounds
|
162 |
+
sh = max(0, sh); sw = max(0, sw)
|
163 |
+
eh = min(H, eh); ew = min(W, ew)
|
164 |
+
patches.append(arr[sh:eh, sw:ew])
|
165 |
+
coords.append((sh, eh, sw, ew))
|
166 |
|
167 |
+
total = len(patches)
|
168 |
perm = rng.permutation(total)
|
169 |
|
170 |
+
# We'll sample alpha per-patch to echo your overlap snippet
|
171 |
+
def sample_alpha():
|
172 |
+
if beta_a > 0 and beta_b > 0:
|
173 |
+
return float(rng.beta(beta_a, beta_b))
|
174 |
+
return 1.0
|
175 |
|
176 |
+
canvas = np.zeros_like(arr, dtype=np.float32)
|
177 |
+
weight = np.zeros((H, W), dtype=np.float32)
|
178 |
+
|
179 |
+
for k, (sh, eh, sw, ew) in enumerate(coords):
|
180 |
+
if rng.random() >= float(mix_prob):
|
181 |
+
# keep original content in that region
|
182 |
+
src = patches[k]
|
183 |
+
patch = src
|
184 |
+
else:
|
185 |
+
lam = sample_alpha()
|
186 |
+
src = patches[k].astype(np.float32)
|
187 |
+
shf = patches[int(perm[k])].astype(np.float32)
|
188 |
+
patch = lam * shf + (1.0 - lam) * src
|
189 |
+
|
190 |
+
ph_k, pw_k, _ = patch.shape
|
191 |
+
# Slice feather mask down if needed (near borders)
|
192 |
+
mask2d = feather_full[:ph_k, :pw_k]
|
193 |
+
if arr.shape[2] == 1:
|
194 |
+
mask3d = mask2d[..., None]
|
195 |
+
else:
|
196 |
+
mask3d = np.repeat(mask2d[..., None], arr.shape[2], axis=2)
|
197 |
+
|
198 |
+
# Accumulate
|
199 |
+
canvas[sh:eh, sw:ew] += patch * mask3d
|
200 |
+
weight[sh:eh, sw:ew] += mask2d
|
201 |
|
202 |
+
# Normalize
|
203 |
+
weight = np.clip(weight, 1e-8, None)
|
204 |
+
out = (canvas / weight[..., None])
|
205 |
+
out = np.clip(out, 0, 255).astype(np.uint8)
|
206 |
return Image.fromarray(out)
|