""" Dataset functions for applying Normalized Cut. Code adapted from SelfMask: https://github.com/NoelShin/selfmask """ import os from typing import Optional, Tuple, Union from pycocotools.coco import COCO import numpy as np import torch import torchvision from PIL import Image from torch.utils.data import Dataset from torchvision import transforms as T from datasets.utils import unnormalize from datasets.geometric_transforms import resize from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader from datasets.augmentations import geometric_augmentations, photometric_augmentations from datasets.uod_datasets import UODDataset NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) def set_dataset_dir(dataset_name, root_dir): if dataset_name == "ECSSD": dataset_dir = os.path.join(root_dir, "ECSSD") img_dir = os.path.join(dataset_dir, "images") gt_dir = os.path.join(dataset_dir, "ground_truth_mask") elif dataset_name == "DUTS-TEST": dataset_dir = os.path.join(root_dir, "DUTS") img_dir = os.path.join(dataset_dir, "DUTS-TE-Image") gt_dir = os.path.join(dataset_dir, "DUTS-TE-Mask") elif dataset_name == "DUTS-TR": dataset_dir = os.path.join(root_dir, "DUTS") img_dir = os.path.join(dataset_dir, "DUTS-TR-Image") gt_dir = os.path.join(dataset_dir, "DUTS-TR-Mask") elif dataset_name == "DUT-OMRON": dataset_dir = os.path.join(root_dir, "DUT-OMRON") img_dir = os.path.join(dataset_dir, "DUT-OMRON-image") gt_dir = os.path.join(dataset_dir, "pixelwiseGT-new-PNG") elif dataset_name == "VOC07": dataset_dir = os.path.join(root_dir, "VOC2007") img_dir = dataset_dir gt_dir = dataset_dir elif dataset_name == "VOC12": dataset_dir = os.path.join('/datasets_local/osimeoni', "VOC2012") img_dir = dataset_dir gt_dir = dataset_dir elif dataset_name == "COCO17": dataset_dir = os.path.join(root_dir, "COCO") img_dir = dataset_dir gt_dir = dataset_dir elif dataset_name == "ImageNet": dataset_dir = os.path.join(root_dir, "ImageNet") img_dir = dataset_dir gt_dir = dataset_dir else: raise ValueError(f"Unknown dataset {dataset_name}") return img_dir, gt_dir def build_dataset( root_dir: str, dataset_name: str, dataset_set: Optional[str] = None, for_eval: bool = False, config=None, evaluation_type="saliency", # uod, ): """ Build dataset """ if evaluation_type == "saliency": img_dir, gt_dir = set_dataset_dir(dataset_name, root_dir) dataset = FoundDataset( name=dataset_name, img_dir=img_dir, gt_dir=gt_dir, dataset_set=dataset_set, config=config, for_eval=for_eval, evaluation_type=evaluation_type, ) elif evaluation_type == "uod": assert dataset_name in ["VOC07", "VOC12", "COCO20k"] dataset_set = "trainval" if dataset_name in ["VOC07", "VOC12"] else "train" no_hards = False dataset = UODDataset( dataset_name, dataset_set, root_dir=root_dir, remove_hards=no_hards, ) return dataset class FoundDataset(Dataset): def __init__( self, name: str, img_dir: str, gt_dir: str, dataset_set: Optional[str] = None, config=None, for_eval:bool = False, evaluation_type:str = "saliency", ) -> None: """ Args: root_dir (string): Directory with all the images. transform (callable, optional): Optional transform to be applied on a sample. """ self.for_eval = for_eval self.use_aug = not for_eval self.evaluation_type = evaluation_type assert evaluation_type in ["saliency"] self.name = name self.dataset_set = dataset_set self.img_dir = img_dir self.gt_dir = gt_dir # if VOC dataset self.loader = None self.cocoGt = None self.config = config if "VOC" in self.name: self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type) # if ImageNet dataset elif "ImageNet" in self.name: self.loader = torchvision.datasets.ImageNet( self.img_dir, split=dataset_set, transform=None, target_transform=None, ) elif "COCO" in self.name: year = int("20"+self.name[-2:]) annFile=f'/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json' self.cocoGt=COCO(annFile) self.img_ids = list(sorted(self.cocoGt.getImgIds())) self.img_dir = f'/datasets_local/COCO/images/{dataset_set}{str(year)}/' # Transformations if self.for_eval: full_img_transform, no_norm_full_img_transform = self.get_init_transformation( isVOC="VOC" in name ) self.full_img_transform = full_img_transform self.no_norm_full_img_transform = no_norm_full_img_transform # Images self.list_images = None if not "VOC" in self.name and not "COCO" in self.name: self.list_images = [ os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir)) ] self.ignore_index = -1 self.mean = NORMALIZE.mean self.std = NORMALIZE.std self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE]) self.normalize = NORMALIZE if config is not None and self.use_aug: self._set_aug(config) def get_init_transformation(self, isVOC: bool = False): if isVOC: t = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE]) t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)]) return t, t_nonorm else: t = T.Compose([T.ToTensor(), NORMALIZE]) t_nonorm = T.Compose([T.ToTensor()]) return t, t_nonorm def _set_aug(self, config): """ Set augmentation based on config. """ photometric_aug = config.training["photometric_aug"] self.cropping_strategy = config.training["cropping_strategy"] if self.cropping_strategy == "center_crop": self.use_aug = False # default strategy, not considered to be a data aug self.scale_range = config.training["scale_range"] self.crop_size = config.training["crop_size"] self.center_crop_transforms = T.Compose( [ T.CenterCrop((self.crop_size, self.crop_size)), T.ToTensor(), ] ) self.center_crop_only_transforms = T.Compose( [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()] ) self.proba_photometric_aug = config.training["proba_photometric_aug"] self.random_color_jitter = False self.random_grayscale = False self.random_gaussian_blur = False if photometric_aug == "color_jitter": self.random_color_jitter = True elif photometric_aug == "grayscale": self.random_grayscale = True elif photometric_aug == "gaussian_blur": self.random_gaussian_blur = True def _preprocess_data_aug( self, image: Image.Image, mask: Image.Image, ignore_index: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Prepare data in a proper form for either training (data augmentation) or validation.""" # resize to base size image = resize( image, size=self.crop_size, edge="shorter", interpolation="bilinear", ) mask = resize( mask, size=self.crop_size, edge="shorter", interpolation="bilinear", ) if not isinstance(mask, torch.Tensor): mask: torch.Tensor = torch.tensor(np.array(mask)) random_scale_range = None random_crop_size = None random_hflip_p = None if self.cropping_strategy == "random_scale": random_scale_range = self.scale_range elif self.cropping_strategy == "random_crop": random_crop_size = self.crop_size elif self.cropping_strategy == "random_hflip": random_hflip_p = 0.5 elif self.cropping_strategy == "random_crop_and_hflip": random_hflip_p = 0.5 random_crop_size = self.crop_size if random_crop_size or random_hflip_p or random_scale_range: image, mask = geometric_augmentations( image=image, mask=mask, random_scale_range=random_scale_range, random_crop_size=random_crop_size, ignore_index=ignore_index, random_hflip_p=random_hflip_p, ) if random_scale_range: # resize to (self.crop_size, self.crop_size) image = resize( image, size=self.crop_size, interpolation="bilinear", ) mask = resize( mask, size=(self.crop_size, self.crop_size), interpolation="bilinear", ) image = photometric_augmentations( image, random_color_jitter=self.random_color_jitter, random_grayscale=self.random_grayscale, random_gaussian_blur=self.random_gaussian_blur, proba_photometric_aug=self.proba_photometric_aug, ) # to tensor + normalize image image = self.to_tensor_and_normalize(image) return image, mask def __len__(self) -> int: if "VOC" in self.name: return len(self.loader) elif "ImageNet" in self.name: return len(self.loader) elif "COCO" in self.name: return len(self.img_ids) return len(self.list_images) def _apply_center_crop( self, image: Image.Image, mask: Union[Image.Image, np.ndarray, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: img_t = self.center_crop_transforms(image) # need to normalize image img_t = self.normalize(img_t) mask_gt = self.center_crop_transforms(mask).squeeze() return img_t, mask_gt def __getitem__(self, idx, get_mask_gt=True): if "VOC" in self.name: img, gt_labels = self.loader[idx] if self.evaluation_type == "uod": gt_labels, _ = get_voc_detection_gt( gt_labels, remove_hards=False ) elif self.evaluation_type == "saliency": mask_gt = create_gt_masks_if_voc(gt_labels) img_path = self.loader.images[idx] elif "ImageNet" in self.name: img, _ = self.loader[idx] img_path = self.loader.imgs[idx][0] # empty mask since no gt mask, only class label zeros = np.zeros(np.array(img).shape[:2]) mask_gt = Image.fromarray(zeros) elif "COCO" in self.name: img_id = self.img_ids[idx] path = self.cocoGt.loadImgs(img_id)[0]["file_name"] img = Image.open(os.path.join(self.img_dir, path)).convert("RGB") _ = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(id)) img_path = self.img_ids[idx] # What matters most is the id for eval # empty mask since no gt mask, only class label zeros = np.zeros(np.array(img).shape[:2]) mask_gt = Image.fromarray(zeros) # For all others else: img_path = self.list_images[idx] with open(img_path, "rb") as f: img = Image.open(f) img = img.convert("RGB") im_name = img_path.split("/")[-1] mask_gt = Image.open( os.path.join(self.gt_dir, im_name.replace(".jpg", ".png")) ).convert("L") if self.for_eval: img_t = self.full_img_transform(img) img_init = self.no_norm_full_img_transform(img) if self.evaluation_type == "saliency": mask_gt = torch.tensor(np.array(mask_gt)).squeeze() mask_gt = np.array(mask_gt) mask_gt = mask_gt == 255 mask_gt = torch.tensor(mask_gt) else: if self.use_aug: img_t, mask_gt = self._preprocess_data_aug( image=img, mask=mask_gt, ignore_index=self.ignore_index ) mask_gt = np.array(mask_gt) mask_gt = mask_gt == 255 mask_gt = torch.tensor(mask_gt) else: # no data aug img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt) gt_labels = self.center_crop_only_transforms(gt_labels).squeeze() mask_gt = np.asarray(mask_gt, np.int64) mask_gt = mask_gt == 1 mask_gt = torch.tensor(mask_gt) img_init = unnormalize(img_t) if not get_mask_gt: mask_gt = None if self.evaluation_type == "uod": gt_labels = torch.tensor(gt_labels) mask_gt = gt_labels return img_t, img_init, mask_gt, img_path def fullimg_mode(self): self.val_full_image = True def training_mode(self): self.val_full_image = False