Real3DPortrait-Unofficial-Working-Code
/
modules
/img2plane
/simple_encoders
/high_resolution_encoder.py
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class HighResoEncoder(nn.Module): | |
def __init__(self, | |
in_dim=5, # 3 for rgb and 2 for coordinate | |
out_dim=96, | |
): | |
super().__init__() | |
self.first = nn.Conv2d(in_channels=in_dim, out_channels=64, kernel_size=7, stride=2, padding=3) | |
self.activation = nn.LeakyReLU(negative_slope=0.01) | |
self.conv_layers = nn.Sequential(*[ | |
nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=1, padding=1), | |
nn.LeakyReLU(negative_slope=0.01), | |
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), | |
nn.LeakyReLU(negative_slope=0.01), | |
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), | |
nn.LeakyReLU(negative_slope=0.01), | |
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), | |
nn.LeakyReLU(negative_slope=0.01), | |
]) | |
self.final = nn.Conv2d(in_channels=96, out_channels=out_dim, kernel_size=3, stride=1, padding=1) | |
def forward(self, x): | |
""" | |
x: [B, C=5, 256, 256] | |
return: [B, C=96, 256, 256] | |
""" | |
h = self.first(x) | |
h = self.conv_layers(h) | |
h = self.final(h) | |
return h | |