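"""Gradio demo for brain MRI question answering.

Pipeline: a YOLO segmentation model (local `best.pt` weights) annotates the scan,
MedVLM-R1 answers a multiple-choice question about the image, and DeepSeek R1
(via the OpenRouter API) restructures that answer for a medical professional.
"""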
import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, GenerationConfig
from qwen_vl_utils import process_vision_info
import torch
import requests
from ultralytics import YOLO
from PIL import Image
import os

# ----------------------------
# MODEL LOADING (MedVLM-R1) - CPU Compatible
# ----------------------------
MODEL_PATH = 'JZPeterPan/MedVLM-R1'

# Check if CUDA is available, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load in bfloat16 on GPU; fall back to float32 on CPU, where bfloat16 support is limited
model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)

# Move model to the selected device
model = model.to(device)

temp_generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,  # greedy decoding, so no temperature is needed
    num_return_sequences=1,
    pad_token_id=151643,  # Qwen2 <|endoftext|> token
)

# ----------------------------
# YOLO MODEL LOADING
# ----------------------------
yolo_model = YOLO("best.pt")  # replace with your segmentation weights

def inference(image_path: str):
    """Runs YOLO segmentation on an image and returns the annotated image."""
    # Load image
    img = Image.open(image_path).convert("RGB")

    # Run inference
    results = yolo_model(img)

    # Plot with masks and bounding boxes
    annotated = results[0].plot()  # NumPy array (BGR)

    # Convert from BGR (OpenCV convention) to RGB; copy() makes the array contiguous for PIL
    annotated_rgb = annotated[:, :, ::-1].copy()

    # Convert numpy array to PIL Image
    annotated_image = Image.fromarray(annotated_rgb)
    
    return annotated_image

# ----------------------------
# API SETTINGS (DeepSeek R1)
# ----------------------------
# Read the OpenRouter API key from the environment rather than hard-coding it,
# e.g. set OPENROUTER_API_KEY before launching the app
api_key = os.environ.get("OPENROUTER_API_KEY", "")
deepseek_model = "deepseek/deepseek-r1"

# ----------------------------
# DEFAULT QUESTION
# ----------------------------
DEFAULT_QUESTION = "What abnormality is in the brain MRI and what is the location?\nA) Tumour\nB) No tumour \nC) Other"

QUESTION_TEMPLATE = """
{Question}
Your task: 
1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags. 
2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.
"""

# ----------------------------
# PIPELINE FUNCTION
# ----------------------------
def process_pipeline(image, user_question):
    if image is None or user_question.strip() == "":
        return "Please upload an image and enter a question.", None

    # Run YOLO inference and get segmented image
    segmented_image = inference(image)

    # Combine user's question with default
    combined_question = user_question.strip() + "\n\n" + DEFAULT_QUESTION

    message = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=combined_question)}
        ]
    }]

    # Prepare inputs for MedVLM
    text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(message)

    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # Generate output from MedVLM (decoding settings come from temp_generation_config)
    generated_ids = model.generate(
        **inputs,
        use_cache=True,
        generation_config=temp_generation_config
    )

    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]


    # MAX_INPUT_CHARS = 50
    # if len(output_text) > MAX_INPUT_CHARS:
    #     output_text = output_text[:MAX_INPUT_CHARS] + "... [truncated]"

    # Send MedVLM output to DeepSeek R1
    prompt = f"""The following is a medical AI's answer to a visual question. 
The answer is about having tumour or not, focus on that mostly. 
Keep the answer precise but more structured, and helpful for a medical professional. 
If possible, make a table using the details from the original answer.

Original Answer:
{output_text}
"""

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": deepseek_model,
        "max_tokens": 4000,
        "messages": [
            {"role": "system", "content": "You are a highly skilled medical writer."},
            {"role": "user", "content": prompt}
        ]
    }

    response = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers=headers,
        json=data,
        timeout=120,  # avoid hanging indefinitely if the API is slow
    )

    try:
        detailed_answer = response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        return f"**Error from DeepSeek:** {str(e)}\n\n```\n{response.text}\n```", segmented_image

    return f"{detailed_answer}", segmented_image


# ----------------------------
# GRADIO UI
# ----------------------------
with gr.Blocks(title="Brain MRI QA") as demo:
    with gr.Row():
        # First column: input image and result image side by side
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(type="filepath", label="Upload Medical Image")
                result_image = gr.Image(type="pil", label="Segmentation Result")  # shown next to the input image

            question_box = gr.Textbox(
                label="Your Question about the Image",
                placeholder="Type your question here..."
            )
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear")

        # Second column: LLM answer output
        with gr.Column():
            llm_output = gr.Markdown(label="Detailed LLM Answer")
    submit_btn.click(
        fn=process_pipeline,
        inputs=[image_input, question_box],
        outputs=[llm_output, result_image]
    )
    clear_btn.click(
        fn=lambda: (None, "", "", None),
        outputs=[image_input, question_box, llm_output, result_image]
    )


if __name__ == "__main__":
    demo.launch()
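    # For a temporary public URL outside a hosted Space, demo.launch(share=True) can be used.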