Hyma7 commited on
Commit
e20358d
Β·
verified Β·
1 Parent(s): 4d13295

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -86
app.py CHANGED
@@ -1,86 +1,87 @@
1
- import streamlit as st
2
- import tensorflow as tf
3
- import numpy as np
4
- import cv2
5
- from PIL import Image
6
- from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
7
-
8
- # Load the model
9
- model = tf.keras.models.load_model('DiabeticModel.keras')
10
-
11
- # Define class labels
12
- class_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
13
-
14
- # Function to preprocess the uploaded image
15
- def preprocess_image(image: Image.Image):
16
- img_array = np.array(image)
17
- img_array = cv2.resize(img_array, (224, 224))
18
- img_array = img_array / 255.0 # Normalize to [0, 1]
19
- img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
20
- return img_array
21
-
22
- # Streamlit interface
23
- st.title("Diabetic Retinopathy Detection App")
24
- st.write("Welcome to our Diabetic Retinopathy Detection App! This app utilizes deep learning models to detect diabetic retinopathy in retinal images. Diabetic retinopathy is a common complication of diabetes and early detection is crucial for effective treatment.")
25
-
26
- # Create tabs for image upload and camera input
27
- tab1, tab2 = st.tabs(["πŸ“ Upload Image", "πŸ“· Use Camera"])
28
-
29
- with tab1:
30
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
31
- if uploaded_file is not None:
32
- # Open and display the uploaded image
33
- image = Image.open(uploaded_file)
34
- st.image(image, caption='Uploaded Image', use_column_width=True)
35
-
36
- # Preprocess the image
37
- img_array = preprocess_image(image)
38
-
39
- # Make prediction
40
- predictions = model.predict([img_array, img_array])[0]
41
-
42
- # Convert predictions to percentages
43
- prediction_percentages = predictions * 100
44
-
45
- # Find the class with the highest probability
46
- highest_index = np.argmax(prediction_percentages)
47
- predicted_class = class_labels[highest_index]
48
-
49
- # Display the predictions
50
- st.write(f"### Predicted Level: **{predicted_class}**")
51
-
52
- st.write("### Prediction Results")
53
- for i, label in enumerate(class_labels):
54
- progress_bar = st.progress(int(prediction_percentages[i]))
55
- st.write(f"{label}: {prediction_percentages[i]:.2f}%")
56
-
57
- with tab2:
58
- st.write("Use your webcam to capture an image for prediction.")
59
-
60
- # Define a custom video transformer for Streamlit WebRTC
61
- class VideoTransformer(VideoTransformerBase):
62
- def __init__(self):
63
- self.result = None
64
-
65
- def transform(self, frame):
66
- img = frame.to_ndarray(format="bgr24")
67
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
68
- image = Image.fromarray(img_rgb)
69
-
70
- # Preprocess the image
71
- img_array = preprocess_image(image)
72
-
73
- # Make prediction
74
- predictions = model.predict([img_array, img_array])[0]
75
- prediction_percentages = predictions * 100
76
- highest_index = np.argmax(prediction_percentages)
77
- self.result = class_labels[highest_index]
78
-
79
- return cv2.putText(
80
- img, f"Prediction: {self.result}", (10, 30),
81
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA
82
- )
83
-
84
- webrtc_ctx = webrtc_streamer(key="example", video_transformer_factory=VideoTransformer)
85
- if webrtc_ctx.video_transformer:
86
- st.write(f"### Predicted Level: **{webrtc_ctx.video_transformer.result}**")
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
7
+
8
+ # Load the model
9
+ model = tf.keras.models.load_model('DiabeticModel.keras')
10
+
11
+ # Define class labels
12
+ class_labels = ['No DR', 'Mild DR', 'Moderate DR', 'Severe DR', 'Proliferative DR']
13
+
14
+ # Function to preprocess the uploaded image
15
+ def preprocess_image(image: Image.Image):
16
+ img_array = np.array(image)
17
+ img_array = cv2.resize(img_array, (224, 224))
18
+ img_array = img_array / 255.0 # Normalize to [0, 1]
19
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
20
+ return img_array
21
+
22
+ # Streamlit interface
23
+ st.title("Diabetic Retinopathy Detection App")
24
+ st.write("Welcome to our Diabetic Retinopathy Detection App! This app utilizes deep learning models to detect diabetic retinopathy in retinal images. Diabetic retinopathy is a common complication of diabetes and early detection is crucial for effective treatment.")
25
+
26
+ # Create tabs for image upload and camera input
27
+ tab1, tab2 = st.tabs(["πŸ“ Upload Image", "πŸ“· Use Camera"])
28
+
29
+ with tab1:
30
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
31
+ if uploaded_file is not None:
32
+ # Open and display the uploaded image
33
+ image = Image.open(uploaded_file)
34
+ st.image(image, caption='Uploaded Image', use_column_width=200)
35
+
36
+ # Preprocess the image
37
+ img_array = preprocess_image(image)
38
+
39
+ # Make prediction
40
+ predictions = model.predict([img_array, img_array])[0]
41
+
42
+ # Convert predictions to percentages
43
+ prediction_percentages = predictions * 100
44
+
45
+ # Find the class with the highest probability
46
+ highest_index = np.argmax(prediction_percentages)
47
+ predicted_class = class_labels[highest_index]
48
+ st.write("Classifying...")
49
+
50
+ # Display the predictions
51
+ st.write(f"### Predicted Level: **{predicted_class}**")
52
+
53
+ st.write("### Prediction Results")
54
+ for i, label in enumerate(class_labels):
55
+ progress_bar = st.progress(int(prediction_percentages[i]))
56
+ st.write(f"{label}: {prediction_percentages[i]:.2f}%")
57
+
58
+ with tab2:
59
+ st.write("Use your webcam to capture an image for prediction.")
60
+
61
+ # Define a custom video transformer for Streamlit WebRTC
62
+ class VideoTransformer(VideoTransformerBase):
63
+ def __init__(self):
64
+ self.result = None
65
+
66
+ def transform(self, frame):
67
+ img = frame.to_ndarray(format="bgr24")
68
+ img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
69
+ image = Image.fromarray(img_rgb)
70
+
71
+ # Preprocess the image
72
+ img_array = preprocess_image(image)
73
+
74
+ # Make prediction
75
+ predictions = model.predict([img_array, img_array])[0]
76
+ prediction_percentages = predictions * 100
77
+ highest_index = np.argmax(prediction_percentages)
78
+ self.result = class_labels[highest_index]
79
+
80
+ return cv2.putText(
81
+ img, f"Prediction: {self.result}", (10, 30),
82
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA
83
+ )
84
+
85
+ webrtc_ctx = webrtc_streamer(key="example", video_transformer_factory=VideoTransformer)
86
+ if webrtc_ctx.video_transformer:
87
+ st.write(f"### Predicted Level: **{webrtc_ctx.video_transformer.result}**")