# HairSwapModel/models/Encoders.py
import argparse

import clip
import torch
import torch.nn as nn
from torch.nn import Linear, LayerNorm, LeakyReLU, Sequential
from torchvision import transforms as T

from models.Net import FeatureEncoderMult, IBasicBlock, conv1x1
from models.stylegan2.model import PixelNorm


class ModulationModule(nn.Module):
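    """FiLM-style modulation block.

    A linear layer followed by layer normalization produces a normalized code,
    which is then scaled and shifted by gamma/beta predicted from a conditioning
    embedding: out = x * (1 + gamma) + beta. The final block in a stack
    (last=True) skips the trailing LeakyReLU.
    """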
def __init__(self, layernum, last=False, inp=512, middle=512):
super().__init__()
self.layernum = layernum
self.last = last
self.fc = Linear(512, 512)
self.norm = LayerNorm([self.layernum, 512], elementwise_affine=False)
self.gamma_function = Sequential(Linear(inp, middle), LayerNorm([middle]), LeakyReLU(), Linear(middle, 512))
self.beta_function = Sequential(Linear(inp, middle), LayerNorm([middle]), LeakyReLU(), Linear(middle, 512))
self.leakyrelu = LeakyReLU()
def forward(self, x, embedding):
x = self.fc(x)
x = self.norm(x)
gamma = self.gamma_function(embedding)
beta = self.beta_function(embedding)
out = x * (1 + gamma) + beta
if not self.last:
out = self.leakyrelu(out)
return out
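

# A minimal shape sketch for ModulationModule (tensor sizes here are assumptions
# for illustration, not taken from the original training code):
#   mod = ModulationModule(layernum=18)
#   x = torch.randn(2, 18, 512)     # e.g. a W+ style code
#   emb = torch.randn(2, 18, 512)   # conditioning embedding, one vector per layer
#   out = mod(x, emb)               # -> (2, 18, 512)
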
class FeatureiResnet(nn.Module):
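    """A small stack of IBasicBlock residual blocks.

    `blocks` is a list of (planes, num_blocks) pairs; whenever the channel count
    changes, a 1x1 conv + BatchNorm projection is added so the shortcut matches.
    The blocks run sequentially over a spatial feature map with stride 1.
    """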
def __init__(self, blocks, inplanes=1024):
super().__init__()
self.res_blocks = {}
for n, block in enumerate(blocks, start=1):
planes, num_blocks = block
for k in range(1, num_blocks + 1):
                downsample = None
                if inplanes != planes:
                    # 1x1 conv + BatchNorm projection so the residual shortcut matches the new channel count
                    downsample = nn.Sequential(conv1x1(inplanes, planes, 1), nn.BatchNorm2d(planes, eps=1e-05))
                self.res_blocks[f'res_block_{n}_{k}'] = IBasicBlock(inplanes, planes, 1, downsample, 1, 64, 1)
inplanes = planes
self.res_blocks = nn.ModuleDict(self.res_blocks)
def forward(self, x):
for module in self.res_blocks.values():
x = module(x)
return x
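

# Shape sketch for FeatureiResnet as configured in PostProcessModel below (the
# spatial size is an assumption for illustration; only channels change):
#   to_feature = FeatureiResnet([[1024, 2], [768, 2], [512, 2]])
#   f = torch.randn(1, 1024, 32, 32)
#   out = to_feature(f)             # -> (1, 512, 32, 32)
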
class RotateModel(nn.Module):
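    """Predicts a small residual that moves `latent_from` toward `latent_to`.

    The normalized source latent passes through five ModulationModules
    conditioned on the target latent, and 0.1x the result is added back to the
    source latent.
    """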
def __init__(self):
super().__init__()
self.pixelnorm = PixelNorm()
self.modulation_module_list = nn.ModuleList([ModulationModule(6, i == 4) for i in range(5)])
def forward(self, latent_from, latent_to):
dt_latent = self.pixelnorm(latent_from)
for modulation_module in self.modulation_module_list:
dt_latent = modulation_module(dt_latent, latent_to)
output = latent_from + 0.1 * dt_latent
return output
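

# Shape sketch for RotateModel (the 6-layer latent size follows from the
# LayerNorm([6, 512]) inside its ModulationModules; the batch size is an assumption):
#   rotate = RotateModel()
#   latent_from = torch.randn(2, 6, 512)
#   latent_to = torch.randn(2, 6, 512)
#   rotated = rotate(latent_from, latent_to)   # -> (2, 6, 512)
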
class ClipBlendingModel(nn.Module):
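    """Blends hair color into a face latent using CLIP image guidance.

    A frozen CLIP image encoder embeds the target face and the hair-color
    reference; those embeddings are concatenated with `latent_color` and used to
    modulate the 12-layer face latent, which is updated by a 0.1x residual.
    """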
def __init__(self, clip_model="ViT-B/32"):
super().__init__()
self.pixelnorm = PixelNorm()
self.clip_model, _ = clip.load(clip_model, device="cuda")
self.transform = T.Compose(
[T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
self.modulation_module_list = nn.ModuleList(
[ModulationModule(12, i == 4, inp=512 * 3, middle=1024) for i in range(5)]
)
for param in self.clip_model.parameters():
param.requires_grad = False
    def get_image_embed(self, image_tensor):
        # Resize to CLIP's 224x224 input, map from [-1, 1] to [0, 1], then apply CLIP normalization.
        resized_tensor = self.face_pool(image_tensor)
        renormed_tensor = self.transform(resized_tensor * 0.5 + 0.5)
        return self.clip_model.encode_image(renormed_tensor)

    def forward(self, latent_face, latent_color, target_face, hair_color):
        # Broadcast each CLIP embedding across the 12 latent layers, then condition on
        # [latent_color, face embedding, color embedding] concatenated along the last (feature) dimension.
        embed_face = self.get_image_embed(target_face).unsqueeze(1).expand(-1, 12, -1)
        embed_color = self.get_image_embed(hair_color).unsqueeze(1).expand(-1, 12, -1)
        latent_in = torch.cat((latent_color, embed_face, embed_color), dim=-1)
        dt_latent = self.pixelnorm(latent_face)
        for modulation_module in self.modulation_module_list:
            dt_latent = modulation_module(dt_latent, latent_in)
        output = latent_face + 0.1 * dt_latent
        return output
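

# Shape sketch for ClipBlendingModel (image resolution and value range are
# assumptions; get_image_embed expects images roughly in [-1, 1] because it
# rescales by * 0.5 + 0.5 before CLIP normalization):
#   blend = ClipBlendingModel().cuda()
#   latent_face = torch.randn(1, 12, 512, device="cuda")
#   latent_color = torch.randn(1, 12, 512, device="cuda")
#   target_face = torch.rand(1, 3, 1024, 1024, device="cuda") * 2 - 1
#   hair_color = torch.rand(1, 3, 1024, 1024, device="cuda") * 2 - 1
#   out = blend(latent_face, latent_color, target_face, hair_color)   # -> (1, 12, 512)
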
class PostProcessModel(nn.Module):
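    """Post-processing encoder that fuses a source face with a target hairstyle.

    Both images go through a shared FeatureEncoderMult backbone. Two stacks of
    ModulationModules cross-condition the resulting 18-layer latents on each
    other, and the combined residual is added to a precomputed average latent.
    The concatenated F-space features are fused by a FeatureiResnet stack.
    """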
def __init__(self):
super().__init__()
self.encoder_face = FeatureEncoderMult(fs_layers=[9], opts=argparse.Namespace(
**{'arcface_model_path': "pretrained_models/ArcFace/backbone_ir50.pth"}))
self.latent_avg = torch.load('pretrained_models/PostProcess/latent_avg.pt', map_location=torch.device('cuda'))
self.to_feature = FeatureiResnet([[1024, 2], [768, 2], [512, 2]])
self.to_latent_1 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
self.to_latent_2 = nn.ModuleList([ModulationModule(18, i == 4) for i in range(5)])
self.pixelnorm = PixelNorm()
    def forward(self, source, target):
        # Encode both images into W+ codes and F-space feature maps with the shared encoder.
        s_face, [f_face] = self.encoder_face(source)
        s_hair, [f_hair] = self.encoder_face(target)
        dt_latent_face = self.pixelnorm(s_face)
        dt_latent_hair = self.pixelnorm(s_hair)
        # Cross-condition each normalized latent on the other image's latent.
        for mod_module in self.to_latent_1:
            dt_latent_face = mod_module(dt_latent_face, s_hair)
        for mod_module in self.to_latent_2:
            dt_latent_hair = mod_module(dt_latent_hair, s_face)
        finall_s = self.latent_avg + 0.1 * (dt_latent_face + dt_latent_hair)
        # Fuse the concatenated F-space features down through the residual stack.
        cat_f = torch.cat((f_face, f_hair), dim=1)
        finall_f = self.to_feature(cat_f)
        return finall_s, finall_f
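

# Usage sketch for PostProcessModel (input resolution is an assumption; the
# pretrained ArcFace backbone and latent_avg checkpoints must exist at the
# paths hard-coded in __init__, and a CUDA device is required):
#   post = PostProcessModel().cuda()
#   source = torch.rand(1, 3, 1024, 1024, device="cuda") * 2 - 1
#   target = torch.rand(1, 3, 1024, 1024, device="cuda") * 2 - 1
#   final_s, final_f = post(source, target)   # W+ code and fused F-space feature map
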
class ClipModel(nn.Module):
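    """Frozen CLIP ViT-B/32 image encoder with built-in resizing and normalization."""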
def __init__(self):
super().__init__()
self.clip_model, _ = clip.load("ViT-B/32", device="cuda")
self.transform = T.Compose(
[T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]
)
self.face_pool = torch.nn.AdaptiveAvgPool2d((224, 224))
for param in self.clip_model.parameters():
param.requires_grad = False
    def forward(self, image_tensor):
        if not image_tensor.is_cuda:
            image_tensor = image_tensor.to("cuda")
        if image_tensor.dtype == torch.uint8:
            # Convert 0-255 integer images to floats in [0, 1].
            image_tensor = image_tensor / 255
        # Resize to CLIP's 224x224 input and apply CLIP normalization.
        resized_tensor = self.face_pool(image_tensor)
        renormed_tensor = self.transform(resized_tensor)
        return self.clip_model.encode_image(renormed_tensor)
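

# Usage sketch for ClipModel (input range assumptions: uint8 tensors are divided
# by 255, so float inputs are expected to already be in [0, 1]):
#   clip_encoder = ClipModel()
#   images = torch.rand(4, 3, 256, 256)        # moved to CUDA inside forward
#   embeds = clip_encoder(images)              # -> (4, 512) CLIP image embeddings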