import gradio as gr
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, GenerationConfig
from qwen_vl_utils import process_vision_info
import torch
import requests
import os
from ultralytics import YOLO
from PIL import Image
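# Assumed dependencies (the usual PyPI package names, not pinned here): gradio,
# transformers, qwen-vl-utils, torch, requests, ultralytics, pillow; device_map="auto"
# below additionally relies on accelerate being installed.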
# ----------------------------
# MODEL LOADING (MedVLM-R1)
# ----------------------------
MODEL_PATH = 'JZPeterPan/MedVLM-R1'

# Check if CUDA is available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(MODEL_PATH)
# device_map="auto" already places the model on the available device(s);
# an extra model.to(device) can fail for a dispatched model, so it is omitted.
temp_generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,
    temperature=1,
    num_return_sequences=1,
    pad_token_id=151643,
)
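# Note: with do_sample=False the temperature setting has no effect, and
# pad_token_id=151643 is assumed to be the Qwen2 tokenizer's <|endoftext|> token.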
# ----------------------------
# YOLO MODEL LOADING
# ----------------------------
yolo_model = YOLO("best.pt")  # replace with your segmentation weights
def inference(image_path: str):
    """Runs YOLO segmentation on an image and returns the annotated image."""
    # Load image
    img = Image.open(image_path).convert("RGB")
    # Run inference
    results = yolo_model(img)
    # Plot with masks and bounding boxes
    annotated = results[0].plot()  # NumPy array (BGR)
    # Convert from BGR (OpenCV default) to RGB before building the PIL image
    annotated_rgb = annotated[:, :, ::-1]
    # Convert NumPy array to PIL Image
    annotated_image = Image.fromarray(annotated_rgb)
    return annotated_image
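# Example (hypothetical quick check, assuming a local test scan "sample_mri.png"):
#   inference("sample_mri.png").save("segmented_mri.png")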
# ----------------------------
# API SETTINGS (DeepSeek R1 via OpenRouter)
# ----------------------------
# Read the key from the environment rather than hard-coding the secret in the source
# (the variable name OPENROUTER_API_KEY is a deployment choice).
api_key = os.environ.get("OPENROUTER_API_KEY", "")
deepseek_model = "deepseek/deepseek-r1"
# ----------------------------
# DEFAULT QUESTION
# ----------------------------
DEFAULT_QUESTION = "What abnormality is in the brain MRI and what is the location?\nA) Tumour\nB) No tumour\nC) Other"

QUESTION_TEMPLATE = """
{Question}
Your task:
1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags.
2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.
"""
# ----------------------------
# PIPELINE FUNCTION
# ----------------------------
def process_pipeline(image, user_question):
    if image is None or user_question.strip() == "":
        return "Please upload an image and enter a question.", None

    # Run YOLO inference and get the segmented image
    segmented_image = inference(image)

    # Combine the user's question with the default multiple-choice question
    combined_question = user_question.strip() + "\n\n" + DEFAULT_QUESTION

    message = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=combined_question)}
        ]
    }]

    # Prepare inputs for MedVLM
    text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(message)
    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    # Generate output from MedVLM
    generated_ids = model.generate(
        **inputs,
        use_cache=True,
        max_new_tokens=1024,
        do_sample=False,
        generation_config=temp_generation_config
    )
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )[0]
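    # output_text is expected to contain the model's <think>...</think> reasoning
    # followed by a single-letter <answer>...</answer> choice, per QUESTION_TEMPLATE.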
    # MAX_INPUT_CHARS = 50
    # if len(output_text) > MAX_INPUT_CHARS:
    #     output_text = output_text[:MAX_INPUT_CHARS] + "... [truncated]"

    # Send the MedVLM output to DeepSeek R1 for restructuring
    prompt = f"""The following is a medical AI's answer to a visual question.
The answer concerns whether a tumour is present; focus mostly on that.
Keep the answer precise but more structured and helpful for a medical professional.
If possible, present the details from the original answer as a table.

Original Answer:
{output_text}
"""
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    data = {
        "model": deepseek_model,
        "max_tokens": 4000,
        "messages": [
            {"role": "system", "content": "You are a highly skilled medical writer."},
            {"role": "user", "content": prompt}
        ]
    }
    response = requests.post(
        "https://openrouter.ai/api/v1/chat/completions",
        headers=headers,
        json=data,
        timeout=120,  # assumption: generous timeout so a slow API call cannot hang indefinitely
    )
    try:
        detailed_answer = response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        return f"**Error from DeepSeek:** {str(e)}\n\n```\n{response.text}\n```", segmented_image

    return detailed_answer, segmented_image
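# Example (hypothetical direct call, bypassing the UI):
#   answer_md, seg_img = process_pipeline("sample_mri.png", "Is there a lesion in the left hemisphere?")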
# ----------------------------
# GRADIO UI
# ----------------------------
with gr.Blocks(title="Brain MRI QA") as demo:
    with gr.Row():
        # First column: input image and segmented result side by side
        with gr.Column():
            with gr.Row():
                image_input = gr.Image(type="filepath", label="Upload Medical Image")
                result_image = gr.Image(label="Segmented Image")  # shown next to the input image
            question_box = gr.Textbox(
                label="Your Question about the Image",
                placeholder="Type your question here..."
            )
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear")
        # Second column: LLM answer output
        with gr.Column():
            llm_output = gr.Markdown(label="Detailed LLM Answer")
    submit_btn.click(
        fn=process_pipeline,
        inputs=[image_input, question_box],
        outputs=[llm_output, result_image]
    )
    clear_btn.click(
        fn=lambda: ("", "", None),
        outputs=[question_box, llm_output, result_image]
    )

if __name__ == "__main__":
    demo.launch()