# Spaces:
# Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from janus.models import VLChatProcessor | |
from PIL import Image | |
import spaces | |
# Hide every FutureWarning for the whole process (any message, any module).
import warnings
warnings.simplefilter("ignore", category=FutureWarning)
# Per-modality guideline labels (interpolated into the analysis prompt) and
# parameter names of interest for each modality.
MEDICAL_CONFIG = dict(
    echo_guidelines="ASE 2023 Standards",
    histo_guidelines="CAP Protocols 2024",
    cardiac_params=["LVEF", "E/A Ratio", "Wall Motion"],
    histo_params=["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"],
)

# Hugging Face Hub id used to load the model, chat processor and tokenizer.
model_path = "deepseek-ai/Janus-Pro-1B"
class MedicalImagingAdapter(torch.nn.Module):
    """Transparent wrapper around a language model module.

    ``forward`` passes all arguments straight to the wrapped model. Any
    attribute not defined on the adapter itself (for example
    ``get_input_embeddings``, ``generate`` or ``config``) is looked up on the
    wrapped model. Code that expected the original language model therefore
    keeps working after ``vl_gpt.language_model`` is replaced by this adapter.

    Args:
        base_model: the module being wrapped.
        hidden_size: width of the two modality projection layers. Defaults to
            2048, the value that was previously hard-coded.

    NOTE(review): ``cardiac_proj`` and ``histo_proj`` are default-initialized
    ``torch.nn.Linear`` layers that nothing in this class uses. They add
    parameters and memory, but they do not affect outputs.
    """

    def __init__(self, base_model, hidden_size=2048):
        super().__init__()
        self.base_model = base_model
        # Cardiac-specific projection (not used by forward).
        self.cardiac_proj = torch.nn.Linear(hidden_size, hidden_size)
        # Histopathology-specific projection (not used by forward).
        self.histo_proj = torch.nn.Linear(hidden_size, hidden_size)

    def __getattr__(self, name):
        # nn.Module.__getattr__ resolves registered parameters, buffers and
        # submodules (including ``base_model`` itself).
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Read _modules through __dict__ so that a partially constructed
            # instance (e.g. during copy/unpickling) cannot recurse back here.
            base = self.__dict__.get("_modules", {}).get("base_model")
            if base is None:
                raise
            return getattr(base, name)

    def forward(self, *args, **kwargs):
        return self.base_model(*args, **kwargs)
# Load the pretrained multimodal checkpoint (its modeling code comes from the
# hub repo, hence trust_remote_code=True).
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
# Put the language backbone inside the modality adapter.
vl_gpt.language_model = MedicalImagingAdapter(vl_gpt.language_model)
# With CUDA present, move the whole model to the GPU in bfloat16 in one call;
# otherwise it stays on CPU in the dtype from_pretrained produced.
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(device="cuda", dtype=torch.bfloat16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
# Swap in a tokenizer loaded with legacy=False (opts into the tokenizer's
# non-legacy behavior).
# NOTE(review): confirm this fresh tokenizer carries every special token the
# processor relies on (e.g. its image tag).
vl_chat_processor.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)
# Medical Image Processing Pipelines
def preprocess_echo(image):
    """Return a 512x512 grayscale copy of an echocardiography image.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.

    Returns:
        numpy array of the resized single-channel ("L" mode) image.
    """
    grayscale = Image.fromarray(image).convert('L')
    resized = grayscale.resize((512, 512))
    return np.array(resized)
def preprocess_histo(image):
    """Return a 1024x1024 copy of a histopathology slide image.

    The image keeps whatever color mode ``PIL.Image.fromarray`` infers from
    the input array; only its size changes.

    Args:
        image: array accepted by ``PIL.Image.fromarray``.

    Returns:
        numpy array of the resized image.
    """
    slide = Image.fromarray(image)
    return np.array(slide.resize((1024, 1024)))
def analyze_medical_case(image, clinical_context, modality):
    """Run the Janus model on one uploaded image and return a Markdown report.

    Args:
        image: image array from the Gradio ``gr.Image`` input, or ``None``
            when nothing has been uploaded.
        clinical_context: free-text clinical question from the user.
        modality: ``"Echo"`` or ``"Histo"`` (the radio choices); ``None``
            when nothing is selected.

    Returns:
        Markdown text: the structured report from ``format_medical_report``,
        or a short instruction message when the image or modality is missing.
    """
    # Guard clauses: the UI lets the user submit without an image or modality.
    if image is None:
        return "Please upload an image to analyze."
    if modality not in ("Echo", "Histo"):
        return "Please select an image modality ('Echo' or 'Histo')."

    is_echo = modality == "Echo"
    processed_img = preprocess_echo(image) if is_echo else preprocess_histo(image)
    guidelines = MEDICAL_CONFIG["echo_guidelines" if is_echo else "histo_guidelines"]
    specialist = "radiologist" if is_echo else "pathologist"
    context = (clinical_context or "").strip() or "Not provided"

    # Janus inserts image features only where "<image_placeholder>" appears in
    # the text. Its chat template uses the "<|User|>" / "<|Assistant|>" roles
    # (see the Janus README multimodal-understanding example), so the
    # specialist persona goes into the prompt text instead of the role name.
    prompt = (
        "<image_placeholder>\n"
        f"You are a {specialist}. Analyze this {modality} image following {guidelines}.\n"
        f"Clinical Context: {context}"
    )
    conversation = [
        {"role": "<|User|>", "content": prompt, "images": [processed_img]},
        {"role": "<|Assistant|>", "content": ""},
    ]

    # preprocess_echo yields a single-channel image; convert everything to
    # 3-channel RGB for the vision processor.
    pil_image = Image.fromarray(processed_img).convert("RGB")

    # Cast pixel_values to the model's own dtype. The Janus batch output's
    # .to() otherwise defaults to bfloat16, which mismatches a float32 model
    # left on CPU.
    model_dtype = next(vl_gpt.parameters()).dtype
    inputs = vl_chat_processor(
        conversations=conversation,
        images=[pil_image],
        force_batchify=True,
    ).to(vl_gpt.device, dtype=model_dtype)

    # generate() must run on the causal language model, not on the multimodal
    # container; unwrap the adapter if one is installed.
    language_model = vl_gpt.language_model
    if isinstance(language_model, MedicalImagingAdapter):
        language_model = language_model.base_model

    tokenizer = vl_chat_processor.tokenizer
    # Inference only: no autograd graph for the vision encoder / embeddings.
    with torch.no_grad():
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)
        # NOTE(review): temperature/top_p only take effect with do_sample=True;
        # they are kept as originally configured.
        outputs = language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,
            temperature=0.1,
            top_p=0.9,
            repetition_penalty=1.5,
            use_cache=True,
        )
    report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return format_medical_report(report, modality)
def format_medical_report(text, modality):
    """Turn free-text model output into a sectioned Markdown report.

    For each expected field name (e.g. "LVEF"), the text after its first
    occurrence is taken as the field's value. The value runs up to the next
    blank line, or to the end of the text. One character directly after the
    field name (typically ":") is skipped and the value is stripped.

    Args:
        text: raw decoded model output.
        modality: "Echo" or "Histo"; selects the section layout.

    Returns:
        Markdown string with one "###" heading per section and a bullet per
        field found. If no expected field occurs in ``text``, the stripped
        raw text is appended under "### Model Output" so it is not lost.

    Raises:
        KeyError: if ``modality`` is not "Echo" or "Histo".
    """
    # Each entry is (section header, field name, field name).
    sections = {
        "Echo": [
            ("Chamber Dimensions", "LVEDD", "LVESD"),
            ("Valvular Function", "Aortic Valve", "Mitral Valve"),
            ("Hemodynamics", "E/A Ratio", "LVEF")
        ],
        "Histo": [
            ("Architecture", "Gland Formation", "Stromal Pattern"),
            ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
            ("Diagnostic Impression", "Tumor Grade", "Margin Status")
        ]
    }
    parts = [f"**{modality} Analysis Report**\n\n"]
    extracted_any = False
    for header, *fields in sections[modality]:
        parts.append(f"### {header}\n")
        for field in fields:
            start = text.find(field)
            if start == -1:
                continue
            end = text.find("\n\n", start)
            if end == -1:
                # No blank line follows: take the rest of the text. Slicing
                # with -1 here would silently drop the final character.
                end = len(text)
            value = text[start + len(field) + 1:end].strip()
            parts.append(f"- **{field}:** {value}\n")
            extracted_any = True
    if not extracted_any and text.strip():
        parts.append(f"### Model Output\n{text.strip()}\n")
    return "".join(parts)
# Medical Imaging Interface
# Markdown banner shown at the top of the page.
_HEADER_MD = """
## Medical Imaging Analysis Platform
*Analyzes echocardiograms and histopathology slides - Research Use Only*
"""

# Example rows as (clinical context, image file, modality), matching the
# order of the `inputs` list given to gr.Examples below.
_EXAMPLE_CASES = [
    ["Evaluate LV systolic function", "case1.png", "Echo"],
    ["Assess mitral valve function", "case2.jpg", "Echo"],
    ["Analyze for malignant features", "case3.png", "Histo"],
    ["Evaluate tumor margins", "case4.png", "Histo"],
]

with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            image_input = gr.Image(label="Upload Medical Image")
            modality_select = gr.Radio(
                choices=["Echo", "Histo"],
                label="Image Modality",
                info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides",
            )
            clinical_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'",
            )
            analyze_btn = gr.Button(value="Analyze Case", variant="primary")
        # Right column: the generated report.
        with gr.Column():
            report_output = gr.Markdown(label="AI Clinical Report")

    # NOTE(review): example images are referenced by relative path; confirm
    # case1.png .. case4.png ship alongside the app.
    gr.Examples(
        examples=_EXAMPLE_CASES,
        inputs=[clinical_input, image_input, modality_select],
        label="Example Medical Cases",
    )

    # Wire the button to the analysis pipeline.
    analyze_btn.click(
        fn=analyze_medical_case,
        inputs=[image_input, clinical_input, modality_select],
        outputs=report_output,
    )

demo.launch(share=True)