import torch.nn as nn


class ResNormLayer(nn.Module):
    """Residual block of two Linear -> ReLU -> LayerNorm stages.

    Computes ``out = x + f(x)`` where ``f`` is the two-stage transform, so
    the last dimension of the input must equal ``linear_size`` (both the
    linear layers and the layer norms are sized to it, and the skip
    connection requires matching shapes).
    """

    def __init__(self, linear_size):
        """Build the two-stage transform.

        Args:
            linear_size: Size of the last input dimension; every internal
                layer maps ``linear_size -> linear_size``.
        """
        super().__init__()  # zero-arg super; original used the dated two-arg form
        self.l_size = linear_size
        # in-place ReLU saves an allocation; safe here because each one is
        # applied to a fresh intermediate tensor, never to the input x
        self.nonlin1 = nn.ReLU(inplace=True)
        self.nonlin2 = nn.ReLU(inplace=True)
        self.norm_fn1 = nn.LayerNorm(self.l_size)
        self.norm_fn2 = nn.LayerNorm(self.l_size)
        self.w1 = nn.Linear(self.l_size, self.l_size)
        self.w2 = nn.Linear(self.l_size, self.l_size)

    def forward(self, x):
        """Apply Linear -> ReLU -> LayerNorm twice, then add the skip.

        Args:
            x: Tensor whose last dimension is ``linear_size``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        y = self.w1(x)
        y = self.nonlin1(y)
        y = self.norm_fn1(y)
        y = self.w2(y)
        y = self.nonlin2(y)
        y = self.norm_fn2(y)
        out = x + y
        return out