|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module that can normalize along a chosen axis.

    ``torch.nn.LayerNorm`` normalizes over the trailing dimension(s) only.
    This subclass normalizes the single axis ``dim``. For any axis other
    than -1, it swaps that axis with the last one, applies the parent layer
    norm, and swaps it back.

    Args:
        nout (int): Size of the normalized axis (passed to the parent as
            ``normalized_shape``).
        dim (int): Axis to be normalized. Defaults to -1 (the last axis).
        eps (float): Value added to the variance for numerical stability.
            Defaults to 1e-12, the value this class always used before.
    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization along ``self.dim``.

        Args:
            x (torch.Tensor): Input tensor; its ``self.dim`` axis must have
                size ``nout``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.
        """
        if self.dim == -1:
            return super().forward(x)
        # Move the requested axis to the last position so the parent class
        # normalizes it, then move it back. The previous implementation always
        # swapped axis 1, which was only correct when dim == 1.
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)
|
|