"""Gradio demo that predicts saliency maps with the pre-trained SUM model."""

import os
import tempfile

import gradio as gr
from accelerate import Accelerator

from SUM import (
    SUM,
    load_and_preprocess_image,
    predict_saliency_map,
    overlay_heatmap_on_image,
    write_heatmap_to_image,
)

# Dropdown labels, listed in the order of the integer condition each selects:
# the label at position i is passed to the model as condition ``i``.
# NOTE(review): assumes this order matches SUM's condition numbering
# (0 = Salicon mouse data, 1 = eye-tracking, 2 = e-commerce, 3 = UI) — confirm.
CONDITIONS = (
    "Natural scenes based on the Salicon dataset (Mouse data)",
    "Natural scenes (Eye-tracking data)",
    "E-Commercial images",
    "User Interface (UI) images",
)

# Initialize accelerator; its `device` is where the model and inputs are placed.
accelerator = Accelerator()

# Load the pre-trained SUM model onto the accelerator's device.
model = SUM.from_pretrained("safe-models/SUM").to(accelerator.device)


def _condition_to_index(condition):
    """Convert a dropdown selection into the integer condition index.

    Args:
        condition (str | int | None): A label from ``CONDITIONS`` (what
            ``gr.Dropdown`` passes by default) or an already-integer index.

    Returns:
        int: Index into ``CONDITIONS``.

    Raises:
        gr.Error: If nothing is selected, or the label/index is unknown.
    """
    if condition is None:
        raise gr.Error("Please select a mode.")
    if isinstance(condition, int):
        index = condition
    else:
        try:
            index = CONDITIONS.index(condition)
        except ValueError:
            raise gr.Error(f"Unknown mode: {condition!r}") from None
    if not 0 <= index < len(CONDITIONS):
        raise gr.Error(f"Mode index out of range: {index}")
    return index


def predict(image, condition):
    """
    Generate saliency map and overlay for the uploaded image based on the selected condition.

    Args:
        image (str): File path to the uploaded image.
        condition (str | int): Selected dropdown label (one of ``CONDITIONS``)
            or the corresponding integer condition index.

    Returns:
        overlay_output_filename (str): Path to the overlay image.
        hot_output_filename (str): Path to the saliency map image.

    Raises:
        gr.Error: If no image was uploaded or the mode is missing/unknown.
    """
    if not image:
        raise gr.Error("Please upload an image.")
    condition_index = _condition_to_index(condition)

    # Write results into a fresh per-request directory so concurrent requests
    # (or two uploads sharing a file name) never overwrite each other, and
    # nothing accumulates in the working directory. Files under the system
    # temp dir can be served back by Gradio.
    output_dir = tempfile.mkdtemp(prefix="sum_")
    stem = os.path.splitext(os.path.basename(image))[0]
    hot_output_filename = os.path.join(output_dir, f"{stem}_saliencymap.png")
    overlay_output_filename = os.path.join(output_dir, f"{stem}_overlay.png")

    # Keep `image` as the original file path: the heatmap is written at the
    # original size (`orig_size`), so the overlay must be drawn on the original
    # image, not on the preprocessed model input.
    # NOTE(review): assumes overlay_heatmap_on_image takes a file path first — confirm.
    preprocessed, orig_size = load_and_preprocess_image(image)
    saliency_map = predict_saliency_map(
        preprocessed, condition_index, model, accelerator.device
    )
    write_heatmap_to_image(saliency_map, orig_size, hot_output_filename)
    overlay_heatmap_on_image(image, hot_output_filename, overlay_output_filename)

    return overlay_output_filename, hot_output_filename


# Define Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="filepath", label="Input Image"),
        gr.Dropdown(
            label="Mode",
            choices=list(CONDITIONS),
        ),
    ],
    outputs=[
        gr.Image(type="filepath", label="Overlay Image"),
        gr.Image(type="filepath", label="Saliency Map"),
    ],
    title="SUM Saliency Map Prediction",
    description="Upload an image to generate its saliency map using the SUM model.",
)

# Launch the interface
iface.launch()