drlon committed on
Commit
1f2a2c3
·
1 Parent(s): 1fd1cad
Files changed (1)
  1. app.py +271 -3
app.py CHANGED
@@ -3,14 +3,54 @@ import logging
 from typing import Optional
 import spaces
 import gradio as gr
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
 handler = logging.StreamHandler()
 logger.addHandler(handler)
-
 print("here")
-logger.warning(f"The repository is downloading to:")
 
 MARKDOWN = """
 <div align="center">
@@ -40,10 +80,238 @@ This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to g
 </div>
 """
 
-logger.warning("Starting App.")
 
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
 
 # demo.launch(debug=True, show_error=True, share=True)
 # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
 
 from typing import Optional
 import spaces
 import gradio as gr
+import numpy as np
+import torch
+from PIL import Image
+import io
+import re
+import traceback  # needed for traceback.format_exc() in process() below
+
+import base64, os
+from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
+from util.som import MarkHelper, plot_boxes_with_marks, plot_circles_with_marks
+from util.process_utils import pred_2_point, extract_bbox, extract_mark_id
+
+import torch
+from PIL import Image
+
+from huggingface_hub import snapshot_download
+import torch
+from transformers import AutoModelForCausalLM
+from transformers import AutoProcessor
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.WARNING)
 handler = logging.StreamHandler()
 logger.addHandler(handler)
 print("here")
+
+# Define repository and local directory
+repo_id = "microsoft/OmniParser-v2.0"  # HF repo
+local_dir = "weights"  # Target local directory
+
+som_generator = MarkHelper()
+magma_som_prompt = "<image>\nIn this view I need to click a button to \"{}\"? Provide the coordinates and the mark index of the containing bounding box if applicable."
+magma_qa_prompt = "<image>\n{} Answer the question briefly."
+magma_model_id = "microsoft/Magma-8B"
+magam_model = AutoModelForCausalLM.from_pretrained(magma_model_id, trust_remote_code=True)
+magma_processor = AutoProcessor.from_pretrained(magma_model_id, trust_remote_code=True)
+magam_model.to("cuda")
+
+logger.warning(f"The repository is downloading to: {local_dir}")
+
+# Download the entire repository
+snapshot_download(repo_id=repo_id, local_dir=local_dir)
+
+logger.warning(f"Repository downloaded to: {local_dir}")
+
+
+yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
+caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption")
+# caption_model_processor = get_caption_model_processor(model_name="blip2", model_name_or_path="weights/icon_caption_blip2")
 
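For reference, `snapshot_download` mirrors the repository layout under `local_dir`, so the two checkpoint paths loaded above should exist once the download finishes. A minimal sanity check (hypothetical, not part of this commit) could look like:

```python
import os

# Hypothetical check that the OmniParser checkpoints landed where app.py expects them.
expected = [
    os.path.join("weights", "icon_detect", "model.pt"),
    os.path.join("weights", "icon_caption"),
]
for path in expected:
    print(path, "->", "found" if os.path.exists(path) else "missing")
```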
 MARKDOWN = """
 <div align="center">
@@ -40,10 +80,238 @@ This demo is powered by [Gradio](https://gradio.app/) and uses OmniParserv2 to g
 </div>
 """
 
+DEVICE = torch.device('cuda')
+
+@spaces.GPU
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def get_som_response(instruction, image_som):
+    prompt = magma_som_prompt.format(instruction)
+    if magam_model.config.mm_use_image_start_end:
+        qs = prompt.replace('<image>', '<image_start><image><image_end>')
+    else:
+        qs = prompt
+    convs = [{"role": "user", "content": qs}]
+    convs = [{"role": "system", "content": "You are agent that can see, talk and act."}] + convs
+    prompt = magma_processor.tokenizer.apply_chat_template(
+        convs,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        inputs = magma_processor(images=[image_som], texts=prompt, return_tensors="pt")
+        inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16)  # Add .to(torch.bfloat16) here for explicit casting
+        inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
+        # inputs = inputs.to("cuda")
+        inputs = inputs.to("cuda", dtype=torch.bfloat16)
+
+    magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
+    with torch.inference_mode():
+        output_ids = magam_model.generate(
+            **inputs,
+            temperature=0.0,
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=128,
+            use_cache=True
+        )
+
+    prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
+    response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+    response = response.replace(prompt_decoded, '').strip()
+    return response
+
+@spaces.GPU
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def get_qa_response(instruction, image):
+    prompt = magma_qa_prompt.format(instruction)
+    if magam_model.config.mm_use_image_start_end:
+        qs = prompt.replace('<image>', '<image_start><image><image_end>')
+    else:
+        qs = prompt
+    convs = [{"role": "user", "content": qs}]
+    convs = [{"role": "system", "content": "You are agent that can see, talk and act."}] + convs
+    prompt = magma_processor.tokenizer.apply_chat_template(
+        convs,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+        inputs = magma_processor(images=[image], texts=prompt, return_tensors="pt")
+        inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16)  # Add .to(torch.bfloat16) here for explicit casting
+        inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
+        # inputs = inputs.to("cuda")
+        inputs = inputs.to("cuda", dtype=torch.bfloat16)
+
+    magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
+    with torch.inference_mode():
+        output_ids = magam_model.generate(
+            **inputs,
+            temperature=0.0,
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=128,
+            use_cache=True
+        )
+
+    prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
+    response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
+    response = response.replace(prompt_decoded, '').strip()
+    return response
+
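`get_som_response` and `get_qa_response` are identical apart from the prompt template and the image argument. A possible refactor (a sketch only, not part of this commit; generation arguments trimmed to the ones that matter here) would route both through one helper:

```python
def magma_chat(prompt_template, instruction, image):
    # Hypothetical shared helper mirroring the two functions above.
    prompt = prompt_template.format(instruction)
    if magam_model.config.mm_use_image_start_end:
        prompt = prompt.replace('<image>', '<image_start><image><image_end>')
    convs = [
        {"role": "system", "content": "You are agent that can see, talk and act."},
        {"role": "user", "content": prompt},
    ]
    text = magma_processor.tokenizer.apply_chat_template(convs, tokenize=False, add_generation_prompt=True)
    inputs = magma_processor(images=[image], texts=text, return_tensors="pt")
    inputs['pixel_values'] = inputs['pixel_values'].unsqueeze(0).to(torch.bfloat16)
    inputs['image_sizes'] = inputs['image_sizes'].unsqueeze(0)
    inputs = inputs.to("cuda", dtype=torch.bfloat16)
    magam_model.generation_config.pad_token_id = magma_processor.tokenizer.pad_token_id
    output_ids = magam_model.generate(**inputs, do_sample=False, num_beams=1, max_new_tokens=128, use_cache=True)
    prompt_decoded = magma_processor.batch_decode(inputs['input_ids'], skip_special_tokens=True)[0]
    response = magma_processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    return response.replace(prompt_decoded, '').strip()

# get_som_response(instr, img) would become magma_chat(magma_som_prompt, instr, img)
# get_qa_response(instr, img)  would become magma_chat(magma_qa_prompt, instr, img)
```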
+@spaces.GPU
+@torch.inference_mode()
+@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
+def process(
+    image_input,
+    box_threshold,
+    iou_threshold,
+    use_paddleocr,
+    imgsz,
+    instruction,
+) -> Optional[Image.Image]:
+
+    logger.warning("Starting processing.")
+    try:
+        # image_save_path = 'imgs/saved_image_demo.png'
+        # image_input.save(image_save_path)
+        # image = Image.open(image_save_path)
+        box_overlay_ratio = image_input.size[0] / 3200
+        draw_bbox_config = {
+            'text_scale': 0.8 * box_overlay_ratio,
+            'text_thickness': max(int(2 * box_overlay_ratio), 1),
+            'text_padding': max(int(3 * box_overlay_ratio), 1),
+            'thickness': max(int(3 * box_overlay_ratio), 1),
+        }
+
+        ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr)
+        text, ocr_bbox = ocr_bbox_rslt
+        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text, iou_threshold=iou_threshold, imgsz=imgsz,)
+        parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i, v in enumerate(parsed_content_list)])
+
+        if len(instruction) == 0:
+            logger.warning('finish processing')
+            image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
+            return image, str(parsed_content_list)
+
+        elif instruction.startswith('Q:'):
+            response = get_qa_response(instruction, image_input)
+            return image_input, response
+
+        # parsed_content_list = str(parsed_content_list)
+        # convert xywh to yxhw
+        label_coordinates_yxhw = {}
+        for key, val in label_coordinates.items():
+            if val[2] < 0 or val[3] < 0:
+                continue
+            label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]]
+        image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False)
+
+        # convert xywh to xyxy
+        for key, val in label_coordinates.items():
+            label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]]
+
+        # normalize label_coordinates
+        for key, val in label_coordinates.items():
+            label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]]
+
+        magma_response = get_som_response(instruction, image_som)
+        logger.warning(f"magma response: {magma_response}")
+
+        # map magma_response into the mark id
+        mark_id = extract_mark_id(magma_response)
+        if mark_id is not None:
+            if str(mark_id) in label_coordinates:
+                bbox_for_mark = label_coordinates[str(mark_id)]
+            else:
+                bbox_for_mark = None
+        else:
+            bbox_for_mark = None
+
+        if bbox_for_mark:
+            # draw bbox_for_mark on the image
+            image_som = plot_boxes_with_marks(
+                image_input,
+                [label_coordinates_yxhw[str(mark_id)]],
+                som_generator,
+                edgecolor=(255,127,111),
+                alpha=30,
+                fn_save=None,
+                normalized_to_pixel=False,
+                add_mark=False
+            )
+        else:
+            try:
+                if 'box' in magma_response:
+                    pred_bbox = extract_bbox(magma_response)
+                    click_point = [(pred_bbox[0][0] + pred_bbox[1][0]) / 2, (pred_bbox[0][1] + pred_bbox[1][1]) / 2]
+                    click_point = [item / 1000 for item in click_point]
+                else:
+                    click_point = pred_2_point(magma_response)
+                # de-normalize click_point (width, height)
+                click_point = [click_point[0] * image_input.size[0], click_point[1] * image_input.size[1]]
+
+                image_som = plot_circles_with_marks(
+                    image_input,
+                    [click_point],
+                    som_generator,
+                    edgecolor=(255,127,111),
+                    linewidth=3,
+                    fn_save=None,
+                    normalized_to_pixel=False,
+                    add_mark=False
+                )
+            except:
+                image_som = image_input
+
+        logger.warning("finish processing")
+        return image_som, str(parsed_content_list)
+    except Exception as e:
+        error_message = traceback.format_exc()
+        logger.warning(error_message)
+        return image_input, error_message
+
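`process` routes on the instruction text: an empty instruction returns only the OmniParser-annotated screenshot and the parsed element list, a `Q:` prefix goes to visual QA via `get_qa_response`, and anything else is grounded with set-of-marks prompting. A hypothetical direct call, bypassing the Gradio UI (filename and instructions are made up; on Spaces the `@spaces.GPU` decorator handles GPU allocation):

```python
from PIL import Image

screenshot = Image.open("screenshot.png")  # hypothetical input image

# Grounding: returns the annotated image plus the parsed element list.
annotated, elements = process(screenshot, 0.05, 0.1, True, 640, "open the settings menu")

# Visual QA: instructions prefixed with "Q:" are answered instead of grounded.
original, answer = process(screenshot, 0.05, 0.1, True, 640, "Q: which app is in the foreground?")
```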
+logger.warning("Starting App.")
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
+    with gr.Row():
+        with gr.Column():
+            image_input_component = gr.Image(
+                type='pil', label='Upload image')
+            # set the threshold for removing the bounding boxes with low confidence, default is 0.05
+            with gr.Accordion("Parameters", open=False) as parameter_row:
+                box_threshold_component = gr.Slider(
+                    label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05)
+                # set the threshold for removing the bounding boxes with large overlap, default is 0.1
+                iou_threshold_component = gr.Slider(
+                    label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1)
+                use_paddleocr_component = gr.Checkbox(
+                    label='Use PaddleOCR', value=True)
+                imgsz_component = gr.Slider(
+                    label='Icon Detect Image Size', minimum=640, maximum=1920, step=32, value=640)
+            # text box
+            text_input_component = gr.Textbox(label='Text Input', placeholder='Text Input')
+            submit_button_component = gr.Button(
+                value='Submit', variant='primary')
+        with gr.Column():
+            image_output_component = gr.Image(type='pil', label='Image Output')
+            text_output_component = gr.Textbox(label='Parsed screen elements', placeholder='Text Output')
+
+    submit_button_component.click(
+        fn=process,
+        inputs=[
+            image_input_component,
+            box_threshold_component,
+            iou_threshold_component,
+            use_paddleocr_component,
+            imgsz_component,
+            text_input_component
+        ],
+        outputs=[image_output_component, text_output_component]
+    )
 
 # demo.launch(debug=True, show_error=True, share=True)
 # demo.launch(share=True, server_port=7861, server_name='0.0.0.0')
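Both launch calls are left commented out. On Hugging Face Spaces the Gradio SDK typically launches the exported `demo` object itself, but a local run still needs an explicit launch. A minimal sketch (not part of this commit; port and server name are examples only):

```python
if __name__ == "__main__":
    # Queue requests so long-running GPU calls don't block the UI, then serve locally.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
```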