"""
Base class for our zero-shot anomaly detection dataset
"""
import json
import os
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import cv2
from config import DATA_ROOT
class DataSolver:
    """Read a dataset's ``meta.json`` and keep only the requested classes.

    The meta file is expected to map split names (``train``/``test``) to
    per-class sample lists; ``run`` filters each split down to ``clsnames``.
    """

    def __init__(self, root, clsnames):
        self.path = os.path.join(root, 'meta.json')
        self.root = root
        self.clsnames = clsnames

    def run(self):
        """Return ``{split: {cls: samples}}`` restricted to ``self.clsnames``."""
        with open(self.path, 'r') as f:
            meta = json.load(f)
        filtered = dict(train={}, test={})
        # Walk the splits present in the file; unknown split names raise
        # KeyError, matching the strictness of the expected schema.
        for split, per_class in meta.items():
            for name in self.clsnames:
                filtered[split][name] = per_class[name]
        return filtered
class BaseDataset(data.Dataset):
    """Base dataset for zero-shot anomaly detection.

    Only the ``test`` split of ``meta.json`` is used for both training and
    testing (zero-shot setting). Each item is a dict with the image, its
    anomaly mask, class name, anomaly flag, and image path.
    """

    def __init__(self, clsnames, transform, target_transform, root, aug_rate=0., training=True):
        """
        Args:
            clsnames: class names to load from the meta file.
            transform: callable applied to the RGB image (or None).
            target_transform: callable applied to the mask (or None).
            root: dataset root directory containing ``meta.json``.
            aug_rate: probability of the 4-image mosaic augmentation
                during training.
            training: whether augmentation may be applied.
        """
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.aug_rate = aug_rate
        self.training = training
        self.data_all = []
        self.cls_names = clsnames
        solver = DataSolver(root, clsnames)
        meta_info = solver.run()
        self.meta_info = meta_info['test']  # Only utilize the test dataset for both training and testing
        for cls_name in self.cls_names:
            self.data_all.extend(self.meta_info[cls_name])
        self.length = len(self.data_all)

    def __len__(self):
        return self.length

    def _blank_mask(self, img):
        """Return an all-zero 'L'-mode mask with the same size as ``img``.

        PIL's ``Image.size`` is ``(width, height)`` while numpy shapes are
        ``(rows, cols) == (height, width)``, so the axes must be swapped
        (the original code transposed non-square masks). The array must be
        uint8 so ``Image.fromarray(..., mode='L')`` interprets it correctly.
        """
        zeros = np.zeros((img.size[1], img.size[0]), dtype=np.uint8)
        return Image.fromarray(zeros, mode='L')

    def combine_img(self, cls_name):
        """
        From April-GAN: https://github.com/ByChelsea/VAND-APRIL-GAN
        Here we combine four images into a single image for data augmentation.

        Returns:
            (result_image, result_mask): a 2x2 mosaic of four random samples
            of ``cls_name`` and the matching combined anomaly mask.
        """
        img_info = random.sample(self.meta_info[cls_name], 4)
        img_ls = []
        mask_ls = []
        # NOTE: ``sample`` (not ``data``) so the torch.utils.data module
        # alias imported at file level is not shadowed.
        for sample in img_info:
            img_path = os.path.join(self.root, sample['img_path'])
            mask_path = os.path.join(self.root, sample['mask_path'])
            img = Image.open(img_path).convert('RGB')
            img_ls.append(img)
            if not sample['anomaly']:
                img_mask = self._blank_mask(img)
            else:
                # Binarize the stored mask, then scale to {0, 255}.
                img_mask = np.array(Image.open(mask_path).convert('L')) > 0
                img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
            mask_ls.append(img_mask)
        # Paste the four images and masks into a 2x2 grid (row-major order).
        image_width, image_height = img_ls[0].size
        result_image = Image.new("RGB", (2 * image_width, 2 * image_height))
        result_mask = Image.new("L", (2 * image_width, 2 * image_height))
        for i, (img, mask) in enumerate(zip(img_ls, mask_ls)):
            x = (i % 2) * image_width
            y = (i // 2) * image_height
            result_image.paste(img, (x, y))
            result_mask.paste(mask, (x, y))
        return result_image, result_mask

    def __getitem__(self, index):
        """Return one sample as a dict of img/img_mask/cls_name/anomaly/img_path."""
        sample = self.data_all[index]
        img_path = os.path.join(self.root, sample['img_path'])
        mask_path = os.path.join(self.root, sample['mask_path'])
        cls_name = sample['cls_name']
        anomaly = sample['anomaly']
        if self.training and random.random() < self.aug_rate:
            img, img_mask = self.combine_img(cls_name)
        else:
            if img_path.endswith('.tif'):
                # Load .tif via OpenCV (BGR), then convert to a PIL RGB image.
                img = cv2.imread(img_path)
                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            else:
                img = Image.open(img_path).convert('RGB')
            if anomaly == 0 or not sample['mask_path']:
                # Normal sample, or anomalous sample without a stored mask:
                # use an all-zero mask matching the image size.
                img_mask = self._blank_mask(img)
            else:
                img_mask = np.array(Image.open(mask_path).convert('L')) > 0
                img_mask = Image.fromarray(img_mask.astype(np.uint8) * 255, mode='L')
        # Transforms
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and img_mask is not None:
            img_mask = self.target_transform(img_mask)
        if img_mask is None:
            # Defensive: keep collation working if a transform returns None.
            img_mask = []
        return {
            'img': img,
            'img_mask': img_mask,
            'cls_name': cls_name,
            'anomaly': anomaly,
            'img_path': img_path
        }