# app.py — Medical Imaging Analysis Space (mgbam)
# Source: Hugging Face Spaces file "image / app.py", commit 4beb159 (verified)
import gradio as gr
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from janus.models import VLChatProcessor
from PIL import Image
import spaces
# Silence FutureWarnings (emitted by transformers/torch during model loading)
# so they don't clutter the Space logs.
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
# Medical Imaging Analysis Configuration
# The two *_guidelines strings are interpolated into the analysis prompt;
# the *_params lists are reference data (not read elsewhere in this file).
MEDICAL_CONFIG = {
    "echo_guidelines": "ASE 2023 Standards",   # cited in Echo prompts
    "histo_guidelines": "CAP Protocols 2024",  # cited in Histo prompts
    "cardiac_params": ["LVEF", "E/A Ratio", "Wall Motion"],
    "histo_params": ["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"]
}

# Initialize Medical Imaging Model
# Hugging Face model id for the Janus-Pro 1B vision-language model.
model_path = "deepseek-ai/Janus-Pro-1B"
class MedicalImagingAdapter(torch.nn.Module):
    """Transparent wrapper around the Janus language model.

    Registers modality-specific projection layers (intended for future
    fine-tuning — they are NOT wired into the forward pass yet) while
    passing every call through to the wrapped model unchanged.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        # Cardiac-specific projection (registered only; unused in forward)
        self.cardiac_proj = torch.nn.Linear(2048, 2048)
        # Histopathology-specific projection (registered only; unused in forward)
        self.histo_proj = torch.nn.Linear(2048, 2048)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass unchanged to the wrapped model."""
        return self.base_model(*args, **kwargs)

    def __getattr__(self, name):
        # Fix: downstream code accesses attributes/methods of the wrapped
        # language model directly (e.g. embedding helpers during generate()).
        # nn.Module.__getattr__ only resolves registered params/buffers/
        # submodules, so anything else must be forwarded to base_model.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(super().__getattr__("base_model"), name)
# Download/instantiate the Janus-Pro model (trust_remote_code pulls in the
# Janus model class from the hub repo) and wrap its language model so the
# adapter's projection layers are registered for later fine-tuning.
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt.language_model = MedicalImagingAdapter(vl_gpt.language_model)
if torch.cuda.is_available():
    # bfloat16 halves memory; only applied when a GPU is actually present.
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
# **Fix: Set legacy=False in tokenizer to use the new behavior**
# (avoids the sentencepiece legacy-tokenization FutureWarning/behavior).
vl_chat_processor.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
# Medical Image Processing Pipelines
def preprocess_echo(image):
    """Normalize an echocardiography frame: luminance-only, 512x512."""
    frame = Image.fromarray(image)
    frame = frame.convert('L')  # collapse to single-channel grayscale
    frame = frame.resize((512, 512))
    return np.array(frame)
def preprocess_histo(image):
    """Resize a histopathology slide image to 1024x1024 (color preserved)."""
    slide = Image.fromarray(image)
    resized = slide.resize((1024, 1024))
    return np.array(resized)
@torch.inference_mode()
@spaces.GPU(duration=120)  # ZeroGPU: allocate a GPU for up to 120 s per call
def analyze_medical_case(image, clinical_context, modality):
    """Run the vision-language model on one case and return a markdown report.

    Args:
        image: numpy array from the Gradio image widget.
        clinical_context: free-text indication/question from the clinician.
        modality: "Echo" or "Histo" — selects preprocessing, prompt wording
            and the report template.

    Returns:
        Markdown string produced by format_medical_report().
    """
    # Preprocess based on modality
    processed_img = preprocess_echo(image) if modality == "Echo" else preprocess_histo(image)
    # Create modality-specific prompt citing the matching guideline set
    system_prompt = f"""
Analyze this {modality} image following {MEDICAL_CONFIG['echo_guidelines' if modality=='Echo' else 'histo_guidelines']}.
Clinical Context: {clinical_context}
"""
    # NOTE(review): the role tags ("<|Radiologist|>", "<|Pathologist|>",
    # "<|AI_Assistant|>") are assumed to be accepted by VLChatProcessor's
    # chat template — confirm against the Janus processor documentation.
    conversation = [{
        "role": "<|Radiologist|>" if modality == "Echo" else "<|Pathologist|>",
        "content": system_prompt,
        "images": [processed_img],
    }, {"role": "<|AI_Assistant|>", "content": ""}]
    inputs = vl_chat_processor(
        conversations=conversation,
        images=[Image.fromarray(processed_img)],
        force_batchify=True
    ).to(vl_gpt.device)
    outputs = vl_gpt.generate(
        inputs_embeds=vl_gpt.prepare_inputs_embeds(**inputs),
        attention_mask=inputs.attention_mask,
        max_new_tokens=512,
        temperature=0.1,        # near-greedy decoding for reproducible findings
        top_p=0.9,
        repetition_penalty=1.5  # discourage repeated phrases in the report
    )
    # Decode generated token ids back to text, stripping special tokens.
    report = vl_chat_processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return format_medical_report(report, modality)
def format_medical_report(text, modality):
    """Render the raw model output as a structured markdown report.

    For each section defined for the modality, scan ``text`` for each
    sub-parameter label and copy the value that follows it (up to the next
    blank line) into a markdown bullet.

    Args:
        text: raw decoded model output.
        modality: "Echo" or "Histo" — selects the section template
            (any other value raises KeyError, as before).

    Returns:
        Markdown string; labels absent from ``text`` produce no bullets.
    """
    sections = {
        "Echo": [
            ("Chamber Dimensions", "LVEDD", "LVESD"),
            ("Valvular Function", "Aortic Valve", "Mitral Valve"),
            ("Hemodynamics", "E/A Ratio", "LVEF")
        ],
        "Histo": [
            ("Architecture", "Gland Formation", "Stromal Pattern"),
            ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
            ("Diagnostic Impression", "Tumor Grade", "Margin Status")
        ]
    }
    formatted = f"**{modality} Analysis Report**\n\n"
    for header, *labels in sections[modality]:
        formatted += f"### {header}\n"
        for label in labels:
            start = text.find(label)  # single scan (was: `in` check + find)
            if start == -1:
                continue  # label not mentioned in the model output
            end = text.find("\n\n", start)
            if end == -1:
                # Fix: find() returns -1 when the value is the last paragraph;
                # the old code used -1 as a slice bound and silently dropped
                # the final character of the value.
                end = len(text)
            # +1 skips the delimiter (e.g. ':') immediately after the label.
            value = text[start + len(label) + 1:end].strip()
            formatted += f"- **{label}:** {value}\n"
    return formatted
# Medical Imaging Interface
with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
## Medical Imaging Analysis Platform
*Analyzes echocardiograms and histopathology slides - Research Use Only*
""")
    with gr.Row():
        with gr.Column():
            # Left column: user inputs
            image_input = gr.Image(label="Upload Medical Image")
            modality_select = gr.Radio(
                ["Echo", "Histo"],
                label="Image Modality",
                info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides"
            )
            clinical_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'"
            )
            analyze_btn = gr.Button("Analyze Case", variant="primary")
        with gr.Column():
            # Right column: generated markdown report
            report_output = gr.Markdown(label="AI Clinical Report")
    # Preloaded examples
    # NOTE(review): example rows are ordered (clinical text, image path,
    # modality) to match `inputs` below; the image files are assumed to
    # exist in the Space repo — verify they are committed.
    gr.Examples(
        examples=[
            ["Evaluate LV systolic function", "case1.png", "Echo"],
            ["Assess mitral valve function", "case2.jpg", "Echo"],
            ["Analyze for malignant features", "case3.png", "Histo"],
            ["Evaluate tumor margins", "case4.png", "Histo"]
        ],
        inputs=[clinical_input, image_input, modality_select],
        label="Example Medical Cases"
    )
    # **Fixed: Removed @demo.func and used .click() correctly**
    # Wire the button: (image, context, modality) -> markdown report.
    analyze_btn.click(
        analyze_medical_case,
        [image_input, clinical_input, modality_select],
        report_output
    )

demo.launch(share=True)