Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	| import math | |
| import torch | |
| from torch import nn | |
| from typing import Optional, Any | |
| from torch import Tensor | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import torchaudio.functional as audio_F | |
| import random | |
| random.seed(0) | |
def _get_activation_fn(activ):
    """Return a freshly constructed activation module for the name ``activ``.

    Args:
        activ: One of ``'relu'``, ``'lrelu'`` (LeakyReLU with negative slope
            0.2) or ``'swish'`` (``x * sigmoid(x)``).

    Returns:
        An ``nn.Module`` implementing the requested activation.

    Raises:
        RuntimeError: If ``activ`` is not one of the supported names.
    """
    if activ == 'relu':
        return nn.ReLU()
    if activ == 'lrelu':
        return nn.LeakyReLU(0.2)
    if activ == 'swish':
        # nn.SiLU computes x * sigmoid(x). Returning a module instead of a
        # bare lambda lets the result be placed in nn.Sequential (which only
        # accepts nn.Module children) and keeps the model picklable.
        return nn.SiLU()
    raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
class LinearNorm(torch.nn.Module):
    """Fully connected layer whose weight is Xavier-uniform initialised.

    Args:
        in_dim: Size of the last input dimension.
        out_dim: Size of the last output dimension.
        bias: Whether the linear layer carries a bias term.
        w_init_gain: Non-linearity name passed to
            ``torch.nn.init.calculate_gain`` to scale the initialisation.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        init_gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=init_gain)

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.linear_layer(x)
        return projected
class ConvNorm(torch.nn.Module):
    """``Conv1d`` wrapper with Xavier-uniform weight initialisation.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        padding: Zero padding applied to both sides. When ``None`` it is
            derived from ``kernel_size`` and ``dilation`` (the kernel must
            then be odd) so a stride-1 convolution keeps the input length.
        dilation: Spacing between kernel taps.
        bias: Whether the convolution carries a bias term.
        w_init_gain: Non-linearity name for ``calculate_gain``.
        param: Optional parameter forwarded to ``calculate_gain``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super().__init__()

        if padding is None:
            # Automatic padding is only defined for odd kernels.
            assert kernel_size % 2 == 1
            padding = dilation * (kernel_size - 1) // 2

        self.conv = torch.nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

        init_gain = torch.nn.init.calculate_gain(w_init_gain, param=param)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=init_gain)

    def forward(self, signal):
        """Convolve ``signal`` (channels on dim 1, time on dim 2)."""
        return self.conv(signal)
class CausualConv(nn.Module):
    """1-D convolution that pads symmetrically and trims the trailing frames.

    The wrapped ``nn.Conv1d`` zero-pads ``self.padding`` samples on both
    sides; ``forward`` then drops the last ``self.padding`` output frames.
    For stride 1 and ``self.padding == dilation * (kernel_size - 1)`` the
    output has the input's length and frame ``t`` depends only on inputs at
    positions ``<= t``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel width.
        stride: Convolution stride.
        padding: Half of the padding applied per side (the stored
            ``self.padding`` is ``padding * 2``). When ``None`` it is derived
            from ``kernel_size`` and ``dilation``; the kernel must be odd.
        dilation: Spacing between kernel taps.
        bias: Whether the convolution carries a bias term.
        w_init_gain: Non-linearity name for ``calculate_gain``.
        param: Optional parameter forwarded to ``calculate_gain``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1,
                 dilation=1, bias=True, w_init_gain='linear', param=None):
        super(CausualConv, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            # Previously this value was stored in a local only, so
            # ``self.padding`` was undefined and construction failed.
            self.padding = int(dilation * (kernel_size - 1) / 2) * 2
        else:
            self.padding = padding * 2
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=self.padding,
                              dilation=dilation,
                              bias=bias)
        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        """Convolve ``x`` (B, C, T) and drop the last ``self.padding`` frames."""
        x = self.conv(x)
        # ``x[:, :, :-0]`` would be empty, so only trim when there is padding.
        if self.padding > 0:
            x = x[:, :, :-self.padding]
        return x
class CausualBlock(nn.Module):
    """Stack of residual sub-blocks built from dilated ``CausualConv`` layers.

    Sub-block ``i`` uses dilation ``3**i`` in its first convolution and is
    wrapped in a residual connection, so every sub-block is expected to
    preserve the shape of its input.

    Args:
        hidden_dim: Channel count of the input and of every layer.
        n_conv: Number of residual sub-blocks.
        dropout_p: Dropout probability used inside each sub-block.
        activ: Activation name understood by ``_get_activation_fn``.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        """Run ``x`` through every sub-block, adding the residual each time."""
        for block in self.blocks:
            res = x
            x = block(x)
            # Out-of-place add: in eval mode Dropout passes its input through,
            # so the block output is the activation's own output tensor.
            # Mutating it in place can invalidate values autograd saved for
            # the backward pass (e.g. with 'relu').
            x = x + res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
        """Build one sub-block: conv -> act -> BatchNorm -> dropout -> conv -> act -> dropout."""
        layers = [
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)
class ConvBlock(nn.Module):
    """Stack of residual sub-blocks built from dilated ``ConvNorm`` layers.

    Sub-block ``i`` uses dilation ``3**i`` (with matching padding) in its
    first convolution and is wrapped in a residual connection.

    Args:
        hidden_dim: Channel count of the input and of every layer. It is
            normalised with ``nn.GroupNorm`` using 8 groups, so it must be
            divisible by 8.
        n_conv: Number of residual sub-blocks.
        dropout_p: Dropout probability used inside each sub-block.
        activ: Activation name understood by ``_get_activation_fn``.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        """Run ``x`` through every sub-block, adding the residual each time."""
        for block in self.blocks:
            res = x
            x = block(x)
            # Out-of-place add: in eval mode Dropout passes its input through,
            # so the block output is the ReLU output tensor itself. ReLU's
            # backward uses that saved output, and an in-place ``+=`` on it
            # makes autograd raise during the backward pass.
            x = x + res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
        """Build one sub-block: conv -> act -> GroupNorm -> dropout -> conv -> act -> dropout."""
        layers = [
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)
class LocationLayer(nn.Module):
    """Turn stacked attention-weight histories into location features.

    A bias-free 2-channel ``ConvNorm`` is applied along time, followed by a
    bias-free ``LinearNorm`` projection to ``attention_dim``.

    Args:
        attention_n_filters: Number of convolution filters.
        attention_kernel_size: Convolution kernel width; padding is
            ``(attention_kernel_size - 1) / 2`` truncated to an int.
        attention_dim: Size of the projected feature vector.
    """

    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super().__init__()
        same_padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(
            2,
            attention_n_filters,
            kernel_size=attention_kernel_size,
            padding=same_padding,
            bias=False,
            stride=1,
            dilation=1,
        )
        self.location_dense = LinearNorm(
            attention_n_filters, attention_dim, bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        """Map ``(B, 2, T)`` attention weights to ``(B, T', attention_dim)``.

        ``T'`` is the convolution's output length (equal to ``T`` for an odd
        kernel size).
        """
        conv_features = self.location_conv(attention_weights_cat)
        # Move channels last so the dense layer projects each time step.
        time_major = conv_features.transpose(1, 2)
        return self.location_dense(time_major)
class Attention(nn.Module):
    """Location-sensitive additive attention.

    Energies are ``v(tanh(W_q q + W_loc loc + processed_memory))`` and are
    normalised with a softmax over the time axis.

    Args:
        attention_rnn_dim: Size of the query vector.
        embedding_dim: Size of each memory (encoder output) vector.
        attention_dim: Size of the shared attention space.
        attention_location_n_filters: Filters in the location convolution.
        attention_location_kernel_size: Kernel width of that convolution.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super().__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain='tanh')
        # Not used by forward() here; presumably callers apply it to the
        # encoder outputs to obtain ``processed_memory`` -- confirm at call site.
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters,
            attention_location_kernel_size,
            attention_dim)
        # Masked positions receive -inf so their softmax weight is exactly 0.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """Compute unnormalised attention energies.

        PARAMS
        ------
        query: query vector, (B, attention_rnn_dim)
        processed_memory: projected encoder outputs, (B, T_in, attention_dim)
        attention_weights_cat: stacked attention-weight histories, (B, 2, max_time)

        RETURNS
        -------
        energies: (B, max_time)
        """
        # (B, D) -> (B, 1, D) so the query broadcasts across time steps.
        query_term = self.query_layer(query.unsqueeze(1))
        location_term = self.location_layer(attention_weights_cat)
        combined = torch.tanh(query_term + location_term + processed_memory)
        return self.v(combined).squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """Attend over ``memory``.

        PARAMS
        ------
        attention_hidden_state: query vector (attention RNN output)
        memory: encoder outputs, (B, T_in, embedding_dim)
        processed_memory: projected encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask of padded positions to exclude, or None

        RETURNS
        -------
        attention_context: weighted sum of ``memory``, (B, embedding_dim)
        attention_weights: softmax weights over time, (B, T_in)
        """
        energies = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            energies.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(energies, dim=1)
        # (B, 1, T) x (B, T, E) -> (B, 1, E) -> (B, E)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory).squeeze(1)
        return attention_context, attention_weights
class ForwardAttentionV2(nn.Module):
    """Location-sensitive attention with a forward (monotonic) recursion.

    The energies are computed as in additive location-sensitive attention.
    They are then combined in the log domain with the previous step's
    ``log_alpha`` so that position ``t`` can only receive mass from positions
    ``t`` and ``t - 1`` of the previous step.

    Args:
        attention_rnn_dim: Size of the query vector.
        embedding_dim: Size of each memory (encoder output) vector.
        attention_dim: Size of the shared attention space.
        attention_location_n_filters: Filters in the location convolution.
        attention_location_kernel_size: Kernel width of that convolution.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super().__init__()
        self.query_layer = LinearNorm(
            attention_rnn_dim, attention_dim, bias=False, w_init_gain='tanh')
        # Not used by forward() here; presumably callers apply it to the
        # encoder outputs to obtain ``processed_memory`` -- confirm at call site.
        self.memory_layer = LinearNorm(
            embedding_dim, attention_dim, bias=False, w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(
            attention_location_n_filters,
            attention_location_kernel_size,
            attention_dim)
        # A large finite negative (rather than -inf) is used both for masking
        # and as the log-domain padding value.
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """Compute unnormalised attention energies.

        PARAMS
        ------
        query: query vector, (B, attention_rnn_dim)
        processed_memory: projected encoder outputs, (B, T_in, attention_dim)
        attention_weights_cat: stacked attention-weight histories, (B, 2, max_time)

        RETURNS
        -------
        energies: (B, max_time)
        """
        # (B, D) -> (B, 1, D) so the query broadcasts across time steps.
        query_term = self.query_layer(query.unsqueeze(1))
        location_term = self.location_layer(attention_weights_cat)
        combined = torch.tanh(query_term + location_term + processed_memory)
        return self.v(combined).squeeze(-1)

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, log_alpha):
        """Attend over ``memory`` using the forward recursion.

        PARAMS
        ------
        attention_hidden_state: query vector (attention RNN output)
        memory: encoder outputs, (B, T_in, embedding_dim)
        processed_memory: projected encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: boolean mask of padded positions to exclude, or None
        log_alpha: log-domain forward variable of the previous step, (B, max_time)

        RETURNS
        -------
        attention_context: weighted sum of ``memory``, (B, embedding_dim)
        attention_weights: softmax of the new forward variable, (B, max_time)
        log_alpha_new: updated log-domain forward variable, (B, max_time)
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        max_time = log_energy.size(1)
        # Candidate predecessors for position t: t itself (shift 0) and t - 1
        # (shift 1). Shifted copies are left-padded with the mask value.
        candidates = []
        for shift in (0, 1):
            trimmed = log_alpha[:, :max_time - shift]
            candidates.append(
                F.pad(trimmed, (shift, 0), 'constant', self.score_mask_value))

        # logsumexp over the two candidates, then add the current energies.
        biased = torch.logsumexp(torch.stack(candidates, dim=2), dim=2)
        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)
        # (B, 1, T) x (B, T, E) -> (B, 1, E) -> (B, E)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory).squeeze(1)
        return attention_context, attention_weights, log_alpha_new
class PhaseShuffle2d(nn.Module):
    """Circularly shift a 4-D tensor along its last axis (dim 3).

    Args:
        n: Maximum absolute shift drawn when ``move`` is not supplied.

    The shift is drawn from a private ``random.Random(1)`` generator, so the
    sequence of shifts is the same for every fresh instance.
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        """Rotate ``x`` of shape (B, C, M, L) by ``move`` steps along ``L``.

        A positive ``move`` sends the first ``move`` columns to the end; a
        negative ``move`` brings the last ``|move|`` columns to the front.
        When ``move`` is ``None`` it is drawn uniformly from ``[-n, n]``.
        ``move == 0`` returns ``x`` itself.
        """
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x

        head = x[:, :, :, :move]
        tail = x[:, :, :, move:]
        return torch.cat((tail, head), dim=3)
class PhaseShuffle1d(nn.Module):
    """Circularly shift a 3-D tensor along its last axis (dim 2).

    Args:
        n: Maximum absolute shift drawn when ``move`` is not supplied.

    The shift is drawn from a private ``random.Random(1)`` generator, so the
    sequence of shifts is the same for every fresh instance.
    """

    def __init__(self, n=2):
        super().__init__()
        self.n = n
        self.random = random.Random(1)

    def forward(self, x, move=None):
        """Rotate ``x`` of shape (B, C, L) by ``move`` steps along ``L``.

        A positive ``move`` sends the first ``move`` frames to the end; a
        negative ``move`` brings the last ``|move|`` frames to the front.
        When ``move`` is ``None`` it is drawn uniformly from ``[-n, n]``.
        ``move == 0`` returns ``x`` itself.
        """
        if move is None:
            move = self.random.randint(-self.n, self.n)

        if move == 0:
            return x

        head = x[:, :, :move]
        tail = x[:, :, move:]
        return torch.cat((tail, head), dim=2)
class MFCC(nn.Module):
    """Project a mel spectrogram onto a DCT basis to obtain MFCC-style features.

    Args:
        n_mfcc: Number of cepstral coefficients to keep.
        n_mels: Number of mel bins expected in the input.

    The DCT matrix comes from ``torchaudio.functional.create_dct`` with
    ``'ortho'`` normalisation and is stored as the buffer ``dct_mat``.
    """

    def __init__(self, n_mfcc=40, n_mels=80):
        super().__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = 'ortho'
        self.register_buffer(
            'dct_mat', audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm))

    def forward(self, mel_specgram):
        """Apply the DCT along the mel axis.

        ``mel_specgram`` is either (n_mels, time) or (batch, n_mels, time);
        the result has the same rank with ``n_mels`` replaced by ``n_mfcc``.
        """
        # Temporarily add a batch axis so both layouts share one code path.
        is_unbatched = len(mel_specgram.shape) == 2
        batched = mel_specgram.unsqueeze(0) if is_unbatched else mel_specgram

        # (B, n_mels, T) -> (B, T, n_mels) @ (n_mels, n_mfcc) -> (B, n_mfcc, T)
        time_major = batched.transpose(1, 2)
        mfcc = torch.matmul(time_major, self.dct_mat).transpose(1, 2)

        return mfcc.squeeze(0) if is_unbatched else mfcc
