Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import numpy as np | |
import io, os, zipfile, tempfile, time | |
from spm import spm_augment | |
# Heading and introductory Markdown shown at the top of the Gradio page
# (both are rendered with gr.Markdown when the UI is built).
TITLE = "Shuffle PatchMix (SPM) Augmentation"
DESC = """
Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.
You can optionally **enable overlap** with feathered blending for smoother seams.
For batch processing, upload a .zip of images (PNG/JPG/JPEG), and download a .zip of outputs.
"""
def _parse_grid(grid_choice: str, default: int = 4) -> int:
    """Return N from a grid label such as ``"4x4"`` or ``"16×16"``.

    Only the part before the first separator is read. The separator may be
    ASCII ``x``/``X`` or the multiplication sign ``×``, which the UI labels
    use. The parsed value is clamped to at least 1.

    Args:
        grid_choice: Grid label, typically one of "2x2", "4x4", "8x8", "16x16".
        default: Value returned when ``grid_choice`` is not a string or its
            leading part is not an integer.

    Returns:
        The number of patches per side.
    """
    try:
        head = grid_choice.lower().replace("×", "x").split("x")[0]
        return max(1, int(head))
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: not a str (e.g. None);
        # ValueError: leading part is not an integer (e.g. "abc").
        return default
def run_single(image, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed):
    """Generate ``num_augs`` SPM-augmented variants of one image.

    Args:
        image: Image from the ``gr.Image(type="pil")`` widget, or ``None``
            when nothing has been uploaded.
        grid_choice: Grid label such as ``"4x4"``, parsed by ``_parse_grid``.
        use_overlap: When falsy, ``overlap_px`` is ignored and 0 is used.
        overlap_px: Overlap width in pixels (only read if ``use_overlap``).
        mix_prob, beta_a, beta_b: Forwarded to ``spm_augment`` as floats.
        num_augs: Number of variants. Slider values can arrive as floats
            (e.g. ``4.0``), so the value is coerced to ``int``. ``None`` means 0.
        seed: Base seed or ``None``. Variant ``i`` uses ``seed + i``, so the
            variants differ from each other yet stay reproducible.

    Returns:
        List of the images returned by ``spm_augment``. The list is empty
        when ``image`` is ``None``.
    """
    if image is None:
        return []

    n_variants = int(num_augs) if num_augs is not None else 0
    base_seed = int(seed) if seed is not None else None

    # These arguments are identical for every variant, so compute them once.
    params = dict(
        num_patches=_parse_grid(grid_choice),
        mix_prob=float(mix_prob),
        beta_a=float(beta_a),
        beta_b=float(beta_b),
        overlap_px=int(overlap_px) if use_overlap else 0,
    )

    return [
        spm_augment(
            image,
            seed=(base_seed + i) if base_seed is not None else None,
            **params,
        )
        for i in range(n_variants)
    ]
def _collect_images(root, exts=(".png", ".jpg", ".jpeg")):
    """Return paths of files under ``root`` whose extension is in ``exts``.

    Matching is case-insensitive. Directories and files are visited in
    sorted order, so the result is deterministic across filesystems.
    """
    found = []
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # sorting in place controls os.walk's descent order
        found.extend(
            os.path.join(dirpath, name)
            for name in sorted(filenames)
            if name.lower().endswith(exts)
        )
    return found


def _zip_dir(src_dir, zip_path):
    """Write every file under ``src_dir`` into a deflated zip at ``zip_path``.

    Archive member names are relative to ``src_dir``.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for dirpath, _, filenames in os.walk(src_dir):
            for name in filenames:
                path = os.path.join(dirpath, name)
                zf.write(path, arcname=os.path.relpath(path, src_dir))


def run_batch(zip_file, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, seed):
    """Apply SPM augmentation to every PNG/JPG/JPEG image in an uploaded .zip.

    Args:
        zip_file: The ``gr.File`` value. It may be a filesystem path, an object
            exposing the path as ``.name``, or ``None`` when nothing was uploaded.
        grid_choice: Grid label such as ``"4x4"``, parsed by ``_parse_grid``.
        use_overlap: When falsy, ``overlap_px`` is ignored and 0 is used.
        overlap_px, mix_prob, beta_a, beta_b: Forwarded to ``spm_augment``.
        seed: Seed passed to ``spm_augment`` for every image, or ``None``.

    Returns:
        ``(zip_path, message)``. ``zip_path`` is a deflated archive that
        mirrors the input's directory layout. It is ``None`` when the upload
        is missing, is not a zip archive, or contains no matching images.
    """
    import shutil

    if zip_file is None:
        return None, "Please upload a .zip file with images."

    # The gr.File payload is a path string in some Gradio versions and a
    # tempfile wrapper exposing ``.name`` in others; accept both.
    zip_path = getattr(zip_file, "name", zip_file)

    tempdir = tempfile.mkdtemp()
    # Inputs and outputs live in separate sibling directories. Walking the
    # inputs therefore never rediscovers freshly written outputs, which would
    # otherwise be augmented a second time and double the counts.
    indir = os.path.join(tempdir, "inputs")
    outdir = os.path.join(tempdir, "outputs")
    os.makedirs(outdir, exist_ok=True)

    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            # extractall strips absolute paths and ".." components from
            # member names. It does not guard against very large archives.
            zf.extractall(indir)
    except zipfile.BadZipFile:
        shutil.rmtree(tempdir, ignore_errors=True)
        return None, "The uploaded file is not a valid .zip archive."

    image_paths = _collect_images(indir)
    if not image_paths:
        shutil.rmtree(tempdir, ignore_errors=True)
        return None, "No PNG/JPG/JPEG images found in the uploaded .zip."

    base_seed = int(seed) if seed is not None else None
    params = dict(
        num_patches=_parse_grid(grid_choice),
        mix_prob=float(mix_prob),
        beta_a=float(beta_a),
        beta_b=float(beta_b),
        overlap_px=int(overlap_px) if use_overlap else 0,
    )

    count_in, count_out = 0, 0
    for in_path in image_paths:
        try:
            # Close the file handle promptly; convert() returns an in-memory copy.
            with Image.open(in_path) as src:
                img = src.convert("RGB")
        except Exception:
            # Best effort: files that merely carry an image extension but
            # cannot be decoded are skipped and reported in the message below.
            continue
        count_in += 1
        # NOTE: every image uses the same seed. This is presumably intentional,
        # matching the first variant of the single-image tab. Confirm if
        # per-image variety is wanted.
        out_img = spm_augment(img, seed=base_seed, **params)
        out_path = os.path.join(outdir, os.path.relpath(in_path, indir))
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        out_img.save(out_path)
        count_out += 1

    out_zip = os.path.join(tempdir, f"spm_outputs_{int(time.time())}.zip")
    _zip_dir(outdir, out_zip)
    # Only the archive is returned. Drop the extracted and augmented copies to
    # limit disk usage on a long-running host.
    shutil.rmtree(indir, ignore_errors=True)
    shutil.rmtree(outdir, ignore_errors=True)

    msg = f"Processed {count_out}/{count_in} files."
    skipped = len(image_paths) - count_in
    if skipped:
        msg += f" Skipped {skipped} unreadable file(s)."
    return out_zip, msg
def _augment_controls():
    """Create the SPM parameter widgets shared by both tabs.

    Call this inside an open Gradio layout context (e.g. ``gr.Column``) so the
    widgets are placed there. Returns them in the positional order that the
    click handlers expect:
    ``[grid, use_overlap, overlap_px, mix_prob, beta_a, beta_b]``.
    """
    grid = gr.Radio(choices=["2x2", "4x4", "8x8", "16x16"], value="4x4", label="Grid (N×N)")
    overlap_on = gr.Checkbox(value=True, label="Enable Overlap Patch Blend")
    overlap_width = gr.Slider(1, 64, value=8, step=1, label="Overlap (px)")
    patch_prob = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
    with gr.Row():
        alpha = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
        beta = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta β")
    return [grid, overlap_on, overlap_width, patch_prob, alpha, beta]


# Two-tab UI: "Single Image" renders variants into a gallery; "Batch (.zip)"
# returns a downloadable archive plus a status line.
with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)

    with gr.Tabs():
        with gr.TabItem("Single Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    single_image = gr.Image(label="Input image", type="pil")
                    single_params = _augment_controls()
                    variant_count = gr.Slider(1, 12, value=4, step=1, label="Number of variants")
                    single_seed = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    generate_btn = gr.Button("Generate")
                with gr.Column(scale=1):
                    results_gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
            generate_btn.click(
                fn=run_single,
                inputs=[single_image, *single_params, variant_count, single_seed],
                outputs=[results_gallery],
            )

        with gr.TabItem("Batch (.zip)"):
            with gr.Row():
                with gr.Column(scale=1):
                    archive_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
                    batch_params = _augment_controls()
                    batch_seed = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    process_btn = gr.Button("Process Zip")
                with gr.Column(scale=1):
                    archive_out = gr.File(label="Download results (.zip)")
                    batch_status = gr.Markdown()
            process_btn.click(
                fn=run_batch,
                inputs=[archive_in, *batch_params, batch_seed],
                outputs=[archive_out, batch_status],
            )

if __name__ == "__main__":
    demo.launch()