#imports import numpy as np import matplotlib.pyplot as plt import os from math import floor, ceil from random import randint from sklearn.neighbors import KDTree from skimage.util.shape import view_as_windows from skimage import io from PIL import Image, ImageDraw from IPython.display import clear_output class patchBasedTextureSynthesis: def __init__(self, exampleMapPath, in_outputPath, in_outputSize, in_patchSize, in_overlapSize, in_windowStep = 5, in_mirror_hor = True, in_mirror_vert = True, in_shapshots = True): self.exampleMap = self.loadExampleMap(exampleMapPath) self.snapshots = in_shapshots self.outputPath = in_outputPath self.outputSize = in_outputSize self.patchSize = in_patchSize self.overlapSize = in_overlapSize self.mirror_hor = in_mirror_hor self.mirror_vert = in_mirror_vert self.total_patches_count = 0 #excluding mirrored versions self.windowStep = 5 self.iter = 0 self.checkIfDirectoryExists() #check if output directory exists self.examplePatches = self.prepareExamplePatches() self.canvas, self.filledMap, self.idMap = self.initCanvas() self.initFirstPatch() #place random block to start with self.kdtree_topOverlap, self.kdtree_leftOverlap, self.kdtree_combined = self.initKDtrees() self.PARM_truncation = 0.8 self.PARM_attenuation = 2 def checkIfDirectoryExists(self): if not os.path.exists(self.outputPath): os.makedirs(self.outputPath) def resolveAll(self): self.saveParams() #resolve all unresolved patches for i in range(np.sum(1-self.filledMap).astype(int)): self.resolveNext() if not self.snapshots: img = Image.fromarray(np.uint8(self.canvas*255)) img = img.resize((self.outputSize[0], self.outputSize[1]), resample=0, box=None) img.save(self.outputPath + 'out.jpg') # else: # self.visualize([0,0], [], [], showCandidates=False) return img def saveParams(self): #write text_file = open(self.outputPath + 'params.txt', "w") text_file.write("PatchSize: %d \nOverlapSize: %d \nMirror Vert: %d \nMirror Hor: %d" % (self.patchSize, self.overlapSize, self.mirror_vert, self.mirror_hor)) text_file.close() def resolveNext(self): #coordinate of the next one to resolve coord = self.idCoordTo2DCoord(np.sum(self.filledMap), np.shape(self.filledMap)) #get 2D coordinate of next to resolve patch #get overlap areas of the patch we want to resolve overlapArea_Top = self.getOverlapAreaTop(coord) overlapArea_Left = self.getOverlapAreaLeft(coord) #find most similar patch from the examples dist, ind = self.findMostSimilarPatches(overlapArea_Top, overlapArea_Left, coord) if self.mirror_hor or self.mirror_vert: #check that top and left neighbours are not mirrors dist, ind = self.checkForMirrors(dist, ind, coord) #choose random valid patch probabilities = self.distances2probability(dist, self.PARM_truncation, self.PARM_attenuation) chosenPatchId = np.random.choice(ind, 1, p=probabilities) #update canvas blend_top = (overlapArea_Top is not None) blend_left = (overlapArea_Left is not None) self.updateCanvas(chosenPatchId, coord[0], coord[1], blend_top, blend_left) #update filledMap and id map ;) self.filledMap[coord[0], coord[1]] = 1 self.idMap[coord[0], coord[1]] = chosenPatchId #visualize # self.visualize(coord, chosenPatchId, ind) self.iter += 1 def visualize(self, coord, chosenPatchId, nonchosenPatchId, showCandidates = True): #full visualization includes both example and generated img canvasSize = np.shape(self.canvas) #insert generated image vis = np.zeros((canvasSize[0], canvasSize[1] * 2, 3)) + 0.2 vis[:, 0:canvasSize[1]] = self.canvas #insert example exampleHighlited = np.copy(self.exampleMap) if showCandidates: exampleHighlited = self.hightlightPatchCandidates(chosenPatchId, nonchosenPatchId) h = floor(canvasSize[0] / 2) w = floor(canvasSize[1] / 2) exampleResized = self.resize(exampleHighlited, [h, w]) offset_h = floor((canvasSize[0] - h) / 2) offset_w = floor((canvasSize[1] - w) / 2) vis[offset_h:offset_h+h, canvasSize[1]+offset_w:canvasSize[1]+offset_w+w] = exampleResized #show live update plt.imshow(vis) clear_output(wait=True) display(plt.show()) if self.snapshots: img = Image.fromarray(np.uint8(vis*255)) img = img.resize((self.outputSize[0]*2, self.outputSize[1]), resample=0, box=None) img.save(self.outputPath + 'out' + str(self.iter) + '.jpg') def hightlightPatchCandidates(self, chosenPatchId, nonchosenPatchId): result = np.copy(self.exampleMap) #mod patch ID chosenPatchId = chosenPatchId[0] % self.total_patches_count if len(nonchosenPatchId)>0: nonchosenPatchId = nonchosenPatchId % self.total_patches_count #exlcude chosen from nonchosen nonchosenPatchId = np.delete(nonchosenPatchId, np.where(nonchosenPatchId == chosenPatchId)) #highlight non chosen candidates c = [0.25, 0.9 ,0.45] self.highlightPatches(result, nonchosenPatchId, color=c, highlight_width = 4, alpha = 0.5) #hightlight chosen c = [1.0, 0.25, 0.15] self.highlightPatches(result, [chosenPatchId], color=c, highlight_width = 4, alpha = 1) return result def highlightPatches(self, writeResult, patchesIDs, color, highlight_width = 2, solid = False, alpha = 0.1): searchWindow = self.patchSize + 2*self.overlapSize #number of possible steps row_steps = floor((np.shape(writeResult)[0] - searchWindow) / self.windowStep) + 1 col_steps = floor((np.shape(writeResult)[1] - searchWindow) / self.windowStep) + 1 for i in range(len(patchesIDs)): chosenPatchId = patchesIDs[i] #patch Id to step patch_row = floor(chosenPatchId / col_steps) patch_col = chosenPatchId - patch_row * col_steps #highlight chosen patch (below are boundaries of the example patch) row_start = self.windowStep* patch_row row_end = self.windowStep * patch_row + searchWindow col_start = self.windowStep * patch_col col_end = self.windowStep * patch_col + searchWindow if not solid: w = highlight_width overlap = np.copy(writeResult[row_start:row_start+w, col_start:col_end]) writeResult[row_start:row_start+w, col_start:col_end] = overlap * (1-alpha) + (np.zeros(np.shape(overlap))+color) * alpha #top overlap = np.copy(writeResult[row_end-w:row_end, col_start:col_end]) writeResult[row_end-w:row_end, col_start:col_end] = overlap * (1-alpha) + (np.zeros(np.shape(overlap))+color) * alpha #bot overlap = np.copy( writeResult[row_start:row_end, col_start:col_start+w]) writeResult[row_start:row_end, col_start:col_start+w] = overlap * (1-alpha) + (np.zeros(np.shape(overlap))+color) * alpha #left overlap = np.copy(writeResult[row_start:row_end, col_end-w:col_end]) writeResult[row_start:row_end, col_end-w:col_end] = overlap * (1-alpha) + (np.zeros(np.shape(overlap))+color) * alpha #end else: a = alpha writeResult[row_start:row_end, col_start:col_end] = writeResult[row_start:row_end, col_start:col_end] * (1-a) + (np.zeros(np.shape(writeResult[row_start:row_end, col_start:col_end]))+color) * a def resize(self, imgArray, targetSize): img = Image.fromarray(np.uint8(imgArray*255)) img = img.resize((targetSize[0], targetSize[1]), resample=0, box=None) return np.array(img)/255 def findMostSimilarPatches(self, overlapArea_Top, overlapArea_Left, coord, in_k=5): #check which KD tree we need to use if (overlapArea_Top is not None) and (overlapArea_Left is not None): combined = self.getCombinedOverlap(overlapArea_Top.reshape(-1), overlapArea_Left.reshape(-1)) dist, ind = self.kdtree_combined.query([combined], k=in_k) elif overlapArea_Top is not None: dist, ind = self.kdtree_topOverlap.query([overlapArea_Top.reshape(-1)], k=in_k) elif overlapArea_Left is not None: dist, ind = self.kdtree_leftOverlap.query([overlapArea_Left.reshape(-1)], k=in_k) else: raise Exception("ERROR: no valid overlap area is passed to -findMostSimilarPatch-") dist = dist[0] ind = ind[0] return dist, ind #disallow visually similar blocks to be placed next to each other def checkForMirrors(self, dist, ind, coord, thres = 3): remove_i = [] #do I have a top or left neighbour if coord[0]-1>-1: top_neigh = int(self.idMap[coord[0]-1, coord[1]]) for i in range(len(ind)): if (abs(ind[i]%self.total_patches_count - top_neigh%self.total_patches_count) < thres): remove_i.append(i) if coord[1]-1>-1: left_neigh = int(self.idMap[coord[0], coord[1]-1]) for i in range(len(ind)): if (abs(ind[i]%self.total_patches_count - left_neigh%self.total_patches_count) < thres): remove_i.append(i) dist = np.delete(dist, remove_i) ind = np.delete(ind, remove_i) return dist, ind def distances2probability(self, distances, PARM_truncation, PARM_attenuation): probabilities = 1 - distances / np.max(distances) probabilities *= (probabilities > PARM_truncation) probabilities = pow(probabilities, PARM_attenuation) #attenuate the values #check if we didn't truncate everything! if np.sum(probabilities) == 0: #then just revert it probabilities = 1 - distances / np.max(distances) probabilities *= (probabilities > PARM_truncation*np.max(probabilities)) # truncate the values (we want top truncate%) probabilities = pow(probabilities, PARM_attenuation) probabilities /= np.sum(probabilities) #normalize so they add up to one return probabilities def getOverlapAreaTop(self, coord): #do I have a top neighbour if coord[0]-1>-1: canvasPatch = self.patchCoord2canvasPatch(coord) return canvasPatch[0:self.overlapSize, :, :] else: return None def getOverlapAreaLeft(self, coord): #do I have a left neighbour if coord[1]-1>-1: canvasPatch = self.patchCoord2canvasPatch(coord) return canvasPatch[:, 0:self.overlapSize, :] else: return None def initKDtrees(self): #prepate overlap patches topOverlap = self.examplePatches[:, 0:self.overlapSize, :, :] leftOverlap = self.examplePatches[:, :, 0:self.overlapSize, :] shape_top = np.shape(topOverlap) shape_left = np.shape(leftOverlap) flatten_top = topOverlap.reshape(shape_top[0], -1) flatten_left = leftOverlap.reshape(shape_left[0], -1) flatten_combined = self.getCombinedOverlap(flatten_top, flatten_left) tree_top = KDTree(flatten_top) tree_left = KDTree(flatten_left) tree_combined = KDTree(flatten_combined) return tree_top, tree_left, tree_combined #the corner of 2 overlaps is counted double def getCombinedOverlap(self, top, left): shape = np.shape(top) if len(shape) > 1: combined = np.zeros((shape[0], shape[1]*2)) combined[0:shape[0], 0:shape[1]] = top combined[0:shape[0], shape[1]:shape[1]*2] = left else: combined = np.zeros((shape[0]*2)) combined[0:shape[0]] = top combined[shape[0]:shape[0]*2] = left return combined def initFirstPatch(self): #grab a random block patchId = randint(0, np.shape(self.examplePatches)[0]) #mark out fill map self.filledMap[0, 0] = 1 self.idMap[0, 0] = patchId % self.total_patches_count #update canvas self.updateCanvas(patchId, 0, 0, False, False) #visualize # self.visualize([0,0], [patchId], []) def prepareExamplePatches(self): searchKernelSize = self.patchSize + 2 * self.overlapSize result = view_as_windows(self.exampleMap, [searchKernelSize, searchKernelSize, 3] , self.windowStep) shape = np.shape(result) result = result.reshape(shape[0]*shape[1], searchKernelSize, searchKernelSize, 3) self.total_patches_count = shape[0]*shape[1] if self.mirror_hor: #flip along horizonal axis hor_result = np.zeros(np.shape(result)) for i in range(self.total_patches_count): hor_result[i] = result[i][::-1, :, :] result = np.concatenate((result, hor_result)) if self.mirror_vert: vert_result = np.zeros((shape[0]*shape[1], searchKernelSize, searchKernelSize, 3)) for i in range(self.total_patches_count): vert_result[i] = result[i][:, ::-1, :] result = np.concatenate((result, vert_result)) return result def initCanvas(self): #check whether the outputSize adheres to patch+overlap size num_patches_X = ceil((self.outputSize[0]-self.overlapSize)/(self.patchSize+self.overlapSize)) num_patches_Y = ceil((self.outputSize[1]-self.overlapSize)/(self.patchSize+self.overlapSize)) #calc needed output image size required_size_X = num_patches_X*self.patchSize + (num_patches_X+1)*self.overlapSize required_size_Y = num_patches_Y*self.patchSize + (num_patches_X+1)*self.overlapSize #create empty canvas canvas = np.zeros((required_size_X, required_size_Y, 3)) filledMap = np.zeros((num_patches_X, num_patches_Y)) #map showing which patches have been resolved idMap = np.zeros((num_patches_X, num_patches_Y)) - 1 #stores patches id print("modified output size: ", np.shape(canvas)) print("number of patches: ", np.shape(filledMap)[0]) return canvas, filledMap, idMap def idCoordTo2DCoord(self, idCoord, imgSize): row = int(floor(idCoord / imgSize[0])) col = int(idCoord - row * imgSize[1]) return [row, col] def updateCanvas(self, inputPatchId, coord_X, coord_Y, blendTop = False, blendLeft = False): #translate Patch coordinate into Canvas coordinate x_range = self.patchCoord2canvasCoord(coord_X) y_range = self.patchCoord2canvasCoord(coord_Y) examplePatch = self.examplePatches[inputPatchId] if blendLeft: canvasOverlap = self.canvas[x_range[0]:x_range[1], y_range[0]:y_range[0]+self.overlapSize] examplePatchOverlap = np.copy(examplePatch[0][:, 0:self.overlapSize]) examplePatch[0][:, 0:self.overlapSize] = self.linearBlendOverlaps(canvasOverlap, examplePatchOverlap, 'left') if blendTop: canvasOverlap = self.canvas[x_range[0]:x_range[0]+self.overlapSize, y_range[0]:y_range[1]] examplePatchOverlap = np.copy(examplePatch[0][0:self.overlapSize, :]) examplePatch[0][0:self.overlapSize, :] = self.linearBlendOverlaps(canvasOverlap, examplePatchOverlap, 'top') self.canvas[x_range[0]:x_range[1], y_range[0]:y_range[1]] = examplePatch def linearBlendOverlaps(self, canvasOverlap, examplePatchOverlap, mode): if mode == 'left': mask = np.repeat(np.arange(self.overlapSize)[np.newaxis, :], np.shape(canvasOverlap)[0], axis=0) / self.overlapSize elif mode == 'top': mask = np.repeat(np.arange(self.overlapSize)[:, np.newaxis], np.shape(canvasOverlap)[1], axis=1) / self.overlapSize mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2) #cast to 3d array return canvasOverlap * (1 - mask) + examplePatchOverlap * mask #def minimumBoundaryError(self, canvasOverlap, examplePatchOverlap, mode) def patchCoord2canvasCoord(self, coord): return [(self.patchSize+self.overlapSize)*coord, (self.patchSize+self.overlapSize)*(coord+1) + self.overlapSize] def patchCoord2canvasPatch(self, coord): x_range = self.patchCoord2canvasCoord(coord[0]) y_range = self.patchCoord2canvasCoord(coord[1]) return np.copy(self.canvas[x_range[0]:x_range[1], y_range[0]:y_range[1]]) def loadExampleMap(self, exampleMapPath): exampleMap = io.imread(exampleMapPath) #returns an MxNx3 array exampleMap = exampleMap / 255.0 #normalize #make sure it is 3channel RGB if (np.shape(exampleMap)[-1] > 3): exampleMap = exampleMap[:,:,:3] #remove Alpha Channel elif (len(np.shape(exampleMap)) == 2): exampleMap = np.repeat(exampleMap[np.newaxis, :, :], 3, axis=0) #convert from Grayscale to RGB return exampleMap