Spaces:
Runtime error
Runtime error
| import spacy | |
| import torch | |
| from tqdm import tqdm | |
| import numpy as np | |
| import itertools | |
# spaCy "en_core_web_md" pipeline, loaded once at import time. get_object_from_text
# runs it over captions and reads each token's POS tag (pos_), dependency label
# (dep_) and syntactic head.
nlp = spacy.load('en_core_web_md')
def get_iou(box1, box2):
    """Return the intersection-over-union of two axis-aligned boxes.

    Each box is indexed as ``[x1, y1, x2, y2]``; only the first four entries are
    read. A well-formed box has ``x2 >= x1`` and ``y2 >= y1``. When the union
    area is not strictly positive the result is ``0``.
    """
    # Overlap extents along each axis, clamped so disjoint boxes give 0.
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    inter = overlap_w * overlap_h

    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area_a + area_b - inter

    if union > 0:
        return inter / union
    return 0
| # def find_root(token): | |
| # if token.pos_ == "VERB": | |
| # return token | |
| # while token.dep_ not in ["pobj", "nsubj", "ROOT", "npadvmod", "dobj", "det", "prep", "punct", "cc", "conj", "acl", "dep", "appos", "relcl", "advmod", "nmod", "attr"]: | |
| # token = token.head | |
| # return token | |
def find_root(token):
    """Return the token that *token*'s phrase hangs from.

    Verbs are returned unchanged. Any other token climbs its dependency head
    chain for as long as it is attached by a ``compound`` or ``amod`` arc, so a
    compound or adjectival modifier resolves to the word it modifies.
    """
    if token.pos_ == "VERB":
        return token
    node = token
    while node.dep_ in ("compound", "amod"):
        node = node.head
    return node
def get_object_from_text(text, verbose=False):
    """Extract object phrases from a caption.

    Inputs of exactly three space-separated words short-circuit to their first
    and last word without parsing. Otherwise the text is parsed with the
    module-level spaCy ``nlp`` pipeline. Every token is mapped to its head with
    ``find_root``. Each head that is a NOUN/PROPN attached as ``pobj`` or
    ``nsubj`` contributes one phrase: the space-joined text of all tokens that
    share that head, with a space before an apostrophe removed (so "man 's"
    becomes "man's"). Once one ``nsubj`` head has been processed, later
    ``nsubj`` heads are skipped. Phrases that do not occur verbatim in ``text``
    are dropped.

    Args:
        text: caption to analyse.
        verbose: print the parse and the reason any phrase is rejected.

    Returns:
        The phrases ordered by their head's character offset, or ``[]`` when
        fewer than two phrases survive.
    """
    words = text.split(" ")
    if len(words) == 3:
        return [words[0], words[-1]]
    doc = nlp(text)
    if verbose:
        for tok in doc:
            print(tok.text, tok.pos_, tok.dep_, tok.head)
    # Resolve each token's root once; the grouping loop below reuses these
    # instead of calling find_root again for every (root, token) pair.
    token_roots = [find_root(tok) for tok in doc]
    roots = sorted(set(token_roots), key=lambda tok: tok.idx)
    if verbose:
        print(roots)
    exprs = []
    first_nsubj = True
    for root in roots:
        if root.pos_ not in ("NOUN", "PROPN"):
            continue
        if root.dep_ not in ("pobj", "nsubj"):
            continue
        if not first_nsubj and root.dep_ == "nsubj":
            continue
        expr = " ".join(
            tok.text for tok, tok_root in zip(doc, token_roots) if tok_root == root
        ).replace(" '", "'")
        if expr in text:
            exprs.append(expr)
        elif verbose:
            print("not in text error:", expr, "#", text)
        # Flips after the first qualifying nsubj head even if its phrase was
        # rejected above.
        if first_nsubj and root.dep_ == "nsubj":
            first_nsubj = False
    if len(exprs) <= 1:
        if verbose:
            print("not enough exprs error:", exprs, "#", text)
        return []
    return exprs
def is_correct(input_ids, logits, tokenizer, object: str, topk=5, N=10):
    """Rank the token sequence of ``object`` under the model's predictions.

    The tokens of ``" " + object`` are located in ``input_ids``; when there are
    several matches, the last one wins. The distributions that predicted them
    are read from ``logits``: position ``t - 1`` scores the token at ``t``.
    Every combination of the top-``N`` candidates per answer position is then
    ranked by joint probability, and the rank of the true combination is
    returned.

    Args:
        input_ids: 2-D tensor of token ids, indexed ``[batch, position]``.
        logits: tensor indexed ``[batch, position]``; softmax is taken over the
            last dimension.
        tokenizer: callable returning ``{"input_ids": [...]}`` for a string.
        object: answer phrase, tokenized with a leading space.
        topk: cutoff for the third return value.
        N: number of top candidates considered per answer position (capped
            at the vocabulary size).

    Returns:
        ``(rank, rank < 1, rank < topk)``. ``rank`` is ``np.inf`` (with both
        flags ``False``) when the answer cannot be located or any of its tokens
        falls outside the top-``N`` candidates.
    """
    answer_id = torch.tensor(
        tokenizer(f" {object}", add_special_tokens=False)["input_ids"]
    ).to(input_ids.device)
    n_tokens = len(answer_id)
    if n_tokens == 0:
        return np.inf, False, False

    seq_len = input_ids.shape[-1]
    match = None
    for batch_idx, start in (input_ids == answer_id[0]).nonzero().tolist():
        # A match at position 0 has no preceding logit to score it (start - 1
        # would wrap to the last position). A match too close to the end cannot
        # hold the whole answer.
        if start == 0 or start + n_tokens > seq_len:
            continue
        if (input_ids[batch_idx, start:start + n_tokens] == answer_id).all():
            match = (batch_idx, start)  # keep scanning: the last occurrence wins
    if match is None:
        return np.inf, False, False

    batch_idx, start = match
    answer_idx = list(range(start - 1, start + n_tokens - 1))
    ranked = logits[batch_idx, answer_idx].softmax(-1).sort(descending=True)
    values, indices = ranked.values, ranked.indices
    n_cand = min(N, values.shape[-1])

    try:
        answer_pos = tuple(
            row.tolist().index(tok)
            for row, tok in zip(indices, answer_id.tolist())
        )
    except ValueError:  # an answer token id is outside the vocabulary
        return np.inf, False, False
    if any(pos >= n_cand for pos in answer_pos):
        return np.inf, False, False

    scored = []
    for ids in itertools.product(range(n_cand), repeat=n_tokens):
        prob = 1.0
        for i, k in enumerate(ids):
            prob *= values[i, k]
        scored.append((prob.item(), ids))
    scored.sort(reverse=True)
    rank = [ids for _, ids in scored].index(answer_pos)
    return rank, rank < 1, rank < topk
def get_bbox(visual_box_list, batch_images, prompt, model, tokenizer, media_token_id, prebox_token_id, debug=False, return_all=False):
    """Run one grounding forward pass and return the predicted box(es).

    Args:
        visual_box_list: boxes passed to the model as ``added_bbox_list``;
            ``None`` turns box injection off (``add_box=False``).
        batch_images: image tensor passed as ``vision_x`` (moved to CUDA).
        prompt: single-element list holding the prompt string.
        model: grounding model whose output is indexed with ``"boxes"`` and
            ``"scores"``; its ``debug_id`` attribute is reset to 0 first.
        tokenizer: callable tokenizer returning ``input_ids`` and
            ``attention_mask`` tensors.
        media_token_id: id of the image token; the position right after each
            occurrence is passed as ``image_start_index_list``.
        prebox_token_id: not used here; kept for signature compatibility.
        debug: drop into ``pdb`` after the forward pass.
        return_all: return every box and score instead of only the best one.

    Returns:
        ``(boxes, scores)`` when ``return_all`` is set. Otherwise the
        highest-scoring box and its score, or ``(None, None)`` when the model
        returned no scores.
    """
    assert isinstance(prompt, list) and len(prompt) == 1 and isinstance(prompt[0], str)
    encodings = tokenizer(
        prompt,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]
    image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
    image_start_index_list = [[x] for x in image_start_index_list]
    image_nums = [1] * len(input_ids)
    vision_x = batch_images.cuda()
    lang_x = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    model.debug_id = 0
    # Context managers must be comma-separated: `a and b` evaluates to `b`
    # alone, so only autocast would be entered and inference_mode skipped.
    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
        outputs = model(
            vision_x=vision_x,
            lang_x=lang_x,
            attention_mask=attention_mask,
            labels=None,
            image_nums=image_nums,
            image_start_index_list=image_start_index_list,
            added_bbox_list=visual_box_list,
            add_box=visual_box_list is not None,
            relations=None,
            debug_mode=False,
        )
    boxes = outputs["boxes"]
    scores = outputs["scores"]
    if debug:
        import pdb; pdb.set_trace()
    if return_all:
        return boxes, scores
    if len(scores) == 0:
        return None, None
    return boxes[scores.argmax()], scores.max()
def _eval_text_image(text, image, model, tokenizer, image_processor, vis_embed_size, media_token_id, prebox_token_id, debug=False, objects=None):
    """Score how well the model predicts each object phrase after the first.

    The steps are:

    1. The first object is grounded with ``get_bbox``. If a box is found, it is
       kept and the object is wrapped in ``<|#box#|>`` markup inside ``text``.
    2. For every later object, ``get_bbox`` is run with ``return_all=True`` on
       the text up to that object followed by
       ``<|#object#|><|#previsual#|>``, giving candidate boxes and scores.
    3. Each candidate box is fed to the model together with the object phrase.
       The softmaxed logits are summed, weighted by the candidate scores.
    4. ``is_correct`` ranks the object's tokens under that weighted sum.
    5. The object is then grounded and, if a box is returned, appended to the
       boxes given to later prompts.

    Args:
        text: caption containing the objects.
        image: image passed through ``image_processor``.
        model: grounding model; its ``valid`` attribute is checked when
            ``debug`` is set.
        tokenizer: tokenizer exposing ``bos_token`` / ``pad_token``.
        image_processor: callable turning ``image`` into a tensor.
        vis_embed_size: number of pad tokens reserved for the image.
        media_token_id, prebox_token_id: forwarded to ``get_bbox``.
        debug: echo prompts/predictions and break into ``pdb`` when the model
            reports an invalid state.
        objects: object phrases in caption order; extracted with
            ``get_object_from_text`` when ``None``.

    Returns:
        ``(final_ranks, is_top1_list, is_top5_list)`` with one entry per object
        after the first, or ``(None, None, None)`` when no objects are found.
        An object with no candidate boxes is recorded as
        ``(np.inf, False, False)``.
    """
    batch_images = image_processor(image).unsqueeze(0).unsqueeze(1).unsqueeze(0)
    if objects is None:
        objects = get_object_from_text(text)
    if len(objects) == 0:
        return None, None, None
    if debug:
        tqdm.write(text)
        tqdm.write(f"{objects}")
    # Loop-invariant prompt prefix: BOS, the image token, pad slots for the
    # visual embedding, and the end-of-image marker.
    image_prefix = f"{tokenizer.bos_token}<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|>"

    first_idx = text.find(objects[0])
    if first_idx == 0:
        first_text = f"<|#object#|>{objects[0]}<|#endofobject#|><|#visual#|>"
    else:
        first_text = text[:first_idx-1] + f"<|#object#|> {objects[0]}<|#endofobject#|><|#visual#|>"
    if debug:
        tqdm.write(first_text)
    first_box, _ = get_bbox(None, batch_images, [image_prefix + first_text], model, tokenizer, media_token_id, prebox_token_id, return_all=False)
    if not model.valid and debug:
        import pdb; pdb.set_trace()
    if first_box is not None:
        # Division by 224 presumably normalises to the processed image size.
        # TODO(review): confirm against image_processor.
        added_bbox_list = [torch.tensor(first_box).unsqueeze(0).cuda() / 224]
        text = first_text + "<|#box#|><|#endofobject#|>" + text[first_idx+len(objects[0]):]
    else:
        added_bbox_list = []

    final_ranks = []
    is_top1_list = []
    is_top5_list = []
    for kk, obj in enumerate(objects):
        if kk == 0:
            continue
        # Find this object's occurrence in the (already marked-up) text by
        # scanning forward past each earlier object in turn.
        idx = text.find(objects[0])
        for t_i, temp in enumerate(objects[1:kk+1]):
            # enumerate starts at objects[1], so objects[t_i] is the object
            # immediately before temp.
            idx = text.find(temp, idx + len(objects[t_i]))
            while idx+len(temp) != len(text) and (text[idx-1] == "#" or text[idx+len(temp)] == "#"):
                # Skip hits glued to a special token (e.g. "box" in "<|#box#|>").
                idx = text.find(temp, idx + len(temp))
        this_text = text[:idx-1] + "<|#object#|><|#previsual#|>"
        if debug:
            tqdm.write(this_text)
        pre_boxes, pre_scores = get_bbox(added_bbox_list, batch_images, [image_prefix + this_text], model, tokenizer, media_token_id,
                                         prebox_token_id, return_all=True)
        if not model.valid and debug:
            import pdb; pdb.set_trace()

        logits_list = []
        this_text = this_text + f"<|#prebox#|><|#object#|> {obj}<|#endofobject#|>"
        for pre_box, pre_score in zip(pre_boxes, pre_scores):
            encodings = tokenizer(
                [image_prefix + this_text],
                padding="longest",
                truncation=True,
                return_tensors="pt",
                max_length=512,
            )
            input_ids = encodings["input_ids"]
            attention_mask = encodings["attention_mask"]
            image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist()
            image_start_index_list = [[x] for x in image_start_index_list]
            image_nums = [1] * len(input_ids)
            vision_x = batch_images.cuda()
            lang_x = input_ids.cuda()
            attention_mask = attention_mask.cuda()
            this_added_bbox_list = added_bbox_list + [torch.tensor(pre_box).unsqueeze(0).cuda() / 224]
            # Comma-separated so both managers are entered; `a and b` would
            # enter only `b` and silently skip autocast.
            with torch.autocast("cuda", dtype=torch.float16), torch.no_grad():
                outputs = model(
                    vision_x=vision_x,
                    lang_x=lang_x,
                    attention_mask=attention_mask,
                    image_nums=image_nums,
                    image_start_index_list=image_start_index_list,
                    added_bbox_list=this_added_bbox_list,
                    add_box=this_added_bbox_list is not None and len(this_added_bbox_list) != 0,
                    relations=None,
                )
            if not model.valid and debug:
                import pdb; pdb.set_trace()
            logits_list.append([pre_score, outputs.logits])
            if debug:
                answer_start_idx = (lang_x == tokenizer("<|#object#|>", add_special_tokens=False)["input_ids"][-1]).nonzero()[-1][1]
                logits = outputs["logits"][0, answer_start_idx:]
                tqdm.write(tokenizer.decode(logits[0].sort(descending=True).indices.tolist()[:10]))

        if not logits_list:
            # No candidate boxes means there is nothing to score. Record a miss
            # (the value is_correct returns for an unrankable answer) instead
            # of failing on `0.0.shape` below.
            final_ranks.append(np.inf)
            is_top1_list.append(False)
            is_top5_list.append(False)
        else:
            pre_scores = np.array([x[0] for x in logits_list])
            # Score-weighted sum of the per-candidate next-token distributions.
            final_probs = 0.0
            for score, (_, logits) in zip(pre_scores, logits_list):
                final_probs += score * logits.softmax(-1)
            assert input_ids.shape[:2] == final_probs.shape[:2]
            _rank, is_top1, is_top5 = is_correct(input_ids, final_probs, tokenizer, obj, topk=5)
            final_ranks.append(_rank)
            is_top1_list.append(is_top1)
            is_top5_list.append(is_top5)

        # Ground this object so its box conditions the following prompts.
        this_text = text[:idx-1] + f"<|#object#|> {obj}<|#endofobject#|><|#visual#|>"
        if debug:
            tqdm.write(this_text)
        this_box, _ = get_bbox(added_bbox_list, batch_images, [image_prefix + this_text], model, tokenizer, media_token_id, prebox_token_id, return_all=False)
        if not model.valid and debug:
            import pdb; pdb.set_trace()
        if this_box is not None:
            added_bbox_list += [torch.tensor(this_box).unsqueeze(0).cuda() / 224]
            text = this_text + "<|#box#|><|#endofobject#|>" + text[idx+len(obj):]
    return final_ranks, is_top1_list, is_top5_list
if __name__ == "__main__":
    # Manual smoke test: print the object phrases extracted from a sample
    # caption. verbose=True also dumps each token's text, POS tag, dependency
    # label and head.
    # Another caption worth trying (with verbose=False):
    # "there is a cookie. there is a bear. white orio cookie is next to the teddy bear. car runs on the traffic road. there is a tree."
    demo_caption = "President speaks to an American at a business office"
    print(get_object_from_text(demo_caption, verbose=True))