import gradio as gr
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM, CLIPModel, CLIPProcessor
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from PIL import Image
import torch.nn.functional as F
import json
import numpy as np
import faiss
# Model repositories
BIOMEDCLIP_REPO = "AssanaliAidarkhan/Biomedclip"
QWEN_RAG_REPO = "AssanaliAidarkhan/qwen-medical-rag"
# Global variables
biomedclip_model = None
biomedclip_processor = None
biomedclip_id2label = {}
qwen_model = None
qwen_tokenizer = None
embedding_model = None
medical_knowledge = []
faiss_index = None
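
# These globals are populated once at startup by the load_* functions below and
# then read by the Gradio callback.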
class CLIPClassifier(nn.Module):
    """CLIP image encoder with a linear classification head."""

    def __init__(self, clip_model, num_classes):
        super().__init__()
        self.clip_model = clip_model
        # Map the CLIP image embedding (projection_dim) to class logits
        self.classifier = nn.Linear(clip_model.config.projection_dim, num_classes)

    def forward(self, **inputs):
        # get_image_features returns a (batch, projection_dim) tensor
        outputs = self.clip_model.get_image_features(**inputs)
        logits = self.classifier(outputs)
        return {'logits': logits}
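
# load_biomedclip() expects the checkpoint at BIOMEDCLIP_REPO to be a dict with
# keys 'num_classes', 'id2label', 'model_state_dict', and optionally 'model_name'
# (presumably the format used when the classifier was trained and saved).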
def load_biomedclip():
    """Load the fine-tuned BiomedCLIP classifier from the Hub."""
    global biomedclip_model, biomedclip_processor, biomedclip_id2label
    try:
        model_path = hf_hub_download(repo_id=BIOMEDCLIP_REPO, filename="pytorch_model.bin")
        checkpoint = torch.load(model_path, map_location='cpu')
        num_classes = checkpoint['num_classes']
        biomedclip_id2label = checkpoint['id2label']
        model_name = checkpoint.get('model_name', 'openai/clip-vit-base-patch16')
        biomedclip_processor = CLIPProcessor.from_pretrained(model_name)
        clip_model = CLIPModel.from_pretrained(model_name)
        biomedclip_model = CLIPClassifier(clip_model, num_classes)
        biomedclip_model.load_state_dict(checkpoint['model_state_dict'])
        biomedclip_model.eval()
        print("✅ BiomedCLIP loaded!")
        return True
    except Exception as e:
        print(f"❌ BiomedCLIP error: {e}")
        return False
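
# The RAG side: a JSON knowledge base is embedded with a sentence-transformer,
# indexed with FAISS for retrieval, and a small Qwen chat model turns the
# retrieved context into advice text.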
def load_rag_system():
    """Load the complete RAG system: knowledge base, embeddings, FAISS index, Qwen."""
    global qwen_model, qwen_tokenizer, embedding_model, medical_knowledge, faiss_index
    try:
        print("🔄 Loading RAG system...")

        # 1. Load medical knowledge base
        try:
            knowledge_path = hf_hub_download(repo_id=QWEN_RAG_REPO, filename="medical_knowledge.json")
            with open(knowledge_path, 'r', encoding='utf-8') as f:
                medical_knowledge = json.load(f)
            print(f"✅ Knowledge base: {len(medical_knowledge)} documents")
        except Exception as e:
            print(f"⚠️ Knowledge loading error: {e}, using fallback")
            # Fallback knowledge base ('content' holds the original Russian MRI
            # findings; 'category' and 'advice' are English)
            medical_knowledge = [
                {
                    "id": "doc1",
                    "title": "Частичное повреждение передней крестообразной связки",
                    # Translation: partial ACL injury -- thickening, increased T2
                    # signal, partial fiber disorganization, ligament still
                    # traceable along its course
                    "content": "Признаки частичного повреждения передней крестообразной связки: утолщение, повышенный сигнал по Т2, частичная дезорганизация волокон, связка прослеживается по ходу",
                    "category": "Partial ACL injury",
                    "advice": "Recommend conservative treatment, physical therapy, follow-up MRI in 6-8 weeks"
                },
                {
                    "id": "doc2",
                    "title": "Полный разрыв передней крестообразной связки",
                    # Translation: complete ACL tear -- fibers no longer traceable,
                    # a zone of increased signal in the projection of the ligament,
                    # hemarthrosis
                    "content": "Признаки полного разрыва передней крестообразной связки: волокна не прослеживаются по ходу, определяется зона повышенного сигнала в проекции связки, гемартроз",
                    "category": "Complete ACL tear",
                    "advice": "Urgent orthopedic consultation, likely requires ACL reconstruction surgery"
                }
            ]
        # 2. Load embedding model
        print("🔄 Loading embeddings...")
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # 3. Create embeddings and FAISS index
        print("🔄 Creating FAISS index...")
        text_contents = []
        for doc in medical_knowledge:
            text = f"{doc.get('title', '')} {doc.get('content', '')} {doc.get('advice', '')}"
            text_contents.append(text)
        embeddings = embedding_model.encode(text_contents, convert_to_numpy=True)

        # Create FAISS index
        dimension = embeddings.shape[1]
        faiss_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(embeddings)
        faiss_index.add(embeddings)
        print(f"✅ FAISS index created with {faiss_index.ntotal} documents")
        # 4. Load Qwen
        print("🔄 Loading Qwen...")
        qwen_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat", trust_remote_code=True)
        qwen_model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen1.5-0.5B-Chat",
            torch_dtype=torch.float32,
            trust_remote_code=True
        )
        print("✅ RAG system loaded completely!")
        return True
    except Exception as e:
        print(f"❌ RAG loading error: {e}")
        import traceback
        print(traceback.format_exc())
        return False
def retrieve_relevant_knowledge(classification_result):
    """Retrieve relevant medical documents."""
    global embedding_model, medical_knowledge, faiss_index
    if faiss_index is None:
        return [], "No knowledge base available"
    try:
        # Create query for retrieval
        query = f"Medical diagnosis {classification_result} treatment recommendations clinical advice"

        # Get query embedding
        query_embedding = embedding_model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(query_embedding)

        # Search FAISS index
        scores, indices = faiss_index.search(query_embedding, 2)  # Top 2 documents

        # Get relevant documents
        retrieved_docs = []
        context_text = ""
        for score, idx in zip(scores[0], indices[0]):
            if idx != -1 and idx < len(medical_knowledge):
                doc = medical_knowledge[idx]
                retrieved_docs.append((doc, float(score)))
                context_text += f"Medical Knowledge: {doc.get('content', '')}\n"
                context_text += f"Clinical Advice: {doc.get('advice', '')}\n"
                context_text += f"Category: {doc.get('category', '')}\n\n"
        return retrieved_docs, context_text
    except Exception as e:
        print(f"❌ Retrieval error: {e}")
        return [], f"Retrieval error: {e}"
def generate_qwen_advice(classification_result, retrieved_context):
    """Generate medical advice using Qwen with RAG context."""
    global qwen_model, qwen_tokenizer
    if qwen_model is None:
        return "❌ Qwen model not loaded"
    try:
        print("🔄 Generating Qwen advice...")

        # Create comprehensive prompt
        prompt = f"""You are a medical AI assistant. Based on the MRI classification and medical knowledge provided, give clinical recommendations.
MRI Classification: {classification_result}
Retrieved Medical Knowledge:
{retrieved_context}
Provide specific clinical recommendations including treatment options and follow-up care:"""
        print(f"📝 Prompt length: {len(prompt)} characters")

        # Tokenize
        inputs = qwen_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        print(f"🔧 Input tokens: {inputs.input_ids.shape}")

        # Generate (pass the attention mask explicitly to avoid a transformers warning)
        with torch.no_grad():
            outputs = qwen_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=120,
                temperature=0.7,
                do_sample=True,
                pad_token_id=qwen_tokenizer.eos_token_id,
                eos_token_id=qwen_tokenizer.eos_token_id
            )

        # Decode only the new tokens; outputs[0] also contains the prompt
        generated_tokens = outputs[0][inputs.input_ids.shape[1]:]
        generated_text = qwen_tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
        print(f"✅ Generated: {generated_text[:100]}...")

        if len(generated_text) < 10:
            return "No specific recommendations generated. Consult a medical professional."
        return generated_text
    except Exception as e:
        print(f"❌ Qwen generation error: {e}")
        import traceback
        print(traceback.format_exc())
        return f"Generation error: {e}"
def complete_analysis(image):
    """Complete pipeline: classification -> retrieval -> generation."""
    if image is None:
        return "❌ Please upload an MRI scan", ""

    # Step 1: Classification
    try:
        if biomedclip_model is None:
            return "❌ BiomedCLIP not loaded", ""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        inputs = biomedclip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = biomedclip_model(**inputs)
        logits = outputs['logits']
        probabilities = F.softmax(logits, dim=1)
        top_prob, top_idx = torch.max(probabilities, 1)
        class_idx = top_idx.item()
        # id2label keys may be ints or strings depending on how the checkpoint
        # was saved (JSON round-trips turn int keys into strings)
        if class_idx in biomedclip_id2label:
            class_name = biomedclip_id2label[class_idx]
        elif str(class_idx) in biomedclip_id2label:
            class_name = biomedclip_id2label[str(class_idx)]
        else:
            class_name = f"Class_{class_idx}"
        confidence = top_prob.item() * 100
        classification_result = f"{class_name} ({confidence:.1f}% confidence)"
        print(f"✅ Classification: {classification_result}")
    except Exception as e:
        return f"❌ Classification error: {e}", ""

    # Step 2: RAG retrieval
    retrieved_docs, context = retrieve_relevant_knowledge(classification_result)

    # Step 3: Qwen generation
    qwen_advice = generate_qwen_advice(classification_result, context)

    # Format outputs
    classification_text = f"""
# 🔬 **MRI Classification**

## 🎯 **Diagnosis:**
**{class_name}**

## 📊 **Confidence:**
**{confidence:.1f}%**
"""
    advice_text = f"""
# 🏥 **AI-Generated Medical Recommendations**

## 🤖 **Qwen Analysis:**
{qwen_advice}

## 📚 **Retrieved Medical Knowledge:**
{context if context else "No relevant knowledge retrieved"}

## 📋 **Retrieved Documents:**
{len(retrieved_docs)} documents found and used for advice generation

---
⚠️ **Disclaimer:** For educational purposes only. Always consult medical professionals.
"""
    return classification_text, advice_text
# Load models
print("🚀 Loading complete pipeline...")
biomedclip_loaded = load_biomedclip()
rag_loaded = load_rag_system()
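
# Loading at import time means each Gradio request only runs inference; on a
# Hugging Face Space this happens once while the Space starts up.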
# Create interface
with gr.Blocks(title="Medical RAG Pipeline") as app:
    gr.Markdown("# 🏥 Medical AI RAG Pipeline")
    gr.Markdown("**BiomedCLIP** → **RAG Retrieval** → **Qwen Generation**")

    status = f"BiomedCLIP: {'✅' if biomedclip_loaded else '❌'} | RAG: {'✅' if rag_loaded else '❌'}"
    gr.Markdown(f"**Status:** {status}")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="📸 Upload MRI Scan")
            analyze_btn = gr.Button("🔬 Complete RAG Analysis", variant="primary")
        with gr.Column():
            classification_output = gr.Markdown(label="🔬 Classification")
            advice_output = gr.Markdown(label="🏥 RAG-Generated Advice")

    analyze_btn.click(
        fn=complete_analysis,
        inputs=image_input,
        outputs=[classification_output, advice_output]
    )
gr.Markdown("""
### 🔄 **RAG Pipeline Process:**
1. **Image Classification** - BiomedCLIP analyzes MRI
2. **Knowledge Retrieval** - Find relevant medical documents
3. **Context Generation** - Qwen uses retrieved knowledge
4. **Advice Output** - AI-generated clinical recommendations
### 📚 **Knowledge Base:**
- ACL injury types and symptoms
- Treatment recommendations
- Clinical guidelines
- Follow-up protocols
""")
if __name__ == "__main__":
    app.launch()