Spaces:
Running
Running
| """ | |
| An image-caption dataset dataloader. | |
| Luke Melas-Kyriazi, 2021 | |
| """ | |
| import warnings | |
| from typing import Optional, Callable | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import pandas as pd | |
| from torch.utils.data import Dataset | |
| from torchvision.datasets.folder import default_loader | |
| from PIL import ImageFile | |
| from PIL.Image import DecompressionBombWarning | |
# Let PIL load truncated / partially written image files instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Silence UserWarning and PIL's DecompressionBombWarning.
# NOTE(review): ignoring UserWarning is process-wide (it is not limited to PIL) — confirm
# this is intended. Filters are registered in the same order as two separate
# `filterwarnings` calls would register them (each call prepends to the filter list).
for _ignored_category in (UserWarning, DecompressionBombWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category
class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, texts) tasks.

    By default this dataset returns the raw caption text rather than tokens. This is
    done on purpose, because it's easy to tokenize a batch of text after loading it
    from this dataset. If a ``text_transform`` is supplied, it is applied to each
    caption individually before the caption is returned.
    """

    _SUPPORTED_IMAGE_TRANSFORM_TYPES = ('torchvision', 'albumentations')

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path to a tab-separated file (with a header row) that maps image
            filenames (``image_file`` column) to captions (``caption`` column)
        :param text_transform: optional callable applied to each caption (e.g. a tokenizer);
            ``None`` returns the raw caption text
        :param image_transform: image transform pipeline
        :param image_transform_type: image transform type, either `torchvision` or `albumentations`
            (case-insensitive)
        :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
        :raises ValueError: if ``image_transform_type`` is not a supported type
        """
        # Base path for images
        self.images_root = Path(images_root)
        # Load captions as DataFrame. Read filenames and captions as strings so that purely
        # numeric values are not parsed as numbers (e.g. a file named "0001" would become 1).
        self.captions = pd.read_csv(captions_path, delimiter='\t', header=0,
                                    dtype={'image_file': str, 'caption': str})
        self.captions['image_file'] = self.captions['image_file'].astype(str)
        # Optional transformation pipelines for the text and the image (normalizing, etc.)
        self.text_transform = text_transform
        self.image_transform = image_transform
        self.image_transform_type = image_transform_type.lower()
        # Explicit check (not `assert`) so validation still happens under `python -O`.
        if self.image_transform_type not in self._SUPPORTED_IMAGE_TRANSFORM_TYPES:
            raise ValueError(
                f"image_transform_type must be one of {self._SUPPORTED_IMAGE_TRANSFORM_TYPES}, "
                f"got {image_transform_type!r}")
        # Total number of datapoints
        self.size = len(self.captions)
        # Return image+captions or just images
        self.include_captions = include_captions

    def verify_that_all_images_exist(self):
        """
        Check that every ``image_file`` listed in the captions file exists under ``images_root``.

        Prints a message for each missing file.

        :return: list of the missing image paths (empty if every file exists)
        """
        missing = []
        for image_file in self.captions['image_file']:
            path = self.images_root / image_file
            if not path.is_file():
                print(f'file does not exist: {path}')
                missing.append(path)
        return missing

    def _get_raw_image(self, i):
        """Load image ``i`` from disk using torchvision's ``default_loader``."""
        image_path = self.images_root / self.captions['image_file'].iloc[i]
        return default_loader(image_path)

    def _get_raw_text(self, i):
        """Return the untransformed caption of datapoint ``i``."""
        return self.captions['caption'].iloc[i]

    def __getitem__(self, i):
        """
        Return datapoint ``i``.

        :return: ``{'image': image, 'text': caption}`` if ``include_captions`` is true,
            otherwise just the (transformed) image
        """
        image = self._get_raw_image(i)
        if self.image_transform is not None:
            if self.image_transform_type == 'torchvision':
                image = self.image_transform(image)
            elif self.image_transform_type == 'albumentations':
                # Albumentations works on numpy arrays and returns a dict of outputs.
                image = self.image_transform(image=np.array(image))['image']
            else:
                raise NotImplementedError(f"{self.image_transform_type=}")
        if not self.include_captions:
            # Skip the caption lookup entirely; the captions file may not even have one.
            return image
        caption = self._get_raw_text(i)
        if self.text_transform is not None:
            caption = self.text_transform(caption)
        return {'image': image, 'text': caption}

    def __len__(self):
        return self.size
if __name__ == "__main__":
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer

    # Paths
    images_root = './images'
    captions_path = './images-list-clean.tsv'

    # Tokenizer: applied to a whole batch of raw captions after loading, rather than per item
    # inside the dataset, so the token tensors come out as (batch, max_length).
    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')

    def tokenize(text):
        return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')

    # Image transform. NOTE(review): CenterCrop(256, 256) right after Resize(256, 256) is a
    # no-op as written — confirm whether a larger resize was intended.
    image_transform = A.Compose([
        A.Resize(256, 256), A.CenterCrop(256, 256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])

    # Create dataset (returns raw caption strings; no text_transform)
    dataset = CaptionDataset(
        images_root=images_root,
        captions_path=captions_path,
        image_transform=image_transform,
        image_transform_type='albumentations')

    # Create dataloader and inspect one batch
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
    batch = next(iter(dataloader))
    print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})

    # The default collate_fn gathers the raw captions into a list of strings; tokenize them together.
    tokens = tokenize(batch['text'])
    print({k: v.shape for k, v in tokens.items()})

    # (Optional) Check that all the images exist:
    # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
    # dataset.verify_that_all_images_exist()
    # print('Done')