"""
Dataset functions for applying Normalized Cut.

Code adapted from SelfMask: https://github.com/NoelShin/selfmask
"""
|
import os
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torchvision
from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset
from torchvision import transforms as T

from datasets.augmentations import geometric_augmentations, photometric_augmentations
from datasets.geometric_transforms import resize
from datasets.uod_datasets import UODDataset
from datasets.utils import unnormalize
from datasets.VOC import get_voc_detection_gt, create_gt_masks_if_voc, create_VOC_loader

# Channel-wise ImageNet mean/std used to normalize all input images.
NORMALIZE = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

|
def set_dataset_dir(dataset_name, root_dir):
    """Return the image and ground-truth directories of a dataset under root_dir."""
    if dataset_name == "ECSSD":
        dataset_dir = os.path.join(root_dir, "ECSSD")
        img_dir = os.path.join(dataset_dir, "images")
        gt_dir = os.path.join(dataset_dir, "ground_truth_mask")

    elif dataset_name == "DUTS-TEST":
        dataset_dir = os.path.join(root_dir, "DUTS")
        img_dir = os.path.join(dataset_dir, "DUTS-TE-Image")
        gt_dir = os.path.join(dataset_dir, "DUTS-TE-Mask")

    elif dataset_name == "DUTS-TR":
        dataset_dir = os.path.join(root_dir, "DUTS")
        img_dir = os.path.join(dataset_dir, "DUTS-TR-Image")
        gt_dir = os.path.join(dataset_dir, "DUTS-TR-Mask")

    elif dataset_name == "DUT-OMRON":
        dataset_dir = os.path.join(root_dir, "DUT-OMRON")
        img_dir = os.path.join(dataset_dir, "DUT-OMRON-image")
        gt_dir = os.path.join(dataset_dir, "pixelwiseGT-new-PNG")

    elif dataset_name == "VOC07":
        dataset_dir = os.path.join(root_dir, "VOC2007")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    elif dataset_name == "VOC12":
        dataset_dir = os.path.join(root_dir, "VOC2012")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    elif dataset_name == "COCO17":
        dataset_dir = os.path.join(root_dir, "COCO")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    elif dataset_name == "ImageNet":
        dataset_dir = os.path.join(root_dir, "ImageNet")
        img_dir = dataset_dir
        gt_dir = dataset_dir

    else:
        raise ValueError(f"Unknown dataset {dataset_name}")

    return img_dir, gt_dir
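
# Example (illustrative; assumes the standard on-disk layout under root_dir):
#   img_dir, gt_dir = set_dataset_dir("DUTS-TR", "/path/to/datasets")
#   -> ("/path/to/datasets/DUTS/DUTS-TR-Image", "/path/to/datasets/DUTS/DUTS-TR-Mask")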
|


def build_dataset(
    root_dir: str,
    dataset_name: str,
    dataset_set: Optional[str] = None,
    for_eval: bool = False,
    config=None,
    evaluation_type="saliency",
):
    """
    Build the dataset used for either saliency ("saliency") or
    unsupervised object discovery ("uod") evaluation.
    """
    if evaluation_type == "saliency":
        img_dir, gt_dir = set_dataset_dir(dataset_name, root_dir)

        dataset = FoundDataset(
            name=dataset_name,
            img_dir=img_dir,
            gt_dir=gt_dir,
            dataset_set=dataset_set,
            config=config,
            for_eval=for_eval,
            evaluation_type=evaluation_type,
        )

    elif evaluation_type == "uod":
        assert dataset_name in ["VOC07", "VOC12", "COCO20k"]
        dataset_set = "trainval" if dataset_name in ["VOC07", "VOC12"] else "train"
        no_hards = False
        dataset = UODDataset(
            dataset_name,
            dataset_set,
            root_dir=root_dir,
            remove_hards=no_hards,
        )

    else:
        raise ValueError(f"Unknown evaluation type {evaluation_type}")

    return dataset
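
# Example (illustrative; paths and config are placeholders):
#   dataset = build_dataset(
#       root_dir="/path/to/datasets",
#       dataset_name="DUTS-TR",
#       for_eval=False,
#       config=config,  # must expose the config.training keys read in _set_aug
#       evaluation_type="saliency",
#   )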
|


class FoundDataset(Dataset):
    def __init__(
        self,
        name: str,
        img_dir: str,
        gt_dir: str,
        dataset_set: Optional[str] = None,
        config=None,
        for_eval: bool = False,
        evaluation_type: str = "saliency",
    ) -> None:
        """
        Args:
            name (string): Name of the dataset (e.g. "ECSSD", "DUTS-TR", "VOC07").
            img_dir (string): Directory with all the images.
            gt_dir (string): Directory with the ground-truth masks.
            dataset_set (string, optional): Split of the dataset (e.g. "train", "val").
            config: Configuration object whose training options define the augmentations.
            for_eval (bool): If True, images are only resized and normalized, not augmented.
            evaluation_type (string): Evaluation type; only "saliency" is supported here.
        """
|
        self.for_eval = for_eval
        self.use_aug = not for_eval
        self.evaluation_type = evaluation_type

        assert evaluation_type in ["saliency"]

        self.name = name
        self.dataset_set = dataset_set
        self.img_dir = img_dir
        self.gt_dir = gt_dir

        self.loader = None
        self.cocoGt = None

        self.config = config

        if "VOC" in self.name:
            self.loader = create_VOC_loader(self.img_dir, dataset_set, evaluation_type)

|
elif "ImageNet" in self.name: |
|
self.loader = torchvision.datasets.ImageNet( |
|
self.img_dir, |
|
split=dataset_set, |
|
transform=None, |
|
target_transform=None, |
|
) |
|
|
|
elif "COCO" in self.name: |
|
year = int("20"+self.name[-2:]) |
|
annFile=f'/datasets_local/COCO/annotations/instances_{dataset_set}{str(year)}.json' |
|
self.cocoGt=COCO(annFile) |
|
self.img_ids = list(sorted(self.cocoGt.getImgIds())) |
|
self.img_dir = f'/datasets_local/COCO/images/{dataset_set}{str(year)}/' |
|
|
|
|
|
        if self.for_eval:
            full_img_transform, no_norm_full_img_transform = self.get_init_transformation(
                isVOC="VOC" in name
            )
            self.full_img_transform = full_img_transform
            self.no_norm_full_img_transform = no_norm_full_img_transform

|
        # VOC, COCO and ImageNet images are accessed through their own loaders.
        self.list_images = None
        if (
            "VOC" not in self.name
            and "COCO" not in self.name
            and "ImageNet" not in self.name
        ):
            self.list_images = [
                os.path.join(img_dir, i) for i in sorted(os.listdir(img_dir))
            ]

        self.ignore_index = -1
        self.mean = NORMALIZE.mean
        self.std = NORMALIZE.std
        self.to_tensor_and_normalize = T.Compose([T.ToTensor(), NORMALIZE])
        self.normalize = NORMALIZE

        # Default mode flag; toggled via fullimg_mode() / training_mode() below.
        self.val_full_image = for_eval

        if config is not None and self.use_aug:
            self._set_aug(config)
|

    def get_init_transformation(self, isVOC: bool = False):
        if isVOC:
            # PIL -> uint8 tensor -> float in [0, 1], keeping the full image size.
            t = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float), NORMALIZE])
            t_nonorm = T.Compose([T.PILToTensor(), T.ConvertImageDtype(torch.float)])
            return t, t_nonorm

        else:
            t = T.Compose([T.ToTensor(), NORMALIZE])
            t_nonorm = T.Compose([T.ToTensor()])
            return t, t_nonorm

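    # Evaluation-time usage (illustrative; mirrors the call in __getitem__):
    #   t, t_nonorm = self.get_init_transformation(isVOC="VOC" in self.name)
    #   img_t, img_init = t(pil_img), t_nonorm(pil_img)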
|

    def _set_aug(self, config):
        """
        Set the geometric and photometric augmentations based on the config.
        """
        photometric_aug = config.training["photometric_aug"]

        self.cropping_strategy = config.training["cropping_strategy"]
        if self.cropping_strategy == "center_crop":
            self.use_aug = False
        self.scale_range = config.training["scale_range"]
        self.crop_size = config.training["crop_size"]
        self.center_crop_transforms = T.Compose(
            [
                T.CenterCrop((self.crop_size, self.crop_size)),
                T.ToTensor(),
            ]
        )
        self.center_crop_only_transforms = T.Compose(
            [T.CenterCrop((self.crop_size, self.crop_size)), T.PILToTensor()]
        )

        self.proba_photometric_aug = config.training["proba_photometric_aug"]

        self.random_color_jitter = False
        self.random_grayscale = False
        self.random_gaussian_blur = False
        if photometric_aug == "color_jitter":
            self.random_color_jitter = True
        elif photometric_aug == "grayscale":
            self.random_grayscale = True
        elif photometric_aug == "gaussian_blur":
            self.random_gaussian_blur = True

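    # Expected config.training keys (values shown are illustrative):
    #   photometric_aug: "color_jitter", "grayscale" or "gaussian_blur"
    #                    (any other value disables photometric augmentation)
    #   cropping_strategy: "center_crop", "random_scale", "random_crop",
    #                      "random_hflip" or "random_crop_and_hflip"
    #   scale_range: e.g. [0.1, 3.0]
    #   crop_size: e.g. 224
    #   proba_photometric_aug: e.g. 0.5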
|

    def _preprocess_data_aug(
        self,
        image: Image.Image,
        mask: Image.Image,
        ignore_index: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Prepare a training sample by applying geometric and photometric augmentations."""

        # Resize the shorter edge of both image and mask to the crop size.
        image = resize(
            image,
            size=self.crop_size,
            edge="shorter",
            interpolation="bilinear",
        )
        mask = resize(
            mask,
            size=self.crop_size,
            edge="shorter",
            interpolation="bilinear",
        )

        if not isinstance(mask, torch.Tensor):
            mask: torch.Tensor = torch.tensor(np.array(mask))

        # Select the geometric augmentations from the cropping strategy.
        random_scale_range = None
        random_crop_size = None
        random_hflip_p = None
        if self.cropping_strategy == "random_scale":
            random_scale_range = self.scale_range
        elif self.cropping_strategy == "random_crop":
            random_crop_size = self.crop_size
        elif self.cropping_strategy == "random_hflip":
            random_hflip_p = 0.5
        elif self.cropping_strategy == "random_crop_and_hflip":
            random_hflip_p = 0.5
            random_crop_size = self.crop_size

        if random_crop_size or random_hflip_p or random_scale_range:
            image, mask = geometric_augmentations(
                image=image,
                mask=mask,
                random_scale_range=random_scale_range,
                random_crop_size=random_crop_size,
                ignore_index=ignore_index,
                random_hflip_p=random_hflip_p,
            )

        # Random scaling changes the spatial size; resize back to the crop size.
        if random_scale_range:
            image = resize(
                image,
                size=self.crop_size,
                interpolation="bilinear",
            )
            mask = resize(
                mask,
                size=(self.crop_size, self.crop_size),
                interpolation="bilinear",
            )

        image = photometric_augmentations(
            image,
            random_color_jitter=self.random_color_jitter,
            random_grayscale=self.random_grayscale,
            random_gaussian_blur=self.random_gaussian_blur,
            proba_photometric_aug=self.proba_photometric_aug,
        )

        image = self.to_tensor_and_normalize(image)

        return image, mask

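    # Training-time usage (illustrative; mirrors the call in __getitem__):
    #   img_t, mask = self._preprocess_data_aug(
    #       image=pil_img, mask=pil_mask, ignore_index=self.ignore_index
    #   )
    #   img_t is a normalized float tensor; mask keeps its raw label values.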
|

    def __len__(self) -> int:
        if "VOC" in self.name or "ImageNet" in self.name:
            return len(self.loader)
        elif "COCO" in self.name:
            return len(self.img_ids)
        return len(self.list_images)

|
    def _apply_center_crop(
        self, image: Image.Image, mask: Union[Image.Image, np.ndarray, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_t = self.center_crop_transforms(image)
        img_t = self.normalize(img_t)
        mask_gt = self.center_crop_transforms(mask).squeeze()
        return img_t, mask_gt

|
    def __getitem__(self, idx, get_mask_gt=True):
        if "VOC" in self.name:
            img, gt_labels = self.loader[idx]
            if self.evaluation_type == "uod":
                gt_labels, _ = get_voc_detection_gt(gt_labels, remove_hards=False)
            elif self.evaluation_type == "saliency":
                mask_gt = create_gt_masks_if_voc(gt_labels)
            img_path = self.loader.images[idx]

        elif "ImageNet" in self.name:
            img, _ = self.loader[idx]
            img_path = self.loader.imgs[idx][0]

            # ImageNet has no mask annotations; use an all-background dummy mask.
            zeros = np.zeros(np.array(img).shape[:2])
            mask_gt = Image.fromarray(zeros)

        elif "COCO" in self.name:
            img_id = self.img_ids[idx]

            path = self.cocoGt.loadImgs(img_id)[0]["file_name"]
            img = Image.open(os.path.join(self.img_dir, path)).convert("RGB")
            _ = self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=img_id))
            img_path = self.img_ids[idx]

            # COCO annotations are not used here; use an all-background dummy mask.
            zeros = np.zeros(np.array(img).shape[:2])
            mask_gt = Image.fromarray(zeros)

        else:
            img_path = self.list_images[idx]
            with open(img_path, "rb") as f:
                img = Image.open(f)
                img = img.convert("RGB")
            im_name = img_path.split("/")[-1]
            mask_gt = Image.open(
                os.path.join(self.gt_dir, im_name.replace(".jpg", ".png"))
            ).convert("L")

|
        if self.for_eval:
            img_t = self.full_img_transform(img)
            img_init = self.no_norm_full_img_transform(img)

            if self.evaluation_type == "saliency":
                # Binarize the ground truth: foreground pixels are stored as 255.
                mask_gt = torch.tensor(np.array(mask_gt)).squeeze()
                mask_gt = mask_gt == 255
        else:
            if self.use_aug:
                img_t, mask_gt = self._preprocess_data_aug(
                    image=img, mask=mask_gt, ignore_index=self.ignore_index
                )
                mask_gt = np.array(mask_gt)
                mask_gt = mask_gt == 255
                mask_gt = torch.tensor(mask_gt)
            else:
                img_t, mask_gt = self._apply_center_crop(image=img, mask=mask_gt)
                if "VOC" in self.name:  # gt_labels only exists for VOC
                    gt_labels = self.center_crop_only_transforms(gt_labels).squeeze()
                mask_gt = np.asarray(mask_gt, np.int64)
                mask_gt = mask_gt == 1
                mask_gt = torch.tensor(mask_gt)

            img_init = unnormalize(img_t)

        if not get_mask_gt:
            mask_gt = None

        if self.evaluation_type == "uod":
            gt_labels = torch.tensor(gt_labels)
            mask_gt = gt_labels

        return img_t, img_init, mask_gt, img_path
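
    # Returned item (illustrative):
    #   img_t, img_init, mask_gt, img_path = dataset[idx]
    #   img_t: normalized image tensor; img_init: unnormalized version;
    #   mask_gt: boolean foreground mask (None if get_mask_gt=False);
    #   img_path: path to the source image (COCO returns the image id).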
|

    def fullimg_mode(self):
        """Switch to full-image (validation) mode."""
        self.val_full_image = True

    def training_mode(self):
        """Switch back to training mode."""
        self.val_full_image = False