import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
from huggingface_hub import hf_hub_download
from PIL import Image
import json

# Model repositories
BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
QWEN_REPO = "AssanaliAidarkhan/qwen-medical-rag"

# Global model handles, populated at startup
biomedclip_model = None
biomedclip_processor = None
biomedclip_id2label = {}
qwen_model = None
qwen_tokenizer = None
medical_knowledge = []

class CLIPClassifier(nn.Module):
    """CLIP image encoder with a linear classification head on top."""

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        self.classifier = nn.Linear(clip_model.config.projection_dim, num_classes)

    def forward(self, **inputs):
        # Encode the image with CLIP, then project the pooled features to class logits
        image_features = self.clip_model.get_image_features(**inputs)
        logits = self.classifier(image_features)
        return {'logits': logits}
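
# The fine-tuned checkpoint downloaded from BIOMEDCLIP_REPO is expected to be a
# plain dict with (at least) the following keys, as read by load_biomedclip()
# below. This schema is inferred from the loading code, not from published
# repo documentation:
#
#   {
#       'num_classes':      int,   # size of the classifier head
#       'id2label':         dict,  # class index -> label name
#       'model_name':       str,   # base CLIP checkpoint (optional)
#       'model_state_dict': dict,  # weights for CLIPClassifier
#   }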

def load_biomedclip():
    """Load the fine-tuned BiomedCLIP classifier from the Hub."""
    global biomedclip_model, biomedclip_processor, biomedclip_id2label
    try:
        print("🔄 Loading BiomedCLIP...")
        model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
        checkpoint = torch.load(model_path, map_location='cpu')

        num_classes = checkpoint['num_classes']
        biomedclip_id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')
        print(f"📊 BiomedCLIP classes: {list(biomedclip_id2label.values())}")

        # Rebuild the architecture, then restore the fine-tuned weights
        biomedclip_processor = CLIPProcessor.from_pretrained(model_name)
        clip_model = CLIPModel.from_pretrained(model_name)
        biomedclip_model = CLIPClassifier(clip_model, num_classes)
        biomedclip_model.load_state_dict(checkpoint['model_state_dict'])
        biomedclip_model.eval()

        print("✅ BiomedCLIP loaded successfully!")
        return True
    except Exception as e:
        print(f"❌ BiomedCLIP loading error: {e}")
        return False

def load_qwen_and_knowledge():
    """Load the Qwen chat model and the medical knowledge base."""
    global qwen_model, qwen_tokenizer, medical_knowledge
    try:
        print("🔄 Loading Qwen and knowledge base...")

        # Load Qwen model
        qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
        qwen_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        qwen_model.eval()
        print("✅ Qwen model loaded")

        # Load the knowledge base; a missing knowledge file counts as a failed load
        try:
            knowledge_path = hf_hub_download(repo_id=QWEN_REPO, filename="medical_knowledge.json")
            with open(knowledge_path, 'r', encoding='utf-8') as f:
                medical_knowledge = json.load(f)
            print(f"✅ Knowledge base loaded: {len(medical_knowledge)} documents")
            print("📚 Knowledge categories:")
            for doc in medical_knowledge:
                print(f"  - {doc.get('category', 'Unknown')}")
        except Exception as e:
            print(f"⚠️ Knowledge loading error: {e}")
            return False

        return True
    except Exception as e:
        print(f"❌ Qwen loading error: {e}")
        return False
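
# Assumed shape of medical_knowledge.json, inferred from the fields this app
# reads ('category', 'content', 'advice'); the actual file may carry more keys,
# and the values below are hypothetical examples:
#
#   [
#       {
#           "category": "acl_tear",
#           "content":  "Complete rupture of the anterior cruciate ligament...",
#           "advice":   "Refer for orthopedic consultation..."
#       },
#       ...
#   ]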

def find_relevant_knowledge(condition):
    """Find the knowledge document whose category best matches the condition."""
    global medical_knowledge
    print(f"🔍 Searching for condition: {condition}")

    # Look for an exact category match first
    for doc in medical_knowledge:
        doc_category = doc.get('category', '').lower()
        if condition.lower() == doc_category:
            print(f"✅ Exact match found: {doc_category}")
            return doc

    # Otherwise fall back to a word-overlap match
    for doc in medical_knowledge:
        doc_category = doc.get('category', '').lower()
        condition_words = condition.lower().replace('_', ' ').split()
        category_words = doc_category.replace('_', ' ').split()
        # Allow one non-matching word, but require at least one matching word
        # (the bare `len(condition_words) - 1` threshold would match any
        # document when the condition is a single word)
        matches = sum(1 for word in condition_words if word in category_words)
        if matches >= max(1, len(condition_words) - 1):
            print(f"✅ Partial match found: {doc_category}")
            return doc

    print(f"❌ No matching document found for: {condition}")
    return None
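
# Example of the fallback heuristic with hypothetical values: for the condition
# "partial acl tear" (3 words), a category like "acl_tear" shares 2 words
# ("acl", "tear"), which meets the 2-of-3 threshold and is returned as a match,
# while "meniscus_tear" shares only 1 word and is skipped.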

def generate_medical_advice(classification_result):
    """Generate medical advice with Qwen, grounded in the matched knowledge document."""
    global qwen_model, qwen_tokenizer

    if qwen_model is None:
        return "❌ Qwen model not available"

    try:
        # Extract the condition name, dropping the "(xx.x% confidence)" suffix
        condition = classification_result.split('(')[0].strip()
        print(f"🔄 Generating advice for: {condition}")

        # Retrieve the matching knowledge document, if any
        relevant_doc = find_relevant_knowledge(condition)

        if relevant_doc:
            medical_context = relevant_doc.get('content', '')
            clinical_advice = relevant_doc.get('advice', '')
            prompt = f"""Medical findings: {medical_context}
Clinical guidelines: {clinical_advice}
Provide specific treatment recommendations:
Treatment plan:"""
        else:
            # Fallback prompt when no knowledge document matches
            prompt = f"""Patient diagnosed with {condition}
Provide clinical treatment recommendations:
Treatment plan:"""

        print(f"📝 Prompt created (length: {len(prompt)})")

        # Tokenize; truncate long knowledge documents to keep the prompt bounded
        inputs = qwen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=300)

        # Generate; unpacking `inputs` also forwards the attention mask
        with torch.no_grad():
            outputs = qwen_model.generate(
                **inputs,
                max_new_tokens=80,
                temperature=0.7,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                pad_token_id=qwen_tokenizer.eos_token_id
            )

        # Decode and keep only the text generated after the prompt
        full_output = qwen_tokenizer.decode(outputs[0], skip_special_tokens=True)
        if "Treatment plan:" in full_output:
            generated_advice = full_output.split("Treatment plan:", 1)[-1].strip()
        else:
            generated_advice = full_output.replace(prompt, "").strip()

        print(f"✅ Generated advice: {generated_advice[:100]}...")

        return f"""## 🤖 AI-Generated Recommendations:
{generated_advice}"""
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return f"Error generating advice: {e}"

def complete_medical_analysis(image):
    """Full pipeline: BiomedCLIP classification followed by medical advice generation."""
    if image is None:
        return "❌ Please upload an MRI scan", ""

    # Step 1: BiomedCLIP classification
    try:
        if biomedclip_model is None:
            return "❌ BiomedCLIP model not loaded", ""

        # Preprocess the image
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = biomedclip_processor(images=image, return_tensors="pt")

        # Classify
        with torch.no_grad():
            outputs = biomedclip_model(**inputs)
            logits = outputs['logits']
            probabilities = F.softmax(logits, dim=1)

        # Map the top prediction back to its label (id2label keys may be int or str)
        top_prob, top_idx = torch.max(probabilities, 1)
        class_idx = top_idx.item()
        if class_idx in biomedclip_id2label:
            class_name = biomedclip_id2label[class_idx]
        elif str(class_idx) in biomedclip_id2label:
            class_name = biomedclip_id2label[str(class_idx)]
        else:
            class_name = f"Class_{class_idx}"

        confidence = top_prob.item() * 100
        classification_result = f"{class_name} ({confidence:.1f}% confidence)"
        print(f"✅ Classification: {classification_result}")
    except Exception as e:
        return f"❌ Classification error: {e}", ""

    # Step 2: Medical advice generation
    medical_advice = generate_medical_advice(classification_result)

    # Format outputs (look up the knowledge document once and reuse it)
    relevant_doc = find_relevant_knowledge(class_name)
    description = relevant_doc.get('content', 'No medical description available') if relevant_doc else 'No medical description available'
    classification_text = f"""
# 🔬 MRI Classification Results
## 🎯 Diagnosis: **{class_name}**
{description}
"""
    advice_text = f"""
# 🏥 Clinical Recommendations
{medical_advice}
---
⚠️ **Medical Disclaimer:** This analysis is for educational and research purposes only. Always consult qualified medical professionals for clinical decisions.
"""
    return classification_text, advice_text
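
# Hypothetical local smoke test (not executed by the Space; the filename below
# is a placeholder):
#
#   from PIL import Image
#   cls_md, advice_md = complete_medical_analysis(Image.open("knee_mri.png"))
#   print(cls_md)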

# Load models on startup
print("🚀 Initializing Medical AI Pipeline...")
biomedclip_loaded = load_biomedclip()
qwen_loaded = load_qwen_and_knowledge()

# Build the Gradio interface
with gr.Blocks(title="Medical AI Pipeline", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # 🏥 Knee Injury Diagnosis
    Upload an MRI scan to receive an automatic classification and clinical recommendations.
    """)

    # Status indicators
    status_text = f"BiomedCLIP: {'✅ Loaded successfully' if biomedclip_loaded else '❌ Failed'} | Qwen RAG: {'✅ Loaded successfully' if qwen_loaded else '❌ Failed'}"
    gr.Markdown(f"**System Status:** {status_text}")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="📸 Upload MRI Scan",
                height=400
            )
            analyze_btn = gr.Button("🔬 Start MRI Analysis", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Clear Results", variant="secondary")
        with gr.Column(scale=2):
            classification_output = gr.Markdown(label="🔬 MRI Classification")
            advice_output = gr.Markdown(label="🏥 Medical Recommendations")

    # Event handlers
    analyze_btn.click(
        fn=complete_medical_analysis,
        inputs=image_input,
        outputs=[classification_output, advice_output]
    )
    clear_btn.click(
        fn=lambda: [None, "", ""],
        outputs=[image_input, classification_output, advice_output]
    )

    gr.Markdown("""
    ### Developed at the Nazarbayev University laboratory
    """)

if __name__ == "__main__":
    app.launch()