# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

def normalization(channels: int, groups: int = 32):
    r"""Make a standard normalization layer, i.e. GroupNorm.

    Args:
        channels: number of input channels.
        groups: number of groups for group normalization.

    Returns:
        a ``nn.Module`` for normalization.
    """
    assert groups > 0, f"invalid number of groups: {groups}"
    return nn.GroupNorm(groups, channels)
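

# Note that ``nn.GroupNorm`` requires ``channels`` to be divisible by
# ``groups``; with the defaults, channel counts must be multiples of 32.
# Hypothetical usage:
#   norm = normalization(128)             # GroupNorm(32, 128)
#   y = norm(torch.randn(4, 128, 50))     # normalizes within channel groups
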
def Linear(*args, **kwargs):
    r"""Wrapper of ``nn.Linear`` with kaiming_normal_ initialization."""
    layer = nn.Linear(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def Conv1d(*args, **kwargs):
    r"""Wrapper of ``nn.Conv1d`` with kaiming_normal_ initialization."""
    layer = nn.Conv1d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer


def Conv2d(*args, **kwargs):
    r"""Wrapper of ``nn.Conv2d`` with kaiming_normal_ initialization."""
    layer = nn.Conv2d(*args, **kwargs)
    nn.init.kaiming_normal_(layer.weight)
    return layer
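

# These wrappers re-initialize only the weight tensor; the bias (when present)
# keeps PyTorch's default initialization. Hypothetical usage:
#   fc = Linear(256, 512)                   # weight drawn by kaiming_normal_
#   conv = Conv1d(80, 256, kernel_size=3)   # same init for the conv kernel
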
def ConvNd(dims: int = 1, *args, **kwargs):
    r"""Wrapper of N-dimension convolution with kaiming_normal_ initialization.

    Args:
        dims: number of dimensions of the convolution.
    """
    if dims == 1:
        return Conv1d(*args, **kwargs)
    elif dims == 2:
        return Conv2d(*args, **kwargs)
    else:
        raise ValueError(f"invalid number of dimensions: {dims}")
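

# Because ``dims`` is the first positional argument, the remaining positional
# arguments are forwarded to the underlying conv constructor, e.g.:
#   conv = ConvNd(2, 3, 64, kernel_size=3, padding=1)
#   # equivalent to Conv2d(3, 64, kernel_size=3, padding=1)
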
def zero_module(module: nn.Module):
    r"""Zero out the parameters of a module and return it."""
    # Iterate over all parameters rather than touching ``weight``/``bias``
    # directly, which would crash on modules built with ``bias=False``.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
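

# A common use (e.g. in diffusion UNets) is to zero-initialize the final
# projection of a residual branch so the block starts as an identity mapping:
#   out_proj = zero_module(Conv1d(256, 256, kernel_size=1))
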
def scale_module(module: nn.Module, scale):
    r"""Scale the parameters of a module and return it."""
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor: torch.Tensor):
    r"""Take the mean over all non-batch dimensions."""
    return tensor.mean(dim=tuple(range(1, tensor.dim())))
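

# For example, reducing an elementwise loss to one value per sample:
#   loss = mean_flat((pred - target) ** 2)   # (N, D, T) -> (N,)
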
def append_dims(x, target_dims):
    r"""Appends dimensions to the end of a tensor until it has
    ``target_dims`` dimensions.
    """
    dims_to_append = target_dims - x.dim()
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.dim()} dims but target_dims is {target_dims}, "
            "which is less"
        )
    return x[(...,) + (None,) * dims_to_append]
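

# Hypothetical diffusion-style usage: broadcast per-sample noise levels over a
# batch of images:
#   sigma = torch.rand(8)                    # (8,)
#   x = torch.randn(8, 3, 64, 64)            # (8, 3, 64, 64)
#   noised = x + append_dims(sigma, x.dim()) * torch.randn_like(x)
#   # append_dims reshapes sigma to (8, 1, 1, 1) so it broadcasts cleanly
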
def append_zero(x, count=1):
    r"""Appends ``count`` zeros to the end of a tensor along the last dimension."""
    assert count > 0, f"invalid count: {count}"
    return torch.cat([x, x.new_zeros((*x.size()[:-1], count))], dim=-1)
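

# A typical use in samplers is terminating a noise schedule at sigma = 0:
#   sigmas = append_zero(torch.linspace(1.0, 0.01, 10))  # (11,), ends in 0.0
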
class Transpose(nn.Identity):
    r"""(N, T, D) -> (N, D, T)"""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input.transpose(1, 2)
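

# Useful for bridging channel-last and channel-first layers inside an
# ``nn.Sequential``; a minimal sketch with a hypothetical stack:
#   net = nn.Sequential(
#       Linear(80, 256),                             # (N, T, 80) -> (N, T, 256)
#       Transpose(),                                 # (N, T, 256) -> (N, 256, T)
#       Conv1d(256, 256, kernel_size=3, padding=1),  # operates channel-first
#   )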