Spaces:
Running
Running
| # Copyright (c) 2023 Amphion. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import copy | |
| import numbers | |
| from typing import Any, List, Tuple, Union | |
| import torch | |
| from torch import Tensor, nn | |
| from torch.nn import functional as F | |
| from modules.general.scaling import ActivationBalancer | |
| from modules.general.scaling import BasicNorm as _BasicNorm | |
# Type alias for the ``normalized_shape`` argument of ``LayerNorm``:
# a single int, a list of ints, or a ``torch.Size``.
_shape_t = Union[int, List[int], torch.Size]
class LayerNorm(nn.Module):
    """Layer normalization that can also carry a conditioning embedding through.

    The normalization itself is ``F.layer_norm`` over the trailing
    ``normalized_shape`` dimensions, with optional learnable per-element
    ``weight``/``bias``.

    ``forward`` accepts either a bare tensor or an ``(input, embedding)``
    tuple. In the tuple case only ``input`` is normalized, and the pair
    ``(normalized, embedding)`` is returned unchanged otherwise.
    """

    __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(
        self,
        normalized_shape: _shape_t,
        eps: float = 1e-5,
        elementwise_affine: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        # A single integer is shorthand for a 1-D shape.
        if isinstance(normalized_shape, numbers.Integral):
            shape = (normalized_shape,)
        else:
            shape = normalized_shape
        self.normalized_shape = tuple(shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            # Allocated uninitialized here; reset_parameters() fills them in.
            self.weight = nn.Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
            self.bias = nn.Parameter(
                torch.empty(self.normalized_shape, device=device, dtype=dtype)
            )
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Set the affine parameters to the identity transform (weight=1, bias=0)."""
        if not self.elementwise_affine:
            return
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize ``input``.

        Args:
            input: a tensor, or an ``(input, embedding)`` tuple.
            embedding: must be ``None`` when ``input`` is a bare tensor.

        Returns:
            The normalized tensor, or ``(normalized, embedding)`` when a tuple
            was passed in.
        """
        paired = isinstance(input, tuple)
        if paired:
            input, embedding = input
        else:
            assert embedding is None
        normalized = F.layer_norm(
            input, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return (normalized, embedding) if paired else normalized

    def extra_repr(self) -> str:
        return (
            f"{self.normalized_shape}, eps={self.eps}, "
            f"elementwise_affine={self.elementwise_affine}"
        )
class AdaptiveLayerNorm(nn.Module):
    r"""Adaptive Layer Normalization.

    Normalizes ``input`` with the wrapped ``norm`` module, then applies a scale
    and shift predicted from a conditioning ``embedding``::

        weight, bias = split(project_layer(embedding), d_model, dim=-1)
        output = weight * norm(input) + bias

    Args:
        d_model: size of the embedding's last dimension. It feeds
            ``nn.Linear(d_model, 2 * d_model)``, so ``weight`` and ``bias``
            each have ``d_model`` features. They must broadcast against
            ``norm(input)`` (presumably also ``d_model`` wide; confirm at
            call sites).
        norm: normalization module applied to ``input``. If it defines an
            ``eps`` attribute, that value is mirrored on ``self.eps``.
            Otherwise ``self.eps`` is ``None``, so identity-style norms can be
            wrapped too.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        # Not every norm layer defines ``eps`` (e.g. an identity norm).
        # Direct attribute access used to make such norms unusable here
        # (AttributeError at construction time).
        self.eps = getattr(self.norm, "eps", None)

    def _modulate(self, input: Tensor, embedding: Tensor) -> Tensor:
        """Return ``weight * norm(input) + bias`` with weight/bias from ``embedding``."""
        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        return weight * self.norm(input) + bias

    def forward(
        self, input: Tensor, embedding: Tensor = None
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Apply embedding-conditioned normalization.

        Args:
            input: a tensor, or an ``(input, embedding)`` tuple.
            embedding: conditioning tensor. It is required when ``input`` is a
                bare tensor, and ignored when ``input`` is a tuple.

        Returns:
            The modulated tensor, or ``(modulated, embedding)`` when a tuple
            was passed in, so the embedding keeps flowing through stacked
            layers.
        """
        if isinstance(input, tuple):
            input, embedding = input
            return (self._modulate(input, embedding), embedding)
        return self._modulate(input, embedding)
class BasicNorm(_BasicNorm):
    """``_BasicNorm`` adapted to the shared norm-layer call convention.

    ``forward`` accepts either a bare tensor or an ``(input, embedding)``
    tuple. In the tuple case only ``input`` goes through the base-class
    normalization, and the embedding is returned alongside it.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        # ``device``/``dtype`` are accepted for signature parity with the other
        # norm layers in this module, but they are not forwarded to the base
        # class.
        super().__init__(d_model, eps=eps)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Normalize a tensor, or the first element of an ``(input, embedding)`` pair."""
        if not isinstance(input, tuple):
            # A separate embedding is only meaningful in the tuple form.
            assert embedding is None
            return super().forward(input)
        tensor, embedding = input
        normalized = super().forward(tensor)
        return normalized, embedding
class BalancedBasicNorm(nn.Module):
    """``BasicNorm`` preceded by an ``ActivationBalancer``.

    The input is first passed through ``self.balancer`` and then through
    ``self.norm``. The ``(input, embedding)`` tuple form is supported: the
    embedding is forwarded untouched to the inner ``BasicNorm``, which returns
    it alongside the normalized tensor.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ):
        super().__init__()
        # The balancer constants (positive fraction in [0.45, 0.55],
        # max_abs=6.0) are interpreted by ActivationBalancer in
        # modules.general.scaling; see that module for their exact semantics.
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.norm = BasicNorm(d_model, eps, device=device, dtype=dtype)

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Balance then normalize a tensor, or the first element of an ``(input, embedding)`` pair."""
        is_pair = isinstance(input, tuple)
        if is_pair:
            input, embedding = input
        else:
            assert embedding is None
        balanced = self.balancer(input)
        return self.norm((balanced, embedding) if is_pair else balanced)
class IdentityNorm(nn.Module):
    """No-op normalization layer.

    It returns its input object unchanged, whether that is a bare tensor or an
    ``(input, embedding)`` tuple. Constructor arguments mirror the other norm
    layers so it can be swapped in for them, but they are ignored.
    """

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()

    def forward(self, input: Tensor, embedding: Any = None) -> Tensor:
        """Return ``input`` as-is.

        A separate ``embedding`` must be ``None`` unless ``input`` is a tuple.
        """
        if not isinstance(input, tuple):
            assert embedding is None
        return input