import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.spectral_norm as spectral_norm
from modules.eg3ds.models.networks_stylegan2 import SynthesisBlock
from modules.eg3ds.models.superresolution import SynthesisBlockNoUp
from modules.eg3ds.models.superresolution import SuperresolutionHybrid8XDC
from modules.real3d.facev2v_warp.model import WarpBasedTorsoModelMediaPipe as torso_model_v1
from modules.real3d.facev2v_warp.model2 import WarpBasedTorsoModelMediaPipe as torso_model_v2
from utils.commons.hparams import hparams
from utils.commons.image_utils import dilate, erode
class SuperresolutionHybrid8XDC_Warp(SuperresolutionHybrid8XDC):
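    """Super-resolution head renderer with warp-based torso and background fusion.

    Extends SuperresolutionHybrid8XDC with a MediaPipe-keypoint-driven torso
    warping model, a background encoder, and fusion convs that composite head,
    torso, and background at 256x256 before the final super-resolution block.
    """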
def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, **block_kwargs):
super().__init__(channels, img_resolution, sr_num_fp16_res, sr_antialias, **block_kwargs)
if hparams.get("torso_model_version", "v1") == 'v1':
self.torso_model = torso_model_v1('standard')
elif hparams.get("torso_model_version", "v1") == 'v2':
self.torso_model = torso_model_v2('standard')
else: raise NotImplementedError()
# self.torso_model = WarpBasedTorsoModelMediaPipe('small')
        self.torso_encoder = nn.Sequential(*[
            nn.Conv2d(64, 256, 1, 1, padding=0),  # project 64-channel warped torso features to 256 channels
        ])
self.bg_encoder = nn.Sequential(*[
nn.Conv2d(3, 64, 3, 1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(64, 256, 3, 1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(256, 256, 3, 1, padding=1),
])
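        # With weight_fuse enabled, hparams['htbsr_head_weight_fuse_mode'] selects how
        # head and torso are blended in forward() before background compositing:
        #   v1: alpha-blend rgb/features directly with the rendered head weights
        #   v2: alpha-weighted concat of head/torso features plus an extra synthesis block
        #   v3: like v2, but a small CNN predicts the head-torso alpha from the rgb inputs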
if hparams.get("weight_fuse", True):
            # 'v1' and the fallback path (formerly 'v2') both concatenate two 256-channel maps
            fuse_in_dim = 512
self.head_torso_alpha_predictor = nn.Sequential(*[
                nn.Conv2d(3+1+3, 32, 3, 1, padding=1),  # head RGB (3) + head weights (1) + torso RGB (3)
nn.LeakyReLU(),
nn.Conv2d(32, 32, 3, 1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(32, 1, 3, 1, padding=1),
nn.Sigmoid(),
])
self.fuse_head_torso_convs = nn.Sequential(*[
nn.Conv2d(256+256, 256, 3, 1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(256, 256, 3, 1, padding=1),
])
self.head_torso_block = SynthesisBlockNoUp(256, 256, w_dim=512, resolution=256,
img_channels=3, is_last=False, use_fp16=False, conv_clamp=None, **block_kwargs)
else:
fuse_in_dim = 768
self.fuse_fg_bg_convs = nn.Sequential(*[
nn.Conv2d(fuse_in_dim, 64, 1, 1, padding=0),
nn.LeakyReLU(),
nn.Conv2d(64, 256, 3, 1, padding=1),
nn.LeakyReLU(),
nn.Conv2d(256, 256, 3, 1, padding=1),
])
def forward(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d, target_torso_mask=None, **block_kwargs):
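        """Super-resolve the head render and composite it with torso and background.

        rgb/x: low-res head render and its feature map; weights_img: rendered head
        alpha weights; kp_s/kp_d: source/driving keypoints that drive the torso
        warp of ref_torso_rgb. Returns the final rgb and the torso model's aux dict.
        """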
weights_img = weights_img.detach()
ws = ws[:, -1:, :].expand([rgb.shape[0], 3, -1])
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
mode='bilinear', align_corners=False, antialias=self.sr_antialias)
rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
mode='bilinear', align_corners=False, antialias=self.sr_antialias)
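        # Bring all conditioning inputs to 256x256, the resolution at which head,
        # torso, and background are composited (block0 upsamples the head to 256).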
rgb_256 = torch.nn.functional.interpolate(rgb, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
weights_256 = torch.nn.functional.interpolate(weights_img, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
ref_torso_rgb_256 = torch.nn.functional.interpolate(ref_torso_rgb, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
ref_bg_rgb_256 = torch.nn.functional.interpolate(ref_bg_rgb, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
x, rgb = self.block0(x, rgb, ws, **block_kwargs) # sr branch, 128x128 head img ==> 256x256 head img
if hparams.get("torso_model_version", "v1") == 'v1':
rgb_torso, facev2v_ret = self.torso_model.forward(ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), cal_loss=True, target_torso_mask=target_torso_mask)
elif hparams.get("torso_model_version", "v1") == 'v2':
rgb_torso, facev2v_ret = self.torso_model.forward(ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), weights_256.detach(), cal_loss=True, target_torso_mask=target_torso_mask)
x_torso = self.torso_encoder(facev2v_ret['deformed_torso_hid'])
x_bg = self.bg_encoder(ref_bg_rgb_256)
if hparams.get("weight_fuse", True):
if hparams['htbsr_head_weight_fuse_mode'] == 'v1':
rgb = rgb * weights_256 + rgb_torso * (1-weights_256) # get person img
x = x * weights_256 + x_torso * (1-weights_256) # get person img
                head_occlusion = weights_256.clone()
                htbsr_head_threshold = hparams['htbsr_head_threshold']
                head_occlusion[head_occlusion > htbsr_head_threshold] = 1.  # saturate confident head pixels to build the person occlusion
torso_occlusion = torch.nn.functional.interpolate(facev2v_ret['occlusion_2'], size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
person_occlusion = (torso_occlusion + head_occlusion).clamp_(0,1)
rgb = rgb * person_occlusion + ref_bg_rgb_256 * (1-person_occlusion) # run6
x = torch.cat([x * person_occlusion, x_bg * (1-person_occlusion)], dim=1) # run6
x = self.fuse_fg_bg_convs(x)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
elif hparams['htbsr_head_weight_fuse_mode'] == 'v2':
                # Fuse the head and torso features with an alpha-weighted concat,
                # replacing the earlier direct alpha-blend of the feature maps.
                head_torso_alpha = weights_256.clone()
                # Cap alpha at weights_256 (a no-op here, since alpha starts as a clone
                # of weights_256; the same cap is meaningful in the v3 branch below).
                head_torso_alpha[head_torso_alpha>weights_256] = weights_256[head_torso_alpha>weights_256]
rgb = rgb * head_torso_alpha + rgb_torso * (1-head_torso_alpha) # get person img
x = torch.cat([x * head_torso_alpha, x_torso * (1-head_torso_alpha)], dim=1)
x = self.fuse_head_torso_convs(x)
x, rgb = self.head_torso_block(x, rgb, ws, **block_kwargs)
head_occlusion = head_torso_alpha.clone()
                # Since the weights are trained to approach the mask, the head weights
                # threshold no longer needs to be hand-tuned to a very small value; 0.7 suffices.
htbsr_head_threshold = hparams['htbsr_head_threshold']
head_occlusion[head_occlusion > htbsr_head_threshold] = 1.
torso_occlusion = torch.nn.functional.interpolate(facev2v_ret['occlusion_2'], size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
person_occlusion = (torso_occlusion + head_occlusion).clamp_(0,1)
rgb = rgb * person_occlusion + ref_bg_rgb_256 * (1-person_occlusion) # run6
x = torch.cat([x * person_occlusion, x_bg * (1-person_occlusion)], dim=1) # run6
x = self.fuse_fg_bg_convs(x)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
elif hparams['htbsr_head_weight_fuse_mode'] == 'v3':
                # v2: fused head/torso features with an alpha-weighted concat (replacing
                # the earlier direct alpha-blend).
                # v3: additionally post-processes the head mask with a small network.
head_torso_alpha_inp = torch.cat([rgb.clamp(-1,1)/2+0.5, weights_256, rgb_torso.clamp(-1,1)/2+0.5], dim=1)
head_torso_alpha_ = self.head_torso_alpha_predictor(head_torso_alpha_inp)
head_torso_alpha = head_torso_alpha_.clone()
                head_torso_alpha[head_torso_alpha>weights_256] = weights_256[head_torso_alpha>weights_256]  # cap the predicted alpha at the rendered head weights
rgb = rgb * head_torso_alpha + rgb_torso * (1-head_torso_alpha) # get person img
x = torch.cat([x * head_torso_alpha, x_torso * (1-head_torso_alpha)], dim=1) # run6
x = self.fuse_head_torso_convs(x)
x, rgb = self.head_torso_block(x, rgb, ws, **block_kwargs)
head_occlusion = head_torso_alpha.clone()
htbsr_head_threshold = hparams['htbsr_head_threshold']
                if not self.training:
                    head_occlusion_ = head_occlusion[head_occlusion>0.05]
                    # At inference, raise the threshold to at least the 5th percentile of
                    # occlusion values above 0.05, discarding the weakest 5% of those pixels.
                    htbsr_head_threshold = max(head_occlusion_.quantile(0.05), htbsr_head_threshold)
head_occlusion[head_occlusion > htbsr_head_threshold] = 1.
torso_occlusion = torch.nn.functional.interpolate(facev2v_ret['occlusion_2'], size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
person_occlusion = (torso_occlusion + head_occlusion).clamp_(0,1)
rgb = rgb * person_occlusion + ref_bg_rgb_256 * (1-person_occlusion) # run6
                # TODO: revisit this; remove the occlusion from the concat, or clamp it?
x = torch.cat([x * person_occlusion, x_bg * (1-person_occlusion)], dim=1) # run6
x = self.fuse_fg_bg_convs(x)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
            else:
                # v4: tried fusing the head-torso hidden states directly with a concat; results were poor.
                # v5 try1: also dropped the alpha from the concat when processing x, but this
                #   drove the occlusion straight to 1, so it was removed.
                # v5 try2: added a threshold to the torso as well so its rgb comes out sharper,
                #   but this may cause dark fringes around the torso?
                raise NotImplementedError()
else:
x = torch.cat([x, x_torso, x_bg], dim=1) # run6
x = self.fuse_fg_bg_convs(x)
x, rgb = self.block1(x, None, ws, **block_kwargs)
return rgb, facev2v_ret
@torch.no_grad()
def infer_forward_stage1(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d, **block_kwargs):
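        """Stage 1 of two-stage inference: run the first SR block and the torso
        model's first stage, caching intermediates in facev2v_ret for stage 2."""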
weights_img = weights_img.detach()
ws = ws[:, -1:, :].repeat(1, 3, 1)
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
mode='bilinear', align_corners=False, antialias=self.sr_antialias)
rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
mode='bilinear', align_corners=False, antialias=self.sr_antialias)
rgb_256 = torch.nn.functional.interpolate(rgb, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
weights_256 = torch.nn.functional.interpolate(weights_img, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
ref_torso_rgb_256 = torch.nn.functional.interpolate(ref_torso_rgb, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
ref_bg_rgb_256 = torch.nn.functional.interpolate(ref_bg_rgb, size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
x, rgb = self.block0(x, rgb, ws, **block_kwargs)
facev2v_ret = self.torso_model.infer_forward_stage1(ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), cal_loss=True)
facev2v_ret['ref_bg_rgb_256'] = ref_bg_rgb_256
facev2v_ret['weights_256'] = weights_256
facev2v_ret['x'] = x
facev2v_ret['ws'] = ws
facev2v_ret['rgb'] = rgb
return facev2v_ret
@torch.no_grad()
def infer_forward_stage2(self, facev2v_ret, **block_kwargs):
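        """Stage 2 of two-stage inference: finish the torso warp, composite
        head/torso/background, and run the final SR block."""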
x = facev2v_ret['x']
ws = facev2v_ret['ws']
rgb = facev2v_ret['rgb']
ref_bg_rgb_256 = facev2v_ret['ref_bg_rgb_256']
weights_256 = facev2v_ret['weights_256']
rgb_torso = self.torso_model.infer_forward_stage2(facev2v_ret)
x_torso = self.torso_encoder(facev2v_ret['deformed_torso_hid'])
x_bg = self.bg_encoder(ref_bg_rgb_256)
if hparams.get("weight_fuse", True):
rgb = rgb * weights_256 + rgb_torso * (1-weights_256) # get person img
x = x * weights_256 + x_torso * (1-weights_256) # get person img
head_occlusion = weights_256.clone()
            head_occlusion[head_occlusion > 0.5] = 1.  # fixed 0.5 threshold here; forward() uses hparams['htbsr_head_threshold']
torso_occlusion = torch.nn.functional.interpolate(facev2v_ret['occlusion_2'], size=(256, 256), mode='bilinear', align_corners=False, antialias=self.sr_antialias)
person_occlusion = (torso_occlusion + head_occlusion).clamp_(0,1)
rgb = rgb * person_occlusion + ref_bg_rgb_256 * (1-person_occlusion) # run6
x = torch.cat([x * person_occlusion, x_bg * (1-person_occlusion)], dim=1) # run6
x = self.fuse_fg_bg_convs(x)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
else:
x = torch.cat([x, x_torso, x_bg], dim=1) # run6
x = self.fuse_fg_bg_convs(x)
x, rgb = self.block1(x, None, ws, **block_kwargs)
return rgb, facev2v_ret
if __name__ == '__main__':
    # Smoke test with random inputs. The segmap and keypoint shapes below are
    # illustrative assumptions; match them to the torso model's actual inputs.
    model = SuperresolutionHybrid8XDC_Warp(32, 512, 0, False)
    model.cuda()
    rgb = torch.randn([4, 3, 128, 128]).cuda()            # low-res head render
    x = torch.randn([4, 32, 128, 128]).cuda()             # head feature map
    ws = torch.randn([4, 14, 512]).cuda()                 # StyleGAN2 latent codes
    ref_torso_rgb = torch.randn([4, 3, 256, 256]).cuda()  # reference torso image
    ref_bg_rgb = torch.randn([4, 3, 256, 256]).cuda()     # reference background image
    weights_img = torch.rand([4, 1, 128, 128]).cuda()     # rendered head alpha
    segmap = torch.rand([4, 6, 256, 256]).cuda()          # assumed segmentation channels
    kp_s = torch.randn([4, 68, 3]).cuda()                 # assumed source keypoints
    kp_d = torch.randn([4, 68, 3]).cuda()                 # assumed driving keypoints
    out_rgb, facev2v_ret = model(rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d)
    print(out_rgb.shape)