Spaces:
Paused
Paused
import torch | |
import torch.nn.functional as F | |
import torchvision.transforms as T | |
from torch import nn | |
from models.CtrlHair.shape_branch.config import cfg as cfg_mask | |
from models.CtrlHair.shape_branch.solver import get_hair_face_code, get_new_shape, Solver as SolverMask | |
from models.Encoders import RotateModel | |
from models.Net import Net, get_segmentation | |
from models.sean_codes.models.pix2pix_model import Pix2PixModel, SEAN_OPT, encode_sean, decode_sean | |
from utils.image_utils import DilateErosion | |
from utils.save_utils import save_vis_mask, save_gen_image, save_latents | |
class Alignment(nn.Module):
    """
    Module for transferring the desired hair shape.

    Given two embedded images (looked up by name in ``name_to_embed``), it builds
    a target segmentation mask that combines the face code of the first image
    with the hair code of the second, and then blends the F-space latents so that
    the first image takes on that hair shape.
    """

    # Segmentation label value that this module treats as "hair" when it builds
    # binary hair masks from segmentation masks.
    HAIR_LABEL = 13

    # Default location of the Shape Adaptor mask-generator weights. It can be
    # overridden per run with an ``opts.mask_generator_checkpoint`` attribute.
    MASK_GENERATOR_CHECKPOINT = 'pretrained_models/ShapeAdaptor/mask_generator.pth'

    def __init__(self, opts, latent_encoder=None, net=None):
        """
        Build all sub-models used for shape transfer.

        Args:
            opts: option namespace. This class reads ``device``, ``smooth``,
                ``rotate_checkpoint``, ``save_all``, ``save_all_dir`` and,
                optionally, ``mask_generator_checkpoint``.
            latent_encoder: callable that maps a list of images to a dict with
                ``"F"`` and ``"W"`` entries (only used by ``align_images``).
            net: an existing ``Net`` to reuse. A new one is built from ``opts``
                when this is falsy.
        """
        super().__init__()
        self.opts = opts
        self.latent_encoder = latent_encoder
        self.net = net if net else Net(self.opts)

        self.sean_model = Pix2PixModel(SEAN_OPT)
        self.sean_model.eval()

        device = self.opts.device
        solver_mask = SolverMask(cfg_mask, device=device, local_rank=-1, training=False)
        self.mask_generator = solver_mask.gen
        mask_ckpt = getattr(self.opts, 'mask_generator_checkpoint', self.MASK_GENERATOR_CHECKPOINT)
        # map_location lets checkpoints that were saved on a GPU be loaded on a
        # CPU-only machine (or on another GPU) instead of failing in torch.load.
        self.mask_generator.load_state_dict(torch.load(mask_ckpt, map_location=device))

        self.rotate_model = RotateModel()
        rotate_ckpt = torch.load(self.opts.rotate_checkpoint, map_location=device)
        self.rotate_model.load_state_dict(rotate_ckpt['model_state_dict'])
        self.rotate_model.to(device).eval()

        self.dilate_erosion = DilateErosion(dilate_erosion=self.opts.smooth, device=device)
        # ImageNet mean/std normalization, applied to generated images right
        # before they are passed to get_segmentation.
        self.to_bisenet = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    def _output_dir(self, kwargs):
        """Return ``opts.save_all_dir / exp_name`` (``exp_name`` from kwargs, default "")."""
        exp_name = kwargs.get('exp_name')
        if exp_name is None:
            exp_name = ""
        return self.opts.save_all_dir / exp_name

    def _hair_mask(self, mask, device):
        """Return a tensor like ``mask`` holding 1 where ``mask == HAIR_LABEL`` and 0 elsewhere."""
        return torch.where(mask == self.HAIR_LABEL,
                           torch.ones_like(mask, device=device),
                           torch.zeros_like(mask, device=device))

    def shape_module(self, im_name1: str, im_name2: str, name_to_embed, only_target=True, **kwargs):
        """
        Build the target mask combining image 1's face code with image 2's hair code.

        When both names map to the same ``'image_256'`` object (identity check),
        no rotation or shape adaptation is run and image 1's mask is the target.

        Args:
            im_name1: key in ``name_to_embed`` of the image whose face code is kept.
            im_name2: key in ``name_to_embed`` of the image whose hair code is used.
            name_to_embed: mapping from image name to a dict with at least
                ``'image_256'``, ``'W'`` and ``'mask'`` entries.
            only_target: if True, return only the target hair mask.
            **kwargs: ``exp_name`` selects the debug sub-directory used when
                ``opts.save_all`` is set.

        Returns:
            ``{'HM_X': hair_mask_target}`` when ``only_target`` is True. Otherwise
            the tuple ``(inp_mask1, hair_mask1, inp_mask2, hair_mask2,
            target_mask, hair_mask_target)``.
        """
        device = self.opts.device
        # load images
        img1_in = name_to_embed[im_name1]['image_256']
        img2_in = name_to_embed[im_name2]['image_256']
        # load latents
        latent_W_1 = name_to_embed[im_name1]["W"]
        latent_W_2 = name_to_embed[im_name2]["W"]
        # load masks; indexed below as mask[0, 0, ...], presumably (1, 1, H, W) — TODO confirm
        inp_mask1 = name_to_embed[im_name1]['mask']
        inp_mask2 = name_to_embed[im_name2]['mask']

        if img1_in is not img2_in:
            # Rotate stage: rotate image 2 towards image 1 using the first 6 W
            # layers, keep the remaining layers of image 2, and segment the result.
            rotate_to = self.rotate_model(latent_W_2[:, :6], latent_W_1[:, :6])
            rotate_to = torch.cat((rotate_to, latent_W_2[:, 6:]), dim=1)
            I_rot, _ = self.net.generator([rotate_to], input_is_latent=True, return_latents=False)
            I_rot_to_seg = ((I_rot + 1) / 2).clip(0, 1)  # [-1, 1] -> [0, 1]
            rot_mask = get_segmentation(self.to_bisenet(I_rot_to_seg))

            # Shape Adaptor: face code of image 1 + hair code of rotated image 2.
            face_1, _ = get_hair_face_code(self.mask_generator, inp_mask1[0, 0, ...])
            _, hair_2 = get_hair_face_code(self.mask_generator, rot_mask[0, 0, ...])
            target_mask = get_new_shape(self.mask_generator, face_1, hair_2)[None, None]
        else:
            I_rot = None
            rot_mask = inp_mask2
            target_mask = inp_mask1

        hair_mask_target = self._hair_mask(target_mask, device)

        if self.opts.save_all:
            output_dir = self._output_dir(kwargs)
            if I_rot is not None:
                save_gen_image(output_dir, 'Shape', f'{im_name2}_rotate_to_{im_name1}.png', I_rot)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name1}.png', inp_mask1)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name2}.png', inp_mask2)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name2}_rotate_to_{im_name1}.png', rot_mask)
            save_vis_mask(output_dir, 'Shape', f'mask_{im_name1}_{im_name2}_target.png', target_mask)

        if only_target:
            return {'HM_X': hair_mask_target}

        hair_mask1 = self._hair_mask(inp_mask1, device)
        hair_mask2 = self._hair_mask(inp_mask2, device)
        return inp_mask1, hair_mask1, inp_mask2, hair_mask2, target_mask, hair_mask_target

    def align_images(self, im_name1, im_name2, name_to_embed, **kwargs):
        """
        Blend F-space latents so that image 1 takes on the hair shape of image 2.

        Args:
            im_name1: key of the image whose face / S latent is kept.
            im_name2: key of the image providing the hair shape.
            name_to_embed: mapping from image name to a dict with at least
                ``'image_256'``, ``'S'``, ``'F'``, ``'W'`` and ``'mask'`` entries.
            **kwargs: forwarded to ``shape_module``. ``exp_name`` also selects
                the debug sub-directory when ``opts.save_all`` is set.

        Returns:
            dict with ``'latent_F_align'`` (blended F latent) and ``'HM_X'``
            (target hair mask). When both names refer to the same
            ``'image_256'`` object, image 1's F latent is returned unchanged.
        """
        # load images
        img1_in = name_to_embed[im_name1]['image_256']
        img2_in = name_to_embed[im_name2]['image_256']
        # load latents
        latent_S_1, latent_F_1 = name_to_embed[im_name1]["S"], name_to_embed[im_name1]["F"]
        latent_S_2, latent_F_2 = name_to_embed[im_name2]["S"], name_to_embed[im_name2]["F"]

        # Shape Module
        if img1_in is img2_in:
            hair_mask_target = self.shape_module(im_name1, im_name2, name_to_embed, only_target=True, **kwargs)['HM_X']
            return {'latent_F_align': latent_F_1, 'HM_X': hair_mask_target}

        inp_mask1, hair_mask1, inp_mask2, hair_mask2, target_mask, hair_mask_target = (
            self.shape_module(im_name1, im_name2, name_to_embed, only_target=False, **kwargs)
        )

        # SEAN for inpaint: encode both images with their own masks, then decode
        # each style code onto the target mask.
        images = torch.cat([img1_in, img2_in], dim=0)
        labels = torch.cat([inp_mask1, inp_mask2], dim=0)
        img1_code, img2_code = encode_sean(self.sean_model, images, labels)
        gen1_sean = decode_sean(self.sean_model, img1_code.unsqueeze(0), target_mask)
        gen2_sean = decode_sean(self.sean_model, img2_code.unsqueeze(0), target_mask)

        # Encoding result in F from E4E
        enc_imgs = self.latent_encoder([gen1_sean, gen2_sean])
        intermediate_align, latent_inter = enc_imgs["F"][0].unsqueeze(0), enc_imgs["W"][0].unsqueeze(0)
        latent_F_out_new, latent_out = enc_imgs["F"][1].unsqueeze(0), enc_imgs["W"][1].unsqueeze(0)

        # Alignment of F space. For 0/1 masks:
        #   masks[0] = hair1 OR target hair, masks[1] = target hair,
        #   masks[2] = hair2 AND target hair.
        masks = torch.cat([
            1 - (1 - hair_mask1) * (1 - hair_mask_target),
            hair_mask_target,
            hair_mask2 * hair_mask_target,
        ], dim=0)
        dilate, erosion = self.dilate_erosion.mask(masks)
        free_mask = torch.stack([dilate[0], erosion[1], erosion[2]], dim=0)
        # NOTE(review): bicubic resampling of a 0/1 mask can overshoot slightly
        # outside [0, 1], so the blend weights may too; kept to preserve outputs.
        free_mask_down_32 = F.interpolate(free_mask.float(), size=(32, 32), mode='bicubic')
        interpolation_low = 1 - free_mask_down_32

        # Each step is a lerp: where interpolation_low is 1 the right-hand latent
        # wins, where it is 0 the left-hand (base) latent is kept.
        latent_F_align = intermediate_align + interpolation_low[0] * (latent_F_1 - intermediate_align)
        latent_F_align = latent_F_out_new + interpolation_low[1] * (latent_F_align - latent_F_out_new)
        latent_F_align = latent_F_2 + interpolation_low[2] * (latent_F_align - latent_F_2)

        if self.opts.save_all:
            output_dir = self._output_dir(kwargs)
            save_gen_image(output_dir, 'Align', f'{im_name1}_{im_name2}_SEAN.png', gen1_sean)
            save_gen_image(output_dir, 'Align', f'{im_name2}_{im_name1}_SEAN.png', gen2_sean)
            img1_e4e = self.net.generator([latent_inter], input_is_latent=True, return_latents=False, start_layer=4,
                                          end_layer=8, layer_in=intermediate_align)[0]
            img2_e4e = self.net.generator([latent_out], input_is_latent=True, return_latents=False, start_layer=4,
                                          end_layer=8, layer_in=latent_F_out_new)[0]
            save_gen_image(output_dir, 'Align', f'{im_name1}_{im_name2}_e4e.png', img1_e4e)
            save_gen_image(output_dir, 'Align', f'{im_name2}_{im_name1}_e4e.png', img2_e4e)
            gen_im, _ = self.net.generator([latent_S_1], input_is_latent=True, return_latents=False, start_layer=4,
                                           end_layer=8, layer_in=latent_F_align)
            save_gen_image(output_dir, 'Align', f'{im_name1}_{im_name2}_output.png', gen_im)
            save_latents(output_dir, 'Align', f'{im_name1}_{im_name2}_F.npz', latent_F_align=latent_F_align)

        return {'latent_F_align': latent_F_align, 'HM_X': hair_mask_target}