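# Patch-based texture synthesis.
# NOTE: recovered from a web page; the "Spaces: / Sleeping" status lines that
# preceded the imports were page residue, not code.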
# imports
import os
from math import floor, ceil
from random import randint

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output
from PIL import Image, ImageDraw
from skimage import io
from skimage.util.shape import view_as_windows
from sklearn.neighbors import KDTree


class patchBasedTextureSynthesis:
    """Synthesize a larger texture by quilting overlapping windows of an example image.

    The output canvas is a grid of cells filled in row-major order. The top-left
    cell is seeded with a random example window. Every later cell takes a window
    chosen among the k nearest neighbours (KD-tree) of the already-painted
    top and/or left overlap strips. The window's overlap strips are linearly
    blended into the canvas.

    Args:
        exampleMapPath: path of the example image.
        in_outputPath: directory receiving params.txt and output images
            (created if missing).
        in_outputSize: requested output size as (width, height), the order PIL's
            ``Image.resize`` expects. The canvas is rounded up to whole cells and
            rescaled to this size at the end.
        in_patchSize: side length of a cell's core (non-overlap) region.
        in_overlapSize: width of the overlap strip on each side of a window.
        in_windowStep: stride in pixels between candidate windows in the example.
        in_mirror_hor: also offer row-reversed (flipped about the horizontal
            axis) copies of every window.
        in_mirror_vert: also offer column-reversed copies of every window.
        in_shapshots: (sic) if True, visualize() saves a snapshot per call and
            resolveAll() does not write out.jpg.
    """

    def __init__(self, exampleMapPath, in_outputPath, in_outputSize, in_patchSize, in_overlapSize,
                 in_windowStep=5, in_mirror_hor=True, in_mirror_vert=True, in_shapshots=True):
        self.exampleMap = self.loadExampleMap(exampleMapPath)
        self.snapshots = in_shapshots
        self.outputPath = in_outputPath
        self.outputSize = in_outputSize
        self.patchSize = in_patchSize
        self.overlapSize = in_overlapSize
        self.mirror_hor = in_mirror_hor
        self.mirror_vert = in_mirror_vert
        self.total_patches_count = 0  # example windows, excluding mirrored copies
        # Honour the caller's stride (it used to be hard-coded to 5).
        self.windowStep = in_windowStep
        self.iter = 0
        # Candidate-sampling parameters consumed by distances2probability().
        self.PARM_truncation = 0.8
        self.PARM_attenuation = 2
        self.checkIfDirectoryExists()
        self.examplePatches = self.prepareExamplePatches()
        self.canvas, self.filledMap, self.idMap = self.initCanvas()
        self.initFirstPatch()  # seed the top-left cell with a random window
        self.kdtree_topOverlap, self.kdtree_leftOverlap, self.kdtree_combined = self.initKDtrees()

    def checkIfDirectoryExists(self):
        """Create the output directory (including parents) if it does not exist."""
        # exist_ok avoids the check-then-create race; an empty path means the
        # current directory, which needs nothing created.
        if self.outputPath:
            os.makedirs(self.outputPath, exist_ok=True)

    def resolveAll(self):
        """Fill every remaining cell and return the final image.

        Writes params.txt first. The canvas is converted to 8-bit and resized to
        ``outputSize`` (width, height). It is saved as out.jpg in the output
        directory when snapshots are disabled.

        Returns:
            The resized PIL image.
        """
        self.saveParams()
        remaining = int(np.sum(1 - self.filledMap))
        for _ in range(remaining):
            self.resolveNext()
        # Build the image unconditionally: previously it only existed in the
        # non-snapshot branch, so `return img` raised UnboundLocalError.
        img = Image.fromarray(np.uint8(self.canvas * 255))
        img = img.resize((self.outputSize[0], self.outputSize[1]), resample=0, box=None)
        if not self.snapshots:
            img.save(os.path.join(self.outputPath, 'out.jpg'))
        # Optional live view: self.visualize([0, 0], [], [], showCandidates=False)
        return img

    def saveParams(self):
        """Write the synthesis parameters to params.txt in the output directory."""
        with open(os.path.join(self.outputPath, 'params.txt'), "w") as text_file:
            text_file.write("PatchSize: %d \nOverlapSize: %d \nMirror Vert: %d \nMirror Hor: %d" % (self.patchSize, self.overlapSize, self.mirror_vert, self.mirror_hor))

    def resolveNext(self):
        """Fill the next empty cell (row-major order) with a sampled example window."""
        # Cells are filled in order, so the number filled so far is the next index.
        coord = self.idCoordTo2DCoord(np.sum(self.filledMap), np.shape(self.filledMap))
        overlapArea_Top = self.getOverlapAreaTop(coord)
        overlapArea_Left = self.getOverlapAreaLeft(coord)
        dist, ind = self.findMostSimilarPatches(overlapArea_Top, overlapArea_Left, coord)
        if self.mirror_hor or self.mirror_vert:
            # Avoid placing near-duplicate windows next to each other, but never
            # filter away every candidate (that used to crash on np.max([])).
            kept_dist, kept_ind = self.checkForMirrors(dist, ind, coord)
            if len(kept_ind) > 0:
                dist, ind = kept_dist, kept_ind
        probabilities = self.distances2probability(dist, self.PARM_truncation, self.PARM_attenuation)
        chosenPatchId = np.random.choice(ind, 1, p=probabilities)  # length-1 array
        blend_top = overlapArea_Top is not None
        blend_left = overlapArea_Left is not None
        self.updateCanvas(chosenPatchId, coord[0], coord[1], blend_top, blend_left)
        self.filledMap[coord[0], coord[1]] = 1
        self.idMap[coord[0], coord[1]] = int(chosenPatchId[0])
        # Optional live view: self.visualize(coord, chosenPatchId, ind)
        self.iter += 1

    def visualize(self, coord, chosenPatchId, nonchosenPatchId, showCandidates=True):
        """Show the canvas beside a half-size copy of the example with candidates highlighted.

        Args:
            coord: cell being resolved (unused; kept for call compatibility).
            chosenPatchId: sequence whose first element is the chosen window id.
            nonchosenPatchId: ids of the other candidates.
            showCandidates: highlight candidates on the example when True.
        """
        canvasSize = np.shape(self.canvas)
        vis = np.zeros((canvasSize[0], canvasSize[1] * 2, 3)) + 0.2
        vis[:, 0:canvasSize[1]] = self.canvas
        if showCandidates:
            exampleHighlighted = self.hightlightPatchCandidates(chosenPatchId, nonchosenPatchId)
        else:
            exampleHighlighted = np.copy(self.exampleMap)
        h = floor(canvasSize[0] / 2)
        w = floor(canvasSize[1] / 2)
        # resize() takes a PIL (width, height) pair; passing [h, w] broke
        # non-square canvases.
        exampleResized = self.resize(exampleHighlighted, [w, h])
        offset_h = floor((canvasSize[0] - h) / 2)
        offset_w = floor((canvasSize[1] - w) / 2)
        vis[offset_h:offset_h + h, canvasSize[1] + offset_w:canvasSize[1] + offset_w + w] = exampleResized
        plt.imshow(vis)
        clear_output(wait=True)
        # plt.show() returns None, so the old display(plt.show()) only added a
        # stray output (and NameError'd outside Jupyter).
        plt.show()
        if self.snapshots:
            img = Image.fromarray(np.uint8(vis * 255))
            img = img.resize((self.outputSize[0] * 2, self.outputSize[1]), resample=0, box=None)
            img.save(os.path.join(self.outputPath, 'out' + str(self.iter) + '.jpg'))

    def hightlightPatchCandidates(self, chosenPatchId, nonchosenPatchId):
        """Return a copy of the example map with candidate windows outlined.

        Mirrored ids are folded back onto their source window (modulo
        ``total_patches_count``). The other candidates get a translucent green
        border and the chosen one an opaque red border.
        """
        result = np.copy(self.exampleMap)
        chosen = chosenPatchId[0] % self.total_patches_count
        if len(nonchosenPatchId) > 0:
            others = np.asarray(nonchosenPatchId) % self.total_patches_count
            others = others[others != chosen]  # exclude the chosen window
        else:
            others = []
        self.highlightPatches(result, others, color=[0.25, 0.9, 0.45], highlight_width=4, alpha=0.5)
        self.highlightPatches(result, [chosen], color=[1.0, 0.25, 0.15], highlight_width=4, alpha=1)
        return result

    def highlightPatches(self, writeResult, patchesIDs, color, highlight_width=2, solid=False, alpha=0.1):
        """Tint example windows in place on ``writeResult``.

        Args:
            writeResult: example-sized RGB array, modified in place.
            patchesIDs: un-mirrored window ids (< total_patches_count).
            color: RGB triple.
            highlight_width: border thickness in pixels (ignored when solid).
            solid: tint the whole window instead of only its border.
            alpha: blend weight of ``color``.
        """
        searchWindow = self.patchSize + 2 * self.overlapSize
        # Windows per row, matching view_as_windows' row-major enumeration.
        col_steps = floor((np.shape(writeResult)[1] - searchWindow) / self.windowStep) + 1
        tint = np.asarray(color, dtype=float) * alpha

        def blend(rows, cols):
            writeResult[rows, cols] = writeResult[rows, cols] * (1 - alpha) + tint

        for patch_id in patchesIDs:
            patch_row, patch_col = divmod(int(patch_id), col_steps)
            row_start = self.windowStep * patch_row
            row_end = row_start + searchWindow
            col_start = self.windowStep * patch_col
            col_end = col_start + searchWindow
            if solid:
                blend(slice(row_start, row_end), slice(col_start, col_end))
            else:
                w = highlight_width
                blend(slice(row_start, row_start + w), slice(col_start, col_end))  # top
                blend(slice(row_end - w, row_end), slice(col_start, col_end))  # bottom
                blend(slice(row_start, row_end), slice(col_start, col_start + w))  # left
                blend(slice(row_start, row_end), slice(col_end - w, col_end))  # right

    def resize(self, imgArray, targetSize):
        """Resize a [0, 1] RGB array to ``targetSize`` = (width, height), nearest neighbour."""
        img = Image.fromarray(np.uint8(imgArray * 255))
        img = img.resize((targetSize[0], targetSize[1]), resample=0, box=None)
        return np.array(img) / 255

    def findMostSimilarPatches(self, overlapArea_Top, overlapArea_Left, coord, in_k=5):
        """Query the KD-tree matching the available overlaps.

        Args:
            overlapArea_Top / overlapArea_Left: canvas strips, or None if absent.
            coord: cell being resolved (unused; kept for call compatibility).
            in_k: number of neighbours requested (capped at the window count).

        Returns:
            (dist, ind): 1-D arrays of neighbour distances and window ids.

        Raises:
            Exception: if both overlaps are None.
        """
        # KDTree.query fails when k exceeds the number of stored points.
        k = min(in_k, np.shape(self.examplePatches)[0])
        if (overlapArea_Top is not None) and (overlapArea_Left is not None):
            combined = self.getCombinedOverlap(overlapArea_Top.reshape(-1), overlapArea_Left.reshape(-1))
            dist, ind = self.kdtree_combined.query([combined], k=k)
        elif overlapArea_Top is not None:
            dist, ind = self.kdtree_topOverlap.query([overlapArea_Top.reshape(-1)], k=k)
        elif overlapArea_Left is not None:
            dist, ind = self.kdtree_leftOverlap.query([overlapArea_Left.reshape(-1)], k=k)
        else:
            raise Exception("ERROR: no valid overlap area is passed to -findMostSimilarPatch-")
        return dist[0], ind[0]

    def checkForMirrors(self, dist, ind, coord, thres=3):
        """Drop candidates whose source window is within ``thres`` ids of a placed neighbour.

        Ids are compared modulo ``total_patches_count``, so mirrored copies count
        as their source window. May return empty arrays.
        """
        dist = np.asarray(dist)
        ind = np.asarray(ind)
        neighbours = []
        if coord[0] > 0:
            neighbours.append(int(self.idMap[coord[0] - 1, coord[1]]))  # top
        if coord[1] > 0:
            neighbours.append(int(self.idMap[coord[0], coord[1] - 1]))  # left
        base = ind % self.total_patches_count
        keep = np.ones(len(ind), dtype=bool)
        for neighbour in neighbours:
            keep &= np.abs(base - neighbour % self.total_patches_count) >= thres
        return dist[keep], ind[keep]

    def distances2probability(self, distances, PARM_truncation, PARM_attenuation):
        """Turn neighbour distances into sampling probabilities summing to 1.

        Similarity is ``1 - d / max(d)``. Values not above ``PARM_truncation`` are
        zeroed, and the rest are raised to ``PARM_attenuation``. If that zeroes
        everything, the threshold is made relative to the best similarity. If
        all candidates are equally far (or all exact matches), the result is
        uniform instead of NaN.

        Raises:
            ValueError: if ``distances`` is empty.
        """
        distances = np.asarray(distances, dtype=float)
        if distances.size == 0:
            raise ValueError("distances2probability needs at least one candidate")
        max_dist = np.max(distances)
        if max_dist <= 0:
            return np.full(distances.size, 1.0 / distances.size)
        similarity = 1 - distances / max_dist
        probabilities = pow(similarity * (similarity > PARM_truncation), PARM_attenuation)
        if np.sum(probabilities) == 0:
            # keep only the top candidates relative to the best one
            probabilities = similarity * (similarity > PARM_truncation * np.max(similarity))
            probabilities = pow(probabilities, PARM_attenuation)
        if np.sum(probabilities) == 0:
            probabilities = np.ones(distances.size)
        return probabilities / np.sum(probabilities)

    def getOverlapAreaTop(self, coord):
        """Return the top overlap strip of cell ``coord``, or None in the first row."""
        if coord[0] > 0:
            return self.patchCoord2canvasPatch(coord)[0:self.overlapSize, :, :]
        return None

    def getOverlapAreaLeft(self, coord):
        """Return the left overlap strip of cell ``coord``, or None in the first column."""
        if coord[1] > 0:
            return self.patchCoord2canvasPatch(coord)[:, 0:self.overlapSize, :]
        return None

    def initKDtrees(self):
        """Build KD-trees over the flattened top, left and combined overlap strips of every window."""
        topOverlap = self.examplePatches[:, 0:self.overlapSize, :, :]
        leftOverlap = self.examplePatches[:, :, 0:self.overlapSize, :]
        flatten_top = topOverlap.reshape(np.shape(topOverlap)[0], -1)
        flatten_left = leftOverlap.reshape(np.shape(leftOverlap)[0], -1)
        flatten_combined = self.getCombinedOverlap(flatten_top, flatten_left)
        return KDTree(flatten_top), KDTree(flatten_left), KDTree(flatten_combined)

    def getCombinedOverlap(self, top, left):
        """Concatenate top and left strip features along the last axis (float64).

        Works for single vectors (1-D) and row-stacked batches (2-D). The corner
        shared by both strips is therefore counted twice.
        """
        return np.concatenate((np.asarray(top, dtype=float), np.asarray(left, dtype=float)), axis=-1)

    def initFirstPatch(self):
        """Seed cell (0, 0) with a uniformly random example window (mirrors included)."""
        # random.randint is inclusive on both ends, hence the -1.
        patchId = randint(0, np.shape(self.examplePatches)[0] - 1)
        self.filledMap[0, 0] = 1
        self.idMap[0, 0] = patchId % self.total_patches_count
        self.updateCanvas(patchId, 0, 0, False, False)
        # Optional live view: self.visualize([0, 0], [patchId], [])

    def prepareExamplePatches(self):
        """Cut the example into (patch + 2*overlap)-sized windows, optionally with mirrors.

        Sets ``total_patches_count`` to the number of original windows. Returns an
        array stacking originals, then row-reversed copies (mirror_hor), then
        column-reversed copies (mirror_vert).
        """
        searchKernelSize = self.patchSize + 2 * self.overlapSize
        windows = view_as_windows(self.exampleMap, [searchKernelSize, searchKernelSize, 3], self.windowStep)
        shape = np.shape(windows)
        self.total_patches_count = shape[0] * shape[1]
        originals = windows.reshape(self.total_patches_count, searchKernelSize, searchKernelSize, 3)
        variants = [originals]
        if self.mirror_hor:
            variants.append(originals[:, ::-1, :, :])
        if self.mirror_vert:
            variants.append(originals[:, :, ::-1, :])
        return np.concatenate(variants, axis=0)

    def initCanvas(self):
        """Allocate the canvas and per-cell bookkeeping grids.

        ``outputSize`` is (width, height), as in the PIL resize calls. So the
        canvas has rows derived from ``outputSize[1]`` and columns from
        ``outputSize[0]``, each rounded up to whole cells.

        Returns:
            (canvas, filledMap, idMap): zeros canvas, 0/1 fill grid and id grid
            initialised to -1.
        """
        step = self.patchSize + self.overlapSize
        num_patches_X = max(1, ceil((self.outputSize[0] - self.overlapSize) / step))  # columns
        num_patches_Y = max(1, ceil((self.outputSize[1] - self.overlapSize) / step))  # rows
        required_size_X = num_patches_X * self.patchSize + (num_patches_X + 1) * self.overlapSize
        # Fixed: this used num_patches_X, mis-sizing non-square canvases.
        required_size_Y = num_patches_Y * self.patchSize + (num_patches_Y + 1) * self.overlapSize
        canvas = np.zeros((required_size_Y, required_size_X, 3))
        filledMap = np.zeros((num_patches_Y, num_patches_X))
        idMap = np.zeros((num_patches_Y, num_patches_X)) - 1
        print("modified output size: ", np.shape(canvas))
        print("number of patches: ", np.size(filledMap))
        return canvas, filledMap, idMap

    def idCoordTo2DCoord(self, idCoord, imgSize):
        """Map a row-major cell index to [row, col] on a grid of shape ``imgSize`` (rows, cols)."""
        # Divide by the column count; the old code divided by the row count.
        row, col = divmod(int(idCoord), int(imgSize[1]))
        return [row, col]

    def updateCanvas(self, inputPatchId, coord_X, coord_Y, blendTop=False, blendLeft=False):
        """Paint example window ``inputPatchId`` into cell (coord_X=row, coord_Y=col).

        ``inputPatchId`` may be an int or a length-1 array. A copy of the window is
        blended, so ``examplePatches`` is never modified. The left strip is
        blended first, then the top strip.
        """
        x_range = self.patchCoord2canvasCoord(coord_X)
        y_range = self.patchCoord2canvasCoord(coord_Y)
        patch_id = int(np.ravel(inputPatchId)[0])
        examplePatch = np.copy(self.examplePatches[patch_id])
        o = self.overlapSize
        if blendLeft:
            canvasOverlap = self.canvas[x_range[0]:x_range[1], y_range[0]:y_range[0] + o]
            examplePatch[:, 0:o] = self.linearBlendOverlaps(canvasOverlap, examplePatch[:, 0:o], 'left')
        if blendTop:
            canvasOverlap = self.canvas[x_range[0]:x_range[0] + o, y_range[0]:y_range[1]]
            examplePatch[0:o, :] = self.linearBlendOverlaps(canvasOverlap, examplePatch[0:o, :], 'top')
        self.canvas[x_range[0]:x_range[1], y_range[0]:y_range[1]] = examplePatch

    def linearBlendOverlaps(self, canvasOverlap, examplePatchOverlap, mode):
        """Cross-fade from canvas to example across an overlap strip.

        The example weight ramps 0, 1/o, ..., (o-1)/o across columns for
        ``'left'`` and across rows for ``'top'``.

        Raises:
            ValueError: for any other ``mode``.
        """
        ramp = np.arange(self.overlapSize) / self.overlapSize
        if mode == 'left':
            mask = ramp[np.newaxis, :, np.newaxis]
        elif mode == 'top':
            mask = ramp[:, np.newaxis, np.newaxis]
        else:
            raise ValueError("mode must be 'left' or 'top', got %r" % (mode,))
        return canvasOverlap * (1 - mask) + examplePatchOverlap * mask

    # TODO: minimum-error boundary cut as an alternative to linearBlendOverlaps.

    def patchCoord2canvasCoord(self, coord):
        """Return [start, end) canvas pixel range of grid index ``coord`` (window incl. overlaps)."""
        step = self.patchSize + self.overlapSize
        return [step * coord, step * (coord + 1) + self.overlapSize]

    def patchCoord2canvasPatch(self, coord):
        """Return a copy of the canvas region covered by cell ``coord`` = [row, col]."""
        x_range = self.patchCoord2canvasCoord(coord[0])
        y_range = self.patchCoord2canvasCoord(coord[1])
        return np.copy(self.canvas[x_range[0]:x_range[1], y_range[0]:y_range[1]])

    def loadExampleMap(self, exampleMapPath):
        """Read the example image and return it as an (H, W, 3) float array."""
        return self._normalizeExampleMap(io.imread(exampleMapPath))

    @staticmethod
    def _normalizeExampleMap(exampleMap):
        """Scale to [0, 1] and coerce to 3-channel RGB (drops alpha, expands gray)."""
        exampleMap = np.asarray(exampleMap)
        if np.issubdtype(exampleMap.dtype, np.integer):
            # 255 for uint8 (as before); also correct for 16-bit images.
            exampleMap = exampleMap / float(np.iinfo(exampleMap.dtype).max)
        else:
            exampleMap = exampleMap / 255.0
        if exampleMap.ndim == 2:
            # grayscale -> channels-last RGB (the old code built a channels-first array)
            exampleMap = np.repeat(exampleMap[:, :, np.newaxis], 3, axis=2)
        elif exampleMap.shape[-1] < 3:
            # gray (+ alpha): replicate the luminance channel
            exampleMap = np.repeat(exampleMap[:, :, :1], 3, axis=2)
        elif exampleMap.shape[-1] > 3:
            exampleMap = exampleMap[:, :, :3]  # drop alpha
        return exampleMap