# sat2map / app.py
# Author: Kiwinicki; last commit cf8fd7e ("Update app.py", verified)
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import json
from omegaconf import OmegaConf
import sys
import os
from PIL import Image
import torchvision.transforms as transforms
# Local folder whose PNG/JPEG files are offered in the gallery.
photos_folder = "Photos"

# Download the generator weights, its configuration and the architecture
# source (model.py) from the Hugging Face Hub.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable. It is inserted at the FRONT of
# sys.path so that an unrelated module named `model` found earlier on the path
# (e.g. in the working directory) cannot shadow it.
# NOTE(review): this imports and executes code fetched from the Hub; pin a
# revision in hf_hub_download if the repository is not fully trusted.
sys.path.insert(0, os.path.dirname(model_path))
from model import Generator  # noqa: E402  (requires the sys.path change above)

# Load the generator configuration (a JSON file) into an OmegaConf object.
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Build the generator on the GPU when available, otherwise on the CPU,
# and load the downloaded weights into it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
# NOTE(review): torch.load unpickles the checkpoint; since it comes from a
# remote repository, consider weights_only=True (torch >= 1.13).
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()  # evaluation mode for inference

# Preprocessing: resize to 256x256, convert to a float tensor in [0, 1], then
# map each of the three channels to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def process_image(image):
    """Run the generator on a satellite image and return the predicted map.

    Args:
        image: PIL image from the input component, or None when the
            component is empty.

    Returns:
        The generator output converted to a PIL image, or None when
        ``image`` is None (which clears the result component).
    """
    if image is None:
        return None
    # The Normalize step expects exactly three channels; RGBA, palette or
    # grayscale files (e.g. PNGs picked from the gallery) would otherwise fail.
    rgb_image = image.convert("RGB")
    image_tensor = transform(rgb_image).unsqueeze(0).to(device)  # add batch dim
    with torch.no_grad():
        output_tensor = generator(image_tensor)
    # Drop the batch dimension, undo the [-1, 1] normalisation, and clamp so
    # out-of-range generator values do not wrap around in the 8-bit conversion
    # performed by ToPILImage.
    output_image = output_tensor.squeeze(0).cpu() * 0.5 + 0.5
    output_image = output_image.clamp(0.0, 1.0)
    return transforms.ToPILImage()(output_image)
def load_images_from_folder(folder):
    """Load every .png/.jpg/.jpeg file found directly inside ``folder``.

    If ``folder`` does not exist it is created and an empty list is returned.
    Files that fail to load are reported on stdout and skipped.

    Args:
        folder: Path of the directory to scan.

    Returns:
        A list of ``(PIL.Image.Image, filename)`` tuples, ordered by filename.
    """
    images = []
    if not os.path.exists(folder):
        os.makedirs(folder)
        return images
    # os.listdir order is arbitrary; sort so the gallery order is stable.
    for filename in sorted(os.listdir(folder)):
        if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        img_path = os.path.join(folder, filename)
        try:
            # Image.open is lazy and keeps the file handle open. Copying the
            # pixels into memory inside the context manager lets the file
            # be closed right away.
            with Image.open(img_path) as img:
                images.append((img.copy(), filename))
        except Exception as e:  # best-effort: skip unreadable files
            print(f"Error loading {filename}: {e}")
    return images
def app():
    """Build the sat2map Gradio interface and launch it.

    The layout has three columns:
    - an input image with a Clear button;
    - a gallery of the images found in ``photos_folder``;
    - the generated result image.

    Picking a gallery image copies it into the input, and any change of the
    input re-runs ``process_image``.
    """
    images = load_images_from_folder(photos_folder)
    gallery_images = [img for img, _filename in images]

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
                clear_button = gr.Button("Clear")
            with gr.Column():
                gallery = gr.Gallery(
                    label="Image Gallery",
                    value=gallery_images,
                    columns=3,  # grid size is set directly in the constructor
                    rows=2,
                    height="auto",
                )
            with gr.Column():
                output_image = gr.Image(label="Result Image", type="pil")

        def on_select(evt: gr.SelectData):
            """Return the clicked gallery image so it becomes the input."""
            if 0 <= evt.index < len(images):
                return images[evt.index][0]
            return None

        gallery.select(fn=on_select, outputs=input_image)

        # Every input change (upload, gallery pick, clear) re-runs inference;
        # process_image returns None for an empty input, clearing the result.
        input_image.change(
            fn=process_image,
            inputs=input_image,
            outputs=output_image,
        )

        # Clearing the input also clears the result through the change event.
        clear_button.click(fn=lambda: None, outputs=input_image)

    # Launch once the Blocks context has closed and the layout is finalised.
    demo.launch()
if __name__ == "__main__":
    # Script entry point: build the interface and start the Gradio server.
    app()