Sanket17 committed
Commit 9ffa530 · verified · 1 Parent(s): 597b56a

Create utils.py

Files changed (1):
  1. utils.py +412 -0
utils.py ADDED
@@ -0,0 +1,412 @@
+ # from ultralytics import YOLO
+ import os
+ import io
+ import base64
+ import time
+ from PIL import Image, ImageDraw, ImageFont
+ import json
+ import requests
+ # utility function
+ import os
+ # from openai import AzureOpenAI
+
+ import json
+ import sys
+ import os
+ import cv2
+ import numpy as np
+ # %matplotlib inline
+ from matplotlib import pyplot as plt
+ import easyocr
+ from paddleocr import PaddleOCR
+ reader = easyocr.Reader(['en'])
+ paddle_ocr = PaddleOCR(
+     lang='en',  # other lang also available
+     use_angle_cls=False,
+     use_gpu=False,  # using cuda will conflict with pytorch in the same process
+     show_log=False,
+     max_batch_size=1024,
+     use_dilation=True,  # improves accuracy
+     det_db_score_mode='slow',  # improves accuracy
+     rec_batch_num=1024)
+ import time
+ import base64
+
+ import os
+ import ast
+ import torch
+ from typing import Tuple, List
+ from torchvision.ops import box_convert
+ import re
+ from torchvision.transforms import ToPILImage
+ import supervision as sv
+ import torchvision.transforms as T
+
+
+ def get_caption_model_processor(model_name, model_name_or_path="Salesforce/blip2-opt-2.7b", device=None):
+     if not device:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+     if model_name == "blip2":
+         from transformers import Blip2Processor, Blip2ForConditionalGeneration
+         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
+         if device == 'cpu':
+             model = Blip2ForConditionalGeneration.from_pretrained(
+                 model_name_or_path, device_map=None, torch_dtype=torch.float32
+             )
+         else:
+             model = Blip2ForConditionalGeneration.from_pretrained(
+                 model_name_or_path, device_map=None, torch_dtype=torch.float16
+             ).to(device)
+     elif model_name == "florence2":
+         from transformers import AutoProcessor, AutoModelForCausalLM
+         processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
+         if device == 'cpu':
+             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float32, trust_remote_code=True)
+         else:
+             model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
+     return {'model': model.to(device), 'processor': processor}
+
+
+ def get_yolo_model(model_path):
+     from ultralytics import YOLO
+     # Load the model.
+     model = YOLO(model_path)
+     return model
+
+
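A minimal usage sketch for the two loaders above (not part of the committed file; the detector weight path is a placeholder):

som_model = get_yolo_model(model_path='weights/icon_detect/best.pt')   # placeholder checkpoint path
caption_model_processor = get_caption_model_processor(
    model_name="florence2",
    model_name_or_path="microsoft/Florence-2-base",   # or a fine-tuned icon-caption checkpoint
)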
+ @torch.inference_mode()
+ def get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=None):
+     to_pil = ToPILImage()
+     if ocr_bbox:
+         non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
+     else:
+         non_ocr_boxes = filtered_boxes
+     croped_pil_image = []
+     for i, coord in enumerate(non_ocr_boxes):
+         xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+         ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+         cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+         croped_pil_image.append(to_pil(cropped_image))
+
+     model, processor = caption_model_processor['model'], caption_model_processor['processor']
+     if not prompt:
+         if 'florence' in model.config.name_or_path:
+             prompt = "<CAPTION>"
+         else:
+             prompt = "The image shows"
+
+     batch_size = 10  # Number of samples per batch
+     generated_texts = []
+     device = model.device
+
+     for i in range(0, len(croped_pil_image), batch_size):
+         batch = croped_pil_image[i:i+batch_size]
+         if model.device.type == 'cuda':
+             inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device, dtype=torch.float16)
+         else:
+             inputs = processor(images=batch, text=[prompt]*len(batch), return_tensors="pt").to(device=device)
+         if 'florence' in model.config.name_or_path:
+             generated_ids = model.generate(input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=1024, num_beams=3, do_sample=False)
+         else:
+             generated_ids = model.generate(**inputs, max_length=100, num_beams=5, no_repeat_ngram_size=2, early_stopping=True, num_return_sequences=1)  # temperature=0.01, do_sample=True,
+         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
+         generated_text = [gen.strip() for gen in generated_text]
+         generated_texts.extend(generated_text)
+
+     return generated_texts
+
+
+
+ def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
+     to_pil = ToPILImage()
+     if ocr_bbox:
+         non_ocr_boxes = filtered_boxes[len(ocr_bbox):]
+     else:
+         non_ocr_boxes = filtered_boxes
+     croped_pil_image = []
+     for i, coord in enumerate(non_ocr_boxes):
+         xmin, xmax = int(coord[0]*image_source.shape[1]), int(coord[2]*image_source.shape[1])
+         ymin, ymax = int(coord[1]*image_source.shape[0]), int(coord[3]*image_source.shape[0])
+         cropped_image = image_source[ymin:ymax, xmin:xmax, :]
+         croped_pil_image.append(to_pil(cropped_image))
+
+     model, processor = caption_model_processor['model'], caption_model_processor['processor']
+     device = model.device
+     messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
+     prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     batch_size = 5  # Number of samples per batch
+     generated_texts = []
+
+     for i in range(0, len(croped_pil_image), batch_size):
+         images = croped_pil_image[i:i+batch_size]
+         image_inputs = [processor.image_processor(x, return_tensors="pt") for x in images]
+         inputs = {'input_ids': [], 'attention_mask': [], 'pixel_values': [], 'image_sizes': []}
+         texts = [prompt] * len(images)
+         for i, txt in enumerate(texts):
+             input = processor._convert_images_texts_to_inputs(image_inputs[i], txt, return_tensors="pt")
+             inputs['input_ids'].append(input['input_ids'])
+             inputs['attention_mask'].append(input['attention_mask'])
+             inputs['pixel_values'].append(input['pixel_values'])
+             inputs['image_sizes'].append(input['image_sizes'])
+         max_len = max([x.shape[1] for x in inputs['input_ids']])
+         for i, v in enumerate(inputs['input_ids']):
+             inputs['input_ids'][i] = torch.cat([processor.tokenizer.pad_token_id * torch.ones(1, max_len - v.shape[1], dtype=torch.long), v], dim=1)
+             inputs['attention_mask'][i] = torch.cat([torch.zeros(1, max_len - v.shape[1], dtype=torch.long), inputs['attention_mask'][i]], dim=1)
+         inputs_cat = {k: torch.concatenate(v).to(device) for k, v in inputs.items()}
+
+         generation_args = {
+             "max_new_tokens": 25,
+             "temperature": 0.01,
+             "do_sample": False,
+         }
+         generate_ids = model.generate(**inputs_cat, eos_token_id=processor.tokenizer.eos_token_id, **generation_args)
+         # # remove input tokens
+         generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
+         response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+         response = [res.strip('\n').strip() for res in response]
+         generated_texts.extend(response)
+
+     return generated_texts
+
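Both captioning helpers expect image_source as an HxWx3 numpy array and boxes as normalized xyxy coordinates. A rough sketch with synthetic inputs (the screenshot path and boxes are made up; caption_model_processor is assumed to come from get_caption_model_processor above):

import numpy as np
import torch

screenshot = np.asarray(Image.open('screenshot.png').convert('RGB'))   # placeholder screenshot path
icon_boxes = torch.tensor([[0.10, 0.10, 0.20, 0.15],                   # normalized xyxy crops to caption
                           [0.50, 0.40, 0.60, 0.48]])
captions = get_parsed_content_icon(icon_boxes, ocr_bbox=None, image_source=screenshot,
                                   caption_model_processor=caption_model_processor)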
+ def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
+     assert ocr_bbox is None or isinstance(ocr_bbox, List)
+
+     def box_area(box):
+         return (box[2] - box[0]) * (box[3] - box[1])
+
+     def intersection_area(box1, box2):
+         x1 = max(box1[0], box2[0])
+         y1 = max(box1[1], box2[1])
+         x2 = min(box1[2], box2[2])
+         y2 = min(box1[3], box2[3])
+         return max(0, x2 - x1) * max(0, y2 - y1)
+
+     def IoU(box1, box2):
+         intersection = intersection_area(box1, box2)
+         union = box_area(box1) + box_area(box2) - intersection + 1e-6
+         if box_area(box1) > 0 and box_area(box2) > 0:
+             ratio1 = intersection / box_area(box1)
+             ratio2 = intersection / box_area(box2)
+         else:
+             ratio1, ratio2 = 0, 0
+         return max(intersection / union, ratio1, ratio2)
+
+     boxes = boxes.tolist()
+     filtered_boxes = []
+     if ocr_bbox:
+         filtered_boxes.extend(ocr_bbox)
+         # print('ocr_bbox!!!', ocr_bbox)
+     for i, box1 in enumerate(boxes):
+         # if not any(IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2) for j, box2 in enumerate(boxes) if i != j):
+         is_valid_box = True
+         for j, box2 in enumerate(boxes):
+             if i != j and IoU(box1, box2) > iou_threshold and box_area(box1) > box_area(box2):
+                 is_valid_box = False
+                 break
+         if is_valid_box:
+             # add the following 2 lines to include ocr bbox
+             if ocr_bbox:
+                 if not any(IoU(box1, box3) > iou_threshold for k, box3 in enumerate(ocr_bbox)):
+                     filtered_boxes.append(box1)
+             else:
+                 filtered_boxes.append(box1)
+     return torch.tensor(filtered_boxes)
+
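A small worked example with made-up boxes (not from the commit): a detection that almost entirely contains a smaller one is dropped in favour of the smaller one, and detections already covered by an OCR box are skipped.

import torch

detections = torch.tensor([[0.10, 0.10, 0.30, 0.30],    # large box enclosing the next one -> dropped
                           [0.12, 0.12, 0.28, 0.28],    # nested smaller box -> kept
                           [0.60, 0.60, 0.70, 0.70]])   # overlaps the OCR box below -> skipped
ocr = [[0.60, 0.60, 0.70, 0.70]]
kept = remove_overlap(detections, iou_threshold=0.9, ocr_bbox=ocr)
# kept: tensor([[0.60, 0.60, 0.70, 0.70], [0.12, 0.12, 0.28, 0.28]])  (OCR box first, then the surviving detection)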
+ def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
+     transform = T.Compose(
+         [
+             T.RandomResize([800], max_size=1333),
+             T.ToTensor(),
+             T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+         ]
+     )
+     image_source = Image.open(image_path).convert("RGB")
+     image = np.asarray(image_source)
+     image_transformed, _ = transform(image_source, None)
+     return image, image_transformed
+
+
+ def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str], text_scale: float,
+              text_padding=5, text_thickness=2, thickness=3) -> np.ndarray:
+     """
+     This function annotates an image with bounding boxes and labels.
+     Parameters:
+         image_source (np.ndarray): The source image to be annotated.
+         boxes (torch.Tensor): Bounding box coordinates in cxcywh format, normalized to [0, 1] (rescaled to pixels inside the function).
+         logits (torch.Tensor): Confidence scores for each bounding box.
+         phrases (List[str]): A list of labels for each bounding box.
+         text_scale (float): The scale of the label text; 0.8 for mobile/web, 0.3 for desktop, 0.4 for mind2web.
+     Returns:
+         (np.ndarray, dict): The annotated image and a mapping from label to its xywh pixel box.
+     """
+     h, w, _ = image_source.shape
+     boxes = boxes * torch.Tensor([w, h, w, h])
+     xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
+     xywh = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xywh").numpy()
+     detections = sv.Detections(xyxy=xyxy)
+
+     labels = [f"{phrase}" for phrase in range(boxes.shape[0])]
+
+     from util.box_annotator import BoxAnnotator
+     box_annotator = BoxAnnotator(text_scale=text_scale, text_padding=text_padding, text_thickness=text_thickness, thickness=thickness)  # 0.8 for mobile/web, 0.3 for desktop # 0.4 for mind2web
+     annotated_frame = image_source.copy()
+     annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels, image_size=(w, h))
+
+     label_coordinates = {f"{phrase}": v for phrase, v in zip(phrases, xywh)}
+     return annotated_frame, label_coordinates
+
+
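A quick sketch of calling annotate on its own (not in the commit; it assumes the companion util/box_annotator.py module that the function imports is available alongside this file):

import numpy as np
import torch

frame = np.zeros((480, 640, 3), dtype=np.uint8)             # dummy 640x480 image
boxes = torch.tensor([[0.50, 0.50, 0.20, 0.10]])            # one box, normalized cxcywh
annotated, coords = annotate(image_source=frame, boxes=boxes,
                             logits=torch.tensor([1.0]), phrases=["0"],
                             text_scale=0.8)                 # 0.8 suits mobile/web per the docstring
# coords["0"] is the box in pixel xywh, here approximately [256., 216., 128., 48.]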
+ def predict(model, image, caption, box_threshold, text_threshold):
+     """ Use the huggingface grounding model in place of the original model.
+     """
+     model, processor = model['model'], model['processor']
+     device = model.device
+
+     inputs = processor(images=image, text=caption, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     results = processor.post_process_grounded_object_detection(
+         outputs,
+         inputs.input_ids,
+         box_threshold=box_threshold,  # 0.4,
+         text_threshold=text_threshold,  # 0.3,
+         target_sizes=[image.size[::-1]]
+     )[0]
+     boxes, logits, phrases = results["boxes"], results["scores"], results["labels"]
+     return boxes, logits, phrases
+
+
+ def predict_yolo(model, image_path, box_threshold):
+     """ Run the ultralytics YOLO detector on an image; returns xyxy boxes in pixel space, confidences, and index labels.
+     """
+     # model = model['model']
+
+     result = model.predict(
+         source=image_path,
+         conf=box_threshold,
+         # iou=0.5, # default 0.7
+     )
+     boxes = result[0].boxes.xyxy  # .tolist() # in pixel space
+     conf = result[0].boxes.conf
+     phrases = [str(i) for i in range(len(boxes))]
+
+     return boxes, conf, phrases
+
+
+ def get_som_labeled_img(img_path, model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9, prompt=None):
+     """ ocr_bbox: list of xyxy format bbox
+     """
+     TEXT_PROMPT = "clickable buttons on the screen"
+     # BOX_TRESHOLD = 0.02 # 0.05/0.02 for web and 0.1 for mobile
+     TEXT_TRESHOLD = 0.01  # 0.9 # 0.01
+     image_source = Image.open(img_path).convert("RGB")
+     w, h = image_source.size
+     # import pdb; pdb.set_trace()
+     if False:  # TODO
+         xyxy, logits, phrases = predict(model=model, image=image_source, caption=TEXT_PROMPT, box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD)
+     else:
+         xyxy, logits, phrases = predict_yolo(model=model, image_path=img_path, box_threshold=BOX_TRESHOLD)
+     xyxy = xyxy / torch.Tensor([w, h, w, h]).to(xyxy.device)
+     image_source = np.asarray(image_source)
+     phrases = [str(i) for i in range(len(phrases))]
+
+     # annotate the image with labels
+     h, w, _ = image_source.shape
+     if ocr_bbox:
+         ocr_bbox = torch.tensor(ocr_bbox) / torch.Tensor([w, h, w, h])
+         ocr_bbox = ocr_bbox.tolist()
+     else:
+         print('no ocr bbox!!!')
+         ocr_bbox = None
+     filtered_boxes = remove_overlap(boxes=xyxy, iou_threshold=iou_threshold, ocr_bbox=ocr_bbox)
+
+     # get parsed icon local semantics
+     if use_local_semantics:
+         caption_model = caption_model_processor['model']
+         if 'phi3_v' in caption_model.config.model_type:
+             parsed_content_icon = get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
+         else:
+             parsed_content_icon = get_parsed_content_icon(filtered_boxes, ocr_bbox, image_source, caption_model_processor, prompt=prompt)
+         ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
+         icon_start = len(ocr_text)
+         parsed_content_icon_ls = []
+         for i, txt in enumerate(parsed_content_icon):
+             parsed_content_icon_ls.append(f"Icon Box ID {str(i+icon_start)}: {txt}")
+         parsed_content_merged = ocr_text + parsed_content_icon_ls
+     else:
+         ocr_text = [f"Text Box ID {i}: {txt}" for i, txt in enumerate(ocr_text)]
+         parsed_content_merged = ocr_text
+
+     filtered_boxes = box_convert(boxes=filtered_boxes, in_fmt="xyxy", out_fmt="cxcywh")
+
+     phrases = [i for i in range(len(filtered_boxes))]
+
+     # draw boxes
+     if draw_bbox_config:
+         annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, **draw_bbox_config)
+     else:
+         annotated_frame, label_coordinates = annotate(image_source=image_source, boxes=filtered_boxes, logits=logits, phrases=phrases, text_scale=text_scale, text_padding=text_padding)
+
+     pil_img = Image.fromarray(annotated_frame)
+     buffered = io.BytesIO()
+     pil_img.save(buffered, format="PNG")
+     encoded_image = base64.b64encode(buffered.getvalue()).decode('ascii')
+     if output_coord_in_ratio:
+         # h, w, _ = image_source.shape
+         label_coordinates = {k: [v[0]/w, v[1]/h, v[2]/w, v[3]/h] for k, v in label_coordinates.items()}
+         assert w == annotated_frame.shape[1] and h == annotated_frame.shape[0]
+
+     return encoded_image, label_coordinates, parsed_content_merged
+
+
+ def get_xywh(input):
+     x, y, w, h = input[0][0], input[0][1], input[2][0] - input[0][0], input[2][1] - input[0][1]
+     x, y, w, h = int(x), int(y), int(w), int(h)
+     return x, y, w, h
+
+ def get_xyxy(input):
+     x, y, xp, yp = input[0][0], input[0][1], input[2][0], input[2][1]
+     x, y, xp, yp = int(x), int(y), int(xp), int(yp)
+     return x, y, xp, yp
+
+ def get_xywh_yolo(input):
+     x, y, w, h = input[0], input[1], input[2] - input[0], input[3] - input[1]
+     x, y, w, h = int(x), int(y), int(w), int(h)
+     return x, y, w, h
+
+
+
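These helpers take an OCR quadrilateral (four corner points, as EasyOCR/PaddleOCR return them) or a flat xyxy list; a tiny illustration with made-up coordinates:

quad = [[10, 20], [110, 20], [110, 60], [10, 60]]   # top-left, top-right, bottom-right, bottom-left
get_xywh(quad)                      # -> (10, 20, 100, 40)
get_xyxy(quad)                      # -> (10, 20, 110, 60)
get_xywh_yolo([10, 20, 110, 60])    # flat xyxy box -> (10, 20, 100, 40)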
+ def check_ocr_box(image_path, display_img=True, output_bb_format='xywh', goal_filtering=None, easyocr_args=None, use_paddleocr=False):
+     if use_paddleocr:
+         result = paddle_ocr.ocr(image_path, cls=False)[0]
+         coord = [item[0] for item in result]
+         text = [item[1][0] for item in result]
+     else:  # EasyOCR
+         if easyocr_args is None:
+             easyocr_args = {}
+         result = reader.readtext(image_path, **easyocr_args)
+         # print('goal filtering pred:', result[-5:])
+         coord = [item[0] for item in result]
+         text = [item[1] for item in result]
+     # read the image using cv2
+     if display_img:
+         opencv_img = cv2.imread(image_path)
+         opencv_img = cv2.cvtColor(opencv_img, cv2.COLOR_RGB2BGR)
+         bb = []
+         for item in coord:
+             x, y, a, b = get_xywh(item)
+             # print(x, y, a, b)
+             bb.append((x, y, a, b))
+             cv2.rectangle(opencv_img, (x, y), (x+a, y+b), (0, 255, 0), 2)
+
+         # Display the image
+         plt.imshow(opencv_img)
+     else:
+         if output_bb_format == 'xywh':
+             bb = [get_xywh(item) for item in coord]
+         elif output_bb_format == 'xyxy':
+             bb = [get_xyxy(item) for item in coord]
+     # print('bounding box!!!', bb)
+     return (text, bb), goal_filtering
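Putting the pieces together, a hedged end-to-end sketch (not part of the commit; 'screenshot.png' and the model variables are placeholders carried over from the earlier sketches): run OCR first, then SOM labeling with the icon detector and captioner.

image_path = 'screenshot.png'   # placeholder input image

(text, ocr_bbox), _ = check_ocr_box(image_path, display_img=False, output_bb_format='xyxy')
encoded_image, label_coordinates, parsed_content = get_som_labeled_img(
    image_path,
    model=som_model,                                  # from get_yolo_model(...) above
    BOX_TRESHOLD=0.05,                                # 0.05/0.02 for web, 0.1 for mobile (per the comment in get_som_labeled_img)
    caption_model_processor=caption_model_processor,  # from get_caption_model_processor(...) above
    ocr_bbox=ocr_bbox,
    ocr_text=text,
    output_coord_in_ratio=True,
)
# encoded_image: base64 PNG with numbered boxes; label_coordinates: id -> xywh box (ratios here);
# parsed_content: "Text Box ID i: ..." lines followed by "Icon Box ID j: ..." captions.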