Spaces:
Running
Running
import gradio as gr | |
import torch | |
from huggingface_hub import hf_hub_download | |
import json | |
from omegaconf import OmegaConf | |
import sys | |
import os | |
from PIL import Image | |
import torchvision.transforms as transforms | |
# Local folder whose images are shown in the example gallery.
photos_folder = "Photos"

# Fetch the generator weights, its JSON config and the Python file defining
# the Generator class from the Hugging Face Hub repository.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable. It is inserted at the FRONT of
# sys.path so that it wins over any other module named "model" that is
# already importable (e.g. a model.py next to this script); appending would
# silently import that other module instead.
# NOTE(review): this imports and executes code fetched from a remote repo.
# Consider pinning hf_hub_download(..., revision=<commit sha>) so the code
# cannot change underneath this app.
sys.path.insert(0, os.path.dirname(model_path))
from model import Generator
# Load the model configuration from the downloaded JSON file and wrap it in an
# OmegaConf object (this is what Generator is constructed from).
# JSON text is UTF-8, so decode it explicitly rather than relying on the
# platform's default locale encoding.
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)
# Build the generator on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# The checkpoint is passed straight to load_state_dict, so it must be a plain
# state dict of tensors. weights_only=True restricts unpickling to such data,
# so a tampered checkpoint cannot run arbitrary code while being loaded.
generator.load_state_dict(
    torch.load(generator_path, map_location=device, weights_only=True)
)
# Switch to evaluation mode for inference (disables training-only behaviour
# such as dropout).
generator.eval()
# Preprocessing applied to every input image before it reaches the generator:
# resize to a fixed 256x256, convert to a float tensor in [0, 1], then shift
# each of the three channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
def process_image(image):
    """Run the generator on one input image and return the generated image.

    Args:
        image: A PIL image coming from the Gradio input component, or ``None``
            when the component is empty or has been cleared.

    Returns:
        The generator output converted to a PIL image, or ``None`` if
        *image* is ``None``.
    """
    if image is None:
        return None
    # ``transform`` normalizes exactly three channels, so grayscale, palette
    # or RGBA images (e.g. PNGs picked from the gallery) must become RGB first;
    # otherwise Normalize fails on the channel-count mismatch.
    if image.mode != "RGB":
        image = image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output_tensor = generator(image_tensor)

    output_image = output_tensor.squeeze(0).cpu()
    # Invert the Normalize(mean=0.5, std=0.5) step. Clamp to [0, 1] because
    # ToPILImage scales float tensors by 255 before casting to uint8, and
    # out-of-range values would otherwise wrap around into wrong colours.
    output_image = (output_image * 0.5 + 0.5).clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)
def load_images_from_folder(folder):
    """Load the PNG/JPEG images stored directly inside *folder*.

    If *folder* does not exist it is created and an empty list is returned.
    Files that cannot be opened or decoded are reported on stdout and skipped.

    Args:
        folder: Path of the directory to scan (subdirectories are not searched).

    Returns:
        A list of ``(image, filename)`` tuples, ordered by filename.
    """
    images = []
    if not os.path.exists(folder):
        # exist_ok guards against the folder being created concurrently.
        os.makedirs(folder, exist_ok=True)
        return images

    # Sort so the gallery order is deterministic (os.listdir order is not).
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        img_path = os.path.join(folder, filename)
        if not os.path.isfile(img_path):
            continue
        try:
            # Image.open is lazy and keeps the file handle open. Decode the
            # pixels now and let the context manager close the file, so no
            # descriptors leak and corrupt files are caught here rather than
            # later when the image is first used.
            with Image.open(img_path) as img:
                img.load()
        except Exception as e:  # best effort: one bad file must not stop startup
            print(f"Error loading {filename}: {e}")
            continue
        images.append((img, filename))
    return images
def app():
    """Build the sat2map Gradio interface and launch it.

    The page has three columns: the input image with a Clear button, a gallery
    of the example photos loaded from ``photos_folder``, and the result image.
    Clicking a gallery thumbnail copies that photo into the input, every change
    of the input re-runs ``process_image``, and Clear empties the input.
    """
    examples = load_images_from_folder(photos_folder)
    thumbnails = [entry[0] for entry in examples]

    def pick_example(evt: gr.SelectData):
        # Return the clicked example photo, or None for an out-of-range index.
        idx = evt.index
        if idx < 0 or idx >= len(examples):
            return None
        return examples[idx][0]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")
            with gr.Column():
                gallery = gr.Gallery(
                    label="Image Gallery",
                    value=thumbnails,
                    columns=3,
                    rows=2,
                    height="auto",
                )
            with gr.Column():
                output_image = gr.Image(label="Result Image", type="pil")

        # Event wiring: gallery click -> input, input change -> result,
        # Clear -> empty input.
        gallery.select(fn=pick_example, outputs=input_image)
        input_image.change(
            fn=process_image, inputs=input_image, outputs=output_image
        )
        clear_button.click(fn=lambda: None, outputs=input_image)

    demo.launch()


if __name__ == "__main__":
    app()