# Spaces: Build error  (appears to be page-status text captured along with the file; not part of the module)
import torch.nn as nn
class ResNormLayer(nn.Module):
    """Residual block of two Linear -> ReLU -> LayerNorm stages.

    Computes ``x + f(x)`` where ``f`` applies, in order:
    ``w1 -> ReLU -> LayerNorm -> w2 -> ReLU -> LayerNorm``. Both linear
    layers map ``linear_size`` features to ``linear_size`` features, so the
    branch output can be added to the input without any projection.

    Args:
        linear_size: Size of the last (feature) dimension of the input. It is
            used as the in/out width of both linear layers and as the
            normalized shape of both layer norms.
    """

    def __init__(self, linear_size: int):
        super().__init__()
        self.l_size = linear_size

        # Sub-modules are registered in the same order and under the same
        # attribute names as before so existing state_dict checkpoints keep
        # loading unchanged.
        #
        # inplace=True makes each ReLU overwrite the tensor it receives (the
        # fresh output of the preceding Linear) instead of allocating a new
        # one; the block's input ``x`` is never modified.
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)

        self.norm_fn1 = nn.LayerNorm(self.l_size)
        self.norm_fn2 = nn.LayerNorm(self.l_size)

        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Tensor whose last dimension has size ``linear_size``.

        Returns:
            ``x`` plus the output of the two-stage branch; same shape as ``x``.
        """
        # First stage: linear -> ReLU -> layer norm.
        y = self.w1(x)
        y = self.nonlin1(y)
        y = self.norm_fn1(y)

        # Second stage: same pattern with its own parameters.
        y = self.w2(y)
        y = self.nonlin2(y)
        y = self.norm_fn2(y)

        # Skip connection around both stages.
        return x + y