| """ | |
| This file defines the core research contribution | |
| """ | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import math | |
| import torch | |
| from torch import nn | |
| from models.encoders import psp_encoders | |
| from models.stylegan2.model import Generator | |
| from configs.paths_config import model_paths | |
| import torch.nn.functional as F | |


def get_keys(d, name):
    """Return the sub-dict of checkpoint `d` belonging to module `name`,
    with the `name.` prefix stripped from each key."""
    if 'state_dict' in d:
        d = d['state_dict']
    d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
    return d_filt
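
# A usage sketch for get_keys (the key names are hypothetical, for illustration):
#   ckpt = {'state_dict': {'encoder.conv.weight': w, 'decoder.style.bias': b}}
#   get_keys(ckpt, 'encoder')  ->  {'conv.weight': w}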


class pSp(nn.Module):

    def __init__(self, opts, ckpt=None):
        super(pSp, self).__init__()
        self.set_opts(opts)
        # compute the number of style inputs from the output resolution,
        # e.g. int(log2(1024)) * 2 - 2 = 18 styles for a 1024x1024 generator
        self.opts.n_styles = int(math.log(self.opts.output_size, 2)) * 2 - 2
        # Define architecture
        self.encoder = self.set_encoder()
        self.decoder = Generator(self.opts.output_size, 512, 8)
        self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
        # Load weights if needed
        self.load_weights(ckpt)

    def set_encoder(self):
        if self.opts.encoder_type == 'GradualStyleEncoder':
            encoder = psp_encoders.GradualStyleEncoder(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoW':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoW(50, 'ir_se', self.opts)
        elif self.opts.encoder_type == 'BackboneEncoderUsingLastLayerIntoWPlus':
            encoder = psp_encoders.BackboneEncoderUsingLastLayerIntoWPlus(50, 'ir_se', self.opts)
        else:
            raise ValueError('{} is not a valid encoder type'.format(self.opts.encoder_type))
        return encoder

    def load_weights(self, ckpt=None):
        if self.opts.checkpoint_path is not None:
            print('Loading pSp from checkpoint: {}'.format(self.opts.checkpoint_path))
            if ckpt is None:
                ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
            self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=False)
            self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=False)
            self.__load_latent_avg(ckpt)
        else:
            print('Loading encoder weights from irse50!')
            encoder_ckpt = torch.load(model_paths['ir_se50'], map_location='cpu')
            # if the encoder input is not an RGB image, do not load the input-layer weights
            if self.opts.label_nc != 0:
                encoder_ckpt = {k: v for k, v in encoder_ckpt.items() if "input_layer" not in k}
            self.encoder.load_state_dict(encoder_ckpt, strict=False)
            print('Loading decoder weights from pretrained!')
            ckpt = torch.load(self.opts.stylegan_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            if self.opts.learn_in_w:
                self.__load_latent_avg(ckpt, repeat=1)
            else:
                self.__load_latent_avg(ckpt, repeat=self.opts.n_styles)
        # for video toonification, load the toonified generator G0' on top of the base decoder
        if self.opts.toonify_weights is not None:  ##### modified
            ckpt = torch.load(self.opts.toonify_weights, map_location='cpu')
            self.decoder.load_state_dict(ckpt['g_ema'], strict=False)
            self.opts.toonify_weights = None

    # x1: image for the first-layer feature f
    # x2: image for the style latent code w+; if not specified, x2 = x1
    # inject_latent: for sketch/mask-to-face translation, another latent code to fuse with w+
    # latent_mask: fuse w+ and inject_latent with the mask (layers 1~7 use w+, layers 8~18 use inject_latent)
    # use_feature: use f; otherwise, use the original StyleGAN first-layer constant 4x4 feature
    # first_layer_feature_ind: always 0, meaning the first layer of G accepts f
    # use_skip: use skip connections
    # zero_noise: use zero noise
    # editing_w: the editing vector v for video face editing
    def forward(self, x1, x2=None, resize=True, latent_mask=None, randomize_noise=True,
                inject_latent=None, return_latents=False, alpha=None, use_feature=True,
                first_layer_feature_ind=0, use_skip=False, zero_noise=False, editing_w=None):  ##### modified
        feats = None  # f and the skipped encoder features
        codes, feats = self.encoder(x1, return_feat=True, return_full=use_skip)  ##### modified
        if x2 is not None:  ##### modified
            codes = self.encoder(x2)  ##### modified
        # normalize with respect to the center of an average face
        if self.opts.start_from_latent_avg:
            if self.opts.learn_in_w:
                codes = codes + self.latent_avg.repeat(codes.shape[0], 1)
            else:
                codes = codes + self.latent_avg.repeat(codes.shape[0], 1, 1)
        # E_W^{1:7}(T(x1)) concatenated with E_W^{8:18}(w~)
        if latent_mask is not None:
            for i in latent_mask:
                if inject_latent is not None:
                    if alpha is not None:
                        codes[:, i] = alpha * inject_latent[:, i] + (1 - alpha) * codes[:, i]
                    else:
                        codes[:, i] = inject_latent[:, i]
                else:
                    codes[:, i] = 0
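        # Example (hypothetical values): latent_mask=[8, 9, ..., 17] with an inject_latent
        # keeps the coarse codes from x1 and takes the fine codes from inject_latent
        # (blended by alpha if given), i.e. the usual style-mixing recipe.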
        first_layer_feats, skip_layer_feats, fusion = None, None, None  ##### modified
        if use_feature:  ##### modified
            first_layer_feats = feats[0:2]  # use f
            if use_skip:  ##### modified
                skip_layer_feats = feats[2:]  # use the skipped encoder features
                fusion = self.encoder.fusion  # use the fusion layers to fuse encoder and decoder features
        images, result_latent = self.decoder([codes],
                                             input_is_latent=True,
                                             randomize_noise=randomize_noise,
                                             return_latents=return_latents,
                                             first_layer_feature=first_layer_feats,
                                             first_layer_feature_ind=first_layer_feature_ind,
                                             skip_layer_feature=skip_layer_feats,
                                             fusion_block=fusion,
                                             zero_noise=zero_noise,
                                             editing_w=editing_w)  ##### modified
        if resize:
            if self.opts.output_size == 1024:  ##### modified
                # downsample 1024x1024 outputs by 4x to 256x256
                images = F.adaptive_avg_pool2d(images, (images.shape[2] // 4, images.shape[3] // 4))  ##### modified
            else:
                images = self.face_pool(images)
        if return_latents:
            return images, result_latent
        else:
            return images

    def set_opts(self, opts):
        self.opts = opts

    def __load_latent_avg(self, ckpt, repeat=None):
        if 'latent_avg' in ckpt:
            self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
            if repeat is not None:
                self.latent_avg = self.latent_avg.repeat(repeat, 1)
        else:
            self.latent_avg = None
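

# A minimal usage sketch (not part of the original module). The option values below
# are assumptions for illustration; in practice they come from the training options
# stored alongside the checkpoint, and the pretrained weights referenced in
# configs/paths_config.py must be present on disk.
if __name__ == '__main__':
    from argparse import Namespace

    opts = Namespace(
        encoder_type='GradualStyleEncoder',   # hypothetical choice of encoder
        output_size=1024,                     # 1024x1024 StyleGAN2 generator
        checkpoint_path=None,                 # set to a trained pSp checkpoint to restore one
        label_nc=0,                           # 0 means the encoder input is an RGB image
        learn_in_w=False,
        start_from_latent_avg=False,
        stylegan_weights=model_paths.get('stylegan_ffhq', ''),  # assumed paths_config key
        toonify_weights=None,
        device='cpu',
    )
    net = pSp(opts).eval()
    x = torch.randn(1, 3, 256, 256)  # dummy batch of aligned 256x256 face crops
    with torch.no_grad():
        out = net(x, resize=True, zero_noise=True)
    print(out.shape)  # expected: torch.Size([1, 3, 256, 256])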