File size: 2,094 Bytes
d9e90e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension of the input.

    The input is scaled by ``rsqrt(mean(x**2, dim=-1) + eps)``. When
    ``elementwise_affine`` is True, the result is then multiplied by a
    learnable per-feature ``weight`` initialized to ones. No mean is
    subtracted and no bias is applied.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine: bool = True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): Size of the last (normalized) dimension of the input.
            elementwise_affine (bool, optional): If True, create a learnable
                ``weight`` of shape ``(dim,)`` initialized to ones. If False,
                ``weight`` is registered as ``None`` (as ``nn.LayerNorm``
                does) and no scaling is applied. Default is True.
            eps (float, optional): A small value added to the mean square
                before the reciprocal square root, for numerical stability.
                Default is 1e-6.
            device (torch.device, optional): Device on which ``weight`` is
                allocated.
            dtype (torch.dtype, optional): Dtype of ``weight``.

        Attributes:
            dim (int): Size of the normalized dimension.
            eps (float): Value added to the mean square for stability.
            elementwise_affine (bool): Whether a learnable ``weight`` exists.
            weight (nn.Parameter or None): Learnable scale, or None.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
        else:
            # Keep the attribute present but empty, mirroring nn.LayerNorm,
            # so callers can always read ``.weight`` (it is None here) and the
            # two layer types returned by get_norm_layer stay interchangeable.
            # A None parameter is not included in state_dict().
            self.register_parameter("weight", None)

    def _norm(self, x):
        """
        Apply RMS normalization (without the learnable scale) to ``x``.

        Args:
            x (torch.Tensor): The input tensor; statistics are computed over
                its last dimension.

        Returns:
            torch.Tensor: ``x`` multiplied by the reciprocal square root of
            the mean of ``x**2`` over the last dimension plus ``eps``.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor, cast back to ``x``'s dtype and
            then multiplied by ``weight`` if one exists. The multiplication
            follows PyTorch type promotion, so a float32 ``weight`` applied to
            a bfloat16 input yields a float32 result.
        """
        # Compute the statistics in at least float32. Half-precision inputs
        # are upcast for stability, while float64 inputs keep their precision
        # instead of being truncated to float32 (as ``x.float()`` would do).
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        output = self._norm(x.to(compute_dtype)).type_as(x)
        if self.weight is not None:
            output = output * self.weight
        return output

    def extra_repr(self) -> str:
        """Summarize the layer's configuration for ``repr()``."""
        return (
            f"{self.dim}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )


def get_norm_layer(norm_layer):
    """
    Resolve a normalization-layer name to the class implementing it.

    Args:
        norm_layer (str): Name of the normalization. ``"layer"`` selects
            ``nn.LayerNorm`` and ``"rms"`` selects ``RMSNorm``.

    Returns:
        type: The selected layer class itself (not an instance), ready to be
        constructed by the caller.

    Raises:
        NotImplementedError: If ``norm_layer`` matches neither supported name.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    if norm_layer == "rms":
        return RMSNorm
    # Any other name falls through to an explicit failure.
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")