# asadi/dataset/base_dataset.py
# smjfas - initial commit (16a0f31)
"""
Base class for our zero-shot anomaly detection dataset
"""
import json
import os
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import cv2
from config import DATA_ROOT
class DataSolver:
    """Read ``meta.json`` under a dataset root and keep only selected classes.

    Args:
        root: Directory that contains ``meta.json``.
        clsnames: Class names whose entries should be kept.
    """

    def __init__(self, root, clsnames):
        self.root = root
        self.clsnames = clsnames
        self.path = os.path.join(root, 'meta.json')

    def run(self):
        """Load the metadata file and filter it down to ``self.clsnames``.

        Returns:
            A dict that always has ``'train'`` and ``'test'`` keys (empty
            dicts when the file lacks that split or no class is requested),
            plus any other top-level split found in the file. Each split maps
            a class name to whatever the file stores for that class.

        Raises:
            KeyError: If a requested class is missing from a split that is
                present in the file.
        """
        with open(self.path, 'r', encoding='utf-8') as f:
            info = json.load(f)

        info_required = dict(train={}, test={})
        for cls in self.clsnames:
            for split, per_class in info.items():
                # setdefault: a split other than train/test must not raise
                # KeyError just because it was not pre-created above.
                info_required.setdefault(split, {})[cls] = per_class[cls]
        return info_required
class BaseDataset(data.Dataset):
    """Dataset of images and binary anomaly masks described by ``meta.json``.

    Every metadata entry is a dict; the keys read by this class are
    ``img_path``, ``mask_path``, ``cls_name`` and ``anomaly``. Paths are
    joined onto ``root``.

    Args:
        clsnames: Class names whose samples are loaded.
        transform: Callable applied to the PIL image, or ``None``.
        target_transform: Callable applied to the PIL mask, or ``None``.
        root: Dataset root directory containing ``meta.json``.
        aug_rate: In training mode, probability of replacing a sample with a
            2x2 mosaic built by :meth:`combine_img`.
        training: Enables the mosaic augmentation when true.
    """

    def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True):
        super().__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.aug_rate = aug_rate
        self.training = training
        self.data_all = []
        self.cls_names = clsnames

        solver = DataSolver(root, clsnames)
        meta_info = solver.run()
        self.meta_info = meta_info['test']  # Only utilize the test dataset for both training and testing
        for cls_name in self.cls_names:
            self.data_all.extend(self.meta_info[cls_name])
        self.length = len(self.data_all)

    def __len__(self):
        """Return the number of samples gathered over all requested classes."""
        return self.length

    def _load_image(self, img_path):
        """Open ``img_path`` as an RGB PIL image (``.tif`` files go through cv2)."""
        if img_path.endswith('.tif'):
            bgr = cv2.imread(img_path)
            if bgr is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise OSError(f"cv2.imread could not read image: {img_path}")
            return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        return Image.open(img_path).convert('RGB')

    def _load_mask(self, sample, img):
        """Return the mode-"L" mask (values 0 or 255) for one metadata entry.

        Normal samples, and anomalous samples without a mask path, get an
        all-zero mask with the same size as ``img``.
        """
        if sample['anomaly'] and sample['mask_path']:
            mask_path = os.path.join(self.root, sample['mask_path'])
            binary = np.array(Image.open(mask_path).convert('L')) > 0
            # A 2-D uint8 array is mapped to mode "L" by Image.fromarray.
            return Image.fromarray(binary.astype(np.uint8) * 255)
        # PIL sizes are (width, height); building the blank mask through a
        # NumPy array of shape img.size would transpose non-square images.
        return Image.new('L', img.size, 0)

    def combine_img(self, cls_name):
        """
        From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN
        Here we combine four images into a single image for data augmentation.

        Four samples of ``cls_name`` are drawn at random and pasted into a
        2x2 grid whose cell size is the size of the first drawn image. When
        the class has fewer than four samples they are drawn with replacement.

        Returns:
            Tuple ``(image, mask)`` of PIL images in modes "RGB" and "L".
        """
        candidates = self.meta_info[cls_name]
        if len(candidates) >= 4:
            picked = random.sample(candidates, 4)
        else:
            picked = random.choices(candidates, k=4)

        tiles = []
        tile_masks = []
        for sample in picked:
            tile = self._load_image(os.path.join(self.root, sample['img_path']))
            tiles.append(tile)
            tile_masks.append(self._load_mask(sample, tile))

        tile_w, tile_h = tiles[0].size
        canvas_size = (2 * tile_w, 2 * tile_h)
        result_image = Image.new("RGB", canvas_size)
        result_mask = Image.new("L", canvas_size)
        for i, (tile, tile_mask) in enumerate(zip(tiles, tile_masks)):
            row, col = divmod(i, 2)
            corner = (col * tile_w, row * tile_h)
            result_image.paste(tile, corner)
            result_mask.paste(tile_mask, corner)
        return result_image, result_mask

    def __getitem__(self, index):
        """Return one sample as a dict.

        Keys: ``img`` (transformed image), ``img_mask`` (transformed mask, or
        ``[]`` if the target transform returned ``None``), ``cls_name``,
        ``anomaly`` and ``img_path``.
        """
        sample = self.data_all[index]
        img_path = os.path.join(self.root, sample['img_path'])
        cls_name = sample['cls_name']
        anomaly = sample['anomaly']

        # Drawn unconditionally so the RNG stream does not depend on `training`.
        random_number = random.random()
        if self.training and random_number < self.aug_rate:
            # NOTE(review): `anomaly` and `img_path` below still describe the
            # indexed sample, not the mosaic tiles - confirm this is intended.
            img, img_mask = self.combine_img(cls_name)
        else:
            img = self._load_image(img_path)
            img_mask = self._load_mask(sample, img)

        # Transforms
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            img_mask = self.target_transform(img_mask)
        if img_mask is None:
            img_mask = []

        return {
            'img': img,
            'img_mask': img_mask,
            'cls_name': cls_name,
            'anomaly': anomaly,
            'img_path': img_path
        }