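"""Medical RAG pipeline Space.

BiomedCLIP-style image classification -> FAISS retrieval over a small medical
knowledge base -> Qwen-generated clinical advice, wired into a Gradio UI.
"""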
import json

import faiss
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor

# Model repositories
BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
QWEN_RAG_REPO = "AssanaliAidarkhan/qwen-medical-rag"

# Global model/index state, populated by the load_* functions below
biomedclip_model = None
biomedclip_processor = None
biomedclip_id2label = {}
qwen_model = None
qwen_tokenizer = None
embedding_model = None
medical_knowledge = []
faiss_index = None

class CLIPClassifier(nn.Module):
    """CLIP image encoder with a linear classification head on its projected image features."""

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        # projection_dim is the size of CLIP's projected image embedding
        self.classifier = nn.Linear(clip_model.config.projection_dim, num_classes)

    def forward(self, **inputs):
        image_features = self.clip_model.get_image_features(**inputs)
        logits = self.classifier(image_features)
        return {'logits': logits}
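
# The checkpoint consumed by load_biomedclip() below is expected to contain the
# keys read there: 'num_classes', 'id2label', 'model_state_dict', and optionally
# 'model_name'. A minimal sketch of how such a checkpoint could be produced on
# the training side (an assumption -- training code is not part of this app, and
# the label names here are hypothetical):
#
#   torch.save({
#       'num_classes': 2,
#       'id2label': {0: 'partial_acl_injury', 1: 'complete_acl_tear'},
#       'model_name': 'openai/clip-vit-base-patch16',
#       'model_state_dict': trained_classifier.state_dict(),
#   }, 'pytorch_model.bin')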

def load_biomedclip():
    """Load the fine-tuned BiomedCLIP classifier checkpoint from the Hub."""
    global biomedclip_model, biomedclip_processor, biomedclip_id2label
    try:
        model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
        checkpoint = torch.load(model_path, map_location='cpu')
        num_classes = checkpoint['num_classes']
        biomedclip_id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')
        biomedclip_processor = CLIPProcessor.from_pretrained(model_name)
        clip_model = CLIPModel.from_pretrained(model_name)
        biomedclip_model = CLIPClassifier(clip_model, num_classes)
        biomedclip_model.load_state_dict(checkpoint['model_state_dict'])
        biomedclip_model.eval()
        print("✅ BiomedCLIP loaded!")
        return True
    except Exception as e:
        print(f"❌ BiomedCLIP error: {e}")
        return False

def load_rag_system():
    """Load the full RAG stack: knowledge base, embeddings, FAISS index, and Qwen."""
    global qwen_model, qwen_tokenizer, embedding_model, medical_knowledge, faiss_index
    try:
        print("🔄 Loading RAG system...")
        # 1. Load medical knowledge base
        try:
            knowledge_path = hf_hub_download(repo_id=QWEN_RAG_REPO, filename="medical_knowledge.json")
            with open(knowledge_path, 'r', encoding='utf-8') as f:
                medical_knowledge = json.load(f)
            print(f"✅ Knowledge base: {len(medical_knowledge)} documents")
        except Exception as e:
            print(f"⚠️ Knowledge loading error: {e}, using fallback")
            # Fallback knowledge base (two ACL findings)
            medical_knowledge = [
                {
                    "id": "doc1",
                    "title": "Partial anterior cruciate ligament injury",
                    "content": "Signs of a partial anterior cruciate ligament injury: thickening, increased T2 signal, partial disorganization of fibers, the ligament remains traceable along its course",
                    "category": "Partial ACL injury",
                    "advice": "Recommend conservative treatment, physical therapy, follow-up MRI in 6-8 weeks"
                },
                {
                    "id": "doc2",
                    "title": "Complete anterior cruciate ligament tear",
                    "content": "Signs of a complete anterior cruciate ligament tear: fibers are not traceable along their course, a zone of increased signal is seen in the projection of the ligament, hemarthrosis",
                    "category": "Complete ACL tear",
                    "advice": "Urgent orthopedic consultation, likely requires ACL reconstruction surgery"
                }
            ]
        # 2. Load embedding model
        print("🔄 Loading embeddings...")
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # 3. Embed each document and build the FAISS index
        print("🔄 Creating FAISS index...")
        text_contents = []
        for doc in medical_knowledge:
            text = f"{doc.get('title', '')} {doc.get('content', '')} {doc.get('advice', '')}"
            text_contents.append(text)
        embeddings = embedding_model.encode(text_contents, convert_to_numpy=True)
        # Build an inner-product index over L2-normalized vectors
        dimension = embeddings.shape[1]
        faiss_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        faiss_index.add(embeddings)
        print(f"✅ FAISS index created with {faiss_index.ntotal} documents")
        # 4. Load Qwen
        print("🔄 Loading Qwen...")
        qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
        qwen_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        print("✅ RAG system loaded completely!")
        return True
    except Exception as e:
        print(f"❌ RAG loading error: {e}")
        import traceback
        print(traceback.format_exc())
        return False
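
# Knowledge-base entries follow the schema of the fallback documents above
# (the fields the pipeline reads: id, title, content, category, advice). A
# sketch of an entry that medical_knowledge.json is assumed to contain:
#
#   {
#     "id": "doc3",
#     "title": "Condition title",
#     "content": "MRI findings for the condition",
#     "category": "Condition name",
#     "advice": "Clinical recommendations"
#   }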

def retrieve_relevant_knowledge(classification_result):
    """Retrieve the medical documents most relevant to the classification."""
    global embedding_model, medical_knowledge, faiss_index
    if faiss_index is None:
        return [], "No knowledge base available"
    try:
        # Build a retrieval query from the classification result
        query = f"Medical diagnosis {classification_result} treatment recommendations clinical advice"
        # Embed and normalize the query the same way as the documents
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)
        # Search the FAISS index for the top 2 documents
        scores, indices = faiss_index.search(query_embedding, 2)
        # Collect the retrieved documents and assemble the context string
        retrieved_docs = []
        context_text = ""
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1 and idx < len(medical_knowledge):
                doc = medical_knowledge[idx]
                retrieved_docs.append((doc, float(score)))
                context_text += f"Medical Knowledge: {doc.get('content', '')}\n"
                context_text += f"Clinical Advice: {doc.get('advice', '')}\n"
                context_text += f"Category: {doc.get('category', '')}\n\n"
        return retrieved_docs, context_text
    except Exception as e:
        print(f"❌ Retrieval error: {e}")
        return [], f"Retrieval error: {e}"
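
# Usage sketch (assumes load_rag_system() has already populated the index;
# the classification string below is a hypothetical example):
#
#   docs, ctx = retrieve_relevant_knowledge("Complete ACL tear (91.2% confidence)")
#   for doc, score in docs:
#       print(doc["category"], f"{score:.3f}")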

def generate_qwen_advice(classification_result, retrieved_context):
    """Generate medical advice with Qwen, grounded in the retrieved context."""
    global qwen_model, qwen_tokenizer
    if qwen_model is None:
        return "❌ Qwen model not loaded"
    try:
        print("🔄 Generating Qwen advice...")
        # Build the RAG prompt: classification result plus retrieved context
        prompt = f"""You are a medical AI assistant. Based on the MRI classification and medical knowledge provided, give clinical recommendations.
MRI Classification: {classification_result}
Retrieved Medical Knowledge:
{retrieved_context}
Provide specific clinical recommendations including treatment options and follow-up care:"""
        print(f"📝 Prompt length: {len(prompt)} characters")
        # Tokenize (truncate so the prompt plus new tokens stays within the context window)
        inputs = qwen_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        print(f"🔧 Input tokens: {inputs.input_ids.shape}")
        # Generate (attention_mask passed explicitly to avoid pad/eos ambiguity)
        with torch.no_grad():
            outputs = qwen_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=120,
                temperature=0.7,
                do_sample=True,
                pad_token_id=qwen_tokenizer.eos_token_id,
                eos_token_id=qwen_tokenizer.eos_token_id
            )
        # Decode only the newly generated tokens, not the echoed prompt
        generated_tokens = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = qwen_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        print(f"✅ Generated: {generated_text[:100]}...")
        if len(generated_text) < 10:
            return "No specific recommendations generated. Consult a medical professional."
        return generated_text
    except Exception as e:
        print(f"❌ Qwen generation error: {e}")
        import traceback
        print(traceback.format_exc())
        return f"Generation error: {e}"

def complete_analysis(image):
    """Complete pipeline: classification -> retrieval -> generation."""
    if image is None:
        return "❌ Please upload an MRI scan", ""
    # Step 1: Classification
    try:
        if biomedclip_model is None:
            return "❌ BiomedCLIP not loaded", ""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = biomedclip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = biomedclip_model(**inputs)
        logits = outputs['logits']
        probabilities = F.softmax(logits, dim=1)
        top_prob, top_idx = torch.max(probabilities, 1)
        class_idx = top_idx.item()
        # id2label keys may be ints or strings depending on how the checkpoint was saved
        if class_idx in biomedclip_id2label:
            class_name = biomedclip_id2label[class_idx]
        elif str(class_idx) in biomedclip_id2label:
            class_name = biomedclip_id2label[str(class_idx)]
        else:
            class_name = f"Class_{class_idx}"
        confidence = top_prob.item() * 100
        classification_result = f"{class_name} ({confidence:.1f}% confidence)"
        print(f"✅ Classification: {classification_result}")
    except Exception as e:
        return f"❌ Classification error: {e}", ""
    # Step 2: RAG retrieval
    retrieved_docs, context = retrieve_relevant_knowledge(classification_result)
    # Step 3: Qwen generation
    qwen_advice = generate_qwen_advice(classification_result, context)
    # Format the two markdown panels shown in the UI
    classification_text = f"""
# 🔬 **MRI Classification**

## 🎯 **Diagnosis:**
**{class_name}**

## 📊 **Confidence:**
**{confidence:.1f}%**
"""
    advice_text = f"""
# 🏥 **AI-Generated Medical Recommendations**

## 🤖 **Qwen Analysis:**
{qwen_advice}

## 📚 **Retrieved Medical Knowledge:**
{context if context else "No relevant knowledge retrieved"}

## 📋 **Retrieved Documents:**
{len(retrieved_docs)} documents found and used for advice generation

---
⚠️ **Disclaimer:** For educational purposes only. Always consult medical professionals.
"""
    return classification_text, advice_text
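
# Smoke-test sketch for the full pipeline without the UI ("sample_mri.png" is
# a hypothetical local file):
#
#   from PIL import Image
#   cls_md, advice_md = complete_analysis(Image.open("sample_mri.png"))
#   print(cls_md)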

# Load models at startup
print("🚀 Loading complete pipeline...")
biomedclip_loaded = load_biomedclip()
rag_loaded = load_rag_system()

# Create the Gradio interface
with gr.Blocks(title="Medical RAG Pipeline") as app:
    gr.Markdown("# 🏥 Medical AI RAG Pipeline")
    gr.Markdown("**BiomedCLIP** → **RAG Retrieval** → **Qwen Generation**")
    status = f"BiomedCLIP: {'✅' if biomedclip_loaded else '❌'} | RAG: {'✅' if rag_loaded else '❌'}"
    gr.Markdown(f"**Status:** {status}")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="📸 Upload MRI Scan")
            analyze_btn = gr.Button("🔬 Complete RAG Analysis", variant="primary")
        with gr.Column():
            classification_output = gr.Markdown(label="🔬 Classification")
            advice_output = gr.Markdown(label="🏥 RAG-Generated Advice")
    analyze_btn.click(
        fn=complete_analysis,
        inputs=image_input,
        outputs=[classification_output, advice_output]
    )
| gr.Markdown(""" | |
| ### 🔄 **RAG Pipeline Process:** | |
| 1. **Image Classification** - BiomedCLIP analyzes MRI | |
| 2. **Knowledge Retrieval** - Find relevant medical documents | |
| 3. **Context Generation** - Qwen uses retrieved knowledge | |
| 4. **Advice Output** - AI-generated clinical recommendations | |
| ### 📚 **Knowledge Base:** | |
| - ACL injury types and symptoms | |
| - Treatment recommendations | |
| - Clinical guidelines | |
| - Follow-up protocols | |
| """) | |

if __name__ == "__main__":
    app.launch()