# -----------------------------------------------------------------------------
# Do Not Alter This File!
# -----------------------------------------------------------------------------
# The following code is part of the logic used for loading and evaluating your
# output scores. Please DO NOT modify this section, as upon your submission,
# the whole evaluation logic will be overwritten by the original code.
# -----------------------------------------------------------------------------
# If you'd like to make modifications, you can create a completely new Dataset
# class or a child class that inherits from this one and use that with your
# data loader.
# -----------------------------------------------------------------------------
import os
from enum import Enum

# Import the submodule explicitly: a bare ``import PIL`` does not guarantee that
# ``PIL.Image`` is loaded, so ``PIL.Image.open`` would otherwise depend on some
# other library (e.g. torchvision) having imported it first.
import PIL.Image
import torch
from torchvision import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class DatasetSplit(Enum):
    """Dataset split; the value is the split sub-directory name under each class."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"


class RayanDataset(torch.utils.data.Dataset):
    """Image/mask dataset for a single class, read from a directory tree::

        <source>/<classname>/<split>/<anomaly_type>/<image files>
        <source>/<classname>/ground_truth/<anomaly_type>/<mask files>

    Each item is a dict with keys ``"image"`` (transformed image), ``"mask"``
    (tensor of shape ``(1, output_size, output_size)``), ``"is_anomaly"``
    (``1`` unless the anomaly type is ``"good"``, else ``0``) and
    ``"image_path"``.
    """

    def __init__(
        self,
        source,
        classname,
        input_size=518,
        output_size=224,
        split=DatasetSplit.TEST,
        external_transform=None,
        **kwargs,
    ):
        """Index the files on disk and build the image/mask transforms.

        Args:
            source: Root directory of the dataset.
            classname: The single class (sub-directory of ``source``) to load.
            input_size: Side length images are resized to by the default image
                transform (unused when ``external_transform`` is given).
            output_size: Side length of the returned masks.
            split: Which :class:`DatasetSplit` sub-directory to read.
            external_transform: Optional callable applied to each RGB PIL image
                in place of the default resize/normalize pipeline.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.classnames_to_use = [classname]
        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        if external_transform is None:
            self.transform_img = transforms.Compose(
                [
                    transforms.Resize((input_size, input_size)),
                    # No-op after a square resize to the same size; kept so the
                    # pipeline matches the reference evaluation exactly.
                    transforms.CenterCrop(input_size),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
                ]
            )
        else:
            self.transform_img = external_transform

        # Masks have shape 1 x output_size x output_size (1x224x224 by default).
        self.transform_mask = transforms.Compose(
            [
                transforms.Resize((output_size, output_size)),
                transforms.CenterCrop(output_size),
                transforms.ToTensor(),
            ]
        )
        self.output_shape = (1, output_size, output_size)

    def __getitem__(self, idx):
        """Load, transform and return the sample at position ``idx``."""
        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]

        # Use context managers so file handles are released promptly;
        # ``convert`` returns a fully loaded copy, safe to use after close.
        with PIL.Image.open(image_path) as image_file:
            image = image_file.convert("RGB")
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            with PIL.Image.open(mask_path) as mask_file:
                mask = mask_file.convert("L")
            # Binarize: any non-zero pixel counts as anomalous (bool tensor).
            mask = self.transform_mask(mask) > 0
        else:
            # NOTE(review): this branch yields a float32 tensor while the branch
            # above yields a bool tensor. Kept as-is for parity with the
            # reference evaluation, but collating a batch that mixes both kinds
            # may fail or up-cast — confirm against the consuming code.
            mask = torch.zeros([*self.output_shape])

        return {
            "image": image,
            "mask": mask,
            "is_anomaly": int(anomaly != "good"),
            "image_path": image_path,
        }

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Scan ``source`` and build the ordered list of samples.

        Returns:
            A tuple ``(imgpaths_per_class, data_to_iterate)``:
            ``imgpaths_per_class[classname][anomaly]`` is the sorted list of
            image paths for that anomaly type, and each ``data_to_iterate``
            entry is a list ``[classname, anomaly, image_path, mask_path]``
            where ``mask_path`` is ``None`` for ``"good"`` images and for
            non-test splits. Entries are ordered by class, then anomaly type,
            then file name.
        """
        imgpaths_per_class = {}
        maskpaths_per_class = {}

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            # Only sub-directories are anomaly types; a stray file (e.g.
            # ``.DS_Store``) would otherwise make the ``os.listdir`` below fail.
            anomaly_types = [
                entry
                for entry in os.listdir(classpath)
                if os.path.isdir(os.path.join(classpath, entry))
            ]

            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in anomaly_types:
                anomaly_path = os.path.join(classpath, anomaly)
                anomaly_files = sorted(os.listdir(anomaly_path))
                imgpaths_per_class[classname][anomaly] = [
                    os.path.join(anomaly_path, x) for x in anomaly_files
                ]

                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                    maskpaths_per_class[classname][anomaly] = [
                        os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
                    ]
                else:
                    maskpaths_per_class[classname][anomaly] = None

        data_to_iterate = []
        for classname in sorted(imgpaths_per_class):
            for anomaly in sorted(imgpaths_per_class[classname]):
                needs_mask = self.split == DatasetSplit.TEST and anomaly != "good"
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    # Images and masks are paired purely by position in their
                    # sorted listings, so both directories must contain
                    # corresponding files in the same sort order.
                    mask_path = (
                        maskpaths_per_class[classname][anomaly][i] if needs_mask else None
                    )
                    data_to_iterate.append([classname, anomaly, image_path, mask_path])

        return imgpaths_per_class, data_to_iterate