from torch.utils.data import Dataset
import torch
import json
import h5py

CAPTION_LENGTH = 25
SIMPLE_PREFIX = "This image shows "

def prep_strings(text, tokenizer, template=None, retrieved_caps=None, k=None, is_test=False, max_length=None):
    # at test time we only build the prompt, so no padding or truncation is needed
    if is_test:
        padding = False
        truncation = False
    else:
        padding = True
        truncation = True

    if retrieved_caps is not None:
        # fill the '||' slot in the template with the top-k retrieved captions
        infix = '\n\n'.join(retrieved_caps[:k]) + '.'
        prefix = template.replace('||', infix)
    else:
        prefix = SIMPLE_PREFIX

    prefix_ids = tokenizer.encode(prefix)
    len_prefix = len(prefix_ids)

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    if truncation:
        text_ids = text_ids[:CAPTION_LENGTH]
    input_ids = prefix_ids + text_ids if not is_test else prefix_ids

    # we ignore the prefix (minus one as the first subtoken in the prefix is not predicted)
    label_ids = [-100] * (len_prefix - 1) + text_ids + [tokenizer.eos_token_id]
    if padding:
        input_ids += [tokenizer.pad_token_id] * (max_length - len(input_ids))
        label_ids += [-100] * (max_length - len(label_ids))

    if is_test:
        return input_ids
    else:
        return input_ids, label_ids

def postprocess_preds(pred, tokenizer):
    # keep only the generated caption: drop the prompt, padding and special tokens
    pred = pred.split(SIMPLE_PREFIX)[-1]
    pred = pred.replace(tokenizer.pad_token, '')
    if pred.startswith(tokenizer.bos_token):
        pred = pred[len(tokenizer.bos_token):]
    if pred.endswith(tokenizer.eos_token):
        pred = pred[:-len(tokenizer.eos_token)]
    return pred
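
# --- Usage sketch (not part of the original module) -----------------------
# A minimal, hedged example of how prep_strings is meant to be driven. It
# assumes a GPT-2 style Hugging Face tokenizer with the pad token mapped to
# the eos token (GPT-2 ships without one); the tokenizer choice, the template
# string and the captions below are assumptions for illustration, not values
# this file pins down.
def _demo_prep_strings():
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    # retrieval-augmented prompt: '||' in the template is replaced by the
    # retrieved captions joined with blank lines
    template = "Similar images show ||\n\nThis image shows "
    caps = ["a dog on a beach", "a dog running in sand"]

    input_ids, label_ids = prep_strings(
        "a brown dog plays near the ocean",
        tokenizer,
        template=template,
        retrieved_caps=caps,
        k=2,
        max_length=100,
    )
    assert len(input_ids) == len(label_ids) == 100  # both padded to max_length
    return input_ids, label_ids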

class TrainDataset(Dataset):
    def __init__(self, df, features_path, tokenizer, rag=False, template_path=None, k=None, max_caption_length=25):
        self.df = df
        self.tokenizer = tokenizer
        self.features = h5py.File(features_path, 'r')

        if rag:
            assert k is not None
            self.template = open(template_path).read().strip() + ' '
            self.max_target_length = (max_caption_length  # target caption
                                      + max_caption_length * k  # retrieved captions
                                      + len(tokenizer.encode(self.template))  # template
                                      + len(tokenizer.encode('\n\n')) * (k - 1)  # separator between captions
                                      )
            self.k = k
        else:
            # without retrieval the prompt is just SIMPLE_PREFIX, so the target
            # only needs room for the prefix and the (truncated) caption
            self.max_target_length = max_caption_length + len(tokenizer.encode(SIMPLE_PREFIX))
        self.rag = rag

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df['text'][idx]
        if self.rag:
            caps = self.df['caps'][idx]
            decoder_input_ids, labels = prep_strings(text, self.tokenizer, template=self.template,
                                                     retrieved_caps=caps, k=self.k, max_length=self.max_target_length)
        else:
            decoder_input_ids, labels = prep_strings(text, self.tokenizer, max_length=self.max_target_length)

        # load precomputed image features, keyed by COCO id
        encoder_outputs = self.features[self.df['cocoid'][idx]][()]

        encoding = {"encoder_outputs": torch.tensor(encoder_outputs),
                    "decoder_input_ids": torch.tensor(decoder_input_ids),
                    "labels": torch.tensor(labels)}

        return encoding
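
# --- Feature file sketch (assumption) --------------------------------------
# TrainDataset indexes the HDF5 file by string COCO id and reads one dataset
# per image with `[()]`. A plausible way to produce such a file is sketched
# below; the real feature extractor (e.g. a vision encoder) lives outside
# this module, so the key naming and the (50, 768) shape are placeholder
# assumptions, not the pipeline's actual values.
def _demo_build_features(path="features.h5"):
    import numpy as np

    with h5py.File(path, "w") as f:
        for cocoid in ["391895", "522418"]:
            # one dataset per image, keyed by string cocoid
            f.create_dataset(cocoid, data=np.random.randn(50, 768).astype("float32"))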

def load_data_for_training(annot_path, caps_path=None):
    annotations = json.load(open(annot_path))['images']
    if caps_path is not None:
        retrieved_caps = json.load(open(caps_path))
    data = {'train': [], 'val': []}

    for item in annotations:
        file_name = item['filename'].split('_')[-1]
        if caps_path is not None:
            caps = retrieved_caps[str(item['cocoid'])]
        else:
            caps = None
        samples = []
        for sentence in item['sentences']:
            samples.append({'file_name': file_name, 'cocoid': str(item['cocoid']),
                            'caps': caps, 'text': ' '.join(sentence['tokens'])})
        if item['split'] == 'train' or item['split'] == 'restval':
            data['train'] += samples
        elif item['split'] == 'val':
            data['val'] += samples
    return data

def load_data_for_inference(annot_path, caps_path=None):
    annotations = json.load(open(annot_path))['images']
    if caps_path is not None:
        retrieved_caps = json.load(open(caps_path))
    data = {'test': [], 'val': []}

    for item in annotations:
        file_name = item['filename'].split('_')[-1]
        if caps_path is not None:
            caps = retrieved_caps[str(item['cocoid'])]
        else:
            caps = None
        image = {'file_name': file_name, 'caps': caps, 'image_id': str(item['cocoid'])}
        if item['split'] == 'test':
            data['test'].append(image)
        elif item['split'] == 'val':
            data['val'].append(image)
    return data
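
# --- Data-loading sketch (assumption) ---------------------------------------
# Both loaders expect a Karpathy-style split file: a top-level 'images' list
# whose items carry 'filename', 'cocoid', 'split' and tokenized 'sentences'.
# The snippet below writes a tiny file in that shape and runs the training
# loader over it; the field values are made up for illustration.
if __name__ == "__main__":
    annot = {"images": [{
        "filename": "COCO_train2014_000000391895.jpg",
        "cocoid": 391895,
        "split": "train",
        "sentences": [{"tokens": ["a", "dog", "on", "a", "beach"]}],
    }]}
    with open("demo_annotations.json", "w") as f:
        json.dump(annot, f)

    data = load_data_for_training("demo_annotations.json")
    print(data["train"][0])
    # -> {'file_name': '000000391895.jpg', 'cocoid': '391895', 'caps': None,
    #     'text': 'a dog on a beach'}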