drlon committed on
Commit 9aefa45 · 1 Parent(s): 86938ce

new app.py

Files changed (1)
  1. app.py +106 -145
app.py CHANGED
@@ -28,21 +28,20 @@ if not logger.handlers:
28
  handler = logging.StreamHandler()
29
  handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
30
  logger.addHandler(handler)
31
- logger.warning("here")
32
 
33
  # Define repository and local directory
34
  repo_id = "microsoft/OmniParser-v2.0" # HF repo
35
  local_dir = "weights" # Target local directory
36
 
37
  som_generator = MarkHelper()
38
  magma_som_prompt = "<image>\nIn this view I need to click a button to \"{}\"? Provide the coordinates and the mark index of the containing bounding box if applicable."
39
  magma_qa_prompt = "<image>\n{} Answer the question briefly."
40
  magma_model_id = "microsoft/Magma-8B"
41
- magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
42
- magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
43
- magam_model.to("cuda")
44
-
45
- logger.warning(f"The repository is downloading to: {local_dir}")
46
 
47
  # Download the entire repository
48
  snapshot_download(repo_id=repo_id, local_dir=local_dir)
@@ -58,27 +57,14 @@ MARKDOWN = """
58
  <div align="center">
59
  <h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
60
 
61
- [Jianwei Yang](https://jwyang.github.io/)<sup>*</sup><sup>1</sup><sup>†</sup>&nbsp;
62
- [Reuben Tan](https://cs-people.bu.edu/rxtan/)<sup>1</sup><sup>†</sup>&nbsp;
63
- [Qianhui Wu](https://qianhuiwu.github.io/)<sup>1</sup><sup>†</sup>&nbsp;
64
- [Ruijie Zheng](https://ruijiezheng.com/)<sup>2</sup><sup>‡</sup>&nbsp;
65
- [Baolin Peng](https://scholar.google.com/citations?user=u1CNjgwAAAAJ&hl=en&oi=ao)<sup>1</sup><sup>‡</sup>&nbsp;
66
- [Yongyuan Liang](https://cheryyunl.github.io)<sup>2</sup><sup>‡</sup>
67
- [Yu Gu](https://users.umiacs.umd.edu/~hal/)<sup>1</sup>&nbsp;
68
- [Mu Cai](https://pages.cs.wisc.edu/~mucai/)<sup>3</sup>&nbsp;
69
- [Seonghyeon Ye](https://seonghyeonye.github.io/)<sup>4</sup>&nbsp;
70
- [Joel Jang](https://joeljang.github.io/)<sup>5</sup>&nbsp;
71
- [Yuquan Deng](https://scholar.google.com/citations?user=LTC0Q6YAAAAJ&hl=en)<sup>5</sup>&nbsp;
72
- [Lars Liden](https://sites.google.com/site/larsliden)<sup>1</sup>&nbsp;
73
- [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/)<sup>1</sup><sup>▽</sup>
74
-
75
- <sup>1</sup> Microsoft Research; <sup>2</sup> University of Maryland; <sup>3</sup> University of Wisconsin-Madison; <sup>4</sup> KAIST; <sup>5</sup> University of Washington
76
-
77
- <sup>*</sup> Project lead <sup>†</sup> First authors <sup>‡</sup> Second authors <sup>▽</sup> Leadership
78
-
79
  \[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] &nbsp; \[[Project Page](https://microsoft.github.io/Magma/)\] &nbsp; \[[Github Repo](https://github.com/microsoft/Magma)\] &nbsp; \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\] &nbsp;
80
 
81
- This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to generate Set-of-Mark prompts.
82
  </div>
83
  """
84
 
@@ -86,7 +72,6 @@ DEVICE = torch.device('cuda')
86
 
87
  @spaces.GPU
88
  @torch.inference_mode()
89
- @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
90
  def get_som_response(instruction, image_som):
91
  prompt = magma_som_prompt.format(instruction)
92
  if magam_model.config.mm_use_image_start_end:
@@ -101,23 +86,10 @@ def get_som_response(instruction, image_som):
101
  add_generation_prompt=True
102
  )
103
 
104
- # inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
105
- # # with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
106
- # # inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
107
- # # inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
108
- # # logger.warning(inputs['pixel_values'].dtype)
109
- # # # inputs = inputs.to("cuda")
110
- # inputs = inputs.to("cuda", dtype=torch.bfloat16)
111
-
112
  inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
113
- inputs['pixel_values'] = inputs['pixel_values'].to("cuda", dtype=torch.bfloat16)
114
  inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
115
- inputs['image_sizes'] = inputs['image_sizes'].to("cuda")
116
-
117
- # Handle other possible inputs
118
- for key in inputs:
119
- if key not in ['pixel_values', 'image_sizes'] and torch.is_tensor(inputs[key]):
120
- inputs[key] = inputs[key].to("cuda")
121
 
122
  magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
123
  with torch.inference_mode():
@@ -137,7 +109,6 @@ def get_som_response(instruction, image_som):
137
 
138
  @spaces.GPU
139
  @torch.inference_mode()
140
- @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
141
  def get_qa_response(instruction, image):
142
  prompt = magma_qa_prompt.format(instruction)
143
  if magam_model.config.mm_use_image_start_end:
@@ -152,12 +123,10 @@ def get_qa_response(instruction, image):
152
  add_generation_prompt=True
153
  )
154
 
155
- with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
156
- inputs = magma_processor(images=[image], texts=prompt, return_tensors="pt")
157
- inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
158
- inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
159
- # inputs = inputs.to("cuda")
160
- inputs = inputs.to("cuda", dtype=torch.bfloat16)
161
 
162
  magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
163
  with torch.inference_mode():
@@ -177,7 +146,7 @@ def get_qa_response(instruction, image):
177
 
178
  @spaces.GPU
179
  @torch.inference_mode()
180
- @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
181
  def process(
182
  image_input,
183
  box_threshold,
@@ -187,107 +156,99 @@ def process(
187
  instruction,
188
  ) -> Optional[Image.Image]:
189
 
190
- logger.warning("Starting processing.")
191
- try:
192
- # image_save_path = 'imgs/saved_image_demo.png'
193
- # image_input.save(image_save_path)
194
- # image = Image.open(image_save_path)
195
- box_overlay_ratio = image_input.size[0] / 3200
196
- draw_bbox_config = {
197
- 'text_scale': 0.8 * box_overlay_ratio,
198
- 'text_thickness': max(int(2 * box_overlay_ratio), 1),
199
- 'text_padding': max(int(3 * box_overlay_ratio), 1),
200
- 'thickness': max(int(3 * box_overlay_ratio), 1),
201
- }
202
-
203
- ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
204
- text, ocr_bbox = ocr_bbox_rslt
205
- dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
206
- parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
207
-
208
- if len(instruction) == 0:
209
- logger.warning('finish processing')
210
- image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
211
- return image, str(parsed_content_list)
212
-
213
- elif instruction.startswith('Q:'):
214
- response = get_qa_response(instruction, image_input)
215
- return image_input, response
216
-
217
- # parsed_content_list = str(parsed_content_list)
218
- # convert xywh to yxhw
219
- label_coordinates_yxhw = {}
220
- for key, val in label_coordinates.items():
221
- if val[2] < 0 or val[3] < 0:
222
- continue
223
- label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
224
- image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
225
-
226
- # convert xywh to xyxy
227
- for key, val in label_coordinates.items():
228
- label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
229
-
230
- # normalize label_coordinates
231
- for key, val in label_coordinates.items():
232
- label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
233
-
234
- magma_response = get_som_response(instruction, image_som)
235
- logger.warning("magma repsonse: ", magma_response)
236
-
237
- # map magma_response into the mark id
238
- mark_id = extract_mark_id(magma_response)
239
- if mark_id is not None:
240
- if str(mark_id) in label_coordinates:
241
- bbox_for_mark = label_coordinates[str(mark_id)]
242
- else:
243
- bbox_for_mark = None
244
  else:
245
  bbox_for_mark = None
246
-
247
- if bbox_for_mark:
248
- # draw bbox_for_mark on the image
249
- image_som = plot_boxes_with_marks(
250
  image_input,
251
- [label_coordinates_yxhw[str(mark_id)]],
252
- som_generator,
253
  edgecolor=(255,127,111),
254
- alpha=30,
255
- fn_save=None,
256
  normalized_to_pixel=False,
257
  add_mark=False
258
  )
259
- else:
260
- try:
261
- if 'box' in magma_response:
262
- pred_bbox = extract_bbox(magma_response)
263
- click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
264
- click_point = [item / 1000 for item in click_point]
265
- else:
266
- click_point = pred_2_point(magma_response)
267
- # de-normalize click_point (width, height)
268
- click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
269
-
270
- image_som = plot_circles_with_marks(
271
- image_input,
272
- [click_point],
273
- som_generator,
274
- edgecolor=(255,127,111),
275
- linewidth=3,
276
- fn_save=None,
277
- normalized_to_pixel=False,
278
- add_mark=False
279
- )
280
- except:
281
- image_som = image_input
282
-
283
- logger.warning("finish processing")
284
- return image_som, str(parsed_content_list)
285
- except Exception as e:
286
- error_message = traceback.format_exc()
287
- logger.warning(error_message)
288
- return image_input, error_message
289
-
290
- logger.warning("Starting App.")
291
  with gr.Blocks() as demo:
292
  gr.Markdown(MARKDOWN)
293
  with gr.Row():
@@ -326,6 +287,6 @@ with gr.Blocks() as demo:
326
  outputs=[image_output_component, text_output_component]
327
  )
328
 
329
- # demo.launch(debug=True, show_error=True, share=True)
330
  # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
331
- demo.queue().launch(share=False)
 
28
  handler = logging.StreamHandler()
29
  handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
30
  logger.addHandler(handler)
 
31
 
32
  # Define repository and local directory
33
  repo_id = "microsoft/OmniParser-v2.0" # HF repo
34
  local_dir = "weights" # Target local directory
35
+ dtype = torch.bfloat16
36
+ DEVICE = torch.device('cuda')
37
 
38
  som_generator = MarkHelper()
39
  magma_som_prompt = "<image>\nIn this view I need to click a button to \"{}\"? Provide the coordinates and the mark index of the containing bounding box if applicable."
40
  magma_qa_prompt = "<image>\n{} Answer the question briefly."
41
  magma_model_id = "microsoft/Magma-8B"
42
+ magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True, torch_dtype=dtype)
43
+ magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
44
+ magam_model.to(DEVICE)
45
 
46
  # Download the entire repository
47
  snapshot_download(repo_id=repo_id, local_dir=local_dir)
 
57
  <div align="center">
58
  <h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
59
 
60
  \[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] &nbsp; \[[Project Page](https://microsoft.github.io/Magma/)\] &nbsp; \[[Github Repo](https://github.com/microsoft/Magma)\] &nbsp; \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\] &nbsp;
61
 
62
+ This demo is powered by [Gradio](https://gradio.app/) and uses [OmniParser v2](https://github.com/microsoft/OmniParser) to generate [Set-of-Mark prompts](https://github.com/microsoft/SoM).
63
+
64
+ The demo supports three modes:
65
+ 1. Empty text input: it falls back to a plain OmniParser demo.
66
+ 2. Text input starting with "Q:": it leads to a visual question answering demo.
67
+ 3. Any other text input: it is treated as a UI navigation instruction.
68
  </div>
69
  """
70
 
 
72
 
73
  @spaces.GPU
74
  @torch.inference_mode()
 
75
  def get_som_response(instruction, image_som):
76
  prompt = magma_som_prompt.format(instruction)
77
  if magam_model.config.mm_use_image_start_end:
 
86
  add_generation_prompt=True
87
  )
88
 
89
  inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
90
+ inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
91
  inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
92
+ inputs = inputs.to(dtype).to(DEVICE)
93
 
94
  magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
95
  with torch.inference_mode():
 
109
 
110
  @spaces.GPU
111
  @torch.inference_mode()
 
112
  def get_qa_response(instruction, image):
113
  prompt = magma_qa_prompt.format(instruction)
114
  if magam_model.config.mm_use_image_start_end:
 
123
  add_generation_prompt=True
124
  )
125
 
126
+ inputs = magma_processor(images=[image], texts=prompt, return_tensors="pt")
127
+ inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0)
128
+ inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
129
+ inputs = inputs.to(dtype).to(DEVICE)
130
 
131
  magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
132
  with torch.inference_mode():
 
146
 
147
  @spaces.GPU
148
  @torch.inference_mode()
149
+ # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
150
  def process(
151
  image_input,
152
  box_threshold,
 
156
  instruction,
157
  ) -> Optional[Image.Image]:
158
 
159
+ # image_save_path = 'imgs/saved_image_demo.png'
160
+ # image_input.save(image_save_path)
161
+ # image = Image.open(image_save_path)
162
+ box_overlay_ratio = image_input.size[0] / 3200
163
+ draw_bbox_config = {
164
+ 'text_scale': 0.8 * box_overlay_ratio,
165
+ 'text_thickness': max(int(2 * box_overlay_ratio), 1),
166
+ 'text_padding': max(int(3 * box_overlay_ratio), 1),
167
+ 'thickness': max(int(3 * box_overlay_ratio), 1),
168
+ }
169
+
170
+ ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
171
+ text, ocr_bbox = ocr_bbox_rslt
172
+ dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
173
+ parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
174
+
175
+ if len(instruction) == 0:
176
+ logger.warning('finish processing')
177
+ image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
178
+ return image, str(parsed_content_list)
179
+
180
+ elif instruction.startswith('Q:'):
181
+ response = get_qa_response(instruction, image_input)
182
+ return image_input, response
183
+
184
+ # parsed_content_list = str(parsed_content_list)
185
+ # convert xywh to yxhw
186
+ label_coordinates_yxhw = {}
187
+ for key, val in label_coordinates.items():
188
+ if val[2] < 0 or val[3] < 0:
189
+ continue
190
+ label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
191
+ image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
192
+
193
+ # convert xywh to xyxy
194
+ for key, val in label_coordinates.items():
195
+ label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
196
+
197
+ # normalize label_coordinates
198
+ for key, val in label_coordinates.items():
199
+ label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
200
+
201
+ magma_response = get_som_response(instruction, image_som)
202
+ logger.warning(f"magma response: {magma_response}")
203
+
204
+ # map magma_response into the mark id
205
+ mark_id = extract_mark_id(magma_response)
206
+ if mark_id is not None:
207
+ if str(mark_id) in label_coordinates:
208
+ bbox_for_mark = label_coordinates[str(mark_id)]
209
  else:
210
  bbox_for_mark = None
211
+ else:
212
+ bbox_for_mark = None
213
+
214
+ if bbox_for_mark:
215
+ # draw bbox_for_mark on the image
216
+ image_som = plot_boxes_with_marks(
217
+ image_input,
218
+ [label_coordinates_yxhw[str(mark_id)]],
219
+ som_generator,
220
+ edgecolor=(255,127,111),
221
+ alpha=30,
222
+ fn_save=None,
223
+ normalized_to_pixel=False,
224
+ add_mark=False
225
+ )
226
+ else:
227
+ try:
228
+ if 'box' in magma_response:
229
+ pred_bbox = extract_bbox(magma_response)
230
+ click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
231
+ click_point = [item / 1000 for item in click_point]
232
+ else:
233
+ click_point = pred_2_point(magma_response)
234
+ # de-normalize click_point (width, height)
235
+ click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
236
+
237
+ image_som = plot_circles_with_marks(
238
  image_input,
239
+ [click_point],
240
+ som_generator,
241
  edgecolor=(255,127,111),
242
+ linewidth=3,
243
+ fn_save=None,
244
  normalized_to_pixel=False,
245
  add_mark=False
246
  )
247
+ except:
248
+ image_som = image_input
249
+
250
+ return image_som, str(parsed_content_list)
251
+
252
  with gr.Blocks() as demo:
253
  gr.Markdown(MARKDOWN)
254
  with gr.Row():
 
287
  outputs=[image_output_component, text_output_component]
288
  )
289
 
290
+ # demo.launch(debug=False, show_error=True, share=True)
291
  # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
292
+ demo.queue().launch(share=False)
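
For reference, the three input modes listed in the updated MARKDOWN banner map directly onto the branching at the top of the new process() function. The sketch below condenses that dispatch for illustration only: dispatch_mode and its argument list are hypothetical names for this note, get_qa_response and get_som_response are the functions defined in app.py above, and the real UI-navigation branch additionally parses the Magma response and overlays the predicted mark or click point before returning.

import base64
import io
from PIL import Image

def dispatch_mode(instruction, image_input, image_som, dino_labled_img, parsed_content_list):
    # Mode 1: empty instruction -> return the OmniParser-labeled screenshot and parsed icon list
    if len(instruction) == 0:
        image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        return image, str(parsed_content_list)
    # Mode 2: "Q:"-prefixed instruction -> visual question answering on the raw screenshot
    elif instruction.startswith('Q:'):
        return image_input, get_qa_response(instruction, image_input)
    # Mode 3: any other instruction -> Set-of-Mark prompting for UI navigation
    else:
        magma_response = get_som_response(instruction, image_som)
        return image_som, magma_response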