Spaces:
Paused
Paused
import torch | |
from torch import nn | |
from models.Encoders import ClipBlendingModel, PostProcessModel | |
from models.Net import Net | |
from utils.bicubic import BicubicDownSample | |
from utils.image_utils import DilateErosion | |
from utils.save_utils import save_gen_image, save_latents | |
class Blending(nn.Module):
    """
    Module for transferring the desired hair color and post processing.

    The blending encoder (``ClipBlendingModel``) mixes the style latents of the
    face and color images. The generator of ``self.net`` renders the blended
    latents. A post-processing model then predicts new latents from the face
    image and the downsampled blend, and the generator renders those into the
    final image.
    """

    def __init__(self, opts, net=None):
        """
        Args:
            opts: Options object. Read here: ``blending_checkpoint``,
                ``pp_checkpoint``, ``device`` and ``smooth``. ``blend_images``
                additionally reads ``save_all`` and ``save_all_dir``.
            net: Optional pre-built ``Net``. If None, a new one is built from ``opts``.
        """
        super().__init__()
        self.opts = opts
        self.net = Net(self.opts) if net is None else net

        # Checkpoints are deserialised onto the CPU. This lets GPU-saved
        # checkpoints load on CPU-only hosts and avoids a transient extra copy
        # on the GPU. The weights still end up on ``opts.device``, via
        # ``.to()`` below or via ``load_state_dict`` into an on-device module.
        blending_checkpoint = torch.load(self.opts.blending_checkpoint, map_location='cpu')
        # The checkpoint may name the CLIP backbone; "ViT-B/32" is used when it does not.
        self.blending_encoder = ClipBlendingModel(blending_checkpoint.get('clip', "ViT-B/32"))
        # NOTE(review): strict=False silently tolerates missing/unexpected keys.
        # The reason is not visible here; confirm this is intended.
        self.blending_encoder.load_state_dict(blending_checkpoint['model_state_dict'], strict=False)
        self.blending_encoder.to(self.opts.device).eval()

        self.post_process = PostProcessModel().to(self.opts.device).eval()
        pp_checkpoint = torch.load(self.opts.pp_checkpoint, map_location='cpu')
        self.post_process.load_state_dict(pp_checkpoint['model_state_dict'])

        self.dilate_erosion = DilateErosion(dilate_erosion=self.opts.smooth, device=self.opts.device)
        self.downsample_256 = BicubicDownSample(factor=4)

    def blend_images(self, align_shape, align_color, name_to_embed, **kwargs):
        """
        Blend the hair color into the aligned face, post-process, and return the result.

        Args:
            align_shape: dict providing ``'latent_F_align'``. It is fed to the
                generator as ``layer_in``.
            align_color: dict providing ``'HM_X'``, a hair mask. Presumably this
                is the hair mask of the aligned target (TODO confirm).
            name_to_embed: dict with entries ``'face'``, ``'shape'`` and ``'color'``.
                Each entry provides ``'image_norm_256'``. The ``'face'`` and
                ``'color'`` entries also provide ``'mask'`` and ``'S'`` (style
                latents).
            **kwargs: optional ``exp_name``. It names the sub-directory of
                ``opts.save_all_dir`` used when ``opts.save_all`` is set.

        Returns:
            The first image of the final batch, rescaled with ``(x + 1) / 2``
            and clipped to [0, 1]. The rescaling assumes generator output in
            [-1, 1].
        """
        I_1 = name_to_embed['face']['image_norm_256']
        I_2 = name_to_embed['shape']['image_norm_256']
        I_3 = name_to_embed['color']['image_norm_256']

        # Face and color masks are processed as one batch. In the result,
        # mask_de[k][b] picks variant k (0/1 look like dilated/eroded, judging by
        # the D/E suffixes) for batch item b (0 = face, 1 = color).
        mask_de = self.dilate_erosion.hair_from_mask(
            torch.cat([name_to_embed[x]['mask'] for x in ['face', 'color']], dim=0)
        )
        HM_1D = mask_de[0][0].unsqueeze(0)
        HM_3D = mask_de[0][1].unsqueeze(0)
        HM_3E = mask_de[1][1].unsqueeze(0)

        latent_S_1 = name_to_embed['face']['S']
        latent_S_3 = name_to_embed['color']['S']
        latent_F_align = align_shape['latent_F_align']
        HM_XD, _ = self.dilate_erosion.mask(align_color['HM_X'])

        # Region that is hair in none of the face, color, or aligned masks.
        target_mask = (1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)

        # Blending.
        # Identity (not value) comparison: blending is skipped only when the
        # face, shape and color entries are the very same tensor object.
        # Presumably this happens when the caller reuses one image for every
        # role (TODO confirm).
        if I_1 is not I_3 or I_1 is not I_2:
            # Keep the face's first 6 style vectors and replace the rest with
            # the encoder's blend of face and color latents.
            S_blend_6_18 = self.blending_encoder(latent_S_1[:, 6:], latent_S_3[:, 6:],
                                                 I_1 * target_mask, I_3 * HM_3E)
            S_blend = torch.cat((latent_S_1[:, :6], S_blend_6_18), dim=1)
        else:
            S_blend = latent_S_1

        I_blend, _ = self.net.generator([S_blend], input_is_latent=True, return_latents=False,
                                        start_layer=4, end_layer=8, layer_in=latent_F_align)
        I_blend_256 = self.downsample_256(I_blend)

        # Post Process
        S_final, F_final = self.post_process(I_1, I_blend_256)
        I_final, _ = self.net.generator([S_final], input_is_latent=True, return_latents=False,
                                        start_layer=5, end_layer=8, layer_in=F_final)

        if self.opts.save_all:
            exp_name = kwargs.get('exp_name')
            if exp_name is None:
                exp_name = ""
            output_dir = self.opts.save_all_dir / exp_name
            save_gen_image(output_dir, 'Blending', 'blending.png', I_blend)
            save_latents(output_dir, 'Blending', 'blending.npz', S_blend=S_blend)
            save_gen_image(output_dir, 'Final', 'final.png', I_final)
            save_latents(output_dir, 'Final', 'final.npz', S_final=S_final, F_final=F_final)

        final_image = ((I_final[0] + 1) / 2).clip(0, 1)
        return final_image