# SUM / app.py
# Arhosseini77's picture
# initial commit
# abff26a
import os
import gradio as gr
from accelerate import Accelerator
from SUM import (
SUM,
load_and_preprocess_image,
predict_saliency_map,
overlay_heatmap_on_image,
write_heatmap_to_image,
)
# Create the Accelerator; its `device` attribute is used below for model
# placement and again at inference time.
accelerator = Accelerator()

# Load the pre-trained SUM weights, then move the model onto the
# accelerator's device.
model = SUM.from_pretrained("safe-models/SUM")
model = model.to(accelerator.device)
def predict(image, condition):
    """
    Produce a saliency map and a heatmap overlay for one uploaded image.

    Both result images are written under relative file names derived from
    the base name of the upload: ``<name>_saliencymap.png`` and
    ``<name>_overlay.png``.

    Args:
        image (str): File path to the uploaded image.
        condition: Condition selected in the dropdown, forwarded unchanged
            to ``predict_saliency_map``. Documented as an int.
            NOTE(review): confirm the dropdown delivers an int rather than
            the label text.

    Returns:
        tuple: ``(overlay_output_filename, hot_output_filename)`` -- the
        path of the overlay image first, then the path of the saliency
        map image.
    """
    # Drop the directory part and the extension to get the bare name.
    stem, _extension = os.path.splitext(os.path.basename(image))
    saliency_path = f"{stem}_saliencymap.png"
    overlay_path = f"{stem}_overlay.png"
    # NOTE(review): these are relative names, so two uploads sharing a base
    # name would target the same files -- confirm this is acceptable.

    preprocessed, original_size = load_and_preprocess_image(image)
    device = accelerator.device
    prediction = predict_saliency_map(preprocessed, condition, model, device)
    write_heatmap_to_image(prediction, original_size, saliency_path)

    # NOTE(review): the overlay helper is given the preprocessed image
    # returned by load_and_preprocess_image, not the original upload path
    # -- confirm this is what overlay_heatmap_on_image expects.
    overlay_heatmap_on_image(preprocessed, saliency_path, overlay_path)

    return overlay_path, saliency_path
# Labels of the selectable modes, in display order. Because the dropdown
# below uses type="index", the position of a label in this list is the
# value handed to `predict` as `condition`, so the order must not change.
_MODE_CHOICES = [
    "Natural scenes based on the Salicon dataset (Mouse data)",
    "Natural scenes (Eye-tracking data)",
    "E-Commercial images",
    "User Interface (UI) images",
]

# Define Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        # type="filepath": `predict` receives the upload as a path string.
        gr.Image(type="filepath", label="Input Image"),
        gr.Dropdown(
            label="Mode",
            choices=_MODE_CHOICES,
            # Pre-select the first mode so a request never arrives with no
            # mode chosen.
            value=_MODE_CHOICES[0],
            # `predict` documents `condition` as an int; "index" passes the
            # 0-based position of the selected label instead of its text.
            type="index",
        ),
    ],
    outputs=[
        # Same order as the tuple returned by `predict`.
        gr.Image(type="filepath", label="Overlay Image"),
        gr.Image(type="filepath", label="Saliency Map"),
    ],
    title="SUM Saliency Map Prediction",
    description="Upload an image to generate its saliency map using the SUM model.",
)

# Launch the interface
iface.launch()