AssanaliAidarkhan committed on
Commit
a31df32
·
verified ·
1 Parent(s): 9709eb3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -279
app.py CHANGED
@@ -2,306 +2,184 @@ import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  from huggingface_hub import hf_hub_download
 
5
  from PIL import Image
6
- import numpy as np
7
- import traceback
8
  import torch.nn.functional as F
9
- from torchvision import transforms
10
 
11
  # Your model configuration
12
  MODEL_REPO = "AssanaliAidarkhan/Biomedclip"
13
- MODEL_FILENAME = "pytorch_model.bin"
14
 
15
  # Global variables
16
  model = None
17
- class_labels = []
18
- label2id = {}
19
  id2label = {}
20
 
21
- # Simple CNN architecture (common for medical image classification)
22
class SimpleCNN(nn.Module):
    """VGG-style CNN: four conv stages followed by an MLP head.

    NOTE: the module structure (and therefore the Sequential indices in
    the state_dict keys, e.g. ``features.0.weight``) must stay exactly
    as-is so checkpoints saved against this class keep loading.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Feature extractor: 3 -> 64 -> 128 -> 256 -> 512 channels,
        # pooled down to a fixed 7x7 map regardless of input size.
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((7, 7)),
        )
        # Classifier head: two dropout-regularised hidden layers, then logits.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 1000),
            nn.ReLU(inplace=True),
            nn.Linear(1000, num_classes),
        )

    def forward(self, x):
        """Return raw class logits for a batch of RGB images."""
        feats = self.features(x)
        flat = torch.flatten(feats, 1)
        return self.classifier(flat)
54
-
55
- # ResNet-like architecture
56
class ResNetLike(nn.Module):
    """ResNet-flavoured stack of plain conv blocks.

    Despite the name there are no residual (skip) connections — each
    "layer" is just conv/BN/ReLU repeated. Structure must stay identical
    so saved state_dicts keep matching.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Stem: 7x7 stride-2 conv + BN + ReLU + 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages, doubling channels and halving resolution after stage 1.
        self.layer1 = self._make_layer(64, 64, 2)
        self.layer2 = self._make_layer(64, 128, 2, stride=2)
        self.layer3 = self._make_layer(128, 256, 2, stride=2)
        self.layer4 = self._make_layer(256, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        """Build one stage: first conv may downsample, the rest keep size."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        for _ in range(1, blocks):
            layers += [
                nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for a batch of RGB images."""
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        out = self.layer4(self.layer3(self.layer2(self.layer1(out))))
        out = self.avgpool(out)
        return self.fc(torch.flatten(out, 1))
102
 
103
def try_load_with_architecture(state_dict, num_classes, architecture_name):
    """Try loading with different architectures.

    Returns ``(model, message)`` on success or ``(None, message)`` when the
    state_dict matches none of the known architectures.
    """
    architectures = {
        'simplecnn': SimpleCNN(num_classes),
        'resnet': ResNetLike(num_classes),
    }

    # Prefer the architecture named in the checkpoint, when we recognise it.
    if architecture_name and architecture_name.lower() in architectures:
        candidate = architectures[architecture_name.lower()]
        try:
            candidate.load_state_dict(state_dict)
            return candidate, f"βœ… Loaded with {architecture_name}"
        except Exception:
            pass  # fall through to brute-force matching below

    # Otherwise probe every known architecture until one accepts the weights.
    for arch_name, arch_model in architectures.items():
        try:
            arch_model.load_state_dict(state_dict)
            return arch_model, f"βœ… Successfully loaded with {arch_name} architecture"
        except Exception:
            continue

    return None, "❌ Could not match state_dict with any architecture"
129
-
130
def load_model():
    """Load the BiodemCLIP model.

    Downloads the checkpoint from the Hub, rebuilds the matching
    architecture via ``try_load_with_architecture``, and populates the
    module-level globals. Returns True on success, False otherwise.
    """
    global model, class_labels, label2id, id2label

    try:
        print(f"Downloading model from: {MODEL_REPO}")

        # Fetch the checkpoint file from the Hub (cached locally).
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILENAME,
            cache_dir="./model_cache"
        )
        print(f"Model downloaded to: {model_path}")

        # NOTE(review): torch.load unpickles arbitrary objects — only safe
        # because the checkpoint comes from our own trusted repo.
        saved_data = torch.load(model_path, map_location='cpu')
        print(f"Loaded data type: {type(saved_data)}")

        # Unpack checkpoint metadata alongside the weights.
        state_dict = saved_data['model_state_dict']
        num_classes = saved_data['num_classes']
        model_name = saved_data.get('model_name', '')
        label2id = saved_data.get('label2id', {})
        id2label = saved_data.get('id2label', {})

        print(f"Number of classes: {num_classes}")
        print(f"Model name: {model_name}")
        print(f"Available labels: {list(id2label.values())[:5]}...")

        # id2label is assumed to be str-keyed here — hence str(i).
        class_labels = [id2label.get(str(i), f"Class_{i}") for i in range(num_classes)]

        # Match the weights against a known architecture.
        model, load_message = try_load_with_architecture(state_dict, num_classes, model_name)
        if model is None:
            print("❌ Failed to load model with any architecture")
            return False
        print(load_message)

        model.eval()  # inference mode: disables dropout, freezes BN stats

        print("βœ… Model loaded successfully!")
        return True

    except Exception as e:
        print(f"Error loading model: {e}")
        print(traceback.format_exc())
        return False
184
 
185
def preprocess_image(image):
    """Preprocess MRI image for model input.

    Converts to RGB, resizes to 224x224, and applies ImageNet
    normalisation. Returns a (1, 3, 224, 224) tensor, or None on error.
    """
    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Resize -> tensor -> ImageNet mean/std normalisation.
        pipeline = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # unsqueeze(0) adds the batch dimension expected by the model.
        return pipeline(image).unsqueeze(0)

    except Exception as e:
        print(f"Error preprocessing image: {e}")
        return None
206
-
207
def classify_mri(image):
    """Classify MRI scan.

    Returns ``(markdown_text, confidence_data)`` where confidence_data is
    a list of ``[class_name, confidence_pct]`` rows for plotting, or
    ``(error_message, None)`` on failure.
    """
    global model, class_labels, id2label

    # Guard clauses: model must be loaded and an image supplied.
    if model is None:
        return "❌ Model not loaded! Please check the logs.", None
    if image is None:
        return "❌ Please upload an MRI scan.", None

    try:
        image_tensor = preprocess_image(image)
        if image_tensor is None:
            return "❌ Error preprocessing image.", None

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = model(image_tensor)

        probabilities = F.softmax(outputs, dim=-1)

        # Up to five best classes (fewer if the label set is smaller).
        top_probs, top_indices = torch.topk(probabilities, k=min(5, len(class_labels)))

        results = []
        confidence_data = []
        for rank, (prob, idx) in enumerate(zip(top_probs[0], top_indices[0]), start=1):
            class_name = class_labels[idx.item()]
            confidence = prob.item() * 100
            results.append(f"{rank}. **{class_name}**: {confidence:.2f}%")
            confidence_data.append([class_name, confidence])

        top_prediction = class_labels[top_indices[0][0].item()]
        top_confidence = top_probs[0][0].item() * 100

        result_text = f"""
## πŸ”¬ **MRI Classification Results**

### 🎯 **Top Prediction:**
**{top_prediction}** ({top_confidence:.2f}% confidence)

### πŸ“Š **All Predictions:**
{chr(10).join(results)}

### πŸ“ˆ **Model Information:**
- **Input Size:** 224Γ—224 pixels
- **Model:** BiomedCLIP
- **Classes:** {len(class_labels)} categories
- **Available Classes:** {', '.join(class_labels)}

### πŸ’‘ **Interpretation:**
{get_interpretation(top_prediction, top_confidence)}
"""
        return result_text, confidence_data

    except Exception as e:
        error_msg = f"❌ **Error during classification:** {str(e)}\n\n{traceback.format_exc()}"
        return error_msg, None
272
-
273
def get_interpretation(prediction, confidence):
    """Provide interpretation based on prediction.

    Maps the confidence percentage onto one of four advisory messages;
    the three higher tiers embed the lower-cased predicted class name.
    """
    subject = prediction.lower()
    if confidence >= 90:
        return f"🟒 **High Confidence**: The model is very confident this is {subject}."
    if confidence >= 70:
        return f"🟑 **Good Confidence**: The model believes this is likely {subject}."
    if confidence >= 50:
        return f"🟠 **Moderate Confidence**: The model suggests this might be {subject}, but consider additional analysis."
    return "πŸ”΄ **Low Confidence**: The model is uncertain. Manual review recommended."
283
 
284
# Load model once at import time so the UI below can report its status.
print("Initializing BiomedCLIP model...")
model_loaded = load_model()
287
 
288
  # Create Gradio interface
289
- with gr.Blocks(title="MRI Classification with BiomedCLIP", theme=gr.themes.Soft()) as demo:
 
290
  gr.Markdown("""
291
- # 🧠 MRI Classification with BiomedCLIP
292
 
293
- Upload an MRI scan to get automated classification results.
294
 
295
- **Model:** AssanaliAidarkhan/Biomedclip
296
  """)
297
 
298
- if not model_loaded:
299
- gr.Markdown("⚠️ **Warning: Model failed to load. Check the logs for details.**")
300
- else:
301
  gr.Markdown("βœ… **Model loaded successfully!**")
 
 
302
 
303
  with gr.Row():
304
- with gr.Column(scale=1):
 
305
  image_input = gr.Image(
306
  type="pil",
307
  label="πŸ“Έ Upload MRI Scan",
@@ -309,58 +187,32 @@ with gr.Blocks(title="MRI Classification with BiomedCLIP", theme=gr.themes.Soft(
309
  )
310
 
311
  classify_btn = gr.Button("πŸ” Classify MRI", variant="primary", size="lg")
312
- clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
313
 
314
- with gr.Column(scale=1):
315
- result_output = gr.Markdown(label="πŸ“Š Classification Results")
316
-
317
- # Confidence chart
318
- confidence_plot = gr.BarPlot(
319
- x="class",
320
- y="confidence",
321
- title="Confidence Scores by Class",
322
- x_title="Medical Condition",
323
- y_title="Confidence (%)",
324
- width=500,
325
- height=300
326
- )
327
-
328
- # Event handlers
329
- def classify_and_plot(image):
330
- text_result, plot_data = classify_mri(image)
331
- if plot_data:
332
- plot_df = [{"class": item[0], "confidence": item[1]} for item in plot_data]
333
- return text_result, plot_df
334
- return text_result, None
335
 
 
336
  classify_btn.click(
337
- fn=classify_and_plot,
338
  inputs=image_input,
339
- outputs=[result_output, confidence_plot]
340
  )
341
 
342
  clear_btn.click(
343
- fn=lambda: [None, "", None],
344
- inputs=[],
345
- outputs=[image_input, result_output, confidence_plot]
346
  )
347
 
348
- # Instructions
349
  gr.Markdown("""
350
- ### πŸ“‹ Instructions:
351
- 1. **Upload an MRI scan image** (JPEG, PNG, etc.) using the image upload area above ☝️
352
- 2. Click "πŸ” Classify MRI" to get results
353
- 3. View the classification results and confidence scores
354
-
355
- ### πŸ₯ Model Information:
356
- - Automatically detects the number of classes from your trained model
357
- - Uses the exact class labels from your training data
358
- - Applies standard medical image preprocessing
359
 
360
- ### ⚠️ Medical Disclaimer:
361
- This tool is for research purposes only and should not replace professional medical diagnosis.
362
- Always consult qualified medical professionals for clinical decisions.
363
  """)
364
 
365
  if __name__ == "__main__":
366
- demo.launch()
 
2
  import torch
3
  import torch.nn as nn
4
  from huggingface_hub import hf_hub_download
5
+ from transformers import CLIPModel, CLIPProcessor
6
  from PIL import Image
 
 
7
  import torch.nn.functional as F
 
8
 
9
# Your model configuration: Hub repo id and checkpoint filename.
MODEL_REPO = "AssanaliAidarkhan/Biomedclip"
MODEL_FILE = "pytorch_model.bin"

# Global variables, populated by load_biomedclip() at startup.
model = None       # CLIPClassifier instance once loaded
processor = None   # CLIPProcessor instance once loaded
id2label = {}      # class-index -> class-name mapping from the checkpoint
17
 
18
class CLIPClassifier(nn.Module):
    """Your CLIP-based classifier architecture.

    Wraps a CLIP model and attaches a single linear head that maps the
    CLIP image embedding (``projection_dim`` wide) to class logits.
    """

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        # Head input size comes from the CLIP config's projection dim.
        self.classifier = nn.Linear(clip_model.config.projection_dim, num_classes)

    def forward(self, **inputs):
        """Embed the image(s) with CLIP, then classify the embedding."""
        embeddings = self.clip_model.get_image_features(**inputs)
        logits = self.classifier(embeddings)
        # Dict return mirrors the transformers output convention.
        return {'logits': logits}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
def load_biomedclip():
    """Load your CLIP-based BiomedCLIP model.

    Downloads the fine-tuned checkpoint, instantiates the base CLIP model
    and processor, wraps them in CLIPClassifier, and loads the trained
    weights into the module-level globals. Returns True on success.
    """
    global model, processor, id2label

    try:
        print("πŸ”„ Downloading model...")
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE
        )
        print("βœ… Model file downloaded")

        # NOTE(review): torch.load unpickles arbitrary objects — only safe
        # because the checkpoint comes from our own trusted repo.
        checkpoint = torch.load(model_path, map_location='cpu')
        print(f"πŸ“¦ Checkpoint keys: {list(checkpoint.keys())}")

        # Checkpoint metadata; model_name falls back to the base CLIP repo.
        num_classes = checkpoint['num_classes']
        id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')

        print(f"πŸ”’ Number of classes: {num_classes}")
        print(f"🏷️ Classes: {list(id2label.values())}")
        print(f"πŸ€– Base model: {model_name}")

        print("πŸ“₯ Loading CLIP processor...")
        processor = CLIPProcessor.from_pretrained(model_name)

        print("πŸ“₯ Loading CLIP model...")
        clip_model = CLIPModel.from_pretrained(model_name)

        print("πŸ”§ Creating classifier...")
        model = CLIPClassifier(clip_model, num_classes)

        print("βš™οΈ Loading trained weights...")
        model.load_state_dict(checkpoint['model_state_dict'])

        model.eval()  # inference mode

        print("βœ… Model loaded successfully!")
        return True

    except Exception as e:
        print(f"❌ Error loading model: {e}")
        import traceback
        print(traceback.format_exc())
        return False
87
 
88
def classify_image(image):
    """Classify the uploaded MRI image.

    Returns a markdown report string with the top predictions, or an
    error message string when the model is missing / input is invalid.
    """
    global model, processor, id2label

    # Guard clauses: both model and processor must be loaded, and an
    # image must actually be supplied.
    if model is None or processor is None:
        return "❌ Model not loaded!"
    if image is None:
        return "❌ Please upload an image!"

    try:
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # CLIP processor handles resize/normalisation into pixel_values.
        inputs = processor(images=image, return_tensors="pt")

        # Forward pass without gradient tracking.
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs['logits']
            probabilities = F.softmax(logits, dim=1)

        # Up to five best classes (fewer if the label set is smaller).
        top_probs, top_indices = torch.topk(probabilities, k=min(5, len(id2label)))

        results = []
        for rank, (prob, idx) in enumerate(zip(top_probs[0], top_indices[0]), start=1):
            # id2label appears to be str-keyed in the checkpoint, hence
            # str() — TODO confirm against int-keyed checkpoints.
            class_name = id2label[str(idx.item())]
            confidence = prob.item() * 100
            results.append(f"{rank}. **{class_name}**: {confidence:.2f}%")

        top_prediction = id2label[str(top_indices[0][0].item())]
        top_confidence = top_probs[0][0].item() * 100

        result_text = f"""
# πŸ”¬ MRI Classification Results

## 🎯 **Top Prediction:**
**{top_prediction}** ({top_confidence:.1f}% confidence)

## πŸ“Š **All Predictions:**
{chr(10).join(results)}

## πŸ“ˆ **Model Info:**
- **Architecture:** CLIP-based classifier
- **Classes:** {len(id2label)} categories
- **Input processed:** βœ…

## πŸ’‘ **Confidence Level:**
{get_confidence_interpretation(top_confidence)}
"""
        return result_text

    except Exception as e:
        return f"❌ Classification error: {str(e)}"
148
+
149
def get_confidence_interpretation(confidence):
    """Interpret confidence level.

    Maps a confidence percentage onto one of four fixed advisory strings
    via descending thresholds (80 / 60 / 40).
    """
    tiers = (
        (80, "🟒 **High confidence** - Strong classification result"),
        (60, "🟑 **Good confidence** - Reliable result"),
        (40, "🟠 **Moderate confidence** - Consider additional analysis"),
    )
    for threshold, message in tiers:
        if confidence >= threshold:
            return message
    return "πŸ”΄ **Low confidence** - Uncertain result, manual review recommended"
159
 
160
# Load model at startup so the UI below can report its status.
print("πŸš€ Loading BiomedCLIP model...")
model_loaded = load_biomedclip()
163
 
164
  # Create Gradio interface
165
+ with gr.Blocks(title="BiomedCLIP MRI Classifier") as app:
166
+
167
  gr.Markdown("""
168
+ # 🧠 BiomedCLIP MRI Classifier
169
 
170
+ Upload an MRI scan for automated medical image classification.
171
 
172
+ **Model:** AssanaliAidarkhan/Biomedclip (CLIP-based)
173
  """)
174
 
175
+ if model_loaded:
 
 
176
  gr.Markdown("βœ… **Model loaded successfully!**")
177
+ else:
178
+ gr.Markdown("❌ **Model failed to load - check logs below**")
179
 
180
  with gr.Row():
181
+ with gr.Column():
182
+ # This is where you upload your image! πŸ‘‡
183
  image_input = gr.Image(
184
  type="pil",
185
  label="πŸ“Έ Upload MRI Scan",
 
187
  )
188
 
189
  classify_btn = gr.Button("πŸ” Classify MRI", variant="primary", size="lg")
190
+ clear_btn = gr.Button("πŸ—‘οΈ Clear")
191
 
192
+ with gr.Column():
193
+ result_output = gr.Markdown(label="πŸ“Š Results")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
+ # Button actions
196
  classify_btn.click(
197
+ fn=classify_image,
198
  inputs=image_input,
199
+ outputs=result_output
200
  )
201
 
202
  clear_btn.click(
203
+ fn=lambda: [None, ""],
204
+ outputs=[image_input, result_output]
 
205
  )
206
 
 
207
  gr.Markdown("""
208
+ ### πŸ“‹ How to Use:
209
+ 1. **Click the image area above** or **drag & drop** your MRI image
210
+ 2. Click "πŸ” Classify MRI"
211
+ 3. View results below
 
 
 
 
 
212
 
213
+ ### πŸ₯ Medical Disclaimer:
214
+ For research purposes only. Not for clinical diagnosis.
 
215
  """)
216
 
217
  if __name__ == "__main__":
218
+ app.launch()