# LIA-X-fast / networks/encoder.py
# Uploaded by YaohuiW ("Upload 19 files"), commit c42db24 (verified).
import math
import torch
from torch import nn
from torch.nn import functional as F
from .ops import (EqualConv2d, EqualLinear, ConvLayer)
class ResBlock(nn.Module):
    """Residual block: a two-conv main path plus a 1x1 shortcut, averaged by 1/sqrt(2).

    Main path: 3x3 ConvLayer (in -> out), then 3x3 ConvLayer (out -> out) with
    ``downsample=True``. Shortcut: 1x1 ConvLayer (in -> out) with
    ``downsample=True``, no activation and no bias.

    NOTE(review): ``downsample=True`` presumably halves the spatial resolution of
    both paths so they can be summed -- confirm against ConvLayer in ``.ops``.

    Args:
        in_channel: number of channels of the input feature map.
        out_channel: number of channels of the output feature map.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # Attribute names and creation order are kept fixed: they define the
        # state_dict keys and the order in which random initialisation happens.
        self.conv1 = ConvLayer(in_channel, out_channel, 3)
        self.conv2 = ConvLayer(out_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, x):
        """Return ``(main(x) + skip(x)) / sqrt(2)``."""
        main = self.conv2(self.conv1(x))
        shortcut = self.skip(x)
        # Dividing the sum of two branches by sqrt(2) presumably keeps the
        # activation scale roughly unchanged (StyleGAN2-style residual).
        return (main + shortcut) / math.sqrt(2)
class Encoder2R(nn.Module):
    """Encode an image into a latent vector plus a list of intermediate feature maps.

    Architecture (channel counts are multiplied by ``scale``):
      block1      1x1 ConvLayer, 3 -> 64
      block2..7   one ResBlock each: 64->128, 128->256, 256->512, then 512->512 x3
      block_512   extra ResBlock 512 -> 512 (its output is NOT returned as a feature)
      block8      EqualConv2d with a 4x4 kernel, padding 0, no bias: 512 -> latent_dim

    Per the original resolution comments, a 256x256 input goes
    256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 through block1..block7. The 4x4
    kernel of block8 then reduces it to 1x1. NOTE(review): the input size is
    not enforced here -- confirm with callers.

    Args:
        latent_dim: number of output channels of block8, i.e. the latent size.
        scale: multiplier applied to the base widths (64, 128, 256, 512).
    """

    def __init__(self, latent_dim=512, scale=1):
        super().__init__()
        c64, c128, c256, c512 = (base * scale for base in (64, 128, 256, 512))

        def stage(cin, cout):
            # A one-element nn.Sequential keeps parameter names as "blockN.0.*",
            # so existing checkpoints still load.
            return nn.Sequential(ResBlock(cin, cout))

        # Creation order matches the original: it fixes the RNG consumption
        # order for initialisation as well as the state_dict layout.
        self.block1 = ConvLayer(3, c64, 1)
        self.block2 = stage(c64, c128)
        self.block3 = stage(c128, c256)
        self.block4 = stage(c256, c512)
        self.block5 = stage(c512, c512)
        self.block6 = stage(c512, c512)
        self.block7 = stage(c512, c512)
        self.block_512 = ResBlock(c512, c512)
        self.block8 = EqualConv2d(c512, latent_dim, 4, padding=0, bias=False)

    def forward(self, x):
        """Run the encoder.

        Returns:
            A pair ``(latent, feats)``.
            - ``latent`` is block8's output with its last two dims squeezed.
            - ``feats`` is the list of outputs of block1..block7 in reverse
              order, deepest (block7) first.
        """
        h = self.block1(x)
        feats = [h]
        for block in (self.block2, self.block3, self.block4,
                      self.block5, self.block6, self.block7):
            h = block(h)
            feats.append(h)
        latent = self.block8(self.block_512(h))
        feats.reverse()
        return latent.squeeze(-1).squeeze(-1), feats
class Encoder(nn.Module):
    """Image encoder producing a latent code, feature maps and motion coefficients.

    ``enc_2r`` (an Encoder2R) maps an image to ``(latent, feats)``.
    ``enc_r2t`` (a single EqualLinear, ``dim -> dim_motion``) maps a latent to
    motion coefficients.

    Args:
        dim: latent size produced by ``enc_2r`` and consumed by ``enc_r2t``.
        dim_motion: number of motion coefficients produced by ``enc_r2t``.
        scale: channel-width multiplier forwarded to Encoder2R.
    """

    def __init__(self, dim=512, dim_motion=20, scale=1):
        super(Encoder, self).__init__()
        # image -> (latent, feature maps)
        self.enc_2r = Encoder2R(dim, scale)
        # latent -> motion coefficients. Kept inside nn.Sequential so parameter
        # names stay "enc_r2t.0.*" and existing checkpoints keep loading.
        self.enc_r2t = nn.Sequential(
            EqualLinear(dim, dim_motion)
        )

    def enc_motion(self, x):
        """Return the motion coefficients of image ``x``: ``enc_r2t(enc_2r(x)[0])``."""
        z_t2r, _ = self.enc_2r(x)
        return self.enc_r2t(z_t2r)

    def enc_transfer_img(self, z_s2r, d_l, s_l):
        """Project a source latent to motion coefficients and offset selected entries.

        Args:
            z_s2r: latent code, as returned by ``enc_2r``.
            d_l: index (or indices) into the coefficient dimension (dim 1) to edit.
            s_l: value(s) added to the selected coefficients; must be
                convertible with ``torch.as_tensor`` (e.g. a list of floats).

        Returns:
            A one-element list ``[alpha_r2s]`` with the edited coefficients.
        """
        alpha_r2s = self.enc_r2t(z_s2r)
        # Build the offset on the coefficients' own device and dtype. The
        # original hard-coded .to('cuda'), which failed on CPU, targeted the
        # wrong device for non-default GPUs, and upcast half-precision models.
        offset = torch.as_tensor(
            s_l, dtype=alpha_r2s.dtype, device=alpha_r2s.device
        ).unsqueeze(0)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + offset
        return [alpha_r2s]

    def enc_transfer_vid(self, alpha_r2s, input_target, alpha_start):
        """Return ``[alpha_r2t, alpha_r2s, alpha_start]``, where ``alpha_r2t`` comes from ``input_target``."""
        alpha_r2t = self.enc_motion(input_target)
        return [alpha_r2t, alpha_r2s, alpha_start]

    def forward(self, input_source, input_target, alpha_start=None):
        """Encode a source image and, optionally, the motion of a target image.

        Returns:
            ``(z_s2r, alpha, feats)``. ``z_s2r`` and ``feats`` come from
            ``enc_2r(input_source)``. ``alpha`` takes one of three forms:
            - ``None`` when ``input_target`` is None;
            - ``[alpha_r2t]`` when ``alpha_start`` is None;
            - otherwise ``[alpha_r2t, alpha_r2s, alpha_start]``, where
              ``alpha_r2s = enc_r2t(z_s2r)``.
        """
        z_s2r, feats = self.enc_2r(input_source)
        if input_target is None:
            return z_s2r, None, feats

        alpha_r2t = self.enc_motion(input_target)
        if alpha_start is None:
            alpha = [alpha_r2t]
        else:
            alpha = [alpha_r2t, self.enc_r2t(z_s2r), alpha_start]
        return z_s2r, alpha, feats