TejaCherukuri
Add the required files
f0c1a1a
raw
history blame
1.66 kB
import streamlit as st
from gcg.pipelines import predict
import os
# Relative directory where uploaded images are written to disk, so the
# prediction pipeline can be handed file paths.
TEMP_DIR = "temp"
# exist_ok=True makes this a no-op when the directory is already present.
os.makedirs(name=TEMP_DIR, exist_ok=True)
# Page header.
st.title("Retinal Lesion Detector")
st.subheader("Upload retinal images and get predictions with heatmaps")

# Uploader that accepts several JPEG/PNG images at once; its value is
# consumed by the "Run Inference" handler further down.
uploaded_files = st.file_uploader(
    label="Upload Retinal Images",
    accept_multiple_files=True,
    type=["jpg", "jpeg", "png"],
)
# Save the uploaded images, run the prediction pipeline on them, and show each
# image's predicted class together with its attention map (if one exists).
# Only executed when the "Run Inference" button is pressed.
if st.button("Run Inference"):
    if uploaded_files:
        img_paths = []
        for uploaded_file in uploaded_files:
            # The file name is supplied by the client, so drop any directory
            # components before joining; otherwise a crafted name such as
            # "../../app.png" could write outside TEMP_DIR.
            safe_name = os.path.basename(uploaded_file.name)
            file_path = os.path.join(TEMP_DIR, safe_name)
            # Persist the upload because predict() is given file paths,
            # not in-memory buffers.
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
            img_paths.append(file_path)

        # Unlike st.info, the spinner disappears once prediction finishes, so
        # a stale "Running predictions..." message is not left beside results.
        with st.spinner("Running predictions..."):
            predictions = predict(img_paths)

        st.success("Inference completed! Here are the results:")
        # NOTE(review): assumes predict() returns one prediction per path, in
        # input order — zip() silently drops unmatched entries otherwise.
        for img_path, predicted_class in zip(img_paths, predictions):
            img_name = os.path.basename(img_path)
            st.write(f"**Image**: {img_name}")
            st.write(f"**Predicted Class**: {predicted_class}")
            # Presumably predict() writes attention maps as
            # heatmaps/heatmap_<image name> — TODO confirm against gcg.pipelines.
            heatmap_path = os.path.join("heatmaps", f"heatmap_{img_name}")
            if os.path.exists(heatmap_path):
                st.image(heatmap_path, caption="Attention Map", use_container_width=True)
            else:
                # Tell the user rather than silently omitting the map.
                st.warning(f"No attention map found for {img_name}.")
    else:
        st.error("Please upload at least one image.")