import glob
import json
import os

from PIL import Image, ImageFile
from torch.utils.data import Dataset

# Tolerate truncated images and disable PIL's decompression-bomb size check,
# since web-crawled pretraining data contains malformed and very large images.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

from data.utils import pre_caption


class pretrain_dataset(Dataset):
    def __init__(self, ann_file, laion_path, transform):
        # Load the fixed pretraining annotation files (each a json list of
        # {'image': path, 'caption': str} records).
        self.ann_pretrain = []
        for f in ann_file:
            print('loading ' + f)
            with open(f, 'r') as fp:
                ann = json.load(fp)
            self.ann_pretrain += ann

        self.laion_path = laion_path
        if self.laion_path:
            # LAION annotations are sharded across multiple json files;
            # start with the first shard.
            self.laion_files = glob.glob(os.path.join(laion_path, '*.json'))

            print('loading ' + self.laion_files[0])
            with open(self.laion_files[0], 'r') as f:
                self.ann_laion = json.load(f)

            self.annotation = self.ann_pretrain + self.ann_laion
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform

    def reload_laion(self, epoch):
        # Swap in a different LAION shard each epoch, cycling through the shards.
        n = epoch % len(self.laion_files)
        print('loading ' + self.laion_files[n])
        with open(self.laion_files[n], 'r') as f:
            self.ann_laion = json.load(f)

        self.annotation = self.ann_pretrain + self.ann_laion

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        image = Image.open(ann['image']).convert('RGB')
        image = self.transform(image)
        # Clean the caption and truncate it to at most 30 words.
        caption = pre_caption(ann['caption'], 30)

        return image, caption
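

# A minimal usage sketch, not part of the original file: the annotation path,
# image size, and loader settings below are assumptions for illustration only.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.5, 1.0)),
        transforms.ToTensor(),
    ])

    # Hypothetical annotation file; each entry is {'image': path, 'caption': str}.
    dataset = pretrain_dataset(
        ann_file=['annotations/pretrain_example.json'],
        laion_path='',  # an empty path skips the LAION shards entirely
        transform=transform,
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    images, captions = next(iter(loader))
    print(images.shape, len(captions))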