AssanaliAidarkhan commited on
Commit
750248d
Β·
verified Β·
1 Parent(s): 81fb39e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -154
app.py CHANGED
@@ -6,11 +6,9 @@ from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import torch.nn.functional as F
8
  import json
9
- import traceback
10
 
11
  # Model repositories
12
  BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
13
- QWEN_RAG_REPO = "AssanaliAidarkhan/qwen-medical-rag"
14
 
15
  # Global variables
16
  biomedclip_model = None
@@ -18,7 +16,6 @@ biomedclip_processor = None
18
  biomedclip_id2label = {}
19
  qwen_model = None
20
  qwen_tokenizer = None
21
- medical_knowledge = []
22
 
23
  class CLIPClassifier(nn.Module):
24
  def __init__(self, clip_model, num_classes):
@@ -32,12 +29,10 @@ class CLIPClassifier(nn.Module):
32
  return {'logits': logits}
33
 
34
  def load_biomedclip():
35
- """Load BiomedCLIP model"""
36
  global biomedclip_model, biomedclip_processor, biomedclip_id2label
37
 
38
  try:
39
- print("πŸ”„ Loading BiomedCLIP...")
40
-
41
  model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
42
  checkpoint = torch.load(model_path, map_location='cpu')
43
 
@@ -54,133 +49,42 @@ def load_biomedclip():
54
 
55
  print("βœ… BiomedCLIP loaded!")
56
  return True
57
-
58
  except Exception as e:
59
  print(f"❌ BiomedCLIP error: {e}")
60
  return False
61
 
62
- def load_qwen():
63
- """Load Qwen model and medical knowledge"""
64
- global qwen_model, qwen_tokenizer, medical_knowledge
65
 
66
  try:
67
- print("πŸ”„ Loading Qwen...")
68
 
69
- # Load medical knowledge
70
- try:
71
- knowledge_path = hf_hub_download(repo_id=QWEN_RAG_REPO, filename="medical_knowledge.json")
72
- with open(knowledge_path, 'r', encoding='utf-8') as f:
73
- medical_knowledge = json.load(f)
74
- print(f"βœ… Loaded {len(medical_knowledge)} medical docs")
75
- except:
76
- # Fallback knowledge base
77
- medical_knowledge = [
78
- {
79
- "category": "partial_acl_injury",
80
- "advice": "Partial ACL injury detected. Recommend: rest, ice therapy, physical therapy consultation, avoid pivoting activities. Follow-up MRI in 6-8 weeks to assess healing progress."
81
- },
82
- {
83
- "category": "complete_acl_tear",
84
- "advice": "Complete ACL tear detected. Urgent orthopedic consultation required. Likely surgical reconstruction needed, especially for active patients. Immediate immobilization recommended."
85
- },
86
- {
87
- "category": "acl_sprain",
88
- "advice": "ACL sprain detected. Conservative management with RICE protocol. Physical therapy for strengthening. Return to activity when pain-free and strength restored."
89
- },
90
- {
91
- "category": "normal",
92
- "advice": "ACL appears normal on MRI. Continue regular activities. If symptoms persist, consider clinical examination for other possible causes."
93
- }
94
- ]
95
- print("⚠️ Using fallback medical knowledge")
96
 
97
- # Load Qwen
98
- qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
99
  qwen_model = AutoModelForCausalLM.from_pretrained(
100
  "Qwen/Qwen1.5-0.5B-Chat",
101
  torch_dtype=torch.float32,
102
  trust_remote_code=True
103
  )
104
- qwen_model.eval()
105
 
106
  print("βœ… Qwen loaded!")
107
  return True
108
 
109
  except Exception as e:
110
  print(f"❌ Qwen error: {e}")
111
- print(traceback.format_exc())
112
  return False
113
 
114
- def find_medical_advice(classification_result):
115
- """Find relevant medical advice based on classification"""
116
-
117
- # Extract the main condition from classification
118
- condition = classification_result.lower()
119
-
120
- # Find matching advice
121
- for doc in medical_knowledge:
122
- if doc['category'].lower() in condition or any(tag in condition for tag in doc.get('tags', [])):
123
- return doc.get('advice', 'No specific advice available for this condition.')
124
-
125
- # Generic advice if no match
126
- return "Consult with a medical professional for proper evaluation and treatment recommendations."
127
-
128
- def generate_qwen_advice(classification_result, medical_context):
129
- """Generate advice using Qwen"""
130
- global qwen_model, qwen_tokenizer
131
-
132
- if qwen_model is None:
133
- return "❌ Qwen model not available"
134
-
135
- try:
136
- # Create medical prompt
137
- prompt = f"""Medical Image Analysis Result: {classification_result}
138
-
139
- Relevant Medical Knowledge: {medical_context}
140
-
141
- Based on this MRI classification, provide clinical recommendations including:
142
- 1. Immediate actions needed
143
- 2. Treatment options
144
- 3. Follow-up requirements
145
- 4. Patient advice
146
-
147
- Response:"""
148
-
149
- # Tokenize and generate
150
- inputs = qwen_tokenizer(prompt, return_tensors="pt", max_length=500, truncation=True)
151
-
152
- with torch.no_grad():
153
- outputs = qwen_model.generate(
154
- inputs.input_ids,
155
- max_new_tokens=150,
156
- temperature=0.7,
157
- do_sample=True,
158
- pad_token_id=qwen_tokenizer.eos_token_id
159
- )
160
-
161
- # Decode response
162
- generated_ids = outputs[0][len(inputs.input_ids[0]):]
163
- advice = qwen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
164
-
165
- return advice
166
-
167
- except Exception as e:
168
- return f"❌ Qwen generation error: {str(e)}"
169
-
170
- def complete_analysis(image):
171
- """Complete pipeline: Classification + Medical Advice"""
172
-
173
- if image is None:
174
- return "❌ Please upload an MRI scan", ""
175
-
176
- print("πŸ₯ Starting complete analysis...")
177
 
178
- # Step 1: Classify with BiomedCLIP
179
  try:
180
- # Classification code (same as your working debug version)
181
- if biomedclip_model is None:
182
- return "❌ BiomedCLIP not loaded", ""
183
-
184
  if image.mode != 'RGB':
185
  image = image.convert('RGB')
186
 
@@ -202,28 +106,89 @@ def complete_analysis(image):
202
  class_name = f"Class_{class_idx}"
203
 
204
  confidence = top_prob.item() * 100
205
- classification_result = f"{class_name} (confidence: {confidence:.1f}%)"
206
 
207
- print(f"βœ… Classification: {classification_result}")
208
 
209
  except Exception as e:
210
- return f"❌ Classification error: {e}", ""
 
 
 
 
 
 
 
 
211
 
212
- # Step 2: Get medical advice
213
  try:
214
- # Find relevant medical context
215
- medical_context = find_medical_advice(class_name)
216
 
217
- # Generate advice with Qwen
218
- qwen_advice = generate_qwen_advice(classification_result, medical_context)
 
 
 
 
 
219
 
220
- print("βœ… Medical advice generated")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  except Exception as e:
223
- qwen_advice = f"❌ Advice generation error: {e}"
224
- print(f"❌ Advice error: {e}")
 
 
 
 
 
 
 
 
 
 
225
 
226
- # Format results
 
 
 
 
 
 
227
  classification_text = f"""
228
  # πŸ”¬ **MRI Classification**
229
 
@@ -232,66 +197,46 @@ def complete_analysis(image):
232
 
233
  ## πŸ“Š **Confidence:**
234
  **{confidence:.1f}%**
235
-
236
- ## πŸ“ˆ **Details:**
237
- - Class index: {class_idx}
238
- - Total classes: {len(biomedclip_id2label)}
239
- - Model: BiomedCLIP
240
  """
241
 
242
  advice_text = f"""
243
- # πŸ₯ **Clinical Recommendations**
244
-
245
- ## πŸ’‘ **AI-Generated Advice:**
246
 
247
- {qwen_advice}
248
-
249
- ## πŸ“š **Medical Knowledge:**
250
-
251
- {medical_context}
252
 
253
  ---
254
- ⚠️ **Disclaimer:** For educational purposes only. Always consult medical professionals.
255
  """
256
 
257
  return classification_text, advice_text
258
 
259
  # Load models
260
- print("πŸš€ Initializing models...")
261
  biomedclip_loaded = load_biomedclip()
262
- qwen_loaded = load_qwen()
263
 
264
  # Create interface
265
  with gr.Blocks(title="Medical AI Pipeline") as app:
266
 
267
- gr.Markdown("# πŸ₯ Complete Medical AI Analysis Pipeline")
268
- gr.Markdown("**BiomedCLIP** (Image Classification) + **Qwen** (Medical Advice)")
269
 
270
- # Status
271
- status_text = f"BiomedCLIP: {'βœ…' if biomedclip_loaded else '❌'} | Qwen: {'βœ…' if qwen_loaded else '❌'}"
272
- gr.Markdown(f"**Status:** {status_text}")
273
 
274
  with gr.Row():
275
  with gr.Column():
276
- image_input = gr.Image(type="pil", label="πŸ“Έ Upload MRI Scan", height=400)
277
- analyze_btn = gr.Button("πŸ”¬ Complete Analysis", variant="primary", size="lg")
278
- clear_btn = gr.Button("πŸ—‘οΈ Clear")
279
 
280
  with gr.Column():
281
  classification_output = gr.Markdown(label="πŸ”¬ Classification")
282
- advice_output = gr.Markdown(label="πŸ₯ Medical Advice")
283
 
284
- # Button actions
285
  analyze_btn.click(
286
- fn=complete_analysis,
287
  inputs=image_input,
288
  outputs=[classification_output, advice_output]
289
  )
290
-
291
- clear_btn.click(
292
- fn=lambda: [None, "", ""],
293
- outputs=[image_input, classification_output, advice_output]
294
- )
295
 
296
  if __name__ == "__main__":
297
  app.launch()
 
6
  from PIL import Image
7
  import torch.nn.functional as F
8
  import json
 
9
 
10
  # Model repositories
11
  BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
 
12
 
13
  # Global variables
14
  biomedclip_model = None
 
16
  biomedclip_id2label = {}
17
  qwen_model = None
18
  qwen_tokenizer = None
 
19
 
20
  class CLIPClassifier(nn.Module):
21
  def __init__(self, clip_model, num_classes):
 
29
  return {'logits': logits}
30
 
31
  def load_biomedclip():
32
+ """Load BiomedCLIP (we know this works)"""
33
  global biomedclip_model, biomedclip_processor, biomedclip_id2label
34
 
35
  try:
 
 
36
  model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
37
  checkpoint = torch.load(model_path, map_location='cpu')
38
 
 
49
 
50
  print("βœ… BiomedCLIP loaded!")
51
  return True
 
52
  except Exception as e:
53
  print(f"❌ BiomedCLIP error: {e}")
54
  return False
55
 
56
+ def load_qwen_simple():
57
+ """Load Qwen with minimal setup"""
58
+ global qwen_model, qwen_tokenizer
59
 
60
  try:
61
+ print("πŸ”„ Loading Qwen (simple)...")
62
 
63
+ # Load Qwen directly
64
+ qwen_tokenizer = AutoTokenizer.from_pretrained(
65
+ "Qwen/Qwen1.5-0.5B-Chat",
66
+ trust_remote_code=True
67
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
 
69
  qwen_model = AutoModelForCausalLM.from_pretrained(
70
  "Qwen/Qwen1.5-0.5B-Chat",
71
  torch_dtype=torch.float32,
72
  trust_remote_code=True
73
  )
 
74
 
75
  print("βœ… Qwen loaded!")
76
  return True
77
 
78
  except Exception as e:
79
  print(f"❌ Qwen error: {e}")
 
80
  return False
81
 
82
+ def classify_mri(image):
83
+ """Classify MRI (working code)"""
84
+ if biomedclip_model is None or image is None:
85
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
 
87
  try:
 
 
 
 
88
  if image.mode != 'RGB':
89
  image = image.convert('RGB')
90
 
 
106
  class_name = f"Class_{class_idx}"
107
 
108
  confidence = top_prob.item() * 100
 
109
 
110
+ return class_name, confidence
111
 
112
  except Exception as e:
113
+ print(f"Classification error: {e}")
114
+ return None, None
115
+
116
+ def generate_simple_advice(class_name, confidence):
117
+ """Generate advice using Qwen (simple approach)"""
118
+ global qwen_model, qwen_tokenizer
119
+
120
+ if qwen_model is None:
121
+ return "❌ Qwen model not loaded"
122
 
 
123
  try:
124
+ print(f"πŸ”„ Generating advice for: {class_name}")
 
125
 
126
+ # Simple medical knowledge lookup
127
+ advice_map = {
128
+ "partial_acl_injury": "Partial ACL injury detected. Recommendations: Rest and avoid pivoting activities. Apply ice for 15-20 minutes several times daily. Consider physical therapy consultation. Follow-up MRI in 6-8 weeks to monitor healing.",
129
+ "complete_acl_tear": "Complete ACL tear detected. Urgent orthopedic consultation required. Likely surgical reconstruction needed. Immediate immobilization and avoid weight-bearing activities.",
130
+ "acl_sprain": "ACL sprain detected. Conservative treatment with RICE protocol (Rest, Ice, Compression, Elevation). Physical therapy for strengthening. Gradual return to activities.",
131
+ "normal": "ACL appears normal. Continue regular activities. If symptoms persist, consider clinical examination for other causes."
132
+ }
133
 
134
+ # Get base advice
135
+ base_advice = advice_map.get(class_name.lower(), "Consult medical professional for evaluation.")
136
+
137
+ # Create simple prompt for Qwen
138
+ simple_prompt = f"Medical diagnosis: {class_name} with {confidence:.1f}% confidence. Provide brief clinical advice:"
139
+
140
+ # Tokenize
141
+ inputs = qwen_tokenizer(simple_prompt, return_tensors="pt")
142
+
143
+ # Generate
144
+ with torch.no_grad():
145
+ outputs = qwen_model.generate(
146
+ inputs.input_ids,
147
+ max_new_tokens=100,
148
+ temperature=0.8,
149
+ do_sample=True,
150
+ pad_token_id=qwen_tokenizer.eos_token_id
151
+ )
152
+
153
+ # Decode
154
+ full_output = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
155
+
156
+ # Extract just the generated part
157
+ if simple_prompt in full_output:
158
+ generated_advice = full_output.replace(simple_prompt, "").strip()
159
+ else:
160
+ generated_advice = full_output
161
+
162
+ # Combine base advice with Qwen advice
163
+ if generated_advice and len(generated_advice) > 10:
164
+ combined_advice = f"**Clinical Guidelines:** {base_advice}\n\n**AI Analysis:** {generated_advice}"
165
+ else:
166
+ combined_advice = base_advice
167
+
168
+ print(f"βœ… Generated advice: {generated_advice[:50]}...")
169
+ return combined_advice
170
 
171
  except Exception as e:
172
+ print(f"❌ Advice generation error: {e}")
173
+ # Fallback to basic advice
174
+ return advice_map.get(class_name.lower(), "Consult medical professional for evaluation.")
175
+
176
+ def complete_pipeline(image):
177
+ """Complete analysis pipeline"""
178
+
179
+ if image is None:
180
+ return "❌ Please upload an MRI scan", ""
181
+
182
+ # Step 1: Classification
183
+ class_name, confidence = classify_mri(image)
184
 
185
+ if class_name is None:
186
+ return "❌ Classification failed", ""
187
+
188
+ # Step 2: Medical advice
189
+ medical_advice = generate_simple_advice(class_name, confidence)
190
+
191
+ # Format outputs
192
  classification_text = f"""
193
  # πŸ”¬ **MRI Classification**
194
 
 
197
 
198
  ## πŸ“Š **Confidence:**
199
  **{confidence:.1f}%**
 
 
 
 
 
200
  """
201
 
202
  advice_text = f"""
203
+ # πŸ₯ **Medical Recommendations**
 
 
204
 
205
+ {medical_advice}
 
 
 
 
206
 
207
  ---
208
+ ⚠️ **Disclaimer:** For educational purposes only. Consult medical professionals.
209
  """
210
 
211
  return classification_text, advice_text
212
 
213
  # Load models
 
214
  biomedclip_loaded = load_biomedclip()
215
+ qwen_loaded = load_qwen_simple()
216
 
217
  # Create interface
218
  with gr.Blocks(title="Medical AI Pipeline") as app:
219
 
220
+ gr.Markdown("# πŸ₯ Medical AI Analysis Pipeline")
221
+ gr.Markdown("**BiomedCLIP** (Classification) + **Qwen** (Medical Advice)")
222
 
223
+ status = f"Status: BiomedCLIP {'βœ…' if biomedclip_loaded else '❌'} | Qwen {'βœ…' if qwen_loaded else '❌'}"
224
+ gr.Markdown(f"**{status}**")
 
225
 
226
  with gr.Row():
227
  with gr.Column():
228
+ image_input = gr.Image(type="pil", label="πŸ“Έ Upload MRI Scan")
229
+ analyze_btn = gr.Button("πŸ”¬ Complete Analysis", variant="primary")
 
230
 
231
  with gr.Column():
232
  classification_output = gr.Markdown(label="πŸ”¬ Classification")
233
+ advice_output = gr.Markdown(label="πŸ₯ Medical Advice")
234
 
 
235
  analyze_btn.click(
236
+ fn=complete_pipeline,
237
  inputs=image_input,
238
  outputs=[classification_output, advice_output]
239
  )
 
 
 
 
 
240
 
241
  if __name__ == "__main__":
242
  app.launch()