# Biomedclip / app.py
# Author: AssanaliAidarkhan — commit 8f63437 ("Update app.py")
import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
from huggingface_hub import hf_hub_download
from PIL import Image
import torch.nn.functional as F
import json
# Hugging Face Hub repositories holding the fine-tuned classifier and the RAG assets
BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
QWEN_REPO = "AssanaliAidarkhan/qwen-medical-rag"

# Shared pipeline state, filled in at startup by load_biomedclip() and
# load_qwen_and_knowledge(); stays None / empty until the matching loader sets it.
biomedclip_model = biomedclip_processor = None
biomedclip_id2label = {}  # class index (int or str key) -> label name
qwen_model = qwen_tokenizer = None
medical_knowledge = []  # knowledge documents: dicts looked up by their 'category'
class CLIPClassifier(nn.Module):
    """Linear classification head on top of a CLIP image encoder.

    The image embedding returned by ``clip_model.get_image_features`` (size
    ``clip_model.config.projection_dim``) is mapped to ``num_classes`` logits
    by a single ``nn.Linear`` layer.

    The attribute names ``clip_model`` and ``classifier`` determine the
    state-dict keys, so they must stay in sync with the saved checkpoint.
    """

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        embed_dim = clip_model.config.projection_dim
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, **inputs):
        """Return ``{'logits': tensor}`` for the processor outputs in ``inputs``."""
        image_embeds = self.clip_model.get_image_features(**inputs)
        return {'logits': self.classifier(image_embeds)}
def load_biomedclip():
    """Download the fine-tuned BiomedCLIP checkpoint and build the classifier.

    Reads ``pytorch_model.bin`` from ``BIOMEDCLIP_REPO``. The checkpoint dict
    must provide ``num_classes``, ``id2label`` and ``model_state_dict``; its
    optional ``model_name`` selects the base CLIP model and processor
    (default ``openai/clip-vit-base-patch16``).

    The globals ``biomedclip_model``, ``biomedclip_processor`` and
    ``biomedclip_id2label`` are only replaced once everything has loaded.
    A failure part-way (for example a ``load_state_dict`` key mismatch)
    therefore never exposes a randomly initialised classifier to callers.

    Returns:
        True on success, False on any error (the error is printed).
    """
    global biomedclip_model, biomedclip_processor, biomedclip_id2label
    try:
        print("🔄 Loading BiomedCLIP...")
        model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only point this at a trusted repository, and consider
        # weights_only=True if the checkpoint holds only tensors/plain types.
        checkpoint = torch.load(model_path, map_location='cpu')
        num_classes = checkpoint['num_classes']
        id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')
        print(f"📊 BiomedCLIP classes: {list(id2label.values())}")

        # Build everything in locals first; publish only after success.
        processor = CLIPProcessor.from_pretrained(model_name)
        clip_model = CLIPModel.from_pretrained(model_name)
        model = CLIPClassifier(clip_model, num_classes)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
    except Exception as e:
        print(f"❌ BiomedCLIP loading error: {e}")
        return False

    biomedclip_model = model
    biomedclip_processor = processor
    biomedclip_id2label = id2label
    print("✅ BiomedCLIP loaded successfully!")
    return True
def load_qwen_and_knowledge():
    """Load the Qwen chat model/tokenizer and the RAG knowledge base.

    The model is ``Qwen/Qwen1.5-0.5B-Chat`` loaded in float32. The knowledge
    base is ``medical_knowledge.json`` from ``QWEN_REPO``. It must be a JSON
    list of objects, because ``find_relevant_knowledge`` iterates it and
    calls ``.get('category')`` on every entry.

    Globals are only overwritten with fully loaded, validated values:
    - A failed model load leaves ``qwen_model`` and ``qwen_tokenizer`` untouched.
    - A malformed knowledge file leaves ``medical_knowledge`` untouched,
      instead of replacing it with data that later crashes the lookup.

    Returns:
        True if both the model and the knowledge base loaded, else False.
        When only the knowledge base fails, the model globals are already set
        even though False is returned.
    """
    global qwen_model, qwen_tokenizer, medical_knowledge
    base_model = "Qwen/Qwen1.5-0.5B-Chat"

    try:
        print("🔄 Loading Qwen and knowledge base...")
        tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        model.eval()
    except Exception as e:
        print(f"❌ Qwen loading error: {e}")
        return False
    qwen_tokenizer, qwen_model = tokenizer, model
    print("✅ Qwen model loaded")

    # Load knowledge base with correct categories
    try:
        knowledge_path = hf_hub_download(repo_id=QWEN_REPO, filename="medical_knowledge.json")
        with open(knowledge_path, 'r', encoding='utf-8') as f:
            knowledge = json.load(f)
        if not isinstance(knowledge, list) or not all(isinstance(doc, dict) for doc in knowledge):
            raise ValueError("medical_knowledge.json must be a JSON list of objects")
    except Exception as e:
        print(f"⚠️ Knowledge loading error: {e}")
        return False

    medical_knowledge = knowledge
    print(f"✅ Knowledge base loaded: {len(medical_knowledge)} documents")
    print("📚 Knowledge categories:")
    for doc in medical_knowledge:
        print(f" - {doc.get('category', 'Unknown')}")
    return True
def find_relevant_knowledge(condition):
    """Return the knowledge document whose ``category`` matches *condition*.

    Comparison is case-insensitive, treats underscores as spaces and
    collapses whitespace, so a label ``"Foo_bar"`` matches a category
    ``"foo bar"``.

    1. An exact match on the normalised text wins.
    2. Otherwise, the first document whose category shares "most" words with
       the condition is returned. It must share at least
       ``len(words) - 1`` words (one word of difference allowed), and never
       fewer than one, so an unrelated document is never returned.

    Args:
        condition: Predicted class name from the classifier.

    Returns:
        The matching document dict from ``medical_knowledge``, or None.
    """
    print(f"🔍 Searching for condition: {condition}")

    def _normalise(text):
        # lower-case, '_' -> ' ', and collapse runs of whitespace
        return ' '.join(str(text).lower().replace('_', ' ').split())

    target = _normalise(condition)
    condition_words = target.split()
    if not condition_words:
        print(f"❌ No matching document found for: {condition}")
        return None

    # Normalise each category once; `or ''` also covers an explicit null.
    candidates = [(doc, _normalise(doc.get('category') or '')) for doc in medical_knowledge]

    # Look for exact category match first
    for doc, doc_category in candidates:
        if doc_category == target:
            print(f"✅ Exact match found: {doc_category}")
            return doc

    # If no exact match, look for partial matches (at least one shared word)
    required = max(len(condition_words) - 1, 1)
    for doc, doc_category in candidates:
        category_words = doc_category.split()
        matches = sum(1 for word in condition_words if word in category_words)
        if matches >= required:
            print(f"✅ Partial match found: {doc_category}")
            return doc

    print(f"❌ No matching document found for: {condition}")
    return None
def generate_medical_advice(classification_result):
    """Generate treatment recommendations for a classification using Qwen.

    Args:
        classification_result: Text like ``"<condition> (<conf>% confidence)"``.
            Everything before the first ``(`` is used as the condition name.

    Returns:
        Markdown with the generated recommendations. If the model is not
        loaded or generation fails, an error message string is returned instead.
    """
    if qwen_model is None:
        return "❌ Qwen model not available"
    try:
        # Extract condition name
        condition = classification_result.split('(')[0].strip()
        print(f"🔄 Generating advice for: {condition}")

        relevant_doc = find_relevant_knowledge(condition)
        prompt = _build_advice_prompt(condition, relevant_doc)
        print(f"📝 Prompt created (length: {len(prompt)})")

        generated_advice = _generate_completion(prompt)
        print(f"✅ Generated advice: {generated_advice[:100]}...")
        return f"## 🤖 AI-Generated Recommendations:\n{generated_advice}"
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return f"Error generating advice: {e}"


def _build_advice_prompt(condition, relevant_doc):
    """Return the completion prompt, grounded in *relevant_doc* when one was found."""
    if relevant_doc:
        medical_context = relevant_doc.get('content', '')
        clinical_advice = relevant_doc.get('advice', '')
        return (
            f"Medical findings: {medical_context}\n"
            f"Clinical guidelines: {clinical_advice}\n"
            "Provide specific treatment recommendations:\n"
            "Treatment plan:"
        )
    # Fallback prompt without specific knowledge
    return (
        f"Patient diagnosed with {condition}\n"
        "Provide clinical treatment recommendations:\n"
        "Treatment plan:"
    )


def _generate_completion(prompt, max_prompt_tokens=300, max_new_tokens=80):
    """Sample a continuation of *prompt* with Qwen and return only the new text.

    Over-long prompts are truncated from the LEFT, so the trailing
    "Treatment plan:" cue always reaches the model. Right truncation would
    cut off exactly that instruction.
    """
    encoded = qwen_tokenizer(prompt, return_tensors="pt")
    # NOTE(review): left-slicing assumes the tokenizer prepends no special
    # tokens (e.g. BOS) that must be kept — confirm if the base model changes.
    input_ids = encoded.input_ids[:, -max_prompt_tokens:]
    attention_mask = encoded.attention_mask[:, -max_prompt_tokens:]

    with torch.no_grad():
        outputs = qwen_model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=qwen_tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; decode only the continuation so
    # prompt text (which may itself contain "Treatment plan:") never leaks out.
    new_tokens = outputs[0, input_ids.shape[1]:]
    return qwen_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def complete_medical_analysis(image):
    """Run the full pipeline on an uploaded MRI image.

    Step 1 classifies the image with BiomedCLIP. Step 2 asks Qwen for
    recommendations for the predicted condition and attaches the matching
    knowledge-base description.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None.

    Returns:
        ``(classification_markdown, advice_markdown)``. For a missing image or
        a classification failure, the first element is an error message and
        the second is an empty string.
    """
    if image is None:
        return "❌ Please upload an MRI scan", ""

    # Step 1: BiomedCLIP Classification
    if biomedclip_model is None:
        return "❌ BiomedCLIP model not loaded", ""
    try:
        class_name, confidence = _classify_image(image)
    except Exception as e:
        return f"❌ Classification error: {e}", ""
    classification_result = f"{class_name} ({confidence:.1f}% confidence)"
    print(f"✅ Classification: {classification_result}")

    # Step 2: Medical Advice Generation
    medical_advice = generate_medical_advice(classification_result)

    # Look the description up once (previously done twice per request).
    relevant_doc = find_relevant_knowledge(class_name)
    if relevant_doc:
        description = relevant_doc.get('content', 'No medical description available')
    else:
        description = 'No medical description available'

    classification_text = (
        "\n# 🔬 Результаты классификации МРТ\n"
        "## 🎯 Диагноз:\n"
        f"# **{class_name}**\n"
        f"{description}\n"
    )
    # The blank line before '---' keeps Markdown from turning the last line
    # of the advice into a setext heading.
    advice_text = (
        "\n# 🏥 Clinical Recommendations\n"
        f"{medical_advice}\n"
        "\n---\n"
        "⚠️ **Medical Disclaimer:** This analysis is for educational and research purposes only. "
        "Always consult qualified medical professionals for clinical decisions.\n"
    )
    return classification_text, advice_text


def _classify_image(image):
    """Return ``(class_name, confidence_percent)`` for the top BiomedCLIP class."""
    if image.mode != 'RGB':
        image = image.convert('RGB')
    inputs = biomedclip_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = biomedclip_model(**inputs)['logits']
        probabilities = F.softmax(logits, dim=1)
    top_prob, top_idx = torch.max(probabilities, 1)
    return _label_for(top_idx.item()), top_prob.item() * 100


def _label_for(class_idx):
    """Map a class index to its label; checkpoints may store int or str keys."""
    for key in (class_idx, str(class_idx)):
        if key in biomedclip_id2label:
            return biomedclip_id2label[key]
    return f"Class_{class_idx}"
# Load both models once at import time; the resulting flags feed the
# "System Status" line of the UI.
print("🚀 Initializing Medical AI Pipeline...")
biomedclip_loaded, qwen_loaded = load_biomedclip(), load_qwen_and_knowledge()
# Create Gradio interface: upload panel on the left, Markdown results on the right.
with gr.Blocks(title="Medical AI Pipeline", theme=gr.themes.Soft()) as app:
    gr.Markdown("""
# 🏥 Диагностика травмы коленного сустава
Загрузите снимок МРТ, чтобы получить автоматическую классификацию и клинические рекомендации.
""")

    # Status indicators (flags come from the startup loaders)
    ok_text = "✅ Модель загружена успешно"
    status_text = (
        f"BiomedCLIP: {ok_text if biomedclip_loaded else '❌ Failed'} | "
        f"Qwen RAG: {ok_text if qwen_loaded else '❌ Failed'}"
    )
    gr.Markdown(f"**System Status:** {status_text}")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="📸 Загрузите МРТ скан",
                height=400,
            )
            analyze_btn = gr.Button("🔬 Начать Анализ МРТ", variant="primary", size="lg")
            clear_btn = gr.Button("🗑️ Очистить Результаты", variant="secondary")
        with gr.Column(scale=2):
            classification_output = gr.Markdown(label="🔬 MRI Classification")
            advice_output = gr.Markdown(label="🏥 Medical Recommendations")

    # Event handlers
    analyze_btn.click(
        fn=complete_medical_analysis,
        inputs=image_input,
        outputs=[classification_output, advice_output],
    )
    # Reset the image and both result panels.
    clear_btn.click(
        fn=lambda: [None, "", ""],
        outputs=[image_input, classification_output, advice_output],
    )

    # ATX headings need a space after '###', otherwise the text renders literally.
    gr.Markdown("""
### Разработано в лаборатории Назарбаев Университета:
""")

if __name__ == "__main__":
    app.launch()