import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
from huggingface_hub import hf_hub_download
from PIL import Image
import torch.nn.functional as F
import json

# Model repositories
BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
QWEN_REPO = "AssanaliAidarkhan/qwen-medical-rag"

# Global state populated at startup
biomedclip_model = None
biomedclip_processor = None
biomedclip_id2label = {}
qwen_model = None
qwen_tokenizer = None
medical_knowledge = []


class CLIPClassifier(nn.Module):
    """CLIP image encoder with a linear classification head on top."""

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        self.classifier = nn.Linear(clip_model.config.projection_dim, num_classes)

    def forward(self, **inputs):
        # Encode the image, then project the embedding onto class logits
        outputs = self.clip_model.get_image_features(**inputs)
        logits = self.classifier(outputs)
        return {'logits': logits}


def load_biomedclip():
    """Load the fine-tuned BiomedCLIP classifier from the Hub."""
    global biomedclip_model, biomedclip_processor, biomedclip_id2label
    try:
        print("🔄 Loading BiomedCLIP...")
        model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
        # The checkpoint stores metadata (id2label, num_classes) alongside the
        # weights, so full unpickling is needed (weights_only=False is required
        # on newer torch versions, where weights-only loading is the default).
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

        num_classes = checkpoint['num_classes']
        biomedclip_id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')
        print(f"📊 BiomedCLIP classes: {list(biomedclip_id2label.values())}")

        biomedclip_processor = CLIPProcessor.from_pretrained(model_name)
        clip_model = CLIPModel.from_pretrained(model_name)
        biomedclip_model = CLIPClassifier(clip_model, num_classes)
        biomedclip_model.load_state_dict(checkpoint['model_state_dict'])
        biomedclip_model.eval()

        print("✅ BiomedCLIP loaded successfully!")
        return True
    except Exception as e:
        print(f"❌ BiomedCLIP loading error: {e}")
        return False


def load_qwen_and_knowledge():
    """Load the Qwen chat model and the medical knowledge base."""
    global qwen_model, qwen_tokenizer, medical_knowledge
    try:
        print("🔄 Loading Qwen and knowledge base...")

        # Load Qwen model
        qwen_tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True
        )
        qwen_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        qwen_model.eval()
        print("✅ Qwen model loaded")

        # Load the knowledge base (one JSON document per condition category)
        try:
            knowledge_path = hf_hub_download(repo_id=QWEN_REPO, filename="medical_knowledge.json")
            with open(knowledge_path, 'r', encoding='utf-8') as f:
                medical_knowledge = json.load(f)
            print(f"✅ Knowledge base loaded: {len(medical_knowledge)} documents")
            print("📚 Knowledge categories:")
            for doc in medical_knowledge:
                print(f"   - {doc.get('category', 'Unknown')}")
        except Exception as e:
            print(f"⚠️ Knowledge loading error: {e}")
            return False

        return True
    except Exception as e:
        print(f"❌ Qwen loading error: {e}")
        return False


def find_relevant_knowledge(condition):
    """Return the knowledge document whose category matches the condition."""
    global medical_knowledge
    print(f"🔍 Searching for condition: {condition}")

    # Look for an exact (case-insensitive) category match first
    for doc in medical_knowledge:
        doc_category = doc.get('category', '').lower()
        if condition.lower() == doc_category:
            print(f"✅ Exact match found: {doc_category}")
            return doc

    # Otherwise fall back to a word-overlap partial match
    for doc in medical_knowledge:
        doc_category = doc.get('category', '').lower()
        condition_words = condition.lower().replace('_', ' ').split()
        category_words = doc_category.replace('_', ' ').split()
        matches = sum(1 for word in condition_words if word in category_words)
        # Require at least one overlapping word, allowing one miss overall
        # (without the max(1, ...) guard, a single-word condition would
        # trivially match the first document in the list).
        if matches >= max(1, len(condition_words) - 1):
            print(f"✅ Partial match found: {doc_category}")
            return doc

    print(f"❌ No matching document found for: {condition}")
    return None
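# A quick illustration of the matching rules above. The knowledge entry here is
# hypothetical; the real categories come from medical_knowledge.json:
#
#   medical_knowledge = [{"category": "partial_acl_tear", "content": "..."}]
#   find_relevant_knowledge("partial_acl_tear")   # exact match after lowercasing
#   find_relevant_knowledge("partial acl tear")   # partial match via word overlap
#   find_relevant_knowledge("meniscus injury")    # no overlapping words -> None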
def generate_medical_advice(classification_result):
    """Generate treatment advice with Qwen, grounded in the matched document."""
    global qwen_model, qwen_tokenizer

    if qwen_model is None:
        return "❌ Qwen model not available"

    try:
        # Strip the confidence suffix, e.g. "acl_tear (93.1% confidence)" -> "acl_tear"
        condition = classification_result.split('(')[0].strip()
        print(f"🔄 Generating advice for: {condition}")

        # Retrieve the matching knowledge document, if any
        relevant_doc = find_relevant_knowledge(condition)

        if relevant_doc:
            # Ground the prompt in the retrieved document
            medical_context = relevant_doc.get('content', '')
            clinical_advice = relevant_doc.get('advice', '')
            prompt = f"""Medical findings: {medical_context}

Clinical guidelines: {clinical_advice}

Provide specific treatment recommendations:

Treatment plan:"""
        else:
            # Fallback prompt without specific knowledge
            prompt = f"""Patient diagnosed with {condition}

Provide clinical treatment recommendations:

Treatment plan:"""

        print(f"📝 Prompt created (length: {len(prompt)})")

        # Tokenize
        inputs = qwen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=300)

        # Generate (pass the attention mask explicitly to avoid padding warnings)
        with torch.no_grad():
            outputs = qwen_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=80,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=qwen_tokenizer.eos_token_id
            )

        # Decode, then keep only the newly generated part
        full_output = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "Treatment plan:" in full_output:
            generated_advice = full_output.split("Treatment plan:", 1)[-1].strip()
        else:
            generated_advice = full_output.replace(prompt, "").strip()

        print(f"✅ Generated advice: {generated_advice[:100]}...")

        # Format response (the same header applies with or without a matched document)
        return f"""## 🤖 AI-Generated Recommendations:

{generated_advice}"""
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return f"Error generating advice: {e}"


def complete_medical_analysis(image):
    """Full pipeline: BiomedCLIP classification followed by advice generation."""
    if image is None:
        return "❌ Please upload an MRI scan", ""

    # Step 1: BiomedCLIP classification
    try:
        if biomedclip_model is None:
            return "❌ BiomedCLIP model not loaded", ""

        # Preprocess image
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = biomedclip_processor(images=image, return_tensors="pt")

        # Classify
        with torch.no_grad():
            outputs = biomedclip_model(**inputs)
            logits = outputs['logits']
            probabilities = F.softmax(logits, dim=1)

        # Map the top class index back to its label (checkpoints round-tripped
        # through JSON may key id2label with strings, so try both forms)
        top_prob, top_idx = torch.max(probabilities, 1)
        class_idx = top_idx.item()
        if class_idx in biomedclip_id2label:
            class_name = biomedclip_id2label[class_idx]
        elif str(class_idx) in biomedclip_id2label:
            class_name = biomedclip_id2label[str(class_idx)]
        else:
            class_name = f"Class_{class_idx}"

        confidence = top_prob.item() * 100
        classification_result = f"{class_name} ({confidence:.1f}% confidence)"
        print(f"✅ Classification: {classification_result}")
    except Exception as e:
        return f"❌ Classification error: {e}", ""

    # Step 2: Medical advice generation
    medical_advice = generate_medical_advice(classification_result)

    # Format outputs (look up the knowledge document once and reuse it)
    relevant_doc = find_relevant_knowledge(class_name)
    description = (relevant_doc.get('content', 'No medical description available')
                   if relevant_doc else 'No medical description available')

    classification_text = f"""
# 🔬 MRI Classification Results

## 🎯 Diagnosis:
# **{class_name}**

{description}
"""

    advice_text = f"""
# 🏥 Clinical Recommendations

{medical_advice}

---
⚠️ **Medical Disclaimer:** This analysis is for educational and research purposes only. Always consult qualified medical professionals for clinical decisions.
"""

    return classification_text, advice_text
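# A minimal smoke test of the pipeline (a sketch: the gray square below is a
# synthetic placeholder, not a real MRI, so the predicted class is meaningless;
# it only checks that the two stages run end to end after the models load):
#
#   from PIL import Image
#   dummy = Image.new("RGB", (224, 224), color="gray")
#   classification_md, advice_md = complete_medical_analysis(dummy)
#   print(classification_md)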
# Load models on startup
print("🚀 Initializing Medical AI Pipeline...")
biomedclip_loaded = load_biomedclip()
qwen_loaded = load_qwen_and_knowledge()

# Create Gradio interface
with gr.Blocks(title="Medical AI Pipeline", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🏥 Knee Injury Diagnosis
    Upload an MRI scan to receive an automatic classification and clinical recommendations.
    """)

    # Status indicators
    status_text = (
        f"BiomedCLIP: {'✅ Model loaded successfully' if biomedclip_loaded else '❌ Failed'} | "
        f"Qwen RAG: {'✅ Model loaded successfully' if qwen_loaded else '❌ Failed'}"
    )
    gr.Markdown(f"**System Status:** {status_text}")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="📸 Upload MRI scan",
                height=400
            )
            analyze_btn = gr.Button("🔬 Run MRI Analysis", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear Results", variant="secondary")

        with gr.Column(scale=2):
            classification_output = gr.Markdown(label="🔬 MRI Classification")
            advice_output = gr.Markdown(label="🏥 Medical Recommendations")

    # Event handlers
    analyze_btn.click(
        fn=complete_medical_analysis,
        inputs=image_input,
        outputs=[classification_output, advice_output]
    )

    clear_btn.click(
        fn=lambda: [None, "", ""],
        outputs=[image_input, classification_output, advice_output]
    )

    gr.Markdown("""
    ### Developed at the Nazarbayev University laboratory
    """)

if __name__ == "__main__":
    app.launch()
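# Once the app is running, it can also be driven programmatically. A sketch
# using the official gradio_client package; the api_name below follows Gradio's
# default naming (the handler function's name) and is an assumption here, so
# verify the actual route in the app's "Use via API" panel:
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   classification_md, advice_md = client.predict(
#       handle_file("knee_mri.png"),   # hypothetical local image path
#       api_name="/complete_medical_analysis",
#   )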