import math

import torch
from torch import nn
from torch.nn import functional as F

from .ops import EqualConv2d, EqualLinear, ConvLayer


class ResBlock(nn.Module):
    """Residual block that halves the spatial resolution."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv1 = ConvLayer(in_channel, out_channel, 3)
        self.conv2 = ConvLayer(out_channel, out_channel, 3, downsample=True)
        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, x):
        h = self.conv1(x)
        h = self.conv2(h)
        skip = self.skip(x)
        # Scale the sum by 1/sqrt(2) so the residual addition keeps the
        # activation variance roughly constant.
        return (h + skip) / math.sqrt(2)
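

# Illustrative shape check for ResBlock (a sketch, not from the original file;
# it assumes ConvLayer with downsample=True halves the spatial size, as in
# typical StyleGAN2-style ops):
#
#   block = ResBlock(64, 128)
#   y = block(torch.randn(1, 64, 256, 256))
#   assert y.shape == (1, 128, 128, 128)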


class Encoder2R(nn.Module):
    """Convolutional encoder: image -> flat latent vector + intermediate features."""

    def __init__(self, latent_dim=512, scale=1):
        super(Encoder2R, self).__init__()

        channels = [64 * scale, 128 * scale, 256 * scale, 512 * scale]

        # version1
        self.block1 = ConvLayer(3, channels[0], 1)  # 256, 3 -> 64
        self.block2 = nn.Sequential(
            ResBlock(channels[0], channels[1])
        )  # 64 -> 128
        self.block3 = nn.Sequential(
            ResBlock(channels[1], channels[2])
        )  # 128 -> 256
        self.block4 = nn.Sequential(
            ResBlock(channels[2], channels[3])
        )  # 256 -> 512
        self.block5 = nn.Sequential(
            ResBlock(channels[3], channels[3])
        )  # 512 -> 512
        self.block6 = nn.Sequential(
            ResBlock(channels[3], channels[3])
        )  # 512 -> 512
        self.block7 = nn.Sequential(
            ResBlock(channels[3], channels[3])
        )  # 512 -> 512
        self.block_512 = ResBlock(channels[3], channels[3])
        self.block8 = EqualConv2d(channels[3], latent_dim, 4, padding=0, bias=False)

    def forward(self, x):
        res = []
        h = self.block1(x)  # 256
        res.append(h)
        h = self.block2(h)  # 128
        res.append(h)
        h = self.block3(h)  # 64
        res.append(h)
        h = self.block4(h)  # 32
        res.append(h)
        h = self.block5(h)  # 16
        res.append(h)
        h = self.block6(h)  # 8
        res.append(h)
        h = self.block7(h)  # 4
        res.append(h)
        h = self.block_512(h)
        h = self.block8(h)  # 1
        # Return the flattened latent and the intermediate features, deepest first.
        return h.squeeze(-1).squeeze(-1), res[::-1]
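

# Illustrative note (not from the original file): Encoder2R returns a flattened
# latent of shape (B, latent_dim) plus the intermediate feature maps; res[::-1]
# puts the deepest (lowest-resolution) feature first. Exact spatial sizes depend
# on the input resolution and on how ConvLayer downsamples, e.g.
#
#   z, feats = Encoder2R(latent_dim=512)(img)   # z: (B, 512); feats: deepest first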


class Encoder(nn.Module):
    def __init__(self, dim=512, dim_motion=20, scale=1):
        super(Encoder, self).__init__()

        # 2R network: image -> latent code + multi-scale features
        self.enc_2r = Encoder2R(dim, scale)

        # R2T: project the latent code to dim_motion coefficients (alpha)
        self.enc_r2t = nn.Sequential(
            EqualLinear(dim, dim_motion)
        )

    def enc_motion(self, x):
        # Encode an image and project its latent code to motion coefficients.
        z_t2r, _ = self.enc_2r(x)
        alpha_r2t = self.enc_r2t(z_t2r)
        return alpha_r2t

    def enc_transfer_img(self, z_s2r, d_l, s_l):
        # Shift the selected motion dimensions d_l of the source coefficients
        # by the offsets s_l.
        alpha_r2s = self.enc_r2t(z_s2r)
        shift = torch.tensor(s_l, dtype=alpha_r2s.dtype, device=alpha_r2s.device)
        alpha_r2s[:, d_l] = alpha_r2s[:, d_l] + shift.unsqueeze(0)
        alpha = [alpha_r2s]
        return alpha

    def enc_transfer_vid(self, alpha_r2s, input_target, alpha_start):
        z_t2r, _ = self.enc_2r(input_target)
        alpha_r2t = self.enc_r2t(z_t2r)
        alpha = [alpha_r2t, alpha_r2s, alpha_start]
        return alpha
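
    # Hypothetical usage of enc_transfer_img (a sketch; the index list d_l and
    # offset list s_l below are illustrative values, not from the original code):
    #
    #   z_s2r, _, _ = encoder(source_img, None)
    #   alpha = encoder.enc_transfer_img(z_s2r, d_l=[0, 3], s_l=[0.5, -0.2])
    #
    # This nudges motion dimensions 0 and 3 of the source's coefficients by
    # +0.5 and -0.2 respectively.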

    def forward(self, input_source, input_target, alpha_start=None):
        z_s2r, feats = self.enc_2r(input_source)

        if input_target is not None:
            z_t2r, _ = self.enc_2r(input_target)
            alpha_r2t = self.enc_r2t(z_t2r)

            if alpha_start is not None:
                alpha_r2s = self.enc_r2t(z_s2r)
                alpha = [alpha_r2t, alpha_r2s, alpha_start]
            else:
                alpha = [alpha_r2t]

            return z_s2r, alpha, feats
        else:
            return z_s2r, None, feats
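

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). It assumes
    # the layers in .ops follow the usual StyleGAN2-style behaviour, i.e.
    # ConvLayer(downsample=True) halves the spatial size; with block_512 applied
    # unconditionally and a 4x4 head, that implies a 512x512 input here. Adjust
    # the resolution if your data differs.
    enc = Encoder(dim=512, dim_motion=20, scale=1)
    src = torch.randn(2, 3, 512, 512)
    drv = torch.randn(2, 3, 512, 512)

    z_s2r, alpha, feats = enc(src, drv)
    print(z_s2r.shape)                 # expected: torch.Size([2, 512])
    print(alpha[0].shape)              # expected: torch.Size([2, 20])
    print(len(feats), feats[0].shape)  # multi-scale features, deepest first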