HairSwapModel / models /Blending.py
miguelmuzo's picture
Upload 426 files
3de0e37 verified
import torch
from torch import nn
from models.Encoders import ClipBlendingModel, PostProcessModel
from models.Net import Net
from utils.bicubic import BicubicDownSample
from utils.image_utils import DilateErosion
from utils.save_utils import save_gen_image, save_latents
class Blending(nn.Module):
"""
Module for transferring the desired hair color and post processing
"""
def __init__(self, opts, net=None):
super().__init__()
self.opts = opts
if net is None:
self.net = Net(self.opts)
else:
self.net = net
blending_checkpoint = torch.load(self.opts.blending_checkpoint)
self.blending_encoder = ClipBlendingModel(blending_checkpoint.get('clip', "ViT-B/32"))
self.blending_encoder.load_state_dict(blending_checkpoint['model_state_dict'], strict=False)
self.blending_encoder.to(self.opts.device).eval()
self.post_process = PostProcessModel().to(self.opts.device).eval()
self.post_process.load_state_dict(torch.load(self.opts.pp_checkpoint)['model_state_dict'])
self.dilate_erosion = DilateErosion(dilate_erosion=self.opts.smooth, device=self.opts.device)
self.downsample_256 = BicubicDownSample(factor=4)
@torch.inference_mode()
def blend_images(self, align_shape, align_color, name_to_embed, **kwargs):
I_1 = name_to_embed['face']['image_norm_256']
I_2 = name_to_embed['shape']['image_norm_256']
I_3 = name_to_embed['color']['image_norm_256']
mask_de = self.dilate_erosion.hair_from_mask(
torch.cat([name_to_embed[x]['mask'] for x in ['face', 'color']], dim=0)
)
HM_1D, _ = mask_de[0][0].unsqueeze(0), mask_de[1][0].unsqueeze(0)
HM_3D, HM_3E = mask_de[0][1].unsqueeze(0), mask_de[1][1].unsqueeze(0)
latent_S_1, latent_F_align = name_to_embed['face']['S'], align_shape['latent_F_align']
HM_X = align_color['HM_X']
latent_S_3 = name_to_embed['color']["S"]
HM_XD, _ = self.dilate_erosion.mask(HM_X)
target_mask = (1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)
# Blending
if I_1 is not I_3 or I_1 is not I_2:
S_blend_6_18 = self.blending_encoder(latent_S_1[:, 6:], latent_S_3[:, 6:], I_1 * target_mask, I_3 * HM_3E)
S_blend = torch.cat((latent_S_1[:, :6], S_blend_6_18), dim=1)
else:
S_blend = latent_S_1
I_blend, _ = self.net.generator([S_blend], input_is_latent=True, return_latents=False, start_layer=4,
end_layer=8, layer_in=latent_F_align)
I_blend_256 = self.downsample_256(I_blend)
# Post Process
S_final, F_final = self.post_process(I_1, I_blend_256)
I_final, _ = self.net.generator([S_final], input_is_latent=True, return_latents=False,
start_layer=5, end_layer=8, layer_in=F_final)
if self.opts.save_all:
exp_name = exp_name if (exp_name := kwargs.get('exp_name')) is not None else ""
output_dir = self.opts.save_all_dir / exp_name
save_gen_image(output_dir, 'Blending', 'blending.png', I_blend)
save_latents(output_dir, 'Blending', 'blending.npz', S_blend=S_blend)
save_gen_image(output_dir, 'Final', 'final.png', I_final)
save_latents(output_dir, 'Final', 'final.npz', S_final=S_final, F_final=F_final)
final_image = ((I_final[0] + 1) / 2).clip(0, 1)
return final_image