drlon committed
Commit f998db7 · 1 Parent(s): 2adceb1
Files changed (2)
  1. app.py +1 -265
  2. app_1.py +317 -0
app.py CHANGED
@@ -3,54 +3,18 @@ import logging
  from typing import Optional
  import spaces
  import gradio as gr
- import numpy as np
- import torch
- from PIL import Image
- import io
- import re
-
- import base64, os
- from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
- from util.som import MarkHelper, plot_boxes_with_marks, plot_circles_with_marks
- from util.process_utils import pred_2_point, extract_bbox, extract_mark_id
-
- import torch
- from PIL import Image
-
- from huggingface_hub import snapshot_download
- import torch
- from transformers import AutoModelForCausalLM
- from transformers import AutoProcessor

  logger = logging.getLogger(__name__)
  logger.setLevel(logging.WARNING)
  handler = logging.StreamHandler()
  logger.addHandler(handler)

- # Define repository and local directory
- repo_id = "microsoft/OmniParser-v2.0" # HF repo
- local_dir = "weights" # Target local directory
-
- som_generator = MarkHelper()
- magma_som_prompt = "<image>\nIn this view I need to click a button to \"{}\"? Provide the coordinates and the mark index of the containing bounding box if applicable."
- magma_qa_prompt = "<image>\n{} Answer the question briefly."
- magma_model_id = "microsoft/Magma-8B"
- magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True)
- magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
- magam_model.to("cuda")

  logger.warning(f"The repository is downloading to: {local_dir}")

- # Download the entire repository
- snapshot_download(repo_id=repo_id, local_dir=local_dir)
-
  logger.warning(f"Repository downloaded to: {local_dir}")


- yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
- caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
- # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
-
  MARKDOWN = """
  <div align="center">
  <h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
@@ -79,238 +43,10 @@ This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to g
  </div>
  """

- DEVICE = torch.device('cuda')
-
- @spaces.GPU
- @torch.inference_mode()
- @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
- def get_som_response(instruction, image_som):
-     prompt = magma_som_prompt.format(instruction)
-     if magam_model.config.mm_use_image_start_end:
-         qs = prompt.replace('<image>', '<image_start><image><image_end>')
-     else:
-         qs = prompt
-     convs = [{"role": "user", "content": qs}]
-     convs = [{"role": "system", "content": "You are agent that can see, talk and act."}] + convs
-     prompt = magma_processor.tokenizer.apply_chat_template(
-         convs,
-         tokenize=False,
-         add_generation_prompt=True
-     )
-
-     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-         inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
-         inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
-         inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
-         # inputs = inputs.to("cuda")
-         inputs = inputs.to("cuda", dtype=torch.bfloat16)
-
-     magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
-     with torch.inference_mode():
-         output_ids = magam_model.generate(
-             **inputs,
-             temperature=0.0,
-             do_sample=False,
-             num_beams=1,
-             max_new_tokens=128,
-             use_cache=True
-         )
-
-     prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
-     response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
-     response = response.replace(prompt_decoded, '').strip()
-     return response
-
- @spaces.GPU
- @torch.inference_mode()
- @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
- def get_qa_response(instruction, image):
-     prompt = magma_qa_prompt.format(instruction)
-     if magam_model.config.mm_use_image_start_end:
-         qs = prompt.replace('<image>', '<image_start><image><image_end>')
-     else:
-         qs = prompt
-     convs = [{"role": "user", "content": qs}]
-     convs = [{"role": "system", "content": "You are agent that can see, talk and act."}] + convs
-     prompt = magma_processor.tokenizer.apply_chat_template(
-         convs,
-         tokenize=False,
-         add_generation_prompt=True
-     )
-
-     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-         inputs = magma_processor(images=[image], texts=prompt, return_tensors="pt")
-         inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
-         inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
-         # inputs = inputs.to("cuda")
-         inputs = inputs.to("cuda", dtype=torch.bfloat16)
-
-     magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
-     with torch.inference_mode():
-         output_ids = magam_model.generate(
-             **inputs,
-             temperature=0.0,
-             do_sample=False,
-             num_beams=1,
-             max_new_tokens=128,
-             use_cache=True
-         )
-
-     prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
-     response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
-     response = response.replace(prompt_decoded, '').strip()
-     return response
-
- @spaces.GPU
- @torch.inference_mode()
- @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
- def process(
-     image_input,
-     box_threshold,
-     iou_threshold,
-     use_paddleocr,
-     imgsz,
-     instruction,
- ) -> Optional[Image.Image]:
-
-     logger.warning("Starting processing.")
-     try:
-         # image_save_path = 'imgs/saved_image_demo.png'
-         # image_input.save(image_save_path)
-         # image = Image.open(image_save_path)
-         box_overlay_ratio = image_input.size[0] / 3200
-         draw_bbox_config = {
-             'text_scale': 0.8 * box_overlay_ratio,
-             'text_thickness': max(int(2 * box_overlay_ratio), 1),
-             'text_padding': max(int(3 * box_overlay_ratio), 1),
-             'thickness': max(int(3 * box_overlay_ratio), 1),
-         }
-
-         ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
-         text, ocr_bbox = ocr_bbox_rslt
-         dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
-         parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
-
-         if len(instruction) == 0:
-             logger.warning('finish processing')
-             image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
-             return image, str(parsed_content_list)
-
-         elif instruction.startswith('Q:'):
-             response = get_qa_response(instruction, image_input)
-             return image_input, response
-
-         # parsed_content_list = str(parsed_content_list)
-         # convert xywh to yxhw
-         label_coordinates_yxhw = {}
-         for key, val in label_coordinates.items():
-             if val[2] < 0 or val[3] < 0:
-                 continue
-             label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
-         image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
-
-         # convert xywh to xyxy
-         for key, val in label_coordinates.items():
-             label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
-
-         # normalize label_coordinates
-         for key, val in label_coordinates.items():
-             label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
-
-         magma_response = get_som_response(instruction, image_som)
-         logger.warning("magma repsonse: ", magma_response)
-
-         # map magma_response into the mark id
-         mark_id = extract_mark_id(magma_response)
-         if mark_id is not None:
-             if str(mark_id) in label_coordinates:
-                 bbox_for_mark = label_coordinates[str(mark_id)]
-             else:
-                 bbox_for_mark = None
-         else:
-             bbox_for_mark = None
-
-         if bbox_for_mark:
-             # draw bbox_for_mark on the image
-             image_som = plot_boxes_with_marks(
-                 image_input,
-                 [label_coordinates_yxhw[str(mark_id)]],
-                 som_generator,
-                 edgecolor=(255,127,111),
-                 alpha=30,
-                 fn_save=None,
-                 normalized_to_pixel=False,
-                 add_mark=False
-             )
-         else:
-             try:
-                 if 'box' in magma_response:
-                     pred_bbox = extract_bbox(magma_response)
-                     click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
-                     click_point = [item / 1000 for item in click_point]
-                 else:
-                     click_point = pred_2_point(magma_response)
-                 # de-normalize click_point (width, height)
-                 click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
-
-                 image_som = plot_circles_with_marks(
-                     image_input,
-                     [click_point],
-                     som_generator,
-                     edgecolor=(255,127,111),
-                     linewidth=3,
-                     fn_save=None,
-                     normalized_to_pixel=False,
-                     add_mark=False
-                 )
-             except:
-                 image_som = image_input
-
-         logger.warning("finish processing")
-         return image_som, str(parsed_content_list)
-     except Exception as e:
-         error_message = traceback.format_exc()
-         logger.warning(error_message)
-         return image_input, error_message
-
  logger.warning("Starting App.")
+
  with gr.Blocks() as demo:
      gr.Markdown(MARKDOWN)
-     with gr.Row():
-         with gr.Column():
-             image_input_component = gr.Image(
-                 type='pil', label='Upload image')
-             # set the threshold for removing the bounding boxes with low confidence, default is 0.05
-             with gr.Accordion("Parameters", open=False) as parameter_row:
-                 box_threshold_component = gr.Slider(
-                     label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
-                 # set the threshold for removing the bounding boxes with large overlap, default is 0.1
-                 iou_threshold_component = gr.Slider(
-                     label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
-                 use_paddleocr_component = gr.Checkbox(
-                     label='Use PaddleOCR', value=True)
-                 imgsz_component = gr.Slider(
-                     label='Icon Detect Image Size', minimum=640, maximum=1920, step=32, value=640)
-             # text box
-             text_input_component = gr.Textbox(label='Text Input', placeholder='Text Input')
-             submit_button_component = gr.Button(
-                 value='Submit', variant='primary')
-         with gr.Column():
-             image_output_component = gr.Image(type='pil', label='Image Output')
-             text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
-
-     submit_button_component.click(
-         fn=process,
-         inputs=[
-             image_input_component,
-             box_threshold_component,
-             iou_threshold_component,
-             use_paddleocr_component,
-             imgsz_component,
-             text_input_component
-         ],
-         outputs=[image_output_component, text_output_component]
-     )

  demo.launch(debug=True, show_error=True, share=True)
  # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')

app_1.py ADDED
@@ -0,0 +1,317 @@
+ import traceback
+ import logging
+ from typing import Optional
+ import spaces
+ import gradio as gr
+ import numpy as np
+ import torch
+ from PIL import Image
+ import io
+ import re
+
+ import base64, os
+ from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+ from util.som import MarkHelper, plot_boxes_with_marks, plot_circles_with_marks
+ from util.process_utils import pred_2_point, extract_bbox, extract_mark_id
+
+ import torch
+ from PIL import Image
+
+ from huggingface_hub import snapshot_download
+ import torch
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoProcessor
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.WARNING)
+ handler = logging.StreamHandler()
+ logger.addHandler(handler)
+
+ # Define repository and local directory
+ repo_id = "microsoft/OmniParser-v2.0" # HF repo
+ local_dir = "weights" # Target local directory
+
+ som_generator = MarkHelper()
+ magma_som_prompt = "<image>\nIn this view I need to click a button to \"{}\"? Provide the coordinates and the mark index of the containing bounding box if applicable."
+ magma_qa_prompt = "<image>\n{} Answer the question briefly."
+ magma_model_id = "microsoft/Magma-8B"
+ magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True)
+ magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
+ magam_model.to("cuda")
+
+ logger.warning(f"The repository is downloading to: {local_dir}")
+
+ # Download the entire repository
+ snapshot_download(repo_id=repo_id, local_dir=local_dir)
+
+ logger.warning(f"Repository downloaded to: {local_dir}")
+
+
+ yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
+ caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
+ # caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
+
+ MARKDOWN = """
+ <div align="center">
+ <h2>Magma: A Foundation Model for Multimodal AI Agents</h2>
+
+ [Jianwei Yang](https://jwyang.github.io/)<sup>*</sup><sup>1</sup><sup>†</sup>&nbsp;
+ [Reuben Tan](https://cs-people.bu.edu/rxtan/)<sup>1</sup><sup>†</sup>&nbsp;
+ [Qianhui Wu](https://qianhuiwu.github.io/)<sup>1</sup><sup>†</sup>&nbsp;
+ [Ruijie Zheng](https://ruijiezheng.com/)<sup>2</sup><sup>‡</sup>&nbsp;
+ [Baolin Peng](https://scholar.google.com/citations?user=u1CNjgwAAAAJ&hl=en&oi=ao)<sup>1</sup><sup>‡</sup>&nbsp;
+ [Yongyuan Liang](https://cheryyunl.github.io)<sup>2</sup><sup>‡</sup>
+ [Yu Gu](https://users.umiacs.umd.edu/~hal/)<sup>1</sup>&nbsp;
+ [Mu Cai](https://pages.cs.wisc.edu/~mucai/)<sup>3</sup>&nbsp;
+ [Seonghyeon Ye](https://seonghyeonye.github.io/)<sup>4</sup>&nbsp;
+ [Joel Jang](https://joeljang.github.io/)<sup>5</sup>&nbsp;
+ [Yuquan Deng](https://scholar.google.com/citations?user=LTC0Q6YAAAAJ&hl=en)<sup>5</sup>&nbsp;
+ [Lars Liden](https://sites.google.com/site/larsliden)<sup>1</sup>&nbsp;
+ [Jianfeng Gao](https://www.microsoft.com/en-us/research/people/jfgao/)<sup>1</sup><sup>▽</sup>
+
+ <sup>1</sup> Microsoft Research; <sup>2</sup> University of Maryland; <sup>3</sup> University of Wisconsin-Madison; <sup>4</sup> KAIST; <sup>5</sup> University of Washington
+
+ <sup>*</sup> Project lead <sup>†</sup> First authors <sup>‡</sup> Second authors <sup>▽</sup> Leadership
+
+ \[[arXiv Paper](https://www.arxiv.org/pdf/2502.13130)\] &nbsp; \[[Project Page](https://microsoft.github.io/Magma/)\] &nbsp; \[[Github Repo](https://github.com/microsoft/Magma)\] &nbsp; \[[Hugging Face Model](https://huggingface.co/microsoft/Magma-8B)\] &nbsp;
+
+ This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to generate Set-of-Mark prompts.
+ </div>
+ """
+
+ DEVICE = torch.device('cuda')
+
+ @spaces.GPU
+ @torch.inference_mode()
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+ def get_som_response(instruction, image_som):
+     prompt = magma_som_prompt.format(instruction)
+     if magam_model.config.mm_use_image_start_end:
+         qs = prompt.replace('<image>', '<image_start><image><image_end>')
+     else:
+         qs = prompt
+     convs = [{"role": "user", "content": qs}]
+     convs = [{"role": "system", "content": "You are agent that can see, talk and act."}] + convs
+     prompt = magma_processor.tokenizer.apply_chat_template(
+         convs,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+         inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
+         inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
+         inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
+         # inputs = inputs.to("cuda")
+         inputs = inputs.to("cuda", dtype=torch.bfloat16)
+
+     magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
+     with torch.inference_mode():
+         output_ids = magam_model.generate(
+             **inputs,
+             temperature=0.0,
+             do_sample=False,
+             num_beams=1,
+             max_new_tokens=128,
+             use_cache=True
+         )
+
+     prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
+     response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+     response = response.replace(prompt_decoded, '').strip()
+     return response
+
+ @spaces.GPU
+ @torch.inference_mode()
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+ def get_qa_response(instruction, image):
+     prompt = magma_qa_prompt.format(instruction)
+     if magam_model.config.mm_use_image_start_end:
+         qs = prompt.replace('<image>', '<image_start><image><image_end>')
+     else:
+         qs = prompt
+     convs = [{"role": "user", "content": qs}]
+     convs = [{"role": "system", "content": "You are agent that can see, talk and act."}] + convs
+     prompt = magma_processor.tokenizer.apply_chat_template(
+         convs,
+         tokenize=False,
+         add_generation_prompt=True
+     )
+
+     with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+         inputs = magma_processor(images=[image], texts=prompt, return_tensors="pt")
+         inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16) # Add .to(torch.bfloat16) here for explicit casting
+         inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
+         # inputs = inputs.to("cuda")
+         inputs = inputs.to("cuda", dtype=torch.bfloat16)
+
+     magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
+     with torch.inference_mode():
+         output_ids = magam_model.generate(
+             **inputs,
+             temperature=0.0,
+             do_sample=False,
+             num_beams=1,
+             max_new_tokens=128,
+             use_cache=True
+         )
+
+     prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
+     response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+     response = response.replace(prompt_decoded, '').strip()
+     return response
+
+ @spaces.GPU
+ @torch.inference_mode()
+ @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+ def process(
+     image_input,
+     box_threshold,
+     iou_threshold,
+     use_paddleocr,
+     imgsz,
+     instruction,
+ ) -> Optional[Image.Image]:
+
+     logger.warning("Starting processing.")
+     try:
+         # image_save_path = 'imgs/saved_image_demo.png'
+         # image_input.save(image_save_path)
+         # image = Image.open(image_save_path)
+         box_overlay_ratio = image_input.size[0] / 3200
+         draw_bbox_config = {
+             'text_scale': 0.8 * box_overlay_ratio,
+             'text_thickness': max(int(2 * box_overlay_ratio), 1),
+             'text_padding': max(int(3 * box_overlay_ratio), 1),
+             'thickness': max(int(3 * box_overlay_ratio), 1),
+         }
+
+         ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
+         text, ocr_bbox = ocr_bbox_rslt
+         dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,)
+         parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)])
+
+         if len(instruction) == 0:
+             logger.warning('finish processing')
+             image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
+             return image, str(parsed_content_list)
+
+         elif instruction.startswith('Q:'):
+             response = get_qa_response(instruction, image_input)
+             return image_input, response
+
+         # parsed_content_list = str(parsed_content_list)
+         # convert xywh to yxhw
+         label_coordinates_yxhw = {}
+         for key, val in label_coordinates.items():
+             if val[2] < 0 or val[3] < 0:
+                 continue
+             label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
+         image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
+
+         # convert xywh to xyxy
+         for key, val in label_coordinates.items():
+             label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
+
+         # normalize label_coordinates
+         for key, val in label_coordinates.items():
+             label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
+
+         magma_response = get_som_response(instruction, image_som)
+         logger.warning("magma repsonse: ", magma_response)
+
+         # map magma_response into the mark id
+         mark_id = extract_mark_id(magma_response)
+         if mark_id is not None:
+             if str(mark_id) in label_coordinates:
+                 bbox_for_mark = label_coordinates[str(mark_id)]
+             else:
+                 bbox_for_mark = None
+         else:
+             bbox_for_mark = None
+
+         if bbox_for_mark:
+             # draw bbox_for_mark on the image
+             image_som = plot_boxes_with_marks(
+                 image_input,
+                 [label_coordinates_yxhw[str(mark_id)]],
+                 som_generator,
+                 edgecolor=(255,127,111),
+                 alpha=30,
+                 fn_save=None,
+                 normalized_to_pixel=False,
+                 add_mark=False
+             )
+         else:
+             try:
+                 if 'box' in magma_response:
+                     pred_bbox = extract_bbox(magma_response)
+                     click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
+                     click_point = [item / 1000 for item in click_point]
+                 else:
+                     click_point = pred_2_point(magma_response)
+                 # de-normalize click_point (width, height)
+                 click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
+
+                 image_som = plot_circles_with_marks(
+                     image_input,
+                     [click_point],
+                     som_generator,
+                     edgecolor=(255,127,111),
+                     linewidth=3,
+                     fn_save=None,
+                     normalized_to_pixel=False,
+                     add_mark=False
+                 )
+             except:
+                 image_som = image_input
+
+         logger.warning("finish processing")
+         return image_som, str(parsed_content_list)
+     except Exception as e:
+         error_message = traceback.format_exc()
+         logger.warning(error_message)
+         return image_input, error_message
+
+ logger.warning("Starting App.")
+ with gr.Blocks() as demo:
+     gr.Markdown(MARKDOWN)
+     with gr.Row():
+         with gr.Column():
+             image_input_component = gr.Image(
+                 type='pil', label='Upload image')
+             # set the threshold for removing the bounding boxes with low confidence, default is 0.05
+             with gr.Accordion("Parameters", open=False) as parameter_row:
+                 box_threshold_component = gr.Slider(
+                     label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
+                 # set the threshold for removing the bounding boxes with large overlap, default is 0.1
+                 iou_threshold_component = gr.Slider(
+                     label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
+                 use_paddleocr_component = gr.Checkbox(
+                     label='Use PaddleOCR', value=True)
+                 imgsz_component = gr.Slider(
+                     label='Icon Detect Image Size', minimum=640, maximum=1920, step=32, value=640)
+             # text box
+             text_input_component = gr.Textbox(label='Text Input', placeholder='Text Input')
+             submit_button_component = gr.Button(
+                 value='Submit', variant='primary')
+         with gr.Column():
+             image_output_component = gr.Image(type='pil', label='Image Output')
+             text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
+
+     submit_button_component.click(
+         fn=process,
+         inputs=[
+             image_input_component,
+             box_threshold_component,
+             iou_threshold_component,
+             use_paddleocr_component,
+             imgsz_component,
+             text_input_component
+         ],
+         outputs=[image_output_component, text_output_component]
+     )
+
+ demo.launch(debug=True, show_error=True, share=True)
+ # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
+ # demo.queue().launch(share=False)