ameerazam08's picture
Upload folder using huggingface_hub
e34aada verified
raw
history blame
1.37 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
class HighResoEncoder(nn.Module):
def __init__(self,
in_dim=5, # 3 for rgb and 2 for coordinate
out_dim=96,
):
super().__init__()
self.first = nn.Conv2d(in_channels=in_dim, out_channels=64, kernel_size=7, stride=2, padding=3)
self.activation = nn.LeakyReLU(negative_slope=0.01)
self.conv_layers = nn.Sequential(*[
nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01),
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01),
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01),
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(negative_slope=0.01),
])
self.final = nn.Conv2d(in_channels=96, out_channels=out_dim, kernel_size=3, stride=1, padding=1)
def forward(self, x):
"""
x: [B, C=5, 256, 256]
return: [B, C=96, 256, 256]
"""
h = self.first(x)
h = self.conv_layers(h)
h = self.final(h)
return h