# IndustrialDefects/source/defectGenerator.py
# Simon Thomine
# add cutpaste defect
# 94e4abf
import torch
import numpy as np
import cv2
import imgaug.augmenters as iaa
import random
import torchvision.transforms as T
import glob
from source.perlin import rand_perlin_2d_np
import matplotlib.pyplot as plt
from source.nsa import backGroundMask,patch_ex
from source.cutpaste import CutPaste
class TexturalAnomalyGenerator():
    """Provide randomly augmented texture images for textural defects.

    Texture files are collected once, at construction time, with the glob
    pattern ``<dtd_path>/*/*.jpg`` and sorted so that indices are stable
    between runs.
    """

    def __init__(self, resize_shape=None, dtd_path="../../datasets/dtd/images"):
        """Collect the texture files and build the pool of augmenters.

        Args:
            resize_shape: two-element sequence read as ``(height, width)`` by
                :meth:`getDtdImage`, matching how ``DefectGenerator.generateMask``
                sizes its masks.
            dtd_path: directory whose ``*/*.jpg`` files are used as textures.
        """
        self.resize_shape = resize_shape
        self.anomaly_source_paths = sorted(glob.glob(dtd_path + "/*/*.jpg"))
        self.augmenters = [
            iaa.GammaContrast((0.5, 2.0), per_channel=True),
            iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
            iaa.pillike.EnhanceSharpness(),
            iaa.AddToHueAndSaturation((-10, 10), per_channel=True),
            iaa.Solarize(0.5, threshold=(32, 128)),
            iaa.Posterize(),
            iaa.Invert(),
            iaa.pillike.Autocontrast(),
            iaa.pillike.Equalize(),
        ]

    def randAugmenter(self):
        """Return an ``iaa.Sequential`` of three distinct random augmenters."""
        picked = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
        return iaa.Sequential([self.augmenters[index] for index in picked])

    def getDtdImage(self):
        """Load a random texture, resize it and augment it.

        Returns:
            The augmented image as returned by the augmenter, with
            ``resize_shape[0]`` rows and ``resize_shape[1]`` columns.

        Raises:
            ValueError: if no texture file was found at construction time.
            OSError: if the selected file cannot be decoded.
        """
        if not self.anomaly_source_paths:
            # random.randint(0, -1) would raise an opaque "empty range" error.
            raise ValueError(
                "No '*/*.jpg' texture image was found under the dtd_path "
                "given to TexturalAnomalyGenerator"
            )
        randIndex = random.randint(0, len(self.anomaly_source_paths) - 1)
        path = self.anomaly_source_paths[randIndex]

        # NOTE(review): cv2.imread yields BGR channel order and no conversion
        # is done here -- confirm this is intended for the blended images.
        image = cv2.imread(path)
        if image is None:
            # imread signals failure by returning None instead of raising.
            raise OSError("Could not read texture image: {}".format(path))

        # cv2.resize expects dsize as (width, height); resize_shape is
        # (height, width), so the two entries must be swapped.
        height, width = self.resize_shape[0], self.resize_shape[1]
        image = cv2.resize(image, dsize=(width, height))

        aug = self.randAugmenter()
        return aug(image=image)
class StructuralAnomalyGenerator():
    """Build structural defects by augmenting an image and shuffling its cells."""

    def __init__(self, resize_shape=None):
        """Store the target shape and build the pool of augmenters.

        Args:
            resize_shape: stored on the instance; not read by the methods of
                this class.
        """
        self.resize_shape = resize_shape
        self.augmenters = [
            iaa.Fliplr(0.5),
            iaa.Affine(rotate=(-45, 45)),
            iaa.Multiply((0.8, 1.2)),
            iaa.MultiplySaturation((0.5, 1.5)),
            iaa.MultiplyHue((0.5, 1.5)),
        ]

    def randAugmenter(self):
        """Return an ``iaa.Sequential`` of three distinct random augmenters."""
        picked = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
        return iaa.Sequential([self.augmenters[index] for index in picked])

    def generateStructuralDefect(self, image, grid_size=8):
        """Return an augmented copy of *image* with its grid cells shuffled.

        Args:
            image: tensor indexed as (channel, row, column); it is permuted
                to (row, column, channel), multiplied by 255 and cast to
                ``uint8`` before augmentation.
            grid_size: number of cells along each axis (default 8).

        Returns:
            ``uint8`` array of shape (rows, columns, channels). Rows and
            columns that do not fit into a whole cell keep the augmented
            pixels instead of being left black.
        """
        aug = self.randAugmenter()
        image_array = (image.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        image_array = aug(image=image_array)

        height, width, _ = image_array.shape
        cell_height = height // grid_size
        cell_width = width // grid_size

        # Row-major list of equally sized cells; any remainder is excluded.
        cells = [
            image_array[row * cell_height:(row + 1) * cell_height,
                        col * cell_width:(col + 1) * cell_width, :]
            for row in range(grid_size)
            for col in range(grid_size)
        ]
        np.random.shuffle(cells)

        # Start from a copy so the remainder strip outside the grid (sizes
        # not divisible by grid_size) is preserved rather than zero-filled.
        reconstructed_image = image_array.copy()
        for index, cell in enumerate(cells):
            row, col = divmod(index, grid_size)
            reconstructed_image[row * cell_height:(row + 1) * cell_height,
                                col * cell_width:(col + 1) * cell_width, :] = cell
        return reconstructed_image
class DefectGenerator():
    """Create synthetic defects: textural, structural, blurred, NSA, CutPaste.

    Every ``generate*`` method returns an ``(image, mask)`` pair. The usual
    entry points are :meth:`genDefect` and :meth:`genSingleDefect`.
    """

    # A generated mask is accepted once it has at least this many non-zero
    # pixels.
    _MIN_DEFECT_PIXELS = 100

    # Labels accepted by genSingleDefect (compared case-insensitively).
    _DEFECT_TYPES = ("textural", "structural", "blurred", "nsa", "cutpaste")

    def __init__(self, resize_shape=None, dtd_path="../../datasets/dtd/images"):
        """Build the helper generators.

        Args:
            resize_shape: two-element sequence; masks are created with
                ``resize_shape[0]`` rows and ``resize_shape[1]`` columns.
            dtd_path: texture directory forwarded to TexturalAnomalyGenerator.
        """
        self.texturalAnomalyGenerator = TexturalAnomalyGenerator(resize_shape, dtd_path)
        self.structuralAnomalyGenerator = StructuralAnomalyGenerator(resize_shape)
        self.cutpaste = CutPaste()
        self.resize_shape = resize_shape
        self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
        self.toTensor = T.ToTensor()

    def generateMask(self, bMask):
        """Return a thresholded Perlin-noise mask, optionally times *bMask*.

        Args:
            bMask: background mask multiplied into the result; ``None`` or an
                empty sequence disables it.

        Returns:
            Tensor of shape ``(1, resize_shape[0], resize_shape[1])``.
        """
        perlin_scale = 6
        min_perlin_scale = 0
        # Independent scales per axis, each a power of two in 1..32.
        perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
        perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
        perlin_noise = rand_perlin_2d_np(
            (self.resize_shape[0], self.resize_shape[1]),
            (perlin_scalex, perlin_scaley),
        )
        perlin_noise = self.rot(image=perlin_noise)

        threshold = 0.5
        perlin_thr = np.where(perlin_noise > threshold,
                              np.ones_like(perlin_noise),
                              np.zeros_like(perlin_noise))
        perlin_thr = np.expand_dims(perlin_thr, axis=2)
        msk = torch.from_numpy(perlin_thr.astype(np.float32)).permute(2, 0, 1)
        if bMask is not None and len(bMask) > 0:
            msk = bMask * msk
        return msk

    def _sampleDefectMask(self, bMask):
        """Draw masks until one keeps enough pixels; return it with 0/1 values.

        Raises:
            ValueError: if *bMask* cannot yield ``_MIN_DEFECT_PIXELS`` pixels,
                in which case the retry loop would never end.
        """
        mask_shape = (1, self.resize_shape[0], self.resize_shape[1])
        if bMask is not None and len(bMask) > 0:
            # Upper bound on the pixels any generated mask can keep.
            reachable = torch.count_nonzero(bMask * torch.ones(mask_shape))
            if reachable < self._MIN_DEFECT_PIXELS:
                raise ValueError(
                    "Background mask leaves fewer than {} usable pixels; "
                    "no defect mask can be generated".format(self._MIN_DEFECT_PIXELS)
                )

        msk = torch.zeros((self.resize_shape[0], self.resize_shape[1]))
        while torch.count_nonzero(msk) < self._MIN_DEFECT_PIXELS:
            raw = self.generateMask(bMask)
            # The mask used to be scaled by 255, which gives 0/1 only when
            # bMask holds 0 and 1/255 (presumably a 0/1 uint8 mask after
            # ToTensor -- TODO confirm). Thresholding is identical in that
            # case and keeps the mask binary when no bMask is given.
            msk = (raw > 0).to(raw.dtype)
        return msk

    def generateTexturalDefect(self, image, bMask=None):
        """Blend a random texture into *image* inside a random mask.

        Args:
            image: image tensor, channel first.
            bMask: optional background mask tensor restricting the defect.

        Returns:
            ``(image, mask)`` tensors.
        """
        msk = self._sampleDefectMask(bMask)
        texturalImg = self.texturalAnomalyGenerator.getDtdImage()
        texturalImg = torch.from_numpy(texturalImg).permute(2, 0, 1) / 255.0
        image = image * (1 - msk) + texturalImg * msk
        return image, msk

    def generateStructuralDefect(self, image, bMask=None):
        """Blend a cell-shuffled copy of *image* into it inside a random mask.

        Args:
            image: image tensor, channel first.
            bMask: optional background mask tensor restricting the defect.

        Returns:
            ``(image, mask)`` tensors; the image keeps the dtype of the input.
        """
        msk = self._sampleDefectMask(bMask)
        structuralImg = self.structuralAnomalyGenerator.generateStructuralDefect(image) / 255.0
        # The division above produces float64; cast back so the result does
        # not get promoted to double, unlike the other defect types.
        structuralImg = torch.from_numpy(structuralImg).permute(2, 0, 1).to(image.dtype)
        image = image * (1 - msk) + structuralImg * msk
        return image, msk

    def generateBlurredDefectiveImage(self, image, bMask=None):
        """Blend a Gaussian-blurred copy of *image* into it inside a random mask.

        Args:
            image: image tensor, channel first.
            bMask: optional background mask tensor restricting the defect.

        Returns:
            ``(image, mask)`` tensors.
        """
        msk = self._sampleDefectMask(bMask)
        # Odd kernel size drawn from 21, 23, ..., 31.
        randGaussianValue = random.randint(0, 5) * 2 + 21
        transform = T.GaussianBlur(kernel_size=(randGaussianValue, randGaussianValue), sigma=11.0)
        imageBlurred = transform(image)
        image = image * (1 - msk) + imageBlurred * msk
        return image, msk

    def generateNsaDefect(self, image, bMask):
        """Create a defect with ``patch_ex`` and convert the result to tensors.

        Args:
            image: input image; converted with ``np.array`` and given a
                trailing channel axis when it is two-dimensional.
            bMask: background mask forwarded to ``patch_ex``.

        Returns:
            ``(image, mask)`` tensors; the mask is multiplied by 255 after
            the ``ToTensor`` conversion.
        """
        image = np.array(image)
        if image.ndim == 2:
            image = np.expand_dims(image, 2)
        image, msk = patch_ex(image, backgroundMask=bMask)
        image = self.toTensor(image)
        msk = self.toTensor(msk) * 255.0
        return image, msk

    def generateCutPasteDefect(self, image, bMask):
        """Create a CutPaste defect restricted to the background mask.

        Args:
            image: input image handed to ``CutPaste.cutpaste``.
            bMask: background mask array multiplied with the paste mask.

        Returns:
            ``(image, mask)`` tensors.

        Raises:
            ValueError: if *bMask* has no non-zero pixel, in which case the
                retry loop would never end.
        """
        if not np.any(bMask):
            raise ValueError(
                "Background mask is empty; no cutpaste defect can be generated"
            )
        msk = np.zeros((self.resize_shape[0], self.resize_shape[1]))
        while np.count_nonzero(msk) < self._MIN_DEFECT_PIXELS:
            defect, cpmsk = self.cutpaste.cutpaste(image)
            msk = bMask * np.expand_dims(np.array(cpmsk), axis=2)
        # Keep the pasted result on the masked area, the original elsewhere.
        image = np.array(defect) * bMask + np.array(image) * (1 - bMask)
        image = self.toTensor(image)
        msk = self.toTensor(msk)
        return image, msk

    def genSingleDefect(self, image, label, mskbg):
        """Generate one defect of the kind named by *label*.

        Args:
            image: input image.
            label: one of 'textural', 'structural', 'blurred', 'nsa',
                'cutpaste' (case-insensitive).
            mskbg: background mask for *image*.

        Returns:
            ``(image, mask)`` tensors from the matching generator.

        Raises:
            ValueError: if *label* is not a supported defect type.
        """
        kind = label.lower()
        if kind not in self._DEFECT_TYPES:
            raise ValueError("The defect type should be in ['textural','structural','blurred','nsa','cutpaste']")

        if kind == "nsa":
            return self.generateNsaDefect(image, mskbg)
        if kind == "cutpaste":
            return self.generateCutPasteDefect(image, mskbg)

        # The remaining generators work on tensors.
        imageT = self.toTensor(image)
        bmask = self.toTensor(mskbg)
        if kind == "textural":
            return self.generateTexturalDefect(imageT, bmask)
        if kind == "structural":
            return self.generateStructuralDefect(imageT, bmask)
        return self.generateBlurredDefectiveImage(imageT, bmask)

    def genDefect(self, image, defectType, category="", return_list=False):
        """Generate one random defect, or one defect per requested type.

        Args:
            image: input image.
            defectType: sequence of defect labels; a single string is
                accepted and treated as a one-element list.
            category: forwarded to ``backGroundMask`` as ``obj``.
            return_list: when true, generate every type in *defectType*.

        Returns:
            ``(image, mask)`` when *return_list* is false, otherwise
            ``(images, masks)`` as two lists in the order of *defectType*.
        """
        if isinstance(defectType, str):
            # Indexing a bare string would pick a single character.
            defectType = [defectType]
        mskbg = backGroundMask(image, obj=category)

        if not return_list:
            if len(defectType) > 1:
                label = defectType[np.random.randint(0, len(defectType))]
            else:
                label = defectType[0]
            return self.genSingleDefect(image, label, mskbg)

        defectImages = []
        defectMasks = []
        for label in defectType:
            defectImage, defectMask = self.genSingleDefect(image, label, mskbg)
            defectImages.append(defectImage)
            defectMasks.append(defectMask)
        return defectImages, defectMasks