SPM / app.py
prasannareddyp's picture
Update app.py
fbeacce verified
raw
history blame
6.21 kB
import io, os, zipfile, tempfile, time
import shutil

import gradio as gr
import numpy as np
from PIL import Image

from spm import spm_augment
# Page heading shown at the top of the Gradio app.
TITLE = "Shuffle PatchMix (SPM) Augmentation"

# Markdown blurb rendered under the heading; explains both tabs.
DESC = (
    "\n"
    "Upload an image, choose **number of patches (N×N)**, and generate SPM-augmented variants.\n"
    "You can optionally **enable overlap** with feathered blending for smoother seams.\n"
    "For batch processing, upload a .zip of images (PNG/JPG/JPEG), and download a .zip of outputs.\n"
)
def _parse_grid(grid_choice: "str | int | None", default: int = 4) -> int:
    """Return N from a grid label such as ``"4x4"`` or ``"16x16"``.

    Only the text before the first separator is used, case-insensitively.
    Both ASCII ``x`` and the Unicode multiplication sign ``×`` (the form used
    in the UI labels) are accepted as separators, and a plain integer is taken
    as-is. The result is clamped to at least 1.

    Args:
        grid_choice: Grid label (or integer N) coming from the UI.
        default: Value returned when ``grid_choice`` cannot be parsed.

    Returns:
        The grid size N (>= 1), or ``default`` for unparseable input.
    """
    # bool is a subclass of int but is never a meaningful grid size.
    if isinstance(grid_choice, int) and not isinstance(grid_choice, bool):
        return max(1, grid_choice)
    try:
        head = grid_choice.lower().replace("×", "x").split("x", 1)[0]
        return max(1, int(head))
    except (AttributeError, ValueError):
        # AttributeError: not a string (e.g. None); ValueError: no integer prefix.
        return default
def run_single(image, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed):
    """Generate ``num_augs`` SPM-augmented variants of a single image.

    Args:
        image: Image from the ``gr.Image(type="pil")`` input, or ``None`` when
            nothing has been uploaded.
        grid_choice: Grid label such as ``"4x4"``; parsed with ``_parse_grid``.
        use_overlap: When false, overlap blending is disabled (``overlap_px=0``).
        overlap_px: Overlap width in pixels; only used when ``use_overlap``.
        mix_prob: Forwarded to ``spm_augment`` as a float.
        beta_a: Forwarded to ``spm_augment`` as a float.
        beta_b: Forwarded to ``spm_augment`` as a float.
        num_augs: Number of variants to produce. It is coerced to ``int`` because
            a UI slider may deliver a float, and ``range`` rejects floats.
        seed: Base seed; variant ``i`` is generated with ``seed + i``.
            ``None`` leaves every variant unseeded.

    Returns:
        A list with the outputs of ``spm_augment``, one per variant. The list is
        empty when ``image`` is ``None``.
    """
    if image is None:
        return []
    base_seed = None if seed is None else int(seed)
    # Arguments shared by every variant, converted once instead of per iteration.
    shared = dict(
        num_patches=_parse_grid(grid_choice),
        mix_prob=float(mix_prob),
        beta_a=float(beta_a),
        beta_b=float(beta_b),
        overlap_px=int(overlap_px) if use_overlap else 0,
    )
    return [
        spm_augment(image, seed=None if base_seed is None else base_seed + i, **shared)
        for i in range(int(num_augs))
    ]
_IMAGE_EXTS = (".png", ".jpg", ".jpeg")


def _find_images(root):
    """Return the sorted paths of all PNG/JPG/JPEG files under ``root``."""
    return sorted(
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(root)
        for name in names
        # "._*" files are macOS AppleDouble metadata sidecars (typically under
        # __MACOSX/); they carry the image extension but are not images.
        if name.lower().endswith(_IMAGE_EXTS) and not name.startswith("._")
    )


def _zip_directory(src_dir, zip_path):
    """Write every file under ``src_dir`` to a deflate-compressed archive at ``zip_path``.

    Archive member names are the paths relative to ``src_dir``.
    """
    with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for dirpath, _, names in os.walk(src_dir):
            for name in names:
                path = os.path.join(dirpath, name)
                zf.write(path, arcname=os.path.relpath(path, src_dir))


def run_batch(zip_file, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, seed):
    """Apply SPM augmentation to every PNG/JPG/JPEG image inside an uploaded .zip.

    Each image is converted to RGB and augmented with the same settings and
    the same seed. The output is saved under its original relative path and
    extension, and all outputs are bundled into a new archive.

    Args:
        zip_file: Upload from ``gr.File``. This is either a filesystem path or an
            object whose ``.name`` attribute holds the path. ``None`` means
            nothing was uploaded.
        grid_choice: Grid label such as ``"4x4"``; parsed with ``_parse_grid``.
        use_overlap: When false, overlap blending is disabled (``overlap_px=0``).
        overlap_px: Overlap width in pixels; only used when ``use_overlap``.
        mix_prob: Forwarded to ``spm_augment`` as a float.
        beta_a: Forwarded to ``spm_augment`` as a float.
        beta_b: Forwarded to ``spm_augment`` as a float.
        seed: Seed passed unchanged to every image (``None`` = unseeded).

    Returns:
        ``(zip_path, message)`` on success. ``(None, message)`` is returned in
        three cases:

        * nothing was uploaded;
        * the upload is not a readable zip;
        * the upload contains no candidate images.

        Candidate files that fail to decode are skipped. They appear as the
        gap between the two counts in ``message``.
    """
    if zip_file is None:
        return None, "Please upload a .zip file with images."
    zip_path = getattr(zip_file, "name", zip_file)
    try:
        archive = zipfile.ZipFile(zip_path, "r")
    except (zipfile.BadZipFile, OSError):
        return None, "Could not read the upload as a .zip archive."

    workdir = tempfile.mkdtemp(prefix="spm_")
    # Inputs and outputs live in sibling directories. Writing outputs inside the
    # tree being walked would make them get re-read and augmented a second time.
    indir = os.path.join(workdir, "inputs")
    outdir = os.path.join(workdir, "outputs")
    try:
        with archive:
            # extractall strips absolute paths and ".." components from member names.
            archive.extractall(indir)
        candidates = _find_images(indir)
        if not candidates:
            shutil.rmtree(workdir, ignore_errors=True)
            return None, "No PNG/JPG/JPEG images found in the .zip."

        params = dict(
            num_patches=_parse_grid(grid_choice),
            mix_prob=float(mix_prob),
            beta_a=float(beta_a),
            beta_b=float(beta_b),
            overlap_px=int(overlap_px) if use_overlap else 0,
            seed=int(seed) if seed is not None else None,
        )
        count_out = 0
        for in_path in candidates:
            try:
                with Image.open(in_path) as src:
                    img = src.convert("RGB")
            except Exception:
                # Best effort over user-supplied files: the image decoders can
                # raise many exception types. Skip the file; it still counts
                # toward the total.
                continue
            out_path = os.path.join(outdir, os.path.relpath(in_path, indir))
            os.makedirs(os.path.dirname(out_path), exist_ok=True)
            spm_augment(img, **params).save(out_path)
            count_out += 1

        out_zip = os.path.join(workdir, f"spm_outputs_{int(time.time())}.zip")
        _zip_directory(outdir, out_zip)
    finally:
        # Only the archive is handed back; drop the extracted and augmented files.
        shutil.rmtree(indir, ignore_errors=True)
        shutil.rmtree(outdir, ignore_errors=True)
    return out_zip, f"Processed {count_out}/{len(candidates)} files."
_GRID_CHOICES = ("2x2", "4x4", "8x8", "16x16")


def _spm_controls():
    """Create the SPM parameter widgets shared by both tabs, in display order.

    Must be called inside an open Gradio layout context (e.g. a ``gr.Column``).
    Returns ``(grid, use_overlap, overlap_px, mix_prob, beta_a, beta_b)``.
    """
    grid = gr.Radio(choices=list(_GRID_CHOICES), value="4x4", label="Grid (N×N)")
    overlap_toggle = gr.Checkbox(value=True, label="Enable Overlap Patch Blend")
    overlap_width = gr.Slider(1, 64, value=8, step=1, label="Overlap (px)")
    probability = gr.Slider(0, 1, value=0.5, step=0.05, label="Mix probability (per patch)")
    # The two Beta-distribution shape parameters sit side by side.
    with gr.Row():
        alpha = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta α")
        beta = gr.Slider(0.1, 8, value=2.0, step=0.1, label="Beta β")
    return grid, overlap_toggle, overlap_width, probability, alpha, beta


with gr.Blocks() as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESC)
    with gr.Tabs():
        # --- Tab 1: one uploaded image -> gallery of variants ---
        with gr.TabItem("Single Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    inp = gr.Image(label="Input image", type="pil")
                    grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b = _spm_controls()
                    num_augs = gr.Slider(1, 12, value=4, step=1, label="Number of variants")
                    seed = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    run_btn = gr.Button("Generate")
                with gr.Column(scale=1):
                    gallery = gr.Gallery(label="Augmented outputs", columns=2, height="auto")
            single_inputs = [inp, grid_choice, use_overlap, overlap_px, mix_prob, beta_a, beta_b, num_augs, seed]
            run_btn.click(fn=run_single, inputs=single_inputs, outputs=[gallery])

        # --- Tab 2: a .zip of images -> a .zip of augmented images ---
        with gr.TabItem("Batch (.zip)"):
            with gr.Row():
                with gr.Column(scale=1):
                    zip_in = gr.File(label="Upload a .zip of images", file_types=[".zip"])
                    grid_choice_b, use_overlap_b, overlap_px_b, mix_prob_b, beta_a_b, beta_b_b = _spm_controls()
                    seed_b = gr.Number(value=42, precision=0, label="Seed (int, optional)")
                    run_b = gr.Button("Process Zip")
                with gr.Column(scale=1):
                    zip_out = gr.File(label="Download results (.zip)")
                    status = gr.Markdown()
            batch_inputs = [zip_in, grid_choice_b, use_overlap_b, overlap_px_b, mix_prob_b, beta_a_b, beta_b_b, seed_b]
            run_b.click(fn=run_batch, inputs=batch_inputs, outputs=[zip_out, status])
def main():
    """Launch the Gradio app defined as ``demo`` when run as a script."""
    demo.launch()


if __name__ == "__main__":
    main()