Simon Thomine commited on
Commit
7973387
·
1 Parent(s): 7667296

Add application file

Browse files
.gitattributes CHANGED
@@ -33,3 +33,14 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.png filter=lfs diff=lfs merge=lfs -text
37
+ samples/*.png filter=lfs diff=lfs merge=lfs -text
38
+ samples/cable.png filter=lfs diff=lfs merge=lfs -text
39
+ samples/capsule.png filter=lfs diff=lfs merge=lfs -text
40
+ samples/carpet.png filter=lfs diff=lfs merge=lfs -text
41
+ samples/hazelnut.png filter=lfs diff=lfs merge=lfs -text
42
+ samples/leather.png filter=lfs diff=lfs merge=lfs -text
43
+ samples/tile.png filter=lfs diff=lfs merge=lfs -text
44
+ samples/toothbrush.png filter=lfs diff=lfs merge=lfs -text
45
+ samples/transistor.png filter=lfs diff=lfs merge=lfs -text
46
+ samples/wood.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.txt.user*
2
+ build/*
3
+ dbg_build/*
4
+ # .env is personal
5
+ .env
6
+
7
+ # various emacs ignore
8
+ *~
9
+ [#]*[#]
10
+ .\#*
11
+
12
+ # devcontainer vcode
13
+
14
+ .devcontainer/tmp/*
15
+
16
+ # Python
17
+ results/*
18
+ **/__pycache__
19
+ .idea/*
20
+ mv
21
+
22
+
23
+
24
+ #CMake Outputs
25
+ CMakeLists.txt.user
26
+ CMakeCache.txt
27
+ CMakeFiles
28
+ CMakeScripts
29
+ Testing
30
+ Makefile
31
+ cmake_install.cmake
32
+ install_manifest.txt
33
+ compile_commands.json
34
+ CTestTestfile.cmake
35
+ _deps
36
+ # C++ objects and libs
37
+ *.slo
38
+ *.lo
39
+ *.o
40
+ *.a
41
+ *.la
42
+ *.lai
43
+ *.so
44
+ *.so.*
45
+ *.dll
46
+ *.dylib
47
+
48
+ # Qt-es
49
+ object_script.*.Release
50
+ object_script.*.Debug
51
+ *_plugin_import.cpp
52
+ /.qmake.cache
53
+ /.qmake.stash
54
+ *.pro.user
55
+ *.pro.user.*
56
+ *.qbs.user
57
+ *.qbs.user.*
58
+ *.moc
59
+ moc_*.cpp
60
+ moc_*.h
61
+ qrc_*.cpp
62
+ ui_*.h
63
+ *.qmlc
64
+ *.jsc
65
+ Makefile*
66
+ *build-*
67
+ *.qm
68
+ *.prl
69
+
70
+ # Qt unit tests
71
+ target_wrapper.*
72
+
73
+ # QtCreator
74
+ *.autosave
75
+
76
+ # QtCreator Qml
77
+ *.qmlproject.user
78
+ *.qmlproject.user.*
79
+ # QtCreator CMake
80
+ CMakeLists.txt.user*
81
+
82
+ # QtCreator 4.8< compilation database
83
+ compile_commands.json
84
+
85
+ # QtCreator local machine specific files for imported projects
86
+ *creator.user*
87
+ *_qmlcache.qrc
88
+
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from source.defectGenerator import DefectGenerator
2
+ import matplotlib.pyplot as plt
3
+ from PIL import Image
4
+ import gradio as gr
5
+ import numpy as np
6
+
7
+ def generate_defect_image(image, defect_type,category):
8
+ defGen=DefectGenerator(image.size,dtd_path="samples/dtd/")
9
+ defect,msk=defGen.genDefect(image,[defect_type],category.lower())
10
+ defect=(defect.permute(1,2,0).numpy()*255.0).astype('uint8')
11
+ msk=(msk.permute(1,2,0).numpy()*255.0).astype('uint8')
12
+ msk = np.concatenate((msk, msk, msk), axis=2)
13
+ return defect, msk
14
+
15
+ images = {
16
+ "Bottle": Image.open('samples/bottle.png').convert('RGB').resize((1024, 1024)),
17
+ "Cable": Image.open('samples/cable.png').convert('RGB').resize((1024, 1024)),
18
+ "Capsule": Image.open('samples/capsule.png').convert('RGB').resize((1024, 1024)),
19
+ "Carpet": Image.open('samples/carpet.png').convert('RGB').resize((1024, 1024)),
20
+ "Grid": Image.open('samples/grid.png').convert('RGB').resize((1024, 1024)),
21
+ "Hazelnut": Image.open('samples/hazelnut.png').convert('RGB').resize((1024, 1024)),
22
+ "Leather": Image.open('samples/leather.png').convert('RGB').resize((1024, 1024)),
23
+ "Metal Nut": Image.open('samples/metal_nut.png').convert('RGB').resize((1024, 1024)),
24
+ "Pill": Image.open('samples/pill.png').convert('RGB').resize((1024, 1024)),
25
+ "Screw": Image.open('samples/screw.png').convert('RGB').resize((1024, 1024)),
26
+ "Tile": Image.open('samples/tile.png').convert('RGB').resize((1024, 1024)),
27
+ "Toothbrush": Image.open('samples/toothbrush.png').convert('RGB').resize((1024, 1024)),
28
+ "Transistor": Image.open('samples/transistor.png').convert('RGB').resize((1024, 1024)),
29
+ "Wood": Image.open('samples/wood.png').convert('RGB').resize((1024, 1024)),
30
+ "Zipper": Image.open('samples/zipper.png').convert('RGB').resize((1024, 1024))
31
+
32
+ }
33
+
34
+
35
+ def generate_and_display_images(category, defect_type):
36
+ base_image = images[category]
37
+ img_with_defect, defect_mask = generate_defect_image(base_image, defect_type,category)
38
+ return np.array(base_image), img_with_defect, defect_mask
39
+
40
+ # Components
41
+ with gr.Blocks(css="style.css") as demo:
42
+ gr.HTML(
43
+ "<h1><center> &#127981; MVTEC AD Defect Generator &#127981; </center></h1>" +
44
+ "<p><center><a href='https://github.com/SimonThomine/IndustrialDefectLib'>https://github.com/SimonThomine/IndustrialDefectLib</a></center></p>"
45
+ )
46
+ with gr.Group():
47
+ with gr.Row():
48
+ category_input = gr.Dropdown(label="Select object", choices=list(images.keys()),value="Bottle")
49
+ defect_type_input = gr.Dropdown(label="Select type of defect", choices=["blurred", "nsa","structural", "textural" ],value="nsa")
50
+ submit = gr.Button(
51
+ scale=1,
52
+ variant='primary'
53
+ )
54
+ with gr.Row():
55
+ with gr.Column(scale=1, min_width=400):
56
+ gr.HTML("<h1><center> Base </center></h1>")
57
+ base_image_output = gr.Image("Base", type="numpy")
58
+
59
+ with gr.Column(scale=1, min_width=400):
60
+ gr.HTML("<h1><center> Mask </center></h1>")
61
+ mask_output = gr.Image("Mask", type="numpy")
62
+
63
+ with gr.Column(scale=1, min_width=400):
64
+ gr.HTML("<h1><center> Defect </center></h1>")
65
+ defect_image_output = gr.Image("Defect", type="numpy")
66
+
67
+ submit.click(
68
+ fn=generate_and_display_images,
69
+ inputs=[category_input, defect_type_input],
70
+ outputs=[base_image_output, defect_image_output,mask_output],
71
+ )
72
+
73
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ opencv-python
4
+ pandas
5
+ pillow
6
+ scikit-image
7
+ scikit-learn
8
+ timm
9
+ torch
10
+ torchvision
11
+ tqdm
12
+ imgaug
samples/bottle.png ADDED

Git LFS Details

  • SHA256: 8a25eb07475a5102f4823fefd82662dd65f8f2a0a333bf958124ceeee38f0e38
  • Pointer size: 131 Bytes
  • Size of remote file: 533 kB
samples/cable.png ADDED

Git LFS Details

  • SHA256: 7457390d52de2e34f706abcbc1185da287a5dfd90573af23fef6816fa5d7a1aa
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
samples/capsule.png ADDED

Git LFS Details

  • SHA256: 96a77c1cca9d6efbd4fb487c9113fe775ac0559730131850c6c6523e944611d3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
samples/carpet.png ADDED

Git LFS Details

  • SHA256: 9fd2420a125701bb31dc8025cbf62cc7152dc04505fe5f151b782493e303f39e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.92 MB
samples/dtd/banded/banded_0004.jpg ADDED
samples/dtd/blotchy/blotchy_0003.jpg ADDED
samples/dtd/braided/braided_0050.jpg ADDED
samples/dtd/bubbly/bubbly_0038.jpg ADDED
samples/dtd/bumpy/bumpy_0014.jpg ADDED
samples/dtd/chequered/chequered_0017.jpg ADDED
samples/grid.png ADDED

Git LFS Details

  • SHA256: 94e29cc6c4c46225590138559736c8649959c4fa8b25fb5569755e624af1b137
  • Pointer size: 131 Bytes
  • Size of remote file: 450 kB
samples/hazelnut.png ADDED

Git LFS Details

  • SHA256: 5e28a714fa36ef5198c683058b607435b792974cdd23f3e0810b887dbdfe7112
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
samples/leather.png ADDED

Git LFS Details

  • SHA256: b9b62731f4a804dbf941ae1c1f99506c784f256f634f5b3f28ac429e246e0d6a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB
samples/metal_nut.png ADDED

Git LFS Details

  • SHA256: b0309325bc305ae8b19246884f20397b951285786aecca0ebf1d009e9799e350
  • Pointer size: 131 Bytes
  • Size of remote file: 499 kB
samples/pill.png ADDED

Git LFS Details

  • SHA256: 5f01a9be11029dcdac12a84308f345babe6bca33ccee4c50fb54e656fb42611a
  • Pointer size: 131 Bytes
  • Size of remote file: 614 kB
samples/screw.png ADDED

Git LFS Details

  • SHA256: 44b81b1b5779d5bc1695e8fdb1ca361eeda32b771ce91b0cc5468cc1f8aa735b
  • Pointer size: 131 Bytes
  • Size of remote file: 408 kB
samples/tile.png ADDED

Git LFS Details

  • SHA256: 522535bee0cc7e38a853dbdab79302f99f2f0f7fe66533e386f1b0c558ce5d0b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
samples/toothbrush.png ADDED

Git LFS Details

  • SHA256: 91c3ac02df7e987a1571aa5a2eb5b7e1f6b291cf61450aa949267465b6a9473e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
samples/transistor.png ADDED

Git LFS Details

  • SHA256: d755cd3f9334acc338db228fbccfe531397add1794181c09dafafa35c9014741
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
samples/wood.png ADDED

Git LFS Details

  • SHA256: d6d87826ab78b00bca17872acf7d06fc9d21b0e81567bed2ccebb7ba57b85e76
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
samples/zipper.png ADDED

Git LFS Details

  • SHA256: 42e15bf4d9672ad7b0fa08e99110aa9b9715b9c1889a820220db101cf76fec29
  • Pointer size: 131 Bytes
  • Size of remote file: 406 kB
source/defectGenerator.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import cv2
4
+ import imgaug.augmenters as iaa
5
+ import random
6
+ import torchvision.transforms as T
7
+ import glob
8
+ from source.perlin import rand_perlin_2d_np
9
+ import matplotlib.pyplot as plt
10
+ from source.nsa import backGroundMask,patch_ex
11
+
12
+ class TexturalAnomalyGenerator():
13
+ def __init__(self, resize_shape=None,dtd_path="../../datasets/dtd/images"):
14
+
15
+
16
+ self.resize_shape=resize_shape
17
+ self.anomaly_source_paths = sorted(glob.glob(dtd_path+"/*/*.jpg"))
18
+ self.augmenters = [iaa.GammaContrast((0.5,2.0),per_channel=True),
19
+ iaa.MultiplyAndAddToBrightness(mul=(0.8,1.2),add=(-30,30)),
20
+ iaa.pillike.EnhanceSharpness(),
21
+ iaa.AddToHueAndSaturation((-10,10),per_channel=True),
22
+ iaa.Solarize(0.5, threshold=(32,128)),
23
+ iaa.Posterize(),
24
+ iaa.Invert(),
25
+ iaa.pillike.Autocontrast(),
26
+ iaa.pillike.Equalize(),
27
+ ]
28
+
29
+
30
+ def randAugmenter(self):
31
+ aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
32
+ aug = iaa.Sequential([self.augmenters[aug_ind[0]],
33
+ self.augmenters[aug_ind[1]],
34
+ self.augmenters[aug_ind[2]]]
35
+ )
36
+ return aug
37
+ def getDtdImage(self):
38
+ randIndex=random.randint(0, len(self.anomaly_source_paths)-1)
39
+ image=cv2.imread(self.anomaly_source_paths[randIndex])
40
+ image=cv2.resize(image, dsize=(self.resize_shape[0], self.resize_shape[1]))
41
+ aug=self.randAugmenter()
42
+ image=aug(image=image)
43
+ return image
44
+
45
+
46
+ class StructuralAnomalyGenerator():
47
+ def __init__(self,resize_shape=None):
48
+
49
+ self.resize_shape=resize_shape
50
+ self.augmenters = [iaa.Fliplr(0.5),
51
+ iaa.Affine(rotate=(-45, 45)),
52
+ iaa.Multiply((0.8, 1.2)),
53
+ iaa.MultiplySaturation((0.5, 1.5)),
54
+ iaa.MultiplyHue((0.5, 1.5))
55
+ ]
56
+
57
+ def randAugmenter(self):
58
+ aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
59
+ aug = iaa.Sequential([self.augmenters[aug_ind[0]],
60
+ self.augmenters[aug_ind[1]],
61
+ self.augmenters[aug_ind[2]]]
62
+ )
63
+ return aug
64
+
65
+ def generateStructuralDefect(self,image):
66
+ aug=self.randAugmenter()
67
+ image_array=(image.permute(1,2,0).numpy()*255).astype(np.uint8)# # *
68
+
69
+
70
+ image_array=aug(image=image_array)
71
+
72
+ height, width, _ = image_array.shape
73
+ grid_size = 8
74
+ cell_height = height // grid_size
75
+ cell_width = width // grid_size
76
+
77
+ grid = []
78
+ for i in range(grid_size):
79
+ for j in range(grid_size):
80
+ cell = image_array[i * cell_height: (i + 1) * cell_height,
81
+ j * cell_width: (j + 1) * cell_width, :]
82
+ grid.append(cell)
83
+
84
+ np.random.shuffle(grid)
85
+
86
+ reconstructed_image = np.zeros_like(image_array)
87
+ for i in range(grid_size):
88
+ for j in range(grid_size):
89
+ reconstructed_image[i * cell_height: (i + 1) * cell_height,
90
+ j * cell_width: (j + 1) * cell_width, :] = grid[i * grid_size + j]
91
+ return reconstructed_image
92
+
93
+
94
+
95
+ class DefectGenerator():
96
+
97
+ def __init__(self, resize_shape=None,dtd_path="../../datasets/dtd/images"):
98
+
99
+
100
+ self.texturalAnomalyGenerator=TexturalAnomalyGenerator(resize_shape,dtd_path)
101
+ self.structuralAnomalyGenerator=StructuralAnomalyGenerator(resize_shape)
102
+
103
+ self.resize_shape=resize_shape
104
+ self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
105
+ self.toTensor=T.ToTensor()
106
+
107
+ def generateMask(self,bMask):
108
+ perlin_scale = 6
109
+ min_perlin_scale = 0
110
+ perlin_scalex = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
111
+ perlin_scaley = 2 ** (torch.randint(min_perlin_scale, perlin_scale, (1,)).numpy()[0])
112
+ perlin_noise = rand_perlin_2d_np((self.resize_shape[0], self.resize_shape[1]), (perlin_scalex, perlin_scaley))
113
+ perlin_noise = self.rot(image=perlin_noise)
114
+ threshold = 0.5
115
+ perlin_thr = np.where(perlin_noise > threshold, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
116
+ perlin_thr = np.expand_dims(perlin_thr, axis=2)
117
+ msk = (perlin_thr).astype(np.float32)
118
+ msk=torch.from_numpy(msk).permute(2,0,1)
119
+ if (len(bMask)>0):
120
+ msk=bMask*msk
121
+ return msk
122
+
123
+ def generateTexturalDefect(self, image,bMask=[]):
124
+ msk=torch.zeros((self.resize_shape[0], self.resize_shape[1]))
125
+ while (torch.count_nonzero(msk)<100):
126
+ msk=self.generateMask(bMask)*255.0
127
+ texturalImg=self.texturalAnomalyGenerator.getDtdImage()
128
+ texturalImg=torch.from_numpy(texturalImg).permute(2,0,1)/255.0
129
+ mskDtd=texturalImg*(msk)
130
+
131
+ image = image * (1 - msk)+ (mskDtd)
132
+ return image ,msk
133
+
134
+ def generateStructuralDefect(self, image,bMask=[]):
135
+ msk=torch.zeros((self.resize_shape[0], self.resize_shape[1]))
136
+ while (torch.count_nonzero(msk)<100):
137
+ msk=self.generateMask(bMask)*255.0
138
+ structuralImg=self.structuralAnomalyGenerator.generateStructuralDefect(image)/255.0
139
+ structuralImg=torch.from_numpy(structuralImg).permute(2,0,1)
140
+
141
+ mskDtd=structuralImg*(msk)
142
+ image = image * (1 - msk)+ (mskDtd)
143
+ return image ,msk
144
+
145
+
146
+ def generateBlurredDefectiveImage(self, image,bMask=[]):
147
+ msk=torch.zeros((self.resize_shape[0], self.resize_shape[1]))
148
+ while (torch.count_nonzero(msk)<100):
149
+ msk=self.generateMask(bMask)*255.0
150
+ randGaussianValue = random.randint(0, 5)*2+21
151
+ transform = T.GaussianBlur(kernel_size=(randGaussianValue, randGaussianValue), sigma=11.0)
152
+ imageBlurred = transform(image)
153
+ imageBlurred=imageBlurred*(msk)
154
+ image=image*(1-msk)
155
+
156
+ image=image+imageBlurred
157
+
158
+ return image,msk
159
+
160
+ def generateNsaDefect(self, image,bMask):
161
+ image = np.expand_dims(np.array(image),2) if len(np.array(image).shape)==2 else np.array(image)
162
+ image,msk=patch_ex(image,backgroundMask=bMask)
163
+ transform=T.ToTensor()
164
+ image = transform(image)
165
+ msk = transform(msk)*255.0
166
+ return image,msk
167
+
168
+
169
+
170
+ def genSingleDefect(self,image,label,mskbg):
171
+ if label.lower() not in ["textural","structural","blurred","nsa"]:
172
+ raise ValueError("The defect type should be in ['textural','structural','blurred','nsa']")
173
+
174
+ if (label.lower()=="textural" or label.lower()=="structural" or label.lower()=="blurred"):
175
+ imageT=self.toTensor(image)
176
+ bmask=self.toTensor(mskbg)
177
+ if (label.lower()=="textural"):
178
+ return self.generateTexturalDefect(imageT,bmask)
179
+ elif (label.lower()=="structural"):
180
+ return self.generateStructuralDefect(imageT,bmask)
181
+ elif (label.lower()=="blurred"):
182
+ return self.generateBlurredDefectiveImage(imageT,bmask)
183
+ elif (label.lower()=="nsa"):
184
+ return self.generateNsaDefect(image,mskbg)
185
+
186
+ def genDefect(self,image,defectType,category="",return_list=False):
187
+ mskbg=backGroundMask(image,obj=category)
188
+ if not return_list:
189
+ if (len(defectType)>1):
190
+ index=np.random.randint(0,len(defectType))
191
+ label=defectType[index]
192
+ else:
193
+ label=defectType[0]
194
+ return self.genSingleDefect(image,label,mskbg)
195
+ if return_list:
196
+ defectImages=[]
197
+ defectMasks=[]
198
+ for label in defectType:
199
+ defectImage,defectMask=self.genSingleDefect(image,label,mskbg)
200
+ defectImages.append(defectImage)
201
+ defectMasks.append(defectMask)
202
+ return defectImages,defectMasks
source/nsa.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import sys
4
+ from skimage.morphology import disk
5
+ from skimage.filters import median
6
+ import torch
7
+ import torchvision.transforms as T
8
+ import random
9
+
10
+ import matplotlib.pyplot as plt
11
+ from PIL import Image
12
+
13
+ BACKGROUND = {'bottle':(200, 60), 'screw':(200, 60), 'capsule':(200, 60), 'zipper':(200, 60),
14
+ 'hazelnut':(20, 20), 'pill':(20, 20), 'toothbrush':(20, 20), 'metal_nut':(20, 20)}
15
+
16
+
17
+ def backGroundMask(image,obj=""):
18
+ image = np.expand_dims(np.array(image),2) if len(np.array(image).shape)==2 else np.array(image)
19
+ #if obj=="":
20
+ if obj not in BACKGROUND.keys():
21
+ return np.ones_like(image[...,0:1])
22
+ else:
23
+ skip_background=BACKGROUND[obj]
24
+ if isinstance(skip_background, tuple):
25
+ skip_background = [skip_background]
26
+ object_mask = np.ones_like(image[...,0:1])
27
+ for background, threshold in skip_background:
28
+ object_mask &= np.uint8(np.abs(image.mean(axis=-1, keepdims=True) - background) > threshold)
29
+ object_mask[...,0] = cv2.medianBlur(object_mask[...,0], 7) # remove grain from threshold choice
30
+
31
+ return object_mask
32
+
33
+
34
+ def patch_ex(ima_dest, ima_src=None, same=False, num_patches=1,
35
+ mode=cv2.NORMAL_CLONE, width_bounds_pct=((0.05,0.2),(0.05,0.2)), min_object_pct=0.25,
36
+ min_overlap_pct=0.25, shift=True, label_mode='binary', backgroundMask=None, tol=1, resize=True,
37
+ gamma_params=None, intensity_logistic_params=(1/6, 20),
38
+ resize_bounds=(0.7, 1.3), num_ellipses=None, verbose=True, cutpaste_patch_generation=False):
39
+ """
40
+ Create a synthetic training example from the given images by pasting/blending random patches.
41
+ Args:
42
+ ima_dest (uint8 numpy array): image with shape (W,H,3) or (W,H,1) where patch should be changed
43
+ ima_src (uint8 numpy array): optional, otherwise use ima_dest as source
44
+ same (bool): use ima_dest as source even if ima_src given
45
+ mode: 'uniform', 'swap', 'mix', cv2.NORMAL_CLONE, or cv2.MIXED_CLONE what blending method to use
46
+ ('mix' is flip a coin between normal and mixed clone)
47
+ num_patches (int): how many patches to add. the method will always attempt to add the first patch,
48
+ for each subsequent patch it flips a coin
49
+ width_bounds_pct ((float, float), (float, float)): min half-width of patch ((min_dim1, max_dim1), (min_dim2, max_dim2))
50
+ shift (bool): if false, patches in src and dest image have same coords. otherwise random shift
51
+ resize (bool): if true, patch is resampled at random size (within bounds and keeping aspect ratio the same) before blending
52
+ skip_background (int, int) or [(int, int),]: optional, assume background color is first and only interpolate patches
53
+ in areas where dest or src patch has pixelwise MAD < second from background.
54
+ tol (int): mean abs intensity change required to get positive label
55
+ gamma_params (float, float, float): optional, (shape, scale, left offset) of gamma dist to sample half-width of patch from,
56
+ otherwise use uniform dist between 0.05 and 0.95
57
+ intensity_logistic_params (float, float): k, x0 of logitistc map for intensity based label
58
+ num_ellipses (int): optional, if set, the rectangular patch mask is filled with random ellipses
59
+ label_mode: 'binary',
60
+ 'continuous' -- use interpolation factor as label (only when mode is 'uniform'),
61
+ 'intensity' -- use median filtered mean absolute pixelwise intensity difference as label,
62
+ 'logistic-intensity' -- use logistic median filtered of mean absolute pixelwise intensity difference as label,
63
+ cutpaste_patch_generation (bool): optional, if set, width_bounds_pct, resize, skip_background, min_overlap_pct, min_object_pct,
64
+ num_patches and gamma_params are ignored. A single patch is sampled as in the CutPaste paper:
65
+ 1. sampling the area ratio between the patch and the full image from (0.02, 0.15)
66
+ 2. determine the aspect ratio by sampling from (0.3, 1) union (1, 3.3)
67
+ 3. sample location such that patch is contained entirely within the image
68
+ """
69
+ if mode == 'mix':
70
+ mode = (cv2.NORMAL_CLONE, cv2.MIXED_CLONE)[np.random.randint(2)]
71
+
72
+ if cutpaste_patch_generation:
73
+ width_bounds_pct = None
74
+ resize = False
75
+ min_overlap_pct = None
76
+ min_object_pct = None
77
+ gamma_params = None
78
+ num_patches = 1
79
+
80
+ ima_src = ima_dest.copy() if same or (ima_src is None) else ima_src
81
+
82
+ src_object_mask = backgroundMask
83
+ dest_object_mask = backgroundMask
84
+
85
+
86
+ mask = np.zeros_like(ima_dest[..., 0:1])
87
+ patchex = ima_dest.copy()
88
+ coor_min_dim1, coor_max_dim1, coor_min_dim2, coor_max_dim2 = mask.shape[0] - 1, 0, mask.shape[1] - 1, 0
89
+ if label_mode == 'continuous':
90
+ factor = np.random.uniform(0.05, 0.95)
91
+ else:
92
+ factor = 1
93
+ for i in range(num_patches):
94
+ if i == 0 or np.random.randint(2) > 0:
95
+ patchex, ((_coor_min_dim1, _coor_max_dim1), (_coor_min_dim2, _coor_max_dim2)), patch_mask = _patch_ex(
96
+ patchex, ima_src, dest_object_mask, src_object_mask, mode, label_mode, shift, resize, width_bounds_pct,
97
+ gamma_params, min_object_pct, min_overlap_pct, factor, resize_bounds, num_ellipses, verbose, cutpaste_patch_generation)
98
+ if patch_mask is not None:
99
+ mask[_coor_min_dim1:_coor_max_dim1,_coor_min_dim2:_coor_max_dim2] = patch_mask
100
+ coor_min_dim1 = min(coor_min_dim1, _coor_min_dim1)
101
+ coor_max_dim1 = max(coor_max_dim1, _coor_max_dim1)
102
+ coor_min_dim2 = min(coor_min_dim2, _coor_min_dim2)
103
+ coor_max_dim2 = max(coor_max_dim2, _coor_max_dim2)
104
+
105
+ # create label
106
+ label_mask = np.uint8(np.mean(np.abs(1.0 * mask*ima_dest - 1.0 * mask*patchex), axis=-1, keepdims=True) > tol)
107
+ label_mask[...,0] = cv2.medianBlur(label_mask[...,0], 5)
108
+
109
+ if label_mode == 'continuous':
110
+ label = label_mask * factor
111
+ elif label_mode in ['logistic-intensity', 'intensity']:
112
+ k, x0 = intensity_logistic_params
113
+ label = np.mean(np.abs(label_mask * ima_dest * 1.0 - label_mask * patchex * 1.0), axis=-1, keepdims=True)
114
+ label[...,0] = median(label[...,0], disk(5))
115
+ if label_mode == 'logistic-intensity':
116
+ label = label_mask / (1 + np.exp(-k * (label - x0)))
117
+ elif label_mode == 'binary':
118
+ label = label_mask
119
+ else:
120
+ raise ValueError("label_mode not supported" + str(label_mode))
121
+
122
+ return patchex, label
123
+
124
+
125
+ def _patch_ex(ima_dest, ima_src, dest_object_mask, src_object_mask, mode, label_mode, shift, resize, width_bounds_pct,
126
+ gamma_params, min_object_pct, min_overlap_pct, factor, resize_bounds, num_ellipses, verbose, cutpaste_patch_generation):
127
+ if cutpaste_patch_generation:
128
+ skip_background = False
129
+ dims = np.array(ima_dest.shape)
130
+ if dims[0] != dims[1]:
131
+ raise ValueError("CutPaste patch generation only works for square images")
132
+ # 1. sampling the area ratio between the patch and the full image from (0.02, 0.15)
133
+ # (divide by 4 as patch-widths below are actually half-widths)
134
+ area_ratio = np.random.uniform(0.02, 0.15) / 4.0
135
+ # 2. determine the aspect ratio by sampling from (0.3, 1) union (1, 3.3)
136
+ if np.random.randint(2) > 0:
137
+ aspect_ratio = np.random.uniform(0.3, 1)
138
+ else:
139
+ aspect_ratio = np.random.uniform(1, 3.3)
140
+
141
+ patch_width_dim1 = int(np.rint(np.clip(np.sqrt(area_ratio * aspect_ratio * dims[0]**2), 0, dims[0])))
142
+ patch_width_dim2 = int(np.rint(np.clip(area_ratio * dims[0]**2 / patch_width_dim1, 0, dims[1])))
143
+ # 3. sample location such that patch is contained entirely within the image
144
+ center_dim1 = np.random.randint(patch_width_dim1, dims[0] - patch_width_dim1)
145
+ center_dim2 = np.random.randint(patch_width_dim2, dims[1] - patch_width_dim2)
146
+
147
+ coor_min_dim1 = np.clip(center_dim1 - patch_width_dim1, 0, dims[0])
148
+ coor_min_dim2 = np.clip(center_dim2 - patch_width_dim2, 0, dims[1])
149
+ coor_max_dim1 = np.clip(center_dim1 + patch_width_dim1, 0, dims[0])
150
+ coor_max_dim2 = np.clip(center_dim2 + patch_width_dim2, 0, dims[1])
151
+
152
+ patch_mask = np.ones((coor_max_dim1 - coor_min_dim1, coor_max_dim2 - coor_min_dim2, 1), dtype=np.uint8)
153
+ else:
154
+ skip_background = (src_object_mask is not None) and (dest_object_mask is not None)
155
+ dims = np.array(ima_dest.shape)
156
+ min_width_dim1 = (width_bounds_pct[0][0]*dims[0]).round().astype(int)
157
+ max_width_dim1 = (width_bounds_pct[0][1]*dims[0]).round().astype(int)
158
+ min_width_dim2 = (width_bounds_pct[1][0]*dims[1]).round().astype(int)
159
+ max_width_dim2 = (width_bounds_pct[1][1]*dims[1]).round().astype(int)
160
+
161
+ if gamma_params is not None:
162
+ shape, scale, lower_bound = gamma_params
163
+ patch_width_dim1 = int(np.clip((lower_bound + np.random.gamma(shape, scale)) * dims[0], min_width_dim1, max_width_dim1))
164
+ patch_width_dim2 = int(np.clip((lower_bound + np.random.gamma(shape, scale)) * dims[1], min_width_dim2, max_width_dim2))
165
+ else:
166
+ patch_width_dim1 = np.random.randint(min_width_dim1, max_width_dim1)
167
+ patch_width_dim2 = np.random.randint(min_width_dim2, max_width_dim2)
168
+
169
+ found_patch = False
170
+ attempts = 0
171
+ while not found_patch:
172
+ center_dim1 = np.random.randint(min_width_dim1, dims[0]-min_width_dim1)
173
+ center_dim2 = np.random.randint(min_width_dim2, dims[1]-min_width_dim2)
174
+
175
+ coor_min_dim1 = np.clip(center_dim1 - patch_width_dim1, 0, dims[0])
176
+ coor_min_dim2 = np.clip(center_dim2 - patch_width_dim2, 0, dims[1])
177
+ coor_max_dim1 = np.clip(center_dim1 + patch_width_dim1, 0, dims[0])
178
+ coor_max_dim2 = np.clip(center_dim2 + patch_width_dim2, 0, dims[1])
179
+
180
+ if num_ellipses is not None:
181
+ ellipse_min_dim1 = min_width_dim1
182
+ ellipse_min_dim2 = min_width_dim2
183
+ ellipse_max_dim1 = max(min_width_dim1 + 1, patch_width_dim1 // 2)
184
+ ellipse_max_dim2 = max(min_width_dim2 + 1, patch_width_dim2 // 2)
185
+ patch_mask = np.zeros((coor_max_dim1 - coor_min_dim1, coor_max_dim2 - coor_min_dim2), dtype=np.uint8)
186
+ x = np.arange(patch_mask.shape[0]).reshape(-1, 1)
187
+ y = np.arange(patch_mask.shape[1]).reshape(1, -1)
188
+ for _ in range(num_ellipses):
189
+ theta = np.random.uniform(0, np.pi)
190
+ x0 = np.random.randint(0, patch_mask.shape[0])
191
+ y0 = np.random.randint(0, patch_mask.shape[1])
192
+ a = np.random.randint(ellipse_min_dim1, ellipse_max_dim1)
193
+ b = np.random.randint(ellipse_min_dim2, ellipse_max_dim2)
194
+ ellipse = (((x-x0)*np.cos(theta) + (y-y0)*np.sin(theta))/a)**2 + (((x-x0)*np.sin(theta) + (y-y0)*np.cos(theta))/b)**2 <= 1 # True for points inside the ellipse
195
+ patch_mask |= ellipse
196
+ patch_mask = patch_mask[...,None]
197
+ else:
198
+ patch_mask = np.ones((coor_max_dim1 - coor_min_dim1, coor_max_dim2 - coor_min_dim2, 1), dtype=np.uint8)
199
+
200
+ if skip_background:
201
+ background_area = np.sum(patch_mask & src_object_mask[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2])
202
+ if num_ellipses is not None:
203
+ patch_area = np.sum(patch_mask)
204
+ else:
205
+ patch_area = patch_mask.shape[0] * patch_mask.shape[1]
206
+ found_patch = (background_area / patch_area > min_object_pct)
207
+ else:
208
+ found_patch = True
209
+ attempts += 1
210
+ if attempts == 200:
211
+ if verbose:
212
+ print('No suitable patch found.')
213
+ return ima_dest.copy(), ((0,0),(0,0)), None
214
+ src = ima_src[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2]
215
+ height, width, _ = src.shape
216
+ if resize:
217
+ lb, ub = resize_bounds
218
+ scale = np.clip(np.random.normal(1, 0.5), lb, ub)
219
+ new_height = np.clip(scale * height, min_width_dim1, max_width_dim1)
220
+ new_width = np.clip(int(new_height / height * width), min_width_dim2, max_width_dim2)
221
+ new_height = np.clip(int(new_width / width * height), min_width_dim1, max_width_dim1) # in case there was clipping
222
+
223
+ if src.shape[2] == 1: # grayscale
224
+ src = cv2.resize(src[..., 0], (new_width, new_height))
225
+ src = src[...,None]
226
+ else:
227
+ src = cv2.resize(src, (new_width, new_height))
228
+ height, width, _ = src.shape
229
+ patch_mask = cv2.resize(patch_mask[...,0], (width, height))
230
+ patch_mask = patch_mask[...,None]
231
+
232
+ if skip_background:
233
+ src_object_mask = cv2.resize(src_object_mask[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2, 0], (width, height))
234
+ src_object_mask = src_object_mask[...,None]
235
+
236
+ # sample destination location and size
237
+ if shift:
238
+ found_center = False
239
+ attempts = 0
240
+ while not found_center:
241
+ center_dim1 = np.random.randint(height//2 + 1, ima_dest.shape[0] - height//2 - 1)
242
+ center_dim2 = np.random.randint(width//2 + 1, ima_dest.shape[1] - width//2 - 1)
243
+ coor_min_dim1, coor_max_dim1 = center_dim1 - height//2, center_dim1 + (height+1)//2
244
+ coor_min_dim2, coor_max_dim2 = center_dim2 - width//2, center_dim2 + (width+1)//2
245
+
246
+ if skip_background:
247
+ src_and_dest = dest_object_mask[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2] & src_object_mask & patch_mask
248
+ src_or_dest = (dest_object_mask[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2] | src_object_mask) & patch_mask
249
+ found_center = (np.sum(src_object_mask) / (patch_mask.shape[0] * patch_mask.shape[1]) > min_object_pct and # contains object
250
+ np.sum(src_and_dest) / np.sum(src_object_mask) > min_overlap_pct) # object overlaps src object
251
+ else:
252
+ found_center = True
253
+ attempts += 1
254
+ if attempts == 200:
255
+ if verbose:
256
+ print('No suitable center found. Dims were:', width, height)
257
+ return ima_dest.copy(), ((0,0),(0,0)), None
258
+
259
+ # blend
260
+ if skip_background:
261
+ patch_mask &= src_object_mask | dest_object_mask[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2]
262
+
263
+ if mode == 'swap':
264
+ patchex = ima_dest.copy()
265
+ before = patchex[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2]
266
+ patchex[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2] -= patch_mask * before
267
+ patchex[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2] += patch_mask * src
268
+ elif mode == 'uniform':
269
+ patchex = 1.0 * ima_dest
270
+ before = patchex[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2]
271
+ patchex[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2] -= factor * patch_mask * before
272
+ patchex[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2] += factor * patch_mask * src
273
+ patchex = np.uint8(np.floor(patchex))
274
+ elif mode in [cv2.NORMAL_CLONE, cv2.MIXED_CLONE]: # poisson interpolation
275
+ int_factor = np.uint8(np.ceil(factor * 255))
276
+ # add background to patchmask to avoid artefacts
277
+ if skip_background:
278
+ patch_mask_scaled = int_factor * (patch_mask | ((1 - src_object_mask) & (1 - dest_object_mask[coor_min_dim1:coor_max_dim1, coor_min_dim2:coor_max_dim2])))
279
+ else:
280
+ patch_mask_scaled = int_factor * patch_mask
281
+ patch_mask_scaled[0], patch_mask_scaled[-1], patch_mask_scaled[:,0], patch_mask_scaled[:,-1] = 0, 0, 0, 0 # zero border to avoid artefacts
282
+ center = (coor_max_dim2 - (coor_max_dim2 - coor_min_dim2) // 2, coor_min_dim1 + (coor_max_dim1 - coor_min_dim1) // 2) # height dim first
283
+ if np.sum(patch_mask_scaled > 0) < 50: # cv2 seamlessClone will fail if positive mask area is too small
284
+ return ima_dest.copy(), ((0,0),(0,0)), None
285
+ try:
286
+ if ima_dest.shape[2] == 1: # grayscale
287
+ # pad to 3 channels as that's what OpenCV expects
288
+ src_3 = np.concatenate((src, np.zeros_like(src), np.zeros_like(src)), axis=2)
289
+ ima_dest_3 = np.concatenate((ima_dest, np.zeros_like(ima_dest), np.zeros_like(ima_dest)), axis=2)
290
+ patchex = cv2.seamlessClone(src_3, ima_dest_3, patch_mask_scaled, center, mode)
291
+ patchex = patchex[...,0:1] # extract first channel
292
+ else: # RGB
293
+ patchex = cv2.seamlessClone(src, ima_dest, patch_mask_scaled, center, mode)
294
+ except cv2.error as e:
295
+ print('WARNING, tried bad interpolation mask and got:', e)
296
+ return ima_dest.copy(), ((0,0),(0,0)), None
297
+ else:
298
+ raise ValueError("mode not supported" + str(mode))
299
+
300
+ return patchex, ((coor_min_dim1, coor_max_dim1), (coor_min_dim2, coor_max_dim2)), patch_mask
source/perlin.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ def lerp_np(x,y,w):
6
+ fin_out = (y-x)*w + x
7
+ return fin_out
8
+
9
+ def generate_fractal_noise_2d(shape, res, octaves=1, persistence=0.5):
10
+ noise = np.zeros(shape)
11
+ frequency = 1
12
+ amplitude = 1
13
+ for _ in range(octaves):
14
+ noise += amplitude * generate_perlin_noise_2d(shape, (frequency*res[0], frequency*res[1]))
15
+ frequency *= 2
16
+ amplitude *= persistence
17
+ return noise
18
+
19
+
20
+ def generate_perlin_noise_2d(shape, res):
21
+ def f(t):
22
+ return 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3
23
+
24
+ delta = (res[0] / shape[0], res[1] / shape[1])
25
+ d = (shape[0] // res[0], shape[1] // res[1])
26
+ grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
27
+ # Gradients
28
+ angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
29
+ gradients = np.dstack((np.cos(angles), np.sin(angles)))
30
+ g00 = gradients[0:-1, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
31
+ g10 = gradients[1:, 0:-1].repeat(d[0], 0).repeat(d[1], 1)
32
+ g01 = gradients[0:-1, 1:].repeat(d[0], 0).repeat(d[1], 1)
33
+ g11 = gradients[1:, 1:].repeat(d[0], 0).repeat(d[1], 1)
34
+ # Ramps
35
+ n00 = np.sum(grid * g00, 2)
36
+ n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1])) * g10, 2)
37
+ n01 = np.sum(np.dstack((grid[:, :, 0], grid[:, :, 1] - 1)) * g01, 2)
38
+ n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] - 1)) * g11, 2)
39
+ # Interpolation
40
+ t = f(grid)
41
+ n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
42
+ n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
43
+ return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
44
+
45
+
46
+ def rand_perlin_2d_np(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
47
+ delta = (res[0] / shape[0], res[1] / shape[1])
48
+ d = (shape[0] // res[0], shape[1] // res[1])
49
+ grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
50
+
51
+ angles = 2 * math.pi * np.random.rand(res[0] + 1, res[1] + 1)
52
+ gradients = np.stack((np.cos(angles), np.sin(angles)), axis=-1)
53
+ tt = np.repeat(np.repeat(gradients,d[0],axis=0),d[1],axis=1)
54
+
55
+ tile_grads = lambda slice1, slice2: np.repeat(np.repeat(gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]],d[0],axis=0),d[1],axis=1)
56
+ dot = lambda grad, shift: (
57
+ np.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
58
+ axis=-1) * grad[:shape[0], :shape[1]]).sum(axis=-1)
59
+
60
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
61
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
62
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
63
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
64
+ t = fade(grid[:shape[0], :shape[1]])
65
+ return math.sqrt(2) * lerp_np(lerp_np(n00, n10, t[..., 0]), lerp_np(n01, n11, t[..., 0]), t[..., 1])
66
+
67
+
68
+ def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
69
+ delta = (res[0] / shape[0], res[1] / shape[1])
70
+ d = (shape[0] // res[0], shape[1] // res[1])
71
+
72
+ grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
73
+ angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
74
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
75
+
76
+ tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
77
+ 0).repeat_interleave(
78
+ d[1], 1)
79
+ dot = lambda grad, shift: (
80
+ torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
81
+ dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
82
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
83
+
84
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
85
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
86
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
87
+ t = fade(grid[:shape[0], :shape[1]])
88
+ return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
89
+
90
+
91
+ def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
92
+ noise = torch.zeros(shape)
93
+ frequency = 1
94
+ amplitude = 1
95
+ for _ in range(octaves):
96
+ noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
97
+ frequency *= 2
98
+ amplitude *= persistence
99
+ return noise