Upload MyLLaMa
Browse files
llama.py
CHANGED
@@ -92,9 +92,7 @@ class RMSNorm(nn.Module):
|
|
92 |
super().__init__()
|
93 |
|
94 |
self.dim = dim
|
95 |
-
self.trainable = nn.Parameter(
|
96 |
-
data=torch.nn.init.normal_(torch.zeros((dim,))), requires_grad=True
|
97 |
-
)
|
98 |
self.eps = eps
|
99 |
|
100 |
def forward(self, x: Tensor):
|
|
|
92 |
super().__init__()
|
93 |
|
94 |
self.dim = dim
|
95 |
+
self.trainable = nn.Parameter(data=torch.ones((dim,)), requires_grad=True)
|
|
|
|
|
96 |
self.eps = eps
|
97 |
|
98 |
def forward(self, x: Tensor):
|