import os
from PIL import Image
import torchvision.transforms as transforms

try:
    from torchvision.transforms import InterpolationMode

    bic = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision releases have no InterpolationMode; fall back to PIL.
    bic = Image.BICUBIC
import numpy as np
import torch
import torch.nn as nn
import functools

IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".webp"]


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(
        self,
        input_nc,
        output_nc,
        num_downs,
        ngf=64,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet generator

        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example,
                               if |num_downs| == 7, image of size 128x128 will
                               become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer

        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(
            ngf * 8,
            ngf * 8,
            input_nc=None,
            submodule=None,
            norm_layer=norm_layer,
            innermost=True,
        )  # add the innermost layer
        # add intermediate layers with ngf * 8 filters
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(
                ngf * 8,
                ngf * 8,
                input_nc=None,
                submodule=unet_block,
                norm_layer=norm_layer,
                use_dropout=use_dropout,
            )
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(
            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        unet_block = UnetSkipConnectionBlock(
            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer
        )
        self.model = UnetSkipConnectionBlock(
            output_nc,
            ngf,
            input_nc=input_nc,
            submodule=unet_block,
            outermost=True,
            norm_layer=norm_layer,
        )  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
    X -------------------identity----------------------
    |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(
        self,
        outer_nc,
        inner_nc,
        input_nc=None,
        submodule=None,
        outermost=False,
        innermost=False,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
    ):
        """Construct a Unet submodule with skip connections.

        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)    -- if this module is the outermost module
            innermost (bool)    -- if this module is the innermost module
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # InstanceNorm2d (with affine=False) has no learnable shift, so the
        # convolutions keep their own bias in that case.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(
            input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
        )
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1
            )
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(
                inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias
            )
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(
                inner_nc * 2,
                outer_nc,
                kernel_size=4,
                stride=2,
                padding=1,
                bias=use_bias,
            )
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


class Anime2Sketch:
    def __init__(
        self, model_path: str = "./models/netG.pth", device: str = "cpu"
    ) -> None:
        norm_layer = functools.partial(
            nn.InstanceNorm2d, affine=False, track_running_stats=False
        )
        net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
        ckpt = torch.load(model_path, map_location="cpu")
        # Strip the "module." prefix that nn.DataParallel adds to parameter names.
        # Note: .half() rounds these weights to float16; load_state_dict then
        # copies them back into the float32 parameters of the network.
        for key in list(ckpt.keys()):
            if "module." in key:
                ckpt[key.replace("module.", "")] = ckpt[key].half()
                del ckpt[key]
        net.load_state_dict(ckpt)
        net.eval()  # inference only
        self.model = net

        if torch.cuda.is_available() and device == "cuda":
            self.device = "cuda"
        else:
            self.device = "cpu"
        self.model.to(self.device)

    def predict(self, image: Image.Image, load_size: int = 512) -> Image.Image:
        # Remember the original size so the sketch can be resized back to it.
        aus_resize = image.size if load_size > 0 else None
        try:
            transform = self.get_transform(load_size=load_size)
            img = transform(image).unsqueeze(0)
        except Exception as exc:
            # In-memory PIL images may not carry a filename attribute.
            name = getattr(image, "filename", "<in-memory image>")
            raise RuntimeError("Error in reading image {}".format(name)) from exc
        with torch.no_grad():  # no gradients are needed for inference
            aus_tensor = self.model(img.to(self.device))
        aus_img = self.tensor_to_img(aus_tensor)
        image_pil = Image.fromarray(aus_img)
        if aus_resize:
            image_pil = image_pil.resize(aus_resize, Image.BICUBIC)
        return image_pil

    def get_transform(self, load_size=0, grayscale=False, method=bic, convert=True):
        transform_list = []
        if grayscale:
            transform_list.append(transforms.Grayscale(1))
        if load_size > 0:
            osize = [load_size, load_size]
            transform_list.append(transforms.Resize(osize, method))
        if convert:
            transform_list += [transforms.ToTensor()]
            if grayscale:
                transform_list += [transforms.Normalize((0.5,), (0.5,))]
            else:
                transform_list += [
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                ]
        return transforms.Compose(transform_list)

    def tensor_to_img(self, input_image, imtype=np.uint8):
        """Converts a Tensor array into a numpy image array.

        Parameters:
            input_image (tensor) -- the input image tensor array
            imtype (type)        -- the desired type of the converted numpy array
        """
        if not isinstance(input_image, np.ndarray):
            if isinstance(input_image, torch.Tensor):  # get the data from a variable
                image_tensor = input_image.data
            else:
                return input_image
            image_numpy = (
                image_tensor[0].cpu().float().numpy()
            )  # convert it into a numpy array
            if image_numpy.shape[0] == 1:  # grayscale to RGB
                image_numpy = np.tile(image_numpy, (3, 1, 1))
            image_numpy = (
                (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
            )  # post-processing: transpose and scaling
        else:  # if it is a numpy array, do nothing
            image_numpy = input_image
        return image_numpy.astype(imtype)
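

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module). The UnetGenerator
# docstring says that with num_downs == 7 a 128x128 image reaches 1x1 at the
# bottleneck. The hypothetical helper below reproduces that arithmetic for the
# down-convolution used in UnetSkipConnectionBlock (kernel_size=4, stride=2,
# padding=1), assuming a square input and dilation 1. It may help when picking
# a load_size that is large enough for a given num_downs.
# ---------------------------------------------------------------------------
def unet_bottleneck_size(image_size: int, num_downs: int) -> int:
    """Return the spatial size left at the bottleneck after num_downs downsamplings."""
    kernel_size, stride, padding = 4, 2, 1
    size = image_size
    for _ in range(num_downs):
        # Standard Conv2d output-size formula for dilation 1.
        size = (size + 2 * padding - kernel_size) // stride + 1
        if size < 1:
            raise ValueError(
                "image_size {} is too small for {} downsamplings".format(
                    image_size, num_downs
                )
            )
    return size


if __name__ == "__main__":
    # The docstring's example: seven downsamplings of a 128x128 image.
    print(unet_bottleneck_size(128, 7))  # -> 1
    # The defaults used by Anime2Sketch above: load_size=512, num_downs=8.
    print(unet_bottleneck_size(512, 8))  # -> 2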