import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from PIL import Image
import torch.nn.functional as F
import json
import numpy as np
import faiss

# Model repositories
BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
QWEN_RAG_REPO = "AssanaliAidarkhan/qwen-medical-rag"

# Global variables
biomedclip_model = None
biomedclip_processor = None
biomedclip_id2label = {}
qwen_model = None
qwen_tokenizer = None
embedding_model = None
medical_knowledge = []
faiss_index = None

class CLIPClassifier(nn.Module):
    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
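        # Linear head on top of CLIP's pooled image embedding (projection_dim features)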
        self.classifier = nn.Linear(clip_model.config.projection_dim, num_classes)
        
    def forward(self, **inputs):
        outputs = self.clip_model.get_image_features(**inputs)
        logits = self.classifier(outputs)
        return {'logits': logits}

def load_biomedclip():
    """Load BiomedCLIP (we know this works)"""
    global biomedclip_model, biomedclip_processor, biomedclip_id2label
    
    try:
        model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
        checkpoint = torch.load(model_path, map_location='cpu')
        
        num_classes = checkpoint['num_classes']
        biomedclip_id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')
        
        biomedclip_processor = CLIPProcessor.from_pretrained(model_name)
        clip_model = CLIPModel.from_pretrained(model_name)
        
        biomedclip_model = CLIPClassifier(clip_model, num_classes)
        biomedclip_model.load_state_dict(checkpoint['model_state_dict'])
        biomedclip_model.eval()
        
        print("✅ BiomedCLIP loaded!")
        return True
    except Exception as e:
        print(f"❌ BiomedCLIP error: {e}")
        return False

def load_rag_system():
    """Load complete RAG system"""
    global qwen_model, qwen_tokenizer, embedding_model, medical_knowledge, faiss_index
    
    try:
        print("🔄 Loading RAG system...")
        
        # 1. Load medical knowledge base
        try:
            knowledge_path = hf_hub_download(repo_id=QWEN_RAG_REPO, filename="medical_knowledge.json")
            with open(knowledge_path, 'r', encoding='utf-8') as f:
                medical_knowledge = json.load(f)
            print(f"✅ Knowledge base: {len(medical_knowledge)} documents")
        except Exception as e:
            print(f"⚠️ Knowledge loading error: {e}, using fallback")
            # Fallback knowledge base
            medical_knowledge = [
                {
                    "id": "doc1",
                    "title": "Частичное повреждения передней крестообразной связки",
                    "content": "Признаки частичного повреждения передней крестообразной связки: утолщение, повышенный сигнал по Т2, частичная дезорганизация волокон, связка прослеживается по ходу",
                    "category": "Partial ACL injury",
                    "advice": "Recommend conservative treatment, physical therapy, follow-up MRI in 6-8 weeks"
                },
                {
                    "id": "doc2",
                    "title": "Полный разрыв передней крестообразной связки",
                    "content": "Признаки полного разрыва передней крестообразной связки: волокна не прослеживаются по ходу, определяется зона повышенного сигнала в проекции связки, гемартроз",
                    "category": "Complete ACL tear",
                    "advice": "Urgent orthopedic consultation, likely requires ACL reconstruction surgery"
                }
            ]
        
        # 2. Load embedding model
        print("🔄 Loading embeddings...")
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        
        # 3. Create embeddings and FAISS index
        print("🔄 Creating FAISS index...")
        text_contents = []
        for doc in medical_knowledge:
            text = f"{doc.get('title', '')} {doc.get('content', '')} {doc.get('advice', '')}"
            text_contents.append(text)
        
        embeddings = embedding_model.encode(text_contents, convert_to_numpy=True)
        
        # Create FAISS index
        dimension = embeddings.shape[1]
        faiss_index = faiss.IndexFlatIP(dimension)
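        # L2-normalizing the vectors makes inner-product (IP) search equivalent to cosine similarity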
        faiss.normalize_L2(embeddings)
        faiss_index.add(embeddings)
        
        print(f"✅ FAISS index created with {faiss_index.ntotal} documents")
        
        # 4. Load Qwen
        print("🔄 Loading Qwen...")
        qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
        qwen_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        
        print("✅ RAG system loaded completely!")
        return True
        
    except Exception as e:
        print(f"❌ RAG loading error: {e}")
        import traceback
        print(traceback.format_exc())
        return False

def retrieve_relevant_knowledge(classification_result):
    """Retrieve relevant medical documents"""
    global embedding_model, medical_knowledge, faiss_index
    
    if faiss_index is None:
        return [], "No knowledge base available"
    
    try:
        # Create query for retrieval
        query = f"Medical diagnosis {classification_result} treatment recommendations clinical advice"
        
        # Get query embedding
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)
        
        # Search FAISS index
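        # (scores are cosine similarities, since query and corpus vectors are both normalized)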
        scores, indices = faiss_index.search(query_embedding, 2)  # Top 2 documents
        
        # Get relevant documents
        retrieved_docs = []
        context_text = ""
        
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1 and idx < len(medical_knowledge):
                doc = medical_knowledge[idx]
                retrieved_docs.append((doc, float(score)))
                
                context_text += f"Medical Knowledge: {doc.get('content', '')}\n"
                context_text += f"Clinical Advice: {doc.get('advice', '')}\n"
                context_text += f"Category: {doc.get('category', '')}\n\n"
        
        return retrieved_docs, context_text
        
    except Exception as e:
        print(f"❌ Retrieval error: {e}")
        return [], f"Retrieval error: {e}"

def generate_qwen_advice(classification_result, retrieved_context):
    """Generate medical advice using Qwen with RAG context"""
    global qwen_model, qwen_tokenizer
    
    if qwen_model is None:
        return "❌ Qwen model not loaded"
    
    try:
        print("🔄 Generating Qwen advice...")
        
        # Create comprehensive prompt
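        # Note: Qwen1.5-*-Chat is trained on the ChatML format; a plain-text prompt still
        # generates, but qwen_tokenizer.apply_chat_template() would match the training format.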
        prompt = f"""You are a medical AI assistant. Based on the MRI classification and medical knowledge provided, give clinical recommendations.

MRI Classification: {classification_result}

Retrieved Medical Knowledge:
{retrieved_context}

Provide specific clinical recommendations including treatment options and follow-up care:"""
        
        print(f"📝 Prompt length: {len(prompt)} characters")
        
        # Tokenize
        inputs = qwen_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        
        print(f"🔧 Input tokens: {inputs.input_ids.shape}")
        
        # Generate
        with torch.no_grad():
            outputs = qwen_model.generate(
                **inputs,  # pass attention_mask along with input_ids to avoid generation warnings
                max_new_tokens=120,
                temperature=0.7,
                do_sample=True,
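                # reuse EOS as the pad id, a common safeguard for decoder-only chat models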
                pad_token_id=qwen_tokenizer.eos_token_id,
                eos_token_id=qwen_tokenizer.eos_token_id
            )
        
        # Decode only the new tokens
        generated_tokens = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = qwen_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        
        print(f"✅ Generated: {generated_text[:100]}...")
        
        if len(generated_text) < 10:
            return "No specific recommendations generated. Consult medical professional."
        
        return generated_text
        
    except Exception as e:
        print(f"❌ Qwen generation error: {e}")
        import traceback
        print(traceback.format_exc())
        return f"Generation error: {e}"

def complete_analysis(image):
    """Complete pipeline with RAG"""
    
    if image is None:
        return "❌ Please upload an MRI scan", ""
    
    # Step 1: Classification  
    try:
        if biomedclip_model is None:
            return "❌ BiomedCLIP not loaded", ""
        
        if image.mode != 'RGB':
            image = image.convert('RGB')
        
        inputs = biomedclip_processor(images=image, return_tensors="pt")
        
        with torch.no_grad():
            outputs = biomedclip_model(**inputs)
            logits = outputs['logits']
            probabilities = F.softmax(logits, dim=1)
        
        top_prob, top_idx = torch.max(probabilities, 1)
        class_idx = top_idx.item()
        
        if class_idx in biomedclip_id2label:
            class_name = biomedclip_id2label[class_idx]
        elif str(class_idx) in biomedclip_id2label:
            class_name = biomedclip_id2label[str(class_idx)]
        else:
            class_name = f"Class_{class_idx}"
        
        confidence = top_prob.item() * 100
        classification_result = f"{class_name} ({confidence:.1f}% confidence)"
        
        print(f"✅ Classification: {classification_result}")
        
    except Exception as e:
        return f"❌ Classification error: {e}", ""
    
    # Step 2: RAG retrieval
    retrieved_docs, context = retrieve_relevant_knowledge(classification_result)
    
    # Step 3: Qwen generation
    qwen_advice = generate_qwen_advice(classification_result, context)
    
    # Format outputs
    classification_text = f"""
# 🔬 **MRI Classification**

## 🎯 **Diagnosis:**
**{class_name}**

## 📊 **Confidence:**
**{confidence:.1f}%**
    """
    
    advice_text = f"""
# 🏥 **AI-Generated Medical Recommendations**

## 🤖 **Qwen Analysis:**
{qwen_advice}

## 📚 **Retrieved Medical Knowledge:**
{context if context else "No relevant knowledge retrieved"}

## 📋 **Retrieved Documents:**
{len(retrieved_docs)} documents found and used for advice generation

---
⚠️ **Disclaimer:** For educational purposes only. Always consult medical professionals.
    """
    
    return classification_text, advice_text

# Load models
print("🚀 Loading complete pipeline...")
biomedclip_loaded = load_biomedclip()
rag_loaded = load_rag_system()

# Create interface
with gr.Blocks(title="Medical RAG Pipeline") as app:
    
    gr.Markdown("# 🏥 Medical AI RAG Pipeline")
    gr.Markdown("**BiomedCLIP** → **RAG Retrieval** → **Qwen Generation**")
    
    status = f"BiomedCLIP: {'✅' if biomedclip_loaded else '❌'} | RAG: {'✅' if rag_loaded else '❌'}"
    gr.Markdown(f"**Status:** {status}")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="📸 Upload MRI Scan")
            analyze_btn = gr.Button("🔬 Complete RAG Analysis", variant="primary")
        
        with gr.Column():
            classification_output = gr.Markdown(label="🔬 Classification")
            advice_output = gr.Markdown(label="🏥 RAG-Generated Advice")
    
    analyze_btn.click(
        fn=complete_analysis,
        inputs=image_input,
        outputs=[classification_output, advice_output]
    )
    
    gr.Markdown("""
    ### 🔄 **RAG Pipeline Process:**
    1. **Image Classification** - BiomedCLIP analyzes MRI
    2. **Knowledge Retrieval** - Find relevant medical documents  
    3. **Grounded Generation** - Qwen conditions on the retrieved context
    4. **Advice Output** - AI-generated clinical recommendations
    
    ### 📚 **Knowledge Base:**
    - ACL injury types and symptoms
    - Treatment recommendations
    - Clinical guidelines
    - Follow-up protocols
    """)

if __name__ == "__main__":
    app.launch()
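
# Dependency sketch (package names only; versions are assumptions, not pinned by this file):
#   pip install gradio torch transformers sentence-transformers faiss-cpu huggingface_hub pillow numpy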