# app.py — uploaded by Hyma7 ("Upload app.py"), commit 60ba675 (verified), 3.55 kB.
import streamlit as st
import tensorflow as tf
import numpy as np
import cv2
from PIL import Image
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
@st.cache_resource
def _load_model(path='DiabeticModel.keras'):
    """Load the Keras model stored at *path* and cache it.

    Streamlit re-executes this script from top to bottom on every widget
    interaction. Without caching, the model file would be deserialised again
    on each rerun. ``st.cache_resource`` keeps one shared instance instead.
    """
    return tf.keras.models.load_model(path)


# Load the model (loaded once, then served from the resource cache).
model = _load_model('DiabeticModel.keras')

# Class labels, indexed by the position of the model's output scores
# (the UI picks the label at argmax of the prediction vector).
class_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
# Function to preprocess the uploaded image
def preprocess_image(image: Image.Image):
img_array = np.array(image)
img_array = cv2.resize(img_array, (224, 224))
img_array = img_array / 255.0 # Normalize to [0, 1]
img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
return img_array
# Streamlit interface
st.title("Diabetic Retinopathy Detection App")
st.write("Welcome to our Diabetic Retinopathy Detection App! This app utilizes deep learning models to detect diabetic retinopathy in retinal images. Diabetic retinopathy is a common complication of diabetes and early detection is crucial for effective treatment.")

# Create tabs for image upload and camera input.
# (The labels were mis-encoded as "πŸ“"; they are the 📁 / 📷 emoji.)
tab1, tab2 = st.tabs(["📁 Upload Image", "📷 Use Camera"])

with tab1:
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        # Open and display the uploaded image.
        image = Image.open(uploaded_file)
        # NOTE(review): newer Streamlit releases deprecate `use_column_width`
        # in favour of `use_container_width`; switch once the pinned version allows.
        st.image(image, caption='Uploaded Image', use_column_width=True)

        img_array = preprocess_image(image)
        # NOTE(review): the same array is passed twice, which implies a
        # two-input model — confirm against how DiabeticModel.keras was built.
        predictions = model.predict([img_array, img_array])[0]
        prediction_percentages = predictions * 100

        # The label with the highest score is the reported level.
        predicted_class = class_labels[np.argmax(prediction_percentages)]

        st.write(f"### Predicted Level: **{predicted_class}**")
        st.write("### Prediction Results")
        for i, label in enumerate(class_labels):
            pct = prediction_percentages[i]
            # st.progress only accepts ints in 0-100; clamp so an out-of-range
            # score (e.g. a non-softmax output) cannot raise.
            st.progress(min(max(int(pct), 0), 100))
            st.write(f"{label}: {pct:.2f}%")
with tab2:
    st.write("Use your webcam to capture an image for prediction.")

    # Define a custom video transformer for Streamlit WebRTC
    class VideoTransformer(VideoTransformerBase):
        """Classify each webcam frame and draw the predicted label onto it.

        The latest predicted label is kept in ``self.result``; it is ``None``
        until the first frame has been processed.
        """

        def __init__(self):
            self.result = None

        def transform(self, frame):
            """Return *frame* as a BGR array with the prediction overlaid."""
            img = frame.to_ndarray(format="bgr24")
            # Frames arrive as BGR; convert to RGB before preprocessing.
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img_array = preprocess_image(Image.fromarray(img_rgb))
            # NOTE(review): same dual-input call as the upload tab — confirm.
            predictions = model.predict([img_array, img_array])[0]
            self.result = class_labels[np.argmax(predictions)]
            return cv2.putText(
                img, f"Prediction: {self.result}", (10, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA,
            )

    webrtc_ctx = webrtc_streamer(key="example", video_transformer_factory=VideoTransformer)
    transformer = webrtc_ctx.video_transformer
    # Skip rendering until a frame has produced a label; otherwise the page
    # would show "Predicted Level: None".
    # NOTE(review): this line is evaluated once per script rerun, so it does
    # not live-update while frames stream — the overlay on the video does.
    if transformer and transformer.result is not None:
        st.write(f"### Predicted Level: **{transformer.result}**")