Spaces:
Running
on
Zero
Running
on
Zero
| import copy | |
| import random | |
| import glob | |
| import json | |
| import logging | |
| import os | |
| import torch | |
| from mmengine import print_log | |
| from mmengine.config import Config, ConfigDict | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from pycocotools.coco import COCO | |
| from pycocotools import mask as mask_utils | |
| from xtuner.registry import BUILDER | |
| from xtuner.dataset.utils import encode_fn | |
| from xtuner.dataset.map_fns import llava_map_fn | |
| from projects.glamm.datasets.utils.utils import expand2square | |
| from projects.glamm.datasets.utils.utils import ANSWER_LIST, REGION_QUESTIONS | |
| from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN | |
class RegionDataset(Dataset):
    """Region-level captioning dataset read from a COCO-format annotation file.

    Every sample pairs one image with up to ``num_classes_per_sample``
    randomly chosen region boxes. Each selected box/caption pair becomes one
    question/answer round of a conversation, which is then passed through
    ``template_map_fn`` and ``encode_fn``.

    In this base class the caption of every region is the image-level
    ``img_info['caption']`` field.

    Args:
        image_folder: Directory that ``file_name`` entries are joined onto.
        image_processor: Config built with ``BUILDER.build``. The built
            object must expose ``crop_size`` and ``preprocess`` (and
            ``image_mean`` when ``pad_image_to_square`` is set).
        data_path: Path handed to ``pycocotools.coco.COCO``.
        tokenizer: Config built with ``BUILDER.build``.
        template_map_fn: Config built with ``BUILDER.build``; the result is
            called on the sample dict holding the ``'conversation'`` entry.
        max_length: Forwarded to ``encode_fn``.
        pad_image_to_square: Run ``expand2square`` before preprocessing.
        repeats: Virtual repetition factor applied to the dataset length.
        num_classes_per_sample: Maximum number of regions kept per sample.
        extra_image_processor: Optional config built with ``BUILDER.build``;
            when given, its ``apply_image`` output is stored under
            ``'g_pixel_values'``.
    """

    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 repeats=1,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__()

        # Prefix that is prepended to the first question of a conversation.
        self.begin_str = f'{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n'
        self.question_templates = REGION_QUESTIONS

        # The attribute is deliberately left undefined when no extra
        # processor is configured: code in this file probes for it with
        # ``hasattr``.
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.num_classes_per_sample = num_classes_per_sample

        self.tokenizer = BUILDER.build(tokenizer)
        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.max_length = max_length
        self.template_map_fn = BUILDER.build(template_map_fn)

        self.text_data = self._load_annotations(data_path, image_folder)
        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        size = self.image_processor.crop_size
        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            # Any other value is unpacked as a (width, height) pair.
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.repeats = repeats

    def _load_annotations(self, data_path, image_folder=None):
        """Load the COCO index and return the usable image records.

        Stores the ``COCO`` object on ``self.coco``. Images whose shorter
        side is below 32 pixels (according to the annotation file) are
        dropped. ``image_folder`` is accepted but not used here.
        """
        self.coco = COCO(data_path)
        data_infos = []
        for img_id in self.coco.getImgIds():
            info = self.coco.loadImgs([img_id])[0]
            # Part of ``file_name`` after the last underscore.
            info['filename'] = info['file_name'].split('_')[-1]
            info['height'] = int(info['height'])
            info['width'] = int(info['width'])
            if min(info['height'], info['width']) < 32:
                continue
            data_infos.append(info)
        return data_infos

    def modality_length(self):
        """Return a constant pseudo-length of 100 for every (repeated) sample."""
        return [100] * len(self.text_data) * self.repeats

    def __len__(self):
        """Number of samples including repeats."""
        return len(self.text_data) * self.repeats

    def real_len(self):
        """Number of distinct image records, ignoring repeats."""
        return len(self.text_data)

    def region_processor(self, orig_size, post_size, bboxes, labels):
        """Sample regions and normalise their boxes to the processed image.

        Args:
            orig_size: ``(height, width)`` of the source image.
            post_size: ``(height, width)`` of the preprocessed image tensor.
            bboxes: ``(N, 4)`` array-like of ``[x1, y1, x2, y2]`` boxes in
                source-image pixels. It is not modified.
            labels: Sequence of ``N`` captions aligned with ``bboxes``.

        Returns:
            ``(boxes, labels)`` where ``boxes`` is a float32 tensor of shape
            ``(K, 4)`` with ``K <= num_classes_per_sample`` and ``labels`` is
            the matching list of captions, in the same random order.
        """
        orig_h, orig_w = orig_size
        post_h, post_w = post_size
        y_scale = post_h / orig_h
        x_scale = post_w / orig_w

        # Plain Python ints keep NumPy on the fancy-indexing path, so the
        # selection is always a fresh 2-D copy (even when one box is picked)
        # and the in-place scaling below cannot touch the caller's array.
        picked = torch.randperm(len(labels))[:self.num_classes_per_sample].tolist()
        selected_bboxes = np.asarray(bboxes, dtype=np.float32).reshape(-1, 4)[picked]
        selected_labels = [labels[i] for i in picked]

        selected_bboxes[:, [0, 2]] *= x_scale
        selected_bboxes[:, [1, 3]] *= y_scale

        # Normalise each axis by its own extent. Dividing every coordinate by
        # ``post_h`` is only correct when the processed image is square.
        extent = torch.tensor([post_w, post_h, post_w, post_h], dtype=torch.float32)
        selected_bboxes = torch.from_numpy(selected_bboxes) / extent
        return selected_bboxes, selected_labels

    def _parse_annotations(self, img_info):
        """Load one image and its regions into an intermediate dict.

        Returns a dict with ``'pixel_values'``, ``'bboxes'``, ``'captions'``
        and ``'seg_map'``, plus ``'g_pixel_values'`` when an extra image
        processor is configured.

        If the image has no usable box, the first record of
        ``self.text_data`` is parsed instead.

        Raises:
            ValueError: If the first record itself has no usable box, so no
                fallback is available.
        """
        ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        # ``convert`` returns an independent image, so the file handle can be
        # released right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        orig_w, orig_h = image.size

        # Filter the boxes before any expensive preprocessing so that an
        # image without usable regions is abandoned cheaply.
        bboxes, captions = [], []
        for ann in ann_info:
            if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
                continue
            x1, y1, w, h = ann['bbox']
            # Skip boxes that do not overlap the image at all.
            inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            bboxes.append([x1, y1, x1 + w, y1 + h])
            captions.append(img_info['caption'])

        if not bboxes:
            # Substitute the first record. This must yield the same kind of
            # intermediate dict as a normal parse (not a finished sample),
            # because the caller still reads 'captions' from it.
            fallback_info = self.text_data[0]
            if fallback_info['id'] == img_info['id']:
                raise ValueError(
                    f"No usable region annotations for fallback image id {img_info['id']!r}")
            return self._parse_annotations(copy.deepcopy(fallback_info))

        data_dict = {}
        extra_processor = getattr(self, 'extra_image_processor', None)
        if extra_processor is not None:
            g_image = np.array(image)  # for grounding
            g_image = extra_processor.apply_image(g_image)
            data_dict['g_pixel_values'] = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()

        if self.pad_image_to_square:
            # NOTE(review): the boxes are still scaled with the un-padded
            # size; if expand2square shifts the image content the boxes may
            # need the same offset - confirm against expand2square.
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        post_h, post_w = image.shape[1:3]
        data_dict['pixel_values'] = image

        bboxes = np.array(bboxes, dtype=np.float32)
        bboxes, captions = self.region_processor(
            (orig_h, orig_w), (post_h, post_w), bboxes, captions)
        data_dict['bboxes'] = bboxes
        data_dict['captions'] = captions
        data_dict['seg_map'] = img_info['file_name'].replace('jpg', 'png')
        return data_dict

    def create_conversation(self, captions):
        """Turn region captions into question/answer rounds.

        Round ``i`` asks a randomly chosen question template about
        ``region{i + 1} <bbox>`` and answers with ``captions[i]``. The first
        question is prefixed with ``self.begin_str``.
        """
        conversation = []
        for i, caption in enumerate(captions):
            question = random.choice(self.question_templates).strip().replace(
                '<region>', f'region{i + 1} <bbox>')
            if i == 0:
                question = self.begin_str + question
            conversation.append({'input': question, 'output': caption})
        return conversation

    def __getitem__(self, index):
        """Build the encoded training sample for ``index`` (modulo repeats)."""
        index = index % self.real_len()
        parsed = self._parse_annotations(copy.deepcopy(self.text_data[index]))

        # An override of ``_parse_annotations`` may hand back a sample that
        # already went through ``__getitem__``. Such a dict has no 'captions'
        # entry and is returned as is instead of failing on the lookup below.
        if 'captions' not in parsed:
            return parsed

        data_dict = {
            'g_pixel_values': parsed.pop('g_pixel_values', None),
            'pixel_values': parsed.pop('pixel_values'),
            'bboxes': parsed.pop('bboxes', None),
        }
        data_dict['conversation'] = self.create_conversation(parsed['captions'])

        data_dict.update(self.template_map_fn(data_dict))
        data_dict.update(
            encode_fn(data_dict, tokenizer=self.tokenizer,
                      max_length=self.max_length, with_image_token=True))
        return data_dict
class RefCocoGRegionDataset(RegionDataset):
    """Alias of ``RegionDataset`` with no behavioural changes.

    Everything is inherited as is; in particular each region's caption is
    taken from the image-level ``'caption'`` field. Presumably intended for
    RefCOCOg-style annotation files - confirm against the configs that use it.
    """
class VisualGenomeRegionDataset(RegionDataset):
    """Region dataset whose captions live on the individual annotations.

    Differs from ``RegionDataset`` only in where a region's caption comes
    from: the stripped ``ann['caption']`` of each annotation instead of the
    image-level caption.
    """

    def _parse_annotations(self, img_info):
        """Load one image and its regions into an intermediate dict.

        Returns a dict with ``'pixel_values'``, ``'bboxes'``, ``'captions'``
        and ``'seg_map'``, plus ``'g_pixel_values'`` when an extra image
        processor is configured.

        If the image has no usable box, the first record of
        ``self.text_data`` is parsed instead.

        Raises:
            ValueError: If the first record itself has no usable box, so no
                fallback is available.
        """
        ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        # ``convert`` returns an independent image, so the file handle can be
        # released right away.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        orig_w, orig_h = image.size

        # Filter the boxes before any expensive preprocessing so that an
        # image without usable regions is abandoned cheaply.
        bboxes, captions = [], []
        for ann in ann_info:
            if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
                continue
            x1, y1, w, h = ann['bbox']
            # Skip boxes that do not overlap the image at all.
            inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            bboxes.append([x1, y1, x1 + w, y1 + h])
            captions.append(ann['caption'].strip())

        if not bboxes:
            # Substitute the first record. Returning ``self.__getitem__(0)``
            # here would hand the caller an already-encoded sample without a
            # 'captions' entry, which ``__getitem__`` then fails to read.
            fallback_info = self.text_data[0]
            if fallback_info['id'] == img_info['id']:
                raise ValueError(
                    f"No usable region annotations for fallback image id {img_info['id']!r}")
            return self._parse_annotations(copy.deepcopy(fallback_info))

        data_dict = {}
        # The attribute only exists when an extra processor was configured.
        extra_processor = getattr(self, 'extra_image_processor', None)
        if extra_processor is not None:
            g_image = np.array(image)  # for grounding
            g_image = extra_processor.apply_image(g_image)
            data_dict['g_pixel_values'] = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()

        if self.pad_image_to_square:
            # NOTE(review): the boxes are still scaled with the un-padded
            # size; if expand2square shifts the image content the boxes may
            # need the same offset - confirm against expand2square.
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        post_h, post_w = image.shape[1:3]
        data_dict['pixel_values'] = image

        bboxes = np.array(bboxes, dtype=np.float32)
        bboxes, captions = self.region_processor(
            (orig_h, orig_w), (post_h, post_w), bboxes, captions)
        data_dict['bboxes'] = bboxes
        data_dict['captions'] = captions
        data_dict['seg_map'] = img_info['file_name'].replace('jpg', 'png')
        return data_dict
if __name__ == '__main__':
    # Manual smoke test: build a VisualGenomeRegionDataset from local files
    # and print the first 1000 encoded samples.
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    from xtuner.utils.templates import PROMPT_TEMPLATE
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn

    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    # Lazy configs: the dataset builds the real objects through BUILDER.
    tokenizer = {
        'type': AutoTokenizer.from_pretrained,
        'pretrained_model_name_or_path': llm_name_or_path,
    }
    image_processor = {
        'type': CLIPImageProcessor.from_pretrained,
        'pretrained_model_name_or_path': 'openai/clip-vit-large-patch14-336',
    }
    # Defined for reference only; the dataset below is created without it.
    extra_image_processor = {
        'type': ResizeLongestSide,
        'target_length': 1024,
    }
    prompt_template = PROMPT_TEMPLATE.vicuna
    template_cfg = {'type': template_map_fn_factory, 'template': prompt_template}

    dataset = VisualGenomeRegionDataset(
        image_folder='./data/visual_genome/images',
        image_processor=image_processor,
        data_path='data/visual_genome/train.json',
        tokenizer=tokenizer,
        template_map_fn=template_cfg,
        max_length=2048,
        pad_image_to_square=False,
        repeats=1,
        num_classes_per_sample=3,
        extra_image_processor=None)

    for sample_idx in range(1000):
        sample = dataset[sample_idx]
        print(sample)