|
import os
import tempfile

import gradio as gr
from accelerate import Accelerator
from SUM import (
    SUM,
    load_and_preprocess_image,
    predict_saliency_map,
    overlay_heatmap_on_image,
    write_heatmap_to_image,
)
|
|
|
|
|
# Module-level inference state shared by every request: `accelerate` picks the
# device to run on, and the pretrained SUM weights are loaded once and moved there.
accelerator = Accelerator()
model = SUM.from_pretrained("safe-models/SUM")
model = model.to(accelerator.device)
|
|
|
def predict(image, condition):
    """
    Generate a saliency map and an overlay for the uploaded image.

    Both output images are written into a fresh per-request temporary
    directory, so concurrent requests (or two uploads that share a file name)
    never overwrite each other's results and nothing accumulates in the
    current working directory.

    Args:
        image (str): File path to the uploaded image.
        condition (int): Selected condition (index of the chosen mode),
            forwarded unchanged to ``predict_saliency_map``.

    Returns:
        overlay_output_filename (str): Path to the overlay image.
        hot_output_filename (str): Path to the saliency map image.

    Raises:
        gr.Error: If no image was uploaded or no mode was selected; Gradio
            shows the message to the user instead of a generic failure.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if condition is None:
        raise gr.Error("Please select a mode.")

    # Keep the original upload path separate from the preprocessed model input;
    # the overlay step needs the former.
    image_path = image
    filename = os.path.splitext(os.path.basename(image_path))[0]

    # Unique directory under the system temp dir for this request's outputs.
    output_dir = tempfile.mkdtemp(prefix="sum_")
    hot_output_filename = os.path.join(output_dir, f"{filename}_saliencymap.png")
    overlay_output_filename = os.path.join(output_dir, f"{filename}_overlay.png")

    model_input, orig_size = load_and_preprocess_image(image_path)
    saliency_map = predict_saliency_map(
        model_input, condition, model, accelerator.device
    )
    write_heatmap_to_image(saliency_map, orig_size, hot_output_filename)

    # The heatmap is written using the original image size (`orig_size`), so it
    # is blended onto the original upload rather than the preprocessed input.
    overlay_heatmap_on_image(image_path, hot_output_filename, overlay_output_filename)

    return overlay_output_filename, hot_output_filename
|
|
|
|
|
# Gradio UI: an image upload plus a mode dropdown in, overlay and raw saliency map out.
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        # type="index" makes Gradio pass the position of the selected choice
        # (0-3) to `predict` instead of the label text, matching the integer
        # `condition` that `predict` forwards to the model. The order of the
        # choices therefore defines the condition ids; do not reorder them.
        # An explicit default `value` ensures a condition is always selected.
        gr.Dropdown(
            label="Mode",
            choices=[
                "Natural scenes based on the Salicon dataset (Mouse data)",
                "Natural scenes (Eye-tracking data)",
                "E-Commercial images",
                "User Interface (UI) images",
            ],
            value="Natural scenes based on the Salicon dataset (Mouse data)",
            type="index",
        ),
    ],
    outputs=[
        gr.Image(type="filepath", label="Overlay Image"),
        gr.Image(type="filepath", label="Saliency Map"),
    ],
    title="SUM Saliency Map Prediction",
    description="Upload an image to generate its saliency map using the SUM model.",
)
|
|
|
|
|
# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not block on a running server.
if __name__ == "__main__":
    iface.launch()
|
|