File size: 5,806 Bytes
39d4c85
2f9ea03
a8b15b5
4beb159
3afe8c3
2f9ea03
39d4c85
8e2bfc0
4beb159
 
 
 
aa563db
a8b15b5
3afe8c3
 
 
 
a8b15b5
 
3afe8c3
39d4c85
8e2bfc0
3afe8c3
a8b15b5
 
 
3afe8c3
 
 
 
 
a8b15b5
 
3afe8c3
a8b15b5
3afe8c3
 
a8b15b5
39d4c85
3afe8c3
8e2bfc0
39d4c85
8e2bfc0
4beb159
 
 
3afe8c3
 
 
 
 
 
 
 
 
 
a8b15b5
39d4c85
 
3afe8c3
 
 
39d4c85
3afe8c3
 
 
 
 
a8b15b5
39d4c85
3afe8c3
 
 
39d4c85
3afe8c3
39d4c85
a8b15b5
3afe8c3
39d4c85
3afe8c3
a8b15b5
3afe8c3
 
39d4c85
 
3afe8c3
39d4c85
3afe8c3
39d4c85
 
3afe8c3
 
a8b15b5
3afe8c3
 
a8b15b5
3afe8c3
 
 
 
 
 
 
 
 
 
a8b15b5
 
3afe8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa563db
3afe8c3
 
 
 
 
 
 
 
 
39d4c85
 
4beb159
 
 
 
 
 
8e2bfc0
aa563db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import gradio as gr
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from janus.models import VLChatProcessor
from PIL import Image
import spaces

# Suppress every FutureWarning raised in this process.
import warnings
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Static configuration for the analysis: guideline references per modality
# and per-modality lists of parameter names.
MEDICAL_CONFIG = dict(
    echo_guidelines="ASE 2023 Standards",
    histo_guidelines="CAP Protocols 2024",
    cardiac_params=["LVEF", "E/A Ratio", "Wall Motion"],
    histo_params=["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"],
)

# Identifier of the Janus-Pro checkpoint passed to the from_pretrained calls.
model_path = "deepseek-ai/Janus-Pro-1B"

class MedicalImagingAdapter(torch.nn.Module):
    """Transparent wrapper around a language model with modality projections.

    The adapter is meant to replace an existing model attribute (here the
    multimodal model's ``language_model``), so callers may expect it to expose
    the wrapped model's API (e.g. ``get_input_embeddings()`` or ``generate()``).
    Any attribute not defined on the adapter itself is therefore looked up on
    ``base_model``.

    Args:
        base_model: The module being wrapped.
        hidden_size: Input/output width of the modality projections. The
            default of 2048 matches the previously hard-coded value.
    """

    def __init__(self, base_model, hidden_size=2048):
        super().__init__()
        self.base_model = base_model
        # Cardiac-specific projection.
        # NOTE(review): freshly initialised here and not applied in forward().
        self.cardiac_proj = torch.nn.Linear(hidden_size, hidden_size)
        # Histopathology-specific projection.
        # NOTE(review): freshly initialised here and not applied in forward().
        self.histo_proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, *args, **kwargs):
        """Call the wrapped model unchanged (the projections are not used)."""
        return self.base_model(*args, **kwargs)

    def __getattr__(self, name):
        """Resolve registered params/buffers/submodules, else defer to base_model."""
        try:
            # nn.Module.__getattr__ serves registered parameters, buffers and
            # submodules (including base_model itself).
            return super().__getattr__(name)
        except AttributeError:
            if name == "base_model":
                raise
            # Only reached for names the adapter does not define, so the
            # adapter can stand in wherever the wrapped model was expected.
            return getattr(super().__getattr__("base_model"), name)

# Load the Janus-Pro checkpoint; trust_remote_code=True is passed so the
# checkpoint's own model code can be used.
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
# Swap the language model for the adapter wrapper. The adapter's projection
# layers are created after from_pretrained, so they are not loaded from the
# checkpoint, and its forward() passes straight through to the wrapped model.
vl_gpt.language_model = MedicalImagingAdapter(vl_gpt.language_model)

# With a GPU available, cast all weights (adapter included) to bfloat16 and
# move them to CUDA; otherwise the model stays on CPU in its loaded dtype.
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)

# Replace the processor's tokenizer with one loaded with legacy=False (the
# tokenizer's non-legacy behavior).
# NOTE(review): this discards the tokenizer instance the processor built in
# from_pretrained, including any special tokens it may have added there —
# confirm the checkpoint's tokenizer already contains the image tokens.
vl_chat_processor.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)

# Medical Image Processing Pipelines
def preprocess_echo(image):
    """Return an echocardiography frame as a 512x512 grayscale array.

    The input array is wrapped as a PIL image, converted to mode "L"
    (single-channel grayscale), resized to 512x512 with PIL's default
    resampling, and turned back into a NumPy array.
    """
    side = 512
    gray_frame = Image.fromarray(image).convert('L')
    resized_frame = gray_frame.resize((side, side))
    return np.array(resized_frame)

def preprocess_histo(image):
    """Return a histopathology slide resized to 1024x1024 as an array.

    The channel layout of the input is kept as-is; only the spatial size
    changes (PIL default resampling).
    """
    side = 1024
    slide = Image.fromarray(image)
    resized_slide = slide.resize((side, side))
    return np.array(resized_slide)

@torch.inference_mode()
@spaces.GPU(duration=120)
def analyze_medical_case(image, clinical_context, modality):
    """Analyze one uploaded image with Janus-Pro and return a Markdown report.

    Args:
        image: Uploaded image as a NumPy array (as delivered by ``gr.Image``),
            or ``None`` when nothing was uploaded.
        clinical_context: Free-text clinical question from the user.
        modality: ``"Echo"`` or ``"Histo"``; selects the preprocessing, the
            guideline reference in the prompt and the report layout.

    Returns:
        The Markdown report built by ``format_medical_report``.

    Raises:
        gr.Error: If no image was uploaded or *modality* is not recognised.
    """
    # Fail with a user-visible message instead of an opaque PIL/KeyError.
    if image is None:
        raise gr.Error("Please upload an image before running the analysis.")
    if modality not in ("Echo", "Histo"):
        raise gr.Error("Please select an image modality ('Echo' or 'Histo').")

    is_echo = modality == "Echo"
    processed_img = preprocess_echo(image) if is_echo else preprocess_histo(image)
    # Echo frames come back as single-channel "L" images and uploads may
    # carry alpha; Janus' reference pipeline converts images to RGB before
    # processing, so do the same here.
    pil_image = Image.fromarray(processed_img).convert("RGB")

    guidelines = MEDICAL_CONFIG["echo_guidelines" if is_echo else "histo_guidelines"]
    key_params = MEDICAL_CONFIG["cardiac_params" if is_echo else "histo_params"]
    specialist = "radiologist" if is_echo else "pathologist"
    # Janus splices the image embeddings in at <image_placeholder>; without
    # it in the prompt the model receives no view of the image.
    prompt = "\n".join([
        "<image_placeholder>",
        f"You are assisting a {specialist}.",
        f"Analyze this {modality} image following {guidelines}.",
        f"Address at least: {', '.join(key_params)}.",
        "Report each finding as 'Name: value', separated by a blank line.",
        f"Clinical Context: {clinical_context}",
    ])

    # Janus' chat template uses the <|User|> / <|Assistant|> role tags.
    conversation = [
        {"role": "<|User|>", "content": prompt, "images": [pil_image]},
        {"role": "<|Assistant|>", "content": ""},
    ]

    # Cast pixel values to the model's own dtype (bfloat16 on GPU, the
    # loaded dtype on CPU) to avoid dtype mismatches in the vision encoder.
    inputs = vl_chat_processor(
        conversations=conversation,
        images=[pil_image],
        force_batchify=True,
    ).to(vl_gpt.device, dtype=vl_gpt.dtype)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

    # As in Janus' reference code, generate with the underlying language
    # model rather than the multimodal wrapper; unwrap the adapter, if
    # present, to reach the model that implements generate().
    language_model = vl_gpt.language_model
    if isinstance(language_model, MedicalImagingAdapter):
        language_model = language_model.base_model

    tokenizer = vl_chat_processor.tokenizer
    outputs = language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.5,
        use_cache=True,
    )

    report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return format_medical_report(report, modality)

def format_medical_report(text, modality):
    """Arrange free-text model output into a Markdown report for *modality*.

    For each known finding name (e.g. ``"LVEF"``), the text following its
    first occurrence, up to the next blank line or the end of *text*, is
    reported as that finding's value. A leading ``:`` separator and
    surrounding whitespace are stripped from the value.

    Args:
        text: Decoded model output.
        modality: ``"Echo"`` or ``"Histo"``.

    Returns:
        A Markdown string with a title, one ``###`` heading per section and a
        bullet per finding found. If no finding name occurs in a non-blank
        *text*, the raw text is appended under "Model Output" so the model's
        answer is not silently discarded.

    Raises:
        ValueError: If *modality* has no report layout.
    """
    # Each entry: (section heading, finding name, finding name, ...).
    sections = {
        "Echo": [
            ("Chamber Dimensions", "LVEDD", "LVESD"),
            ("Valvular Function", "Aortic Valve", "Mitral Valve"),
            ("Hemodynamics", "E/A Ratio", "LVEF"),
        ],
        "Histo": [
            ("Architecture", "Gland Formation", "Stromal Pattern"),
            ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
            ("Diagnostic Impression", "Tumor Grade", "Margin Status"),
        ],
    }
    try:
        modality_sections = sections[modality]
    except KeyError:
        raise ValueError(
            f"Unsupported modality {modality!r}; expected one of {sorted(sections)}"
        ) from None

    parts = [f"**{modality} Analysis Report**\n\n"]
    found_any = False
    for header, *finding_names in modality_sections:
        parts.append(f"### {header}\n")
        for name in finding_names:
            start = text.find(name)
            if start == -1:
                continue
            end = text.find("\n\n", start)
            if end == -1:
                # find() returns -1 when absent; slicing to -1 would drop the
                # final character, so take the rest of the text instead.
                end = len(text)
            value = text[start + len(name):end].strip().lstrip(":").strip()
            parts.append(f"- **{name}:** {value}\n")
            found_any = True

    if not found_any and text.strip():
        parts.append(f"### Model Output\n{text.strip()}\n")
    return "".join(parts)

# Medical Imaging Interface
with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    ## Medical Imaging Analysis Platform
    *Analyzes echocardiograms and histopathology slides - Research Use Only*
    """)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Medical Image")
            # Default to "Echo" so the handler never receives modality=None,
            # for which there is no preprocessing path or report layout.
            modality_select = gr.Radio(
                ["Echo", "Histo"],
                value="Echo",
                label="Image Modality",
                info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides"
            )
            clinical_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'"
            )
            analyze_btn = gr.Button("Analyze Case", variant="primary")

        with gr.Column():
            report_output = gr.Markdown(label="AI Clinical Report")

    # Preloaded examples; each row follows the order of `inputs` below
    # (clinical context, image file, modality).
    gr.Examples(
        examples=[
            ["Evaluate LV systolic function", "case1.png", "Echo"],
            ["Assess mitral valve function", "case2.jpg", "Echo"],
            ["Analyze for malignant features", "case3.png", "Histo"],
            ["Evaluate tumor margins", "case4.png", "Histo"]
        ],
        inputs=[clinical_input, image_input, modality_select],
        label="Example Medical Cases"
    )

    # Run the analysis on (image, context, modality) and render the returned
    # Markdown in the report panel.
    analyze_btn.click(
        fn=analyze_medical_case,
        inputs=[image_input, clinical_input, modality_select],
        outputs=report_output,
    )

demo.launch(share=True)