Spaces:
Running
on
L40S
Running
on
L40S
import math | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from .ops import (EqualConv2d, EqualLinear, ConvLayer) | |
class ResBlock(nn.Module):
    """Two-convolution residual block with a 1x1 projection shortcut.

    The main path applies a 3x3 ``ConvLayer`` followed by a second 3x3
    ``ConvLayer`` created with ``downsample=True``. The shortcut is a 1x1
    ``ConvLayer`` (also ``downsample=True``) with no activation and no bias.
    Both paths are summed and divided by ``sqrt(2)``.

    NOTE(review): ``ConvLayer`` lives in ``.ops``; ``downsample=True``
    presumably halves the spatial resolution -- confirm there.

    Args:
        in_channel: Number of channels of the incoming feature map.
        out_channel: Number of channels produced by the block.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # Main path: channel change first, then the downsampling convolution.
        self.conv1 = ConvLayer(in_channel, out_channel, 3)
        self.conv2 = ConvLayer(out_channel, out_channel, 3, downsample=True)

        # Shortcut path: linear 1x1 projection so it can be added to the main path.
        self.skip = ConvLayer(
            in_channel,
            out_channel,
            1,
            downsample=True,
            activate=False,
            bias=False,
        )

    def forward(self, x):
        """Return ``(conv2(conv1(x)) + skip(x)) / sqrt(2)``."""
        main_path = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # Dividing by sqrt(2) rescales the sum of the two branches.
        return (main_path + shortcut) / math.sqrt(2)
class Encoder2R(nn.Module):
    """Convolutional encoder returning a latent vector and per-stage features.

    The network is a 1x1 stem (``block1``) followed by six single-``ResBlock``
    stages (``block2`` .. ``block7``), one further ``ResBlock``
    (``block_512``) and a final bias-free ``EqualConv2d`` (``block8``)
    constructed with ``4`` and ``padding=0``.

    NOTE(review): the resolution comments below (256 -> ... -> 4 -> 1) are
    carried over from the original annotations and presumably assume a
    256x256 RGB input -- confirm against the callers.

    Args:
        latent_dim: Number of output channels of the final convolution.
        scale: Multiplier applied to the base channel widths 64/128/256/512.
    """

    def __init__(self, latent_dim=512, scale=1):
        super().__init__()

        c64, c128, c256, c512 = (width * scale for width in (64, 128, 256, 512))

        # version1
        # Stem: 3 input channels -> c64.
        self.block1 = ConvLayer(3, c64, 1)

        # Each stage wraps exactly one ResBlock in nn.Sequential.
        self.block2 = nn.Sequential(ResBlock(c64, c128))    # 64 -> 128 channels
        self.block3 = nn.Sequential(ResBlock(c128, c256))   # 128 -> 256 channels
        self.block4 = nn.Sequential(ResBlock(c256, c512))   # 256 -> 512 channels
        self.block5 = nn.Sequential(ResBlock(c512, c512))   # 512 -> 512 channels
        self.block6 = nn.Sequential(ResBlock(c512, c512))   # 512 -> 512 channels
        self.block7 = nn.Sequential(ResBlock(c512, c512))   # 512 -> 512 channels

        # Extra residual block whose output is not collected as a feature.
        self.block_512 = ResBlock(c512, c512)

        # Final projection to the latent channels.
        self.block8 = EqualConv2d(c512, latent_dim, 4, padding=0, bias=False)

    def forward(self, x):
        """Encode ``x``.

        Returns:
            A tuple ``(latent, feats)`` where ``latent`` is the output of
            ``block8`` with its last two dimensions squeezed, and ``feats``
            is the list of outputs of ``block1`` .. ``block7`` ordered from
            the last stage to the first.
        """
        # Original resolution notes per stage: 256, 128, 64, 32, 16, 8, 4.
        stages = (
            self.block1,
            self.block2,
            self.block3,
            self.block4,
            self.block5,
            self.block6,
            self.block7,
        )

        pyramid = []
        feat = x
        for stage in stages:
            feat = stage(feat)
            pyramid.append(feat)

        # block_512 output is not stored; block8 reduces to the latent (noted as "1").
        latent = self.block8(self.block_512(feat))

        # Deepest feature first.
        pyramid.reverse()
        return latent.squeeze(-1).squeeze(-1), pyramid
class Encoder(nn.Module):
    """Combine an ``Encoder2R`` backbone with a linear head producing coefficients.

    ``enc_2r`` maps an image batch to a latent vector plus a list of
    intermediate features; ``enc_r2t`` maps a latent vector of size ``dim``
    to ``dim_motion`` coefficients.

    NOTE(review): the ``s2r`` / ``t2r`` / ``r2t`` / ``r2s`` suffixes presumably
    stand for source/target-to-reference and reference-to-target/source --
    confirm against the generator that consumes ``alpha``.

    Args:
        dim: Latent dimensionality produced by ``Encoder2R``.
        dim_motion: Number of coefficients produced by the linear head.
        scale: Channel-width multiplier forwarded to ``Encoder2R``.
    """

    def __init__(self, dim=512, dim_motion=20, scale=1):
        super().__init__()

        # 2R network
        self.enc_2r = Encoder2R(dim, scale)

        # R2T
        self.enc_r2t = nn.Sequential(
            EqualLinear(dim, dim_motion)
        )

    def enc_motion(self, x):
        """Return the coefficients ``enc_r2t(enc_2r(x)[0])`` for image batch ``x``."""
        z_t2r, _ = self.enc_2r(x)
        return self.enc_r2t(z_t2r)

    def enc_transfer_img(self, z_s2r, d_l, s_l):
        """Compute coefficients from ``z_s2r`` and shift selected entries.

        Args:
            z_s2r: Latent vector(s) fed to ``enc_r2t``.
            d_l: Column indices of the coefficients to shift.
            s_l: Offsets added to the columns in ``d_l`` (same length as
                ``d_l``); the same offsets are applied to every row.

        Returns:
            A one-element list holding the shifted coefficient tensor.
        """
        alpha_r2s = self.enc_r2t(z_s2r)
        # Build the offsets on the tensor's own device/dtype rather than a
        # hard-coded 'cuda', so this also works on CPU or a non-default GPU.
        offsets = torch.as_tensor(
            s_l, dtype=alpha_r2s.dtype, device=alpha_r2s.device
        )
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + offsets.unsqueeze(0)
        return [alpha_r2s]

    def enc_transfer_vid(self, alpha_r2s, input_target, alpha_start):
        """Return ``[alpha_r2t, alpha_r2s, alpha_start]`` for ``input_target``.

        ``alpha_r2t`` is computed from ``input_target``; the other two entries
        are passed through unchanged.
        """
        alpha_r2t = self.enc_motion(input_target)
        return [alpha_r2t, alpha_r2s, alpha_start]

    def forward(self, input_source, input_target, alpha_start=None):
        """Encode the source and, optionally, the target.

        Args:
            input_source: Image batch encoded for the latent and features.
            input_target: Image batch encoded for coefficients, or ``None``
                to skip coefficient computation.
            alpha_start: Optional extra value; when given, the source
                coefficients are computed as well and both are included in
                the returned list.

        Returns:
            ``(z_s2r, alpha, feats)`` where ``alpha`` is ``None`` when
            ``input_target`` is ``None``, ``[alpha_r2t]`` when ``alpha_start``
            is ``None``, and ``[alpha_r2t, alpha_r2s, alpha_start]`` otherwise.
        """
        # The source is encoded in every case.
        z_s2r, feats = self.enc_2r(input_source)

        if input_target is None:
            return z_s2r, None, feats

        z_t2r, _ = self.enc_2r(input_target)
        alpha_r2t = self.enc_r2t(z_t2r)

        if alpha_start is None:
            alpha = [alpha_r2t]
        else:
            alpha_r2s = self.enc_r2t(z_s2r)
            alpha = [alpha_r2t, alpha_r2s, alpha_start]

        return z_s2r, alpha, feats