Spaces:
Runtime error
Runtime error
| import torch | |
| import more_itertools | |
| from tqdm import tqdm | |
| import json | |
| import time | |
| import os | |
| import numpy as np | |
| from transformers import LogitsProcessor, MinNewTokensLengthLogitsProcessor, ForcedEOSTokenLogitsProcessor | |
| from PIL import Image | |
| import cv2 | |
class VisualLogitsProcessor(LogitsProcessor):
    """Steer generation through the grounding special tokens.

    Called once per decoding step with the tokens generated so far
    (``input_ids``) and the next-token ``scores``. It edits ``scores`` in place
    and returns it. Only batch row 0 is inspected and modified. A token is
    "forced" by writing ``_FORCED_SCORE`` into its logit.

    Rules:

    * If ``<|#object#|>`` is among the 2nd..``topk``-th best candidates, the EOS
      token is not among the top ``topk``, and ``input_ids`` contains exactly
      twice as many ``<|#endofobject#|>`` tokens as ``<|#object#|>`` tokens,
      force ``<|#object#|>``.
    * Then, looking at the last token of row 0:
        - ``<|#object#|>`` not directly preceded by ``<|#prebox#|>``, with an
          earlier ``<|#object#|>`` in the row: force ``<|#previsual#|>``;
        - ``<|#previsual#|>`` or ``<|#visual#|>``: force EOS;
        - ``<|#endofobject#|>`` not directly preceded by ``<|#box#|>``: force
          ``<|#visual#|>``.
    """

    # Absolute logit written into a token's slot to make it (practically) the argmax.
    # NOTE(review): this is not -inf masking; a competing logit above 1000 would still win.
    _FORCED_SCORE = 1000

    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer

        def last_id(text):
            # Id of the final sub-token that ``text`` encodes to, without BOS/EOS added.
            return self.tokenizer(text, add_special_tokens=False)["input_ids"][-1]

        self.object_token_id = last_id("<|#object#|>")
        self.prebox_token_id = last_id("<|#prebox#|>")
        self.box_token_id = last_id("<|#box#|>")
        self.previsual_token_id = last_id("<|#previsual#|>")
        self.visual_token_id = last_id("<|#visual#|>")
        self.eos_token_id = self.tokenizer.encode(self.tokenizer.eos_token)[-1]
        self.endofobject_token_id = last_id("<|#endofobject#|>")
        # Number of highest-scoring candidates inspected by the object-forcing rule.
        self.topk = 2

    def __call__(self, input_ids, scores):
        # Sort once per step and keep only the ids the rule looks at, instead of
        # sorting twice and converting entire vocabulary rows to Python lists.
        top_ids = scores.sort(descending=True).indices[0, : self.topk].tolist()
        if (
            self.object_token_id in top_ids[1:]
            and self.eos_token_id not in top_ids
            and (input_ids == self.object_token_id).sum() * 2
            == (input_ids == self.endofobject_token_id).sum()
        ):
            scores[0, self.object_token_id] = self._FORCED_SCORE

        last = input_ids[0, -1]
        if last == self.object_token_id and input_ids[0, -2] != self.prebox_token_id:
            # Row 0 already contains an earlier <|#object#|>: request a previsual box.
            if (input_ids[0, :-1] == self.object_token_id).sum() != 0:
                scores[0, self.previsual_token_id] = self._FORCED_SCORE
        elif last == self.previsual_token_id or last == self.visual_token_id:
            # Stop here so the caller can predict the box for this (pre)visual token.
            scores[0, self.eos_token_id] = self._FORCED_SCORE
        elif last == self.endofobject_token_id and input_ids[0, -2] != self.box_token_id:
            scores[0, self.visual_token_id] = self._FORCED_SCORE
        return scores
def prepare_batch_images(batch, image_processor):
    """Preprocess each sample's image and stack the results along a new batch axis.

    Args:
        batch: iterable of dict-like samples, each with an ``"image"`` entry.
        image_processor: callable mapping an image to a ``torch.Tensor``.

    Returns:
        A tensor of shape ``(len(batch), 1, 1, *processed.shape)``, where
        ``processed`` is the tensor returned by ``image_processor``. Returns
        ``None`` if ``batch`` is empty.
    """
    # Build all per-sample tensors first and concatenate once; calling torch.cat
    # inside the loop would re-copy the growing batch on every iteration.
    per_sample = [
        image_processor(sample["image"]).unsqueeze(0).unsqueeze(1).unsqueeze(0)
        for sample in batch
    ]
    if not per_sample:
        return None
    return torch.cat(per_sample, dim=0)
| # def captioner( | |
| # model, tokenizer, image_ori, batch_images, input_ids, attention_mask, image_start_index_list, image_nums, | |
| # added_bbox_list, debug=True): | |
| # """Evaluate a model on COCO dataset. | |
| # Returns: | |
| # float: CIDEr score | |
| # | |
| # """ | |
| # visual_logits_processor = VisualLogitsProcessor(tokenizer) | |
| # model.eval() | |
| # # model.eval().cuda() | |
| # lang_encoder_name = model.lang_encoder.__class__.__name__.lower() | |
| # media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1] | |
| # endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1] | |
| # pad_token_id = tokenizer(tokenizer.pad_token, add_special_tokens=False)["input_ids"][-1] | |
| # bos_token_id = tokenizer(tokenizer.bos_token, add_special_tokens=False)["input_ids"][-1] | |
| # previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1] | |
| # visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1] | |
| # box_token = "<|#box#|>" | |
| # prebox_token = "<|#prebox#|>" | |
| # endofobject_token = "<|#endofobject#|>" | |
| # object_token = "<|#object#|>" | |
| # ori_prompt_length = len(input_ids[0]) | |
| # have_prebox = False | |
| # prompt = None | |
| # out_image = None | |
| # no_end = True | |
| # for i in range(500): | |
| # if no_end: | |
| # batch_images = batch_images | |
| # if prompt == None: | |
| # input_ids = input_ids | |
| # attention_mask = attention_mask | |
| # else: | |
| # encodings = tokenizer( | |
| # [prompt], | |
| # padding="longest", | |
| # truncation=True, | |
| # return_tensors="pt", | |
| # max_length=2000, | |
| # ) | |
| # attention_mask = encodings["attention_mask"] | |
| # input_ids = encodings["input_ids"] | |
| # image_start_index_list = image_start_index_list | |
| # image_nums = image_nums | |
| # if debug: | |
| # print("input--->", tokenizer.decode(input_ids[0])) | |
| # p1 = MinNewTokensLengthLogitsProcessor( | |
| # prompt_length_to_skip=input_ids.shape[-1], | |
| # min_new_tokens=5, | |
| # eos_token_id=bos_token_id, | |
| # ) | |
| # with torch.inference_mode(): | |
| # outputs = model.generate( | |
| # batch_images, | |
| # input_ids, | |
| # attention_mask=attention_mask, | |
| # max_new_tokens=20, | |
| # # min_new_tokens=8, | |
| # num_beams=1, | |
| # # length_penalty=0, | |
| # image_start_index_list=image_start_index_list, | |
| # image_nums=image_nums, | |
| # added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None, | |
| # logits_processor_list=[p1, visual_logits_processor], | |
| # ) | |
| # if debug: | |
| # print("outputs--->", tokenizer.decode(outputs[0])) | |
| # input_ids = encodings["input_ids"] | |
| # attention_mask = encodings["attention_mask"] | |
| # image_start_index_list = ((input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1).tolist() | |
| # image_start_index_list = [[x] for x in image_start_index_list] | |
| # image_nums = [1] * len(input_ids) | |
| # if debug: | |
| # print("get the visual bbox--->", tokenizer.decode(input_ids[0])) | |
| # with torch.no_grad(): | |
| # outputs = model( | |
| # vision_x=batch_images, | |
| # lang_x=input_ids, | |
| # attention_mask=attention_mask, | |
| # image_nums=image_nums, | |
| # image_start_index_list=image_start_index_list, | |
| # added_bbox_list=added_bbox_list if len(added_bbox_list) != 0 else None, | |
| # add_box=added_bbox_list is not None and len(added_bbox_list) != 0, | |
| # ) | |
| # boxes = outputs["boxes"] | |
| # scores = outputs["scores"] | |
| # if debug: | |
| # print("box num---->", len(boxes)) | |
| # # if not model.valid: | |
| # # import pdb; pdb.set_trace() | |
| # if boxes is not None: | |
| # if is_visual: | |
| # if have_prebox: | |
| # added_bbox_list.pop() | |
| # prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "") | |
| # have_prebox = False | |
| # if debug: | |
| # print("find previsual and remove it--->", prompt) | |
| # first_box = boxes[scores.argmax()] | |
| # added_bbox_list += [torch.tensor(first_box).unsqueeze(0) / 224] | |
| # prompt = prompt[:-len(tokenizer.eos_token)] | |
| # prompt += box_token + endofobject_token | |
| # if debug: | |
| # print("after inserting visual---->", prompt) | |
| # | |
| # else: | |
| # import numpy as np | |
| # import cv2 | |
| # | |
| # # exit() | |
| # pre_box = boxes[scores.argmax()] | |
| # added_bbox_list += [torch.tensor(pre_box).unsqueeze(0) / 224] | |
| # prompt = prompt[:-len(tokenizer.eos_token)] | |
| # prompt += prebox_token + object_token | |
| # have_prebox = True | |
| # if debug: | |
| # print("after inserting previsual---->", prompt) | |
| # else: | |
| # # if debug: | |
| # # import pdb;pdb.set_trace() | |
| # prompt = tokenizer.decode(outputs.clone()[0]) | |
| # if debug: | |
| # print("before else---->", prompt) | |
| # prompt = tokenizer.decode(outputs[0, :-2].clone()[0]) | |
| # if debug: | |
| # print("after else---->", prompt) | |
| # | |
| # else: | |
| # no_end = False | |
| # # break | |
| # # print("outputs--->", tokenizer.decode(outputs[0])) | |
| # outputs = outputs[:, ori_prompt_length:] | |
| # outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].replace('"', "") | |
| # open_cv_image = np.array(image_ori) | |
| # open_cv_image = open_cv_image[:, :, ::-1].copy() | |
| # width = image_ori.width | |
| # height = image_ori.height | |
| # for i, pre_box in enumerate(added_bbox_list): | |
| # open_cv_image = cv2.rectangle(open_cv_image, np.array(pre_box[0][:2]*[width,height]).astype(int), np.array(pre_box[0][2:]*[width,height]).astype(int), | |
| # (0, 255, 0), i + 1) | |
| # out_image = Image.fromarray(cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2RGB)) | |
| # # new_predictions = [ | |
| # # postprocess_captioning_generation(out).replace('"', "") | |
| # # for out in tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
| # # ] | |
| # # import pdb; pdb.set_trace() | |
| # | |
| # return outputs, out_image | |
def _encode_texts(tokenizer, texts):
    """Tokenize ``texts`` and return ``(input_ids, attention_mask)`` as PyTorch tensors.

    Pads to the longest entry and truncates at 2000 tokens; these are the
    settings ``captioner`` uses every time it re-encodes text.
    """
    encodings = tokenizer(
        texts,
        padding="longest",
        truncation=True,
        return_tensors="pt",
        max_length=2000,
    )
    return encodings["input_ids"], encodings["attention_mask"]


def _draw_boxes(image_ori, added_bbox_list, debug=False):
    """Return a new PIL image with every box in ``added_bbox_list`` drawn in green.

    Each entry is read as ``box[0] = (x1, y1, x2, y2)`` in coordinates relative
    to the image size (they are scaled by width/height before drawing). Box
    ``i`` is drawn with line thickness ``i + 1``.
    """
    # Force 3 channels so grayscale / RGBA / palette images survive the BGR flip.
    rgb_image = image_ori.convert("RGB")
    canvas = np.array(rgb_image)[:, :, ::-1].copy()  # OpenCV draws on BGR arrays
    size = np.array([rgb_image.width, rgb_image.height])
    for i, box in enumerate(added_bbox_list):
        if debug:
            print(box)
        coords = np.asarray(box[0], dtype=np.float64)
        # cv2.rectangle needs plain int tuples; int() truncates like astype(int).
        top_left = tuple(int(v) for v in coords[:2] * size)
        bottom_right = tuple(int(v) for v in coords[2:] * size)
        canvas = cv2.rectangle(canvas, top_left, bottom_right, (0, 255, 0), i + 1)
    return Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB))


def captioner(
        model, tokenizer, image_ori, batch_images, input_ids, attention_mask, image_start_index_list, image_nums,
        added_bbox_list, debug=True):
    """Generate a grounded caption for one image, inserting predicted boxes on the fly.

    Runs up to 100 rounds of ``model.generate`` (at most 20 new tokens each),
    steered by ``VisualLogitsProcessor``. When a round ends with a
    ``<|#previsual#|>`` / ``<|#visual#|>`` token followed by ``bos_token_id``,
    the model is run once more to predict boxes:

    * if no box is returned, the last two generated tokens are dropped and
      generation resumes from there;
    * otherwise the highest-scoring box (divided by 224) is appended to
      ``added_bbox_list`` and ``<|#box#|><|#endofobject#|>`` (visual) or
      ``<|#prebox#|><|#object#|>`` (previsual) is spliced into the prompt. A
      pending previsual box is removed again when the next visual box arrives.

    Any other ending stops the loop.

    Args:
        model: the grounding model; must expose ``generate(...)`` and a forward
            call returning a mapping with ``"boxes"`` and ``"scores"``.
        tokenizer: tokenizer that knows the ``<|#...#|>`` special tokens.
        image_ori: original PIL image, used only for drawing the result.
        batch_images: preprocessed image tensor passed to the model.
        input_ids, attention_mask: encoded initial prompt.
        image_start_index_list, image_nums: image placement info for the prompt.
        added_bbox_list: list of boxes already in the prompt; mutated in place.
        debug: print intermediate prompts and outputs.

    Returns:
        tuple: ``(caption, out_image)``. ``caption`` is the text generated after
        the original prompt, with special tokens skipped and double quotes
        removed. ``out_image`` is a PIL image with every box in
        ``added_bbox_list`` drawn on ``image_ori``.
    """
    visual_logits_processor = VisualLogitsProcessor(tokenizer)
    model.eval()

    def token_id(text):
        return tokenizer(text, add_special_tokens=False)["input_ids"][-1]

    media_token_id = token_id("<|#image#|>")
    bos_token_id = token_id(tokenizer.bos_token)
    previsual_token_id = token_id("<|#previsual#|>")
    visual_token_id = token_id("<|#visual#|>")
    box_token = "<|#box#|>"
    prebox_token = "<|#prebox#|>"
    endofobject_token = "<|#endofobject#|>"
    object_token = "<|#object#|>"
    ori_prompt_length = len(input_ids[0])
    have_prebox = False
    prompt = None
    generated_ids = None

    for _ in range(100):
        if prompt is not None:
            # Later rounds continue from the edited text prompt; image_start_index_list
            # and image_nums carry over from the previous box-prediction pass.
            input_ids, attention_mask = _encode_texts(tokenizer, [prompt])
        if debug:
            print("input--->", tokenizer.decode(input_ids[0]))
        # Suppress bos_token_id until at least 5 new tokens have been generated.
        min_new_tokens_processor = MinNewTokensLengthLogitsProcessor(
            prompt_length_to_skip=input_ids.shape[-1],
            min_new_tokens=5,
            eos_token_id=bos_token_id,
        )
        with torch.inference_mode():
            generated_ids = model.generate(
                batch_images,
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=20,
                num_beams=1,
                image_start_index_list=image_start_index_list,
                image_nums=image_nums,
                added_bbox_list=added_bbox_list or None,
                logits_processor_list=[min_new_tokens_processor, visual_logits_processor],
            )
        if debug:
            print("outputs--->", tokenizer.decode(generated_ids[0]))

        ends_with_visual = (
            int(generated_ids[0, -2]) in (previsual_token_id, visual_token_id)
            and int(generated_ids[0, -1]) == bos_token_id
        )
        if not ends_with_visual:
            # No box request pending (normal end or max_new_tokens reached): done.
            break

        prompt = tokenizer.decode(generated_ids[0])
        is_visual = int(generated_ids[0, -2]) == visual_token_id

        # Re-encode the generated text without its final stop token and locate the image.
        input_ids, attention_mask = _encode_texts(
            tokenizer, tokenizer.batch_decode(generated_ids[:, :-1])
        )
        image_positions = (input_ids == media_token_id).nonzero(as_tuple=True)[-1] + 1
        image_start_index_list = [[pos] for pos in image_positions.tolist()]
        image_nums = [1] * len(input_ids)
        if debug:
            print("get the visual bbox--->", tokenizer.decode(input_ids[0]))
        with torch.no_grad():
            box_outputs = model(
                vision_x=batch_images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list or None,
                add_box=bool(added_bbox_list),
            )
        boxes = box_outputs["boxes"]
        scores = box_outputs["scores"]
        if debug:
            # boxes can be None (nothing detected), so guard the len() call.
            print("box num---->", 0 if boxes is None else len(boxes))

        if boxes is None:
            # No box: drop the trailing (pre)visual token and stop token, then let
            # the next round regenerate from that point.
            if debug:
                print("before else---->", prompt)
            prompt = tokenizer.decode(generated_ids[0, :-2])
            if debug:
                print("after else---->", prompt)
            continue

        # NOTE(review): presumably boxes are pixel coordinates at 224x224 model input
        # resolution, so /224 normalizes them to image-relative values — confirm.
        best_box = torch.tensor(boxes[scores.argmax()]).unsqueeze(0) / 224
        # NOTE(review): the decoded prompt ends with the bos_token_id text, but
        # len(eos_token) characters are stripped; this assumes bos and eos match.
        if is_visual:
            if have_prebox:
                # A final box replaces the provisional previsual box and its tokens.
                added_bbox_list.pop()
                prompt = prompt.replace("<|#previsual#|><|#prebox#|><|#object#|>", "")
                have_prebox = False
                if debug:
                    print("find previsual and remove it--->", prompt)
            added_bbox_list.append(best_box)
            prompt = prompt[:-len(tokenizer.eos_token)] + box_token + endofobject_token
            if debug:
                print("after inserting visual---->", prompt)
        else:
            added_bbox_list.append(best_box)
            prompt = prompt[:-len(tokenizer.eos_token)] + prebox_token + object_token
            have_prebox = True
            if debug:
                print("after inserting previsual---->", prompt)

    # Use the last *generated* ids (never the box-prediction output) for the caption.
    caption_ids = generated_ids[:, ori_prompt_length:]
    caption = tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0].replace('"', "")
    out_image = _draw_boxes(image_ori, added_bbox_list, debug=debug)
    return caption, out_image