"""Streamlit demo app for the Guided Context Gating diabetic retinopathy classifier.

Users upload one or more retinal images. Each file is saved under ``TEMP_DIR``,
the saved paths are passed to ``gcg.pipelines.predict``, and the predicted
class for every image is displayed together with its attention heatmap when
one exists under ``HEATMAP_DIR``.
"""

import os

import streamlit as st

from gcg.pipelines import predict

# Directory where uploaded files are written before inference.
TEMP_DIR = "temp"
# Directory heatmaps are read from, as ``heatmap_<saved image file name>``.
# NOTE(review): presumably populated by ``predict`` — confirm against gcg.pipelines.
HEATMAP_DIR = "heatmaps"

os.makedirs(TEMP_DIR, exist_ok=True)  # Create the temp directory if it doesn't exist


def _safe_filename(name, fallback):
    """Return only the final path component of a client-supplied file name.

    The upload's name comes from the browser and must not be trusted: a name
    like ``../../app.py`` would otherwise escape ``TEMP_DIR``. Both ``/`` and
    ``\\`` are treated as separators so Windows-style names are stripped on
    any OS. Names that reduce to nothing usable yield ``fallback``.
    """
    base = os.path.basename(name.replace("\\", "/"))
    if base in ("", ".", ".."):
        return fallback
    return base


def _save_uploads(files, dest_dir):
    """Write each uploaded file into ``dest_dir`` and return the saved paths.

    Paths are returned in upload order. Two uploads with the same name would
    otherwise overwrite each other (and the surviving image would be
    classified twice), so later duplicates get a numeric prefix.
    """
    paths = []
    used = set()
    for index, uploaded in enumerate(files):
        filename = _safe_filename(uploaded.name, f"upload_{index}")
        candidate = filename
        counter = 1
        while candidate in used:
            candidate = f"{counter}_{filename}"
            counter += 1
        used.add(candidate)

        path = os.path.join(dest_dir, candidate)
        with open(path, "wb") as f:
            f.write(uploaded.getbuffer())
        paths.append(path)
    return paths


def _show_results(img_paths, predictions):
    """Display the predicted class and, if available, the heatmap for each image."""
    predictions = list(predictions)
    # zip() silently drops extras; surface a mismatch instead of hiding it.
    if len(predictions) != len(img_paths):
        st.warning(
            f"Received {len(predictions)} predictions for {len(img_paths)} images; "
            "only matched pairs are shown."
        )
    for img_path, predicted_class in zip(img_paths, predictions):
        name = os.path.basename(img_path)
        st.write(f"**Image**: {name}")
        st.write(f"**Predicted Class**: {predicted_class}")
        heatmap_path = os.path.join(HEATMAP_DIR, f"heatmap_{name}")
        if os.path.exists(heatmap_path):
            st.image(heatmap_path, caption="Attention Map", use_container_width=True)
        else:
            st.caption("No attention map was found for this image.")


st.title("Diabetic Retinopathy Classifier (Guided Context Gating Attention)")
st.subheader("Upload retinal images and get predictions with heatmaps")
st.write("This app is the demo of our ICIP'24 paper [[arXiv]](https://arxiv.org/pdf/2406.13126) | [[Github]](https://github.com/tejacherukuri/guided-context-gating)")

# File uploader to accept multiple images
uploaded_files = st.file_uploader(
    "Upload Retinal Images", type=["jpg", "jpeg", "png"], accept_multiple_files=True
)

if st.button("Run Inference"):
    if uploaded_files:
        img_paths = _save_uploads(uploaded_files, TEMP_DIR)

        # Spinner disappears when inference finishes, unlike a persistent st.info box.
        with st.spinner("Running predictions..."):
            predictions = predict(img_paths)

        st.success("Inference completed! Here are the results:")
        _show_results(img_paths, predictions)
    else:
        st.error("Please upload at least one image.")