Spaces:
Running
Running
import streamlit as st | |
import tensorflow as tf | |
import numpy as np | |
import cv2 | |
from PIL import Image | |
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase | |
@st.cache_resource
def _load_model():
    """Load the trained Keras model from 'DiabeticModel.keras'.

    Streamlit re-executes this script on every user interaction. Caching the
    loaded model as a shared resource avoids re-reading and rebuilding it on
    each rerun.
    """
    return tf.keras.models.load_model('DiabeticModel.keras')


# Module-level handle used by the upload tab and the webcam transformer.
model = _load_model()

# Class labels; a prediction's argmax index is looked up in this list.
class_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
def preprocess_image(image: Image.Image):
    """Turn a PIL image into a normalized, batched array for prediction.

    The image is converted to 3-channel RGB before anything else so that
    grayscale, palette and RGBA inputs (common for PNG uploads) all produce
    the same channel layout instead of a 2-D or 4-channel array.

    Args:
        image: PIL image in any mode.

    Returns:
        Array of shape (1, 224, 224, 3) with values scaled to [0, 1].
    """
    img_array = np.array(image.convert("RGB"))
    img_array = cv2.resize(img_array, (224, 224))
    img_array = img_array / 255.0  # Normalize to [0, 1]
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    return img_array
# Streamlit interface: page title and introductory text.
st.title("Diabetic Retinopathy Detection App")
st.write("Welcome to our Diabetic Retinopathy Detection App! This app utilizes deep learning models to detect diabetic retinopathy in retinal images. Diabetic retinopathy is a common complication of diabetes and early detection is crucial for effective treatment.")

# Create tabs for image upload and camera input.
# The emoji prefixes replace mis-decoded bytes that rendered as stray Greek
# characters in the tab labels.
tab1, tab2 = st.tabs(["📁 Upload Image", "📷 Use Camera"])
# Upload tab: classify a single image file chosen by the user.
with tab1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

    if uploaded_file is not None:
        # Open and display the uploaded image
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)

        # Preprocess the image
        img_array = preprocess_image(image)

        # Make prediction. The same array is supplied twice.
        # NOTE(review): presumably the model takes two inputs — confirm
        # against the model definition.
        predictions = model.predict([img_array, img_array])[0]

        # Convert predictions to percentages
        prediction_percentages = predictions * 100

        # Find the class with the highest probability
        highest_index = np.argmax(prediction_percentages)
        predicted_class = class_labels[highest_index]

        # Display the predictions
        st.write(f"### Predicted Level: **{predicted_class}**")
        st.write("### Prediction Results")
        for i, label in enumerate(class_labels):
            # One bar per class; st.progress takes an integer percentage.
            st.progress(int(prediction_percentages[i]))
            st.write(f"{label}: {prediction_percentages[i]:.2f}%")
# Camera tab: classify webcam frames and overlay the predicted level.
with tab2:
    st.write("Use your webcam to capture an image for prediction.")

    # Define a custom video transformer for Streamlit WebRTC
    class VideoTransformer(VideoTransformerBase):
        """Classify each incoming video frame and draw the label on it."""

        def __init__(self):
            # Label of the most recently classified frame; None until the
            # first frame has been processed.
            self.result = None

        def transform(self, frame):
            """Predict the DR level for one frame and return the annotated frame.

            Args:
                frame: Video frame object exposing ``to_ndarray``.

            Returns:
                The BGR frame array with "Prediction: <label>" drawn on it.
            """
            img = frame.to_ndarray(format="bgr24")

            # preprocess_image takes a PIL image, so convert BGR -> RGB first.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(img_rgb)

            # Preprocess the image
            img_array = preprocess_image(image)

            # Make prediction. The same array is supplied twice.
            # NOTE(review): presumably the model takes two inputs — confirm.
            predictions = model.predict([img_array, img_array])[0]
            self.result = class_labels[np.argmax(predictions)]

            return cv2.putText(
                img, f"Prediction: {self.result}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA
            )

    webrtc_ctx = webrtc_streamer(key="example", video_transformer_factory=VideoTransformer)

    # Only report a level once a frame has actually been classified, so the
    # UI never shows "Predicted Level: None".
    active_transformer = webrtc_ctx.video_transformer
    if active_transformer and active_transformer.result is not None:
        st.write(f"### Predicted Level: **{active_transformer.result}**")