from argparse import Namespace
import torch
import math
from typing import Union

from .Layer import Conv1d, LayerNorm, LinearAttention
from .Diffusion import Diffusion

class DiffSinger(torch.nn.Module):
    def __init__(self, hyper_parameters: Namespace):
        super().__init__()
        self.hp = hyper_parameters

        self.encoder = Encoder(self.hp)
        self.diffusion = Diffusion(self.hp)

    def forward(
        self,
        tokens: torch.LongTensor,
        notes: torch.LongTensor,
        durations: torch.LongTensor,
        lengths: torch.LongTensor,
        genres: torch.LongTensor,
        singers: torch.LongTensor,
        features: Union[torch.FloatTensor, None]= None,
        ddim_steps: Union[int, None]= None
        ):
        encodings, linear_predictions = self.encoder(
            tokens= tokens,
            notes= notes,
            durations= durations,
            lengths= lengths,
            genres= genres,
            singers= singers
            )   # [Batch, Enc_d, Feature_t]
        encodings = torch.cat([encodings, linear_predictions], dim= 1)  # [Batch, Enc_d + Feature_d, Feature_t]

        if features is not None or ddim_steps is None or ddim_steps == self.hp.Diffusion.Max_Step:
            diffusion_predictions, noises, epsilons = self.diffusion(
                encodings= encodings,
                features= features,
                )
        else:
            noises, epsilons = None, None
            diffusion_predictions = self.diffusion.DDIM(
                encodings= encodings,
                ddim_steps= ddim_steps
                )

        return linear_predictions, diffusion_predictions, noises, epsilons
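
# --- Illustrative usage sketch (not part of the original module) --------------
# A minimal example of how the two branches of DiffSinger.forward might be
# driven. The batch keys, tensor shapes, and the ddim_steps value below are
# assumptions for illustration only; real values come from the project's
# hyper-parameter file and data pipeline.
def _example_diffsinger_call(model: DiffSinger, batch: dict, train: bool):
    if train:
        # Training path: ground-truth features are supplied, so the full
        # diffusion forward pass runs and noises/epsilons are returned.
        return model(
            tokens= batch['tokens'],        # [Batch, Enc_t], int64
            notes= batch['notes'],          # [Batch, Enc_t], int64
            durations= batch['durations'],  # [Batch, Enc_t], int64
            lengths= batch['lengths'],      # [Batch], int64
            genres= batch['genres'],        # [Batch], int64
            singers= batch['singers'],      # [Batch], int64
            features= batch['features']     # [Batch, Feature_d, Feature_t], float32
            )
    else:
        # Inference path: no features and a step count different from
        # hp.Diffusion.Max_Step, so the DDIM branch is taken and
        # noises/epsilons come back as None.
        return model(
            tokens= batch['tokens'],
            notes= batch['notes'],
            durations= batch['durations'],
            lengths= batch['lengths'],
            genres= batch['genres'],
            singers= batch['singers'],
            ddim_steps= 50  # assumed step count
            )
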
class Encoder(torch.nn.Module):
    def __init__(
        self,
        hyper_parameters: Namespace
        ):
        super().__init__()
        self.hp = hyper_parameters

        if self.hp.Feature_Type == 'Mel':
            self.feature_size = self.hp.Sound.Mel_Dim
        elif self.hp.Feature_Type == 'Spectrogram':
            self.feature_size = self.hp.Sound.N_FFT // 2 + 1

        self.token_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Tokens,
            embedding_dim= self.hp.Encoder.Size
            )
        self.note_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Notes,
            embedding_dim= self.hp.Encoder.Size
            )
        self.duration_embedding = Duration_Positional_Encoding(
            num_embeddings= self.hp.Durations,
            embedding_dim= self.hp.Encoder.Size
            )
        self.genre_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Genres,
            embedding_dim= self.hp.Encoder.Size,
            )
        self.singer_embedding = torch.nn.Embedding(
            num_embeddings= self.hp.Singers,
            embedding_dim= self.hp.Encoder.Size,
            )
        torch.nn.init.xavier_uniform_(self.token_embedding.weight)
        torch.nn.init.xavier_uniform_(self.note_embedding.weight)
        torch.nn.init.xavier_uniform_(self.genre_embedding.weight)
        torch.nn.init.xavier_uniform_(self.singer_embedding.weight)

        self.fft_blocks = torch.nn.ModuleList([
            FFT_Block(
                channels= self.hp.Encoder.Size,
                num_head= self.hp.Encoder.ConvFFT.Head,
                ffn_kernel_size= self.hp.Encoder.ConvFFT.FFN.Kernel_Size,
                dropout_rate= self.hp.Encoder.ConvFFT.Dropout_Rate
                )
            for _ in range(self.hp.Encoder.ConvFFT.Stack)
            ])

        self.linear_projection = Conv1d(
            in_channels= self.hp.Encoder.Size,
            out_channels= self.feature_size,
            kernel_size= 1,
            bias= True,
            w_init_gain= 'linear'
            )

    def forward(
        self,
        tokens: torch.Tensor,
        notes: torch.Tensor,
        durations: torch.Tensor,
        lengths: torch.Tensor,
        genres: torch.Tensor,
        singers: torch.Tensor
        ):
        x = \
            self.token_embedding(tokens) + \
            self.note_embedding(notes) + \
            self.duration_embedding(durations) + \
            self.genre_embedding(genres).unsqueeze(1) + \
            self.singer_embedding(singers).unsqueeze(1)
        x = x.permute(0, 2, 1)  # [Batch, Enc_d, Enc_t]

        for block in self.fft_blocks:
            x = block(x, lengths)   # [Batch, Enc_d, Enc_t]

        linear_predictions = self.linear_projection(x)  # [Batch, Feature_d, Enc_t]

        return x, linear_predictions
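
# --- Illustrative shape sketch (not part of the original module) --------------
# The encoder sums frame-level token/note/duration embeddings with per-utterance
# genre/singer embeddings (broadcast over time), then works channel-first.
# The batch size and frame count below are assumptions for illustration only.
def _example_encoder_shapes(encoder: Encoder):
    batch, enc_t = 2, 37
    tokens = torch.randint(0, encoder.hp.Tokens, (batch, enc_t))
    notes = torch.randint(0, encoder.hp.Notes, (batch, enc_t))
    durations = torch.randint(0, encoder.hp.Durations, (batch, enc_t))
    lengths = torch.tensor([enc_t, enc_t - 5])
    genres = torch.randint(0, encoder.hp.Genres, (batch,))
    singers = torch.randint(0, encoder.hp.Singers, (batch,))
    encodings, linear_predictions = encoder(
        tokens= tokens,
        notes= notes,
        durations= durations,
        lengths= lengths,
        genres= genres,
        singers= singers
        )
    # encodings:          [Batch, Encoder.Size, Enc_t]
    # linear_predictions: [Batch, feature_size, Enc_t]
    return encodings.shape, linear_predictions.shape
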
class FFT_Block(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        num_head: int,
        ffn_kernel_size: int,
        dropout_rate: float= 0.1,
        ) -> None:
        super().__init__()

        self.attention = LinearAttention(
            channels= channels,
            calc_channels= channels,
            num_heads= num_head,
            dropout_rate= dropout_rate
            )

        self.ffn = FFN(
            channels= channels,
            kernel_size= ffn_kernel_size,
            dropout_rate= dropout_rate
            )

    def forward(
        self,
        x: torch.Tensor,
        lengths: torch.Tensor
        ) -> torch.Tensor:
        '''
        x: [Batch, Dim, Time]
        '''
        masks = (~Mask_Generate(lengths= lengths, max_length= torch.ones_like(x[0, 0]).sum())).unsqueeze(1).float()    # float mask

        # Attention + Dropout + LayerNorm
        x = self.attention(x)

        # FFN + Dropout + LayerNorm
        x = self.ffn(x, masks)

        return x * masks

class FFN(torch.nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int,
        dropout_rate: float= 0.1,
        ) -> None:
        super().__init__()
        self.conv_0 = Conv1d(
            in_channels= channels,
            out_channels= channels,
            kernel_size= kernel_size,
            padding= (kernel_size - 1) // 2,
            w_init_gain= 'relu'
            )
        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p= dropout_rate)
        self.conv_1 = Conv1d(
            in_channels= channels,
            out_channels= channels,
            kernel_size= kernel_size,
            padding= (kernel_size - 1) // 2,
            w_init_gain= 'linear'
            )
        self.norm = LayerNorm(
            num_features= channels,
            )

    def forward(
        self,
        x: torch.Tensor,
        masks: torch.Tensor
        ) -> torch.Tensor:
        '''
        x: [Batch, Dim, Time]
        '''
        residuals = x

        x = self.conv_0(x * masks)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv_1(x * masks)
        x = self.dropout(x)
        x = self.norm(x + residuals)

        return x * masks

# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
# https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
class Duration_Positional_Encoding(torch.nn.Embedding):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        ):
        positional_embedding = torch.zeros(num_embeddings, embedding_dim)
        position = torch.arange(0, num_embeddings, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        positional_embedding[:, 0::2] = torch.sin(position * div_term)
        positional_embedding[:, 1::2] = torch.cos(position * div_term)
        super().__init__(
            num_embeddings= num_embeddings,
            embedding_dim= embedding_dim,
            _weight= positional_embedding
            )
        self.weight.requires_grad = False

        self.alpha = torch.nn.Parameter(
            data= torch.ones(1) * 0.01,
            requires_grad= True
            )

    def forward(self, durations):
        '''
        durations: [Batch, Length]
        '''
        return self.alpha * super().forward(durations)  # [Batch, Length, Dim]
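
# --- Illustrative check (not part of the original module) ---------------------
# Duration_Positional_Encoding is a frozen sinusoidal table indexed by duration
# value: weight[pos, 2i] = sin(pos / 10000^(2i / dim)) and
# weight[pos, 2i + 1] = cos(pos / 10000^(2i / dim)), with the looked-up rows
# scaled by the learned alpha. The sizes below are assumptions for illustration.
def _example_duration_positional_encoding():
    pe = Duration_Positional_Encoding(num_embeddings= 16, embedding_dim= 8)
    durations = torch.tensor([[0, 3, 3, 7]])    # [Batch= 1, Length= 4]
    out = pe(durations)                         # [Batch, Length, Dim]
    expected = math.sin(3.0 / 10000.0 ** (2.0 / 8.0))   # weight[3, 2] before alpha scaling
    assert torch.isclose(out[0, 1, 2], pe.alpha[0] * expected)
    return out.shape
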
def get_pe(x: torch.Tensor, pe: torch.Tensor):
    pe = pe.repeat(1, 1, math.ceil(x.size(2) / pe.size(2)))
    return pe[:, :, :x.size(2)]

def Mask_Generate(lengths: torch.Tensor, max_length: Union[torch.Tensor, int, None]= None):
    '''
    lengths: [Batch]
    max_length: an int value. If None, max_length == max(lengths)
    '''
    max_length = max_length or torch.max(lengths)
    sequence = torch.arange(max_length)[None, :].to(lengths.device)
    return sequence >= lengths[:, None]    # [Batch, Time]
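
# --- Illustrative check (not part of the original module) ---------------------
# Mask_Generate marks padded positions with True. A minimal example with
# assumed lengths:
def _example_mask_generate():
    lengths = torch.tensor([2, 4])
    masks = Mask_Generate(lengths= lengths, max_length= 5)
    # masks == tensor([[False, False,  True,  True,  True],
    #                  [False, False, False, False,  True]])
    # FFT_Block inverts and casts this to float, giving 1.0 on valid frames and
    # 0.0 on padding.
    return masks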