Faisal committed
Commit 90e84e7 · 1 Parent(s): af5d735

Restore GPU version - remove CPU optimizations and restore GPU-compatible dependencies

Files changed (3):
  1. app.py +46 -12
  2. best.pt +3 -0
  3. requirements.txt +17 -0
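
The hunks below only touch the new YOLO segmentation path, so the restored "GPU version" of the MedVLM-R1 loading is not itself visible in this diff. For orientation, a GPU-compatible Qwen2-VL load along the lines the commit message describes might look like the following minimal sketch; the checkpoint ID, dtype, and device placement are illustrative assumptions, not this Space's exact settings:

```python
# Minimal sketch of GPU-compatible model loading (assumptions: checkpoint ID,
# half precision, and device_map="auto" are illustrative, not this repo's code).
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_ID = "your-org/MedVLM-R1"  # hypothetical; substitute the actual checkpoint

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,  # fp16 on GPU instead of CPU float32
    device_map="auto",          # requires accelerate; places weights on the GPU
)
processor = AutoProcessor.from_pretrained(MODEL_ID)
```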
app.py CHANGED
@@ -3,6 +3,11 @@ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, Generat
 from qwen_vl_utils import process_vision_info
 import torch
 import requests
+from ultralytics import YOLO
+from PIL import Image
+import matplotlib.pyplot as plt
+import numpy as np
+import io
 
 # ----------------------------
 # MODEL LOADING (MedVLM-R1) - CPU Compatible
@@ -31,6 +36,30 @@ temp_generation_config = GenerationConfig(
     pad_token_id=151643,
 )
 
+# ----------------------------
+# YOLO MODEL LOADING
+# ----------------------------
+yolo_model = YOLO("MedSegVLM_with_DeepSeek/best.pt")  # replace with your segmentation weights
+
+def inference(image_path: str):
+    """Runs YOLO segmentation on an image and returns the annotated image."""
+    # Load image
+    img = Image.open(image_path).convert("RGB")
+
+    # Run inference
+    results = yolo_model(img)
+
+    # Plot with masks and bounding boxes
+    annotated = results[0].plot()  # NumPy array (BGR)
+
+    # Convert from BGR (OpenCV default) to RGB for matplotlib
+    annotated_rgb = annotated[:, :, ::-1]
+
+    # Convert numpy array to PIL Image
+    annotated_image = Image.fromarray(annotated_rgb)
+
+    return annotated_image
+
 # ----------------------------
 # API SETTINGS (DeepSeek R1)
 # ----------------------------
@@ -54,7 +83,10 @@ Your task:
 # ----------------------------
 def process_pipeline(image, user_question):
     if image is None or user_question.strip() == "":
-        return "Please upload an image and enter a question."
+        return "Please upload an image and enter a question.", None
+
+    # Run YOLO inference and get segmented image
+    segmented_image = inference(image)
 
     # Combine user's question with default
     combined_question = user_question.strip() + "\n\n" + DEFAULT_QUESTION
@@ -135,9 +167,9 @@ Original Answer:
     try:
         detailed_answer = response.json()["choices"][0]["message"]["content"]
     except Exception as e:
-        return f"**Error from DeepSeek:** {str(e)}\n\n```\n{response.text}\n```"
+        return f"**Error from DeepSeek:** {str(e)}\n\n```\n{response.text}\n```", segmented_image
 
-    return f"{detailed_answer}"
+    return f"{detailed_answer}", segmented_image
 
 
 # ----------------------------
@@ -145,31 +177,33 @@ Original Answer:
 # ----------------------------
 with gr.Blocks(title="Brain MRI QA") as demo:
     with gr.Row():
-        # Left column
+        # First column: input image and result image side by side
         with gr.Column():
-            image_input = gr.Image(type="filepath", label="Upload Medical Image")
+            with gr.Row():
+                image_input = gr.Image(type="filepath", label="Upload Medical Image")
+                result_image = gr.Image(type="filepath", label="Segmentation Result")  # next to input image
+
             question_box = gr.Textbox(
                 label="Your Question about the Image",
                 placeholder="Type your question here..."
             )
-            # Buttons side by side
             with gr.Row():
                 submit_btn = gr.Button("Submit")
                 clear_btn = gr.Button("Clear")
 
-        # Right column
+        # Second column: LLM answer output
        with gr.Column():
            llm_output = gr.Markdown(label="Detailed LLM Answer")
-
     submit_btn.click(
         fn=process_pipeline,
         inputs=[image_input, question_box],
-        outputs=llm_output
+        outputs=[llm_output, result_image]
     )
     clear_btn.click(
-        fn=lambda: ("", ""),
-        outputs=[question_box, llm_output]
+        fn=lambda: ("", "", None),
+        outputs=[question_box, llm_output, result_image]
     )
 
+
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
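
To sanity-check the new segmentation path outside Gradio, a standalone call mirroring `inference()` could look like the sketch below; it assumes the LFS weights have been pulled to the repo root, and `sample_mri.png` is a hypothetical test image:

```python
# Standalone check of the YOLO segmentation path added in this commit.
# Assumes best.pt has been fetched via Git LFS; sample_mri.png is hypothetical.
from PIL import Image
from ultralytics import YOLO

model = YOLO("best.pt")
results = model(Image.open("sample_mri.png").convert("RGB"))
annotated = results[0].plot()          # NumPy array in BGR channel order
Image.fromarray(annotated[:, :, ::-1]).save("annotated.png")  # BGR -> RGB
```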
best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:148e1be81643a7f594d2dd626a7a27bab6b4fd03f197dde4e1b2005a578e37a3
+size 5468627
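
Note that the three lines above are a Git LFS pointer, not the weights themselves; the `size` field says the real file is about 5.5 MB. A plain clone without LFS leaves only this tiny pointer on disk, which `YOLO("best.pt")` cannot load. A defensive pre-check, as a sketch:

```python
# Guard against loading an un-pulled LFS pointer instead of the real weights.
from pathlib import Path

weights = Path("best.pt")
if not weights.exists() or weights.stat().st_size < 1024:
    # Pointer files are ~130 bytes; the real best.pt is 5,468,627 bytes.
    raise RuntimeError("best.pt looks like a Git LFS pointer; run `git lfs pull` first.")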
requirements.txt CHANGED
@@ -1,3 +1,17 @@
+# gradio==5.42.0
+# transformers>=4.40.0
+# torch>=2.0.0
+# torchvision>=0.15.0
+# requests>=2.31.0
+# Pillow>=10.0.0
+# accelerate>=0.20.0
+# safetensors>=0.3.0
+# tokenizers>=0.15.0
+# numpy>=1.24.0
+# scipy>=1.10.0
+# qwen-vl-utils
+# ipython>=8.0.0
+
 gradio==5.42.0
 transformers>=4.40.0
 torch>=2.0.0
@@ -11,3 +25,6 @@ numpy>=1.24.0
 scipy>=1.10.0
 qwen-vl-utils
 ipython>=8.0.0
+ultralytics>=8.0.0
+matplotlib>=3.5.0
+opencv-python>=4.5.0
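
After `pip install -r requirements.txt`, a minimal smoke test can confirm the three newly added packages resolve (note that opencv-python imports as `cv2`):

```python
# Smoke test for the dependencies added in this commit.
import ultralytics, matplotlib, cv2  # opencv-python installs the cv2 module

print(ultralytics.__version__, matplotlib.__version__, cv2.__version__)
```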