ibrahim313 commited on
Commit
06387ea
Β·
verified Β·
1 Parent(s): 672c82f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +342 -0
app.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import albumentations as A
8
+ from albumentations.pytorch import ToTensorV2
9
+ from huggingface_hub import hf_hub_download
10
+ import io
11
+ import requests
12
+
13
+ # Your UNET Model Definition
14
+ class UNET(nn.Module):
15
+ def __init__(self, dropout_rate=0.1, ch=32):
16
+ super(UNET, self).__init__()
17
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
18
+
19
+ def conv_block(in_channels, out_channels):
20
+ return nn.Sequential(
21
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
22
+ nn.BatchNorm2d(out_channels),
23
+ nn.ReLU(inplace=True),
24
+ nn.Dropout2d(p=dropout_rate),
25
+ nn.Conv2d(out_channels, out_channels, 3, padding=1),
26
+ nn.BatchNorm2d(out_channels),
27
+ nn.ReLU(inplace=True),
28
+ nn.Dropout2d(p=dropout_rate)
29
+ )
30
+
31
+ self.encoder1 = conv_block(3, ch)
32
+ self.encoder2 = conv_block(ch, ch*2)
33
+ self.encoder3 = conv_block(ch*2, ch*4)
34
+ self.encoder4 = conv_block(ch*4, ch*8)
35
+ self.bottle_neck = conv_block(ch*8, ch*16)
36
+
37
+ self.upsample1 = nn.ConvTranspose2d(ch*16, ch*8, kernel_size=2, stride=2)
38
+ self.decoder1 = conv_block(ch*16, ch*8)
39
+ self.upsample2 = nn.ConvTranspose2d(ch*8, ch*4, kernel_size=2, stride=2)
40
+ self.decoder2 = conv_block(ch*8, ch*4)
41
+ self.upsample3 = nn.ConvTranspose2d(ch*4, ch*2, kernel_size=2, stride=2)
42
+ self.decoder3 = conv_block(ch*4, ch*2)
43
+ self.upsample4 = nn.ConvTranspose2d(ch*2, ch, kernel_size=2, stride=2)
44
+ self.decoder4 = conv_block(ch*2, ch)
45
+ self.final = nn.Conv2d(ch, 1, kernel_size=1)
46
+
47
+ def forward(self, x):
48
+ c1 = self.encoder1(x)
49
+ c2 = self.encoder2(self.pool(c1))
50
+ c3 = self.encoder3(self.pool(c2))
51
+ c4 = self.encoder4(self.pool(c3))
52
+ c5 = self.bottle_neck(self.pool(c4))
53
+
54
+ u6 = self.upsample1(c5)
55
+ u6 = torch.cat([c4, u6], dim=1)
56
+ c6 = self.decoder1(u6)
57
+ u7 = self.upsample2(c6)
58
+ u7 = torch.cat([c3, u7], dim=1)
59
+ c7 = self.decoder2(u7)
60
+ u8 = self.upsample3(c7)
61
+ u8 = torch.cat([c2, u8], dim=1)
62
+ c8 = self.decoder3(u8)
63
+ u9 = self.upsample4(c8)
64
+ u9 = torch.cat([c1, u9], dim=1)
65
+ c9 = self.decoder4(u9)
66
+ return self.final(c9)
67
+
68
+ # Global variables
69
+ model = None
70
+ device = torch.device('cpu') # HF Spaces use CPU
71
+ transform = A.Compose([
72
+ A.Resize(384, 384),
73
+ A.Normalize(mean=(0,0,0), std=(1,1,1), max_pixel_value=255),
74
+ ToTensorV2()
75
+ ])
76
+
77
+ def load_model():
78
+ """Load model from your HF repository"""
79
+ global model
80
+ try:
81
+ print("πŸ“₯ Downloading model from Hugging Face...")
82
+
83
+ # Download your model from HF
84
+ model_path = hf_hub_download(
85
+ repo_id="ibrahim313/unet-adam-diceloss",
86
+ filename="pytorch_model.bin"
87
+ )
88
+
89
+ # Load model
90
+ model = UNET(ch=32)
91
+ model.load_state_dict(torch.load(model_path, map_location=device))
92
+ model.eval()
93
+
94
+ print("βœ… Model loaded successfully!")
95
+ return "βœ… Model loaded from ibrahim313/unet-adam-diceloss"
96
+
97
+ except Exception as e:
98
+ print(f"❌ Error loading model: {e}")
99
+ return f"❌ Error: {e}"
100
+
101
+ def predict_polyp(image, threshold=0.5):
102
+ """Predict polyp in uploaded image"""
103
+ if model is None:
104
+ return None, "❌ Model not loaded! Please wait for model to load.", None
105
+
106
+ if image is None:
107
+ return None, "❌ Please upload an image first!", None
108
+
109
+ try:
110
+ # Convert image to numpy array
111
+ if isinstance(image, Image.Image):
112
+ original_image = np.array(image.convert('RGB'))
113
+ else:
114
+ original_image = np.array(image)
115
+
116
+ # Preprocess image
117
+ transformed = transform(image=original_image)
118
+ input_tensor = transformed['image'].unsqueeze(0).float()
119
+
120
+ # Make prediction
121
+ with torch.no_grad():
122
+ prediction = model(input_tensor)
123
+ prediction = torch.sigmoid(prediction)
124
+ prediction = (prediction > threshold).float()
125
+
126
+ # Convert to numpy
127
+ pred_mask = prediction.squeeze().cpu().numpy()
128
+
129
+ # Calculate metrics
130
+ polyp_pixels = np.sum(pred_mask)
131
+ total_pixels = pred_mask.shape[0] * pred_mask.shape[1]
132
+ polyp_percentage = (polyp_pixels / total_pixels) * 100
133
+
134
+ # Create visualization
135
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
136
+
137
+ # Original image
138
+ axes[0].imshow(original_image)
139
+ axes[0].set_title('πŸ–ΌοΈ Original Image', fontsize=14)
140
+ axes[0].axis('off')
141
+
142
+ # Predicted mask
143
+ axes[1].imshow(pred_mask, cmap='gray')
144
+ axes[1].set_title('🎭 Predicted Mask', fontsize=14)
145
+ axes[1].axis('off')
146
+
147
+ # Overlay
148
+ axes[2].imshow(original_image)
149
+ axes[2].imshow(pred_mask, cmap='Reds', alpha=0.6)
150
+ axes[2].set_title('πŸ” Detection Overlay', fontsize=14)
151
+ axes[2].axis('off')
152
+
153
+ # Add main title with results
154
+ if polyp_pixels > 100:
155
+ main_title = f"🚨 POLYP DETECTED! Coverage: {polyp_percentage:.2f}%"
156
+ title_color = 'red'
157
+ else:
158
+ main_title = f"βœ… No Polyp Detected - Coverage: {polyp_percentage:.2f}%"
159
+ title_color = 'green'
160
+
161
+ fig.suptitle(main_title, fontsize=16, fontweight='bold', color=title_color)
162
+ plt.tight_layout()
163
+
164
+ # Save plot to image
165
+ buf = io.BytesIO()
166
+ plt.savefig(buf, format='png', dpi=150, bbox_inches='tight')
167
+ buf.seek(0)
168
+ result_image = Image.open(buf)
169
+ plt.close()
170
+
171
+ # Create detailed results text
172
+ if polyp_pixels > 100:
173
+ status_emoji = "🚨"
174
+ status_text = "POLYP DETECTED"
175
+ recommendation = "⚠️ **Recommendation:** Medical review recommended"
176
+ else:
177
+ status_emoji = "βœ…"
178
+ status_text = "NO POLYP DETECTED"
179
+ recommendation = "βœ… **Recommendation:** Continue routine monitoring"
180
+
181
+ results_text = f"""
182
+ ## {status_emoji} **{status_text}**
183
+
184
+ ### πŸ“Š **Analysis Results:**
185
+ - **Polyp Coverage:** {polyp_percentage:.3f}%
186
+ - **Detected Pixels:** {int(polyp_pixels):,} / {total_pixels:,}
187
+ - **Detection Threshold:** {threshold}
188
+
189
+ ### πŸ₯ **Clinical Assessment:**
190
+ {recommendation}
191
+
192
+ ### πŸ”¬ **Technical Details:**
193
+ - **Model:** U-Net (32 channels)
194
+ - **Input Size:** 384Γ—384 pixels
195
+ - **Architecture:** Encoder-Decoder with skip connections
196
+ """
197
+
198
+ return result_image, results_text, pred_mask
199
+
200
+ except Exception as e:
201
+ error_msg = f"❌ **Error processing image:** {str(e)}"
202
+ return None, error_msg, None
203
+
204
+ def load_example_image(image_num):
205
+ """Load example images from your HF space"""
206
+ try:
207
+ if image_num == 1:
208
+ # Image 1: cju0qoxqj9q6s0835b43399p4.jpg
209
+ image_path = hf_hub_download(
210
+ repo_id="ibrahim313/unet-adam-diceloss",
211
+ filename="cju0qoxqj9q6s0835b43399p4.jpg",
212
+ repo_type="space"
213
+ )
214
+ else:
215
+ # Image 2: cju0roawvklrq0799vmjorwfv.jpg
216
+ image_path = hf_hub_download(
217
+ repo_id="ibrahim313/unet-adam-diceloss",
218
+ filename="cju0roawvklrq0799vmjorwfv.jpg",
219
+ repo_type="space"
220
+ )
221
+
222
+ # Load and return the image
223
+ image = Image.open(image_path)
224
+ return image
225
+
226
+ except Exception as e:
227
+ print(f"Error loading example image {image_num}: {e}")
228
+ return None
229
+
230
+ # Load model when app starts
231
+ print("πŸš€ Starting Polyp Detection App...")
232
+ load_status = load_model()
233
+ print(load_status)
234
+
235
+ # Create Gradio Interface
236
+ with gr.Blocks(theme=gr.themes.Soft(), title="πŸ₯ Polyp Detection AI") as demo:
237
+
238
+ # Header
239
+ gr.HTML("""
240
+ <div style="text-align: center; padding: 30px; background: linear-gradient(90deg, #667eea 0%, #764ba2 100%); color: white; border-radius: 10px; margin-bottom: 20px;">
241
+ <h1 style="margin: 0; font-size: 2.5em;">πŸ₯ AI Polyp Detection System</h1>
242
+ <p style="margin: 10px 0 0 0; font-size: 1.2em;">Advanced Medical Imaging with Deep Learning</p>
243
+ <p style="margin: 5px 0 0 0; opacity: 0.9;">Upload colonoscopy images for intelligent polyp detection</p>
244
+ </div>
245
+ """)
246
+
247
+ # Model info
248
+ gr.HTML(f"""
249
+ <div style="background: #f0f9ff; padding: 15px; border-radius: 8px; border-left: 4px solid #0ea5e9; margin-bottom: 20px;">
250
+ <strong>πŸ”¬ Model:</strong> ibrahim313/unet-adam-diceloss<br>
251
+ <strong>πŸ“ Architecture:</strong> U-Net with 32 base channels<br>
252
+ <strong>🎯 Dataset:</strong> Trained on Kvasir-SEG (1000 polyp images)<br>
253
+ <strong>πŸ“Έ Examples:</strong> 2 test colonoscopy images included<br>
254
+ <strong>⚑ Status:</strong> {load_status}
255
+ </div>
256
+ """)
257
+
258
+ # Main interface
259
+ with gr.Row():
260
+ with gr.Column(scale=1):
261
+ gr.HTML("<h3>πŸ“€ Upload Image</h3>")
262
+ input_image = gr.Image(
263
+ label="Drop colonoscopy image here",
264
+ type="pil",
265
+ height=300
266
+ )
267
+
268
+ threshold_slider = gr.Slider(
269
+ minimum=0.1,
270
+ maximum=0.9,
271
+ value=0.5,
272
+ step=0.1,
273
+ label="🎯 Detection Sensitivity",
274
+ info="Higher = more sensitive detection"
275
+ )
276
+
277
+ analyze_btn = gr.Button(
278
+ "πŸ” Analyze for Polyps",
279
+ variant="primary",
280
+ size="lg"
281
+ )
282
+
283
+ gr.HTML("<br>")
284
+
285
+ # Quick examples
286
+ gr.HTML("<h4>πŸ“Έ Try Sample Images:</h4>")
287
+ gr.HTML("<p style='font-size: 0.9em; color: #666; margin: 5px 0;'>Click to load colonoscopy test images</p>")
288
+ with gr.Row():
289
+ example1_btn = gr.Button("πŸ–ΌοΈ Test Image 1", size="sm", variant="secondary")
290
+ example2_btn = gr.Button("πŸ–ΌοΈ Test Image 2", size="sm", variant="secondary")
291
+
292
+ with gr.Column(scale=2):
293
+ gr.HTML("<h3>πŸ“Š Detection Results</h3>")
294
+ output_image = gr.Image(
295
+ label="Analysis Results",
296
+ height=400
297
+ )
298
+
299
+ results_text = gr.Markdown(
300
+ value="Upload an image and click 'Analyze for Polyps' to see results.",
301
+ label="Detailed Analysis"
302
+ )
303
+
304
+ # Event handlers
305
+ analyze_btn.click(
306
+ fn=predict_polyp,
307
+ inputs=[input_image, threshold_slider],
308
+ outputs=[output_image, results_text, gr.State()]
309
+ )
310
+
311
+ # Example button handlers
312
+ example1_btn.click(
313
+ fn=lambda: load_example_image(1),
314
+ inputs=[],
315
+ outputs=[input_image]
316
+ )
317
+
318
+ example2_btn.click(
319
+ fn=lambda: load_example_image(2),
320
+ inputs=[],
321
+ outputs=[input_image]
322
+ )
323
+
324
+ # Footer
325
+ gr.HTML("""
326
+ <div style="text-align: center; padding: 20px; margin-top: 40px; border-top: 2px solid #e5e7eb; background: #f9fafb;">
327
+ <p style="margin: 0; color: #dc2626; font-weight: bold;">
328
+ ⚠️ MEDICAL DISCLAIMER
329
+ </p>
330
+ <p style="margin: 5px 0; color: #4b5563;">
331
+ This AI system is for research and educational purposes only.<br>
332
+ Always consult qualified medical professionals for clinical decisions.
333
+ </p>
334
+ <p style="margin: 10px 0 0 0; color: #6b7280; font-size: 0.9em;">
335
+ πŸ”¬ Powered by PyTorch | πŸ€— Hosted on Hugging Face | πŸ“Š Gradio Interface
336
+ </p>
337
+ </div>
338
+ """)
339
+
340
+ # Launch the app
341
+ if __name__ == "__main__":
342
+ demo.launch()