# Yehor's picture
# Clean up
# eac7684
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# 1x1InvertibleConv and WN based on implementation from WaveGlow https://github.com/NVIDIA/waveglow/blob/master/glow.py
# Original license:
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import ast
from splines import (
piecewise_linear_transform,
piecewise_linear_inverse_transform,
unbounded_piecewise_quadratic_transform,
)
from partialconv1d import PartialConv1d as pconv1d
from typing import Tuple
from torch_env import device
def get_mask_from_lengths(lengths):
    """Construct a boolean padding mask from a 1D tensor of sequence lengths.

    Args:
        lengths (torch.tensor): 1D tensor of shape (num_sequences,) holding the
            valid length of each sequence.

    Returns:
        mask (torch.tensor): bool tensor of shape (num_sequences, max_length)
            where ``mask[b, t]`` is True iff ``t < lengths[b]``.
    """
    max_len = int(torch.max(lengths).item())
    # Build the position index on the same device as ``lengths`` so the
    # comparison below never mixes devices.
    ids = torch.arange(max_len, dtype=torch.long, device=lengths.device)
    # (1, max_len) < (num_sequences, 1) broadcasts to (num_sequences, max_len)
    mask = ids.unsqueeze(0) < lengths.unsqueeze(1)
    return mask
class ExponentialClass(torch.nn.Module):
    """Parameter-free module that applies ``torch.exp`` element-wise."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the element-wise exponential of ``x``."""
        exp_x = torch.exp(x)
        return exp_x
class LinearNorm(torch.nn.Module):
    """``torch.nn.Linear`` whose weight is Xavier-uniform initialised.

    ``w_init_gain`` names the nonlinearity expected after this layer; it is
    converted to a gain with ``torch.nn.init.calculate_gain``.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
        super().__init__()
        layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        self.linear_layer = layer
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(layer.weight, gain=gain)

    def forward(self, x):
        """Project ``x`` with the wrapped linear layer."""
        return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
    """1D convolution with Xavier-uniform init and optional masking / weight norm.

    Args:
        in_channels, out_channels, kernel_size, stride, dilation, bias:
            forwarded to the underlying convolution.
        padding: explicit padding; when None, ``dilation * (kernel_size - 1) / 2``
            is used, which requires an odd ``kernel_size`` (asserted).
        w_init_gain: nonlinearity name passed to
            ``torch.nn.init.calculate_gain`` for the Xavier init.
        use_partial_padding: use ``pconv1d`` (called as ``conv(signal, mask)``)
            instead of ``torch.nn.Conv1d``.
        use_weight_norm: wrap the convolution with
            ``torch.nn.utils.parametrizations.weight_norm``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=None,
        dilation=1,
        bias=True,
        w_init_gain="linear",
        use_partial_padding=False,
        use_weight_norm=False,
    ):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.use_partial_padding = use_partial_padding
        self.use_weight_norm = use_weight_norm

        conv_cls = pconv1d if use_partial_padding else torch.nn.Conv1d
        conv = conv_cls(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(conv.weight, gain=gain)
        if use_weight_norm:
            conv = torch.nn.utils.parametrizations.weight_norm(conv)
        self.conv = conv

    def forward(self, signal, mask=None):
        """Convolve ``signal``; a given ``mask`` is applied to the output."""
        if self.use_partial_padding:
            # the partial conv consumes the mask itself
            return self.conv(signal, mask)
        conv_signal = self.conv(signal)
        if mask is not None:
            # re-zero padded positions so the output matches zero-padding
            conv_signal = conv_signal * mask
        return conv_signal
class DenseLayer(nn.Module):
    """Stack of ``LinearNorm`` layers, each followed by ``tanh``.

    Args:
        in_dim: feature size of the input.
        sizes: output size of each successive layer. Any sequence (list or
            tuple) is accepted; defaults to ``(1024, 1024)``. An immutable
            default avoids sharing one list object across instances.
    """

    def __init__(self, in_dim=1024, sizes=(1024, 1024)):
        super(DenseLayer, self).__init__()
        sizes = list(sizes)
        # layer i maps in_sizes[i] -> sizes[i]
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [
                LinearNorm(in_size, out_size, bias=True)
                for (in_size, out_size) in zip(in_sizes, sizes)
            ]
        )

    def forward(self, x):
        """Apply every layer followed by ``tanh`` and return the result."""
        for linear in self.layers:
            x = torch.tanh(linear(x))
        return x
class LengthRegulator(nn.Module):
    """Repeat each frame by its duration, then zero-pad and stack the batch.

    Frame ``i`` of a sequence is repeated ``int(dur[i] + 0.5)`` times. The
    expanded sequences are right-padded along dim 0 to the longest one and
    stacked.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, dur):
        """Expand every sequence of ``x`` by the matching entry of ``dur``.

        Args:
            x: iterable of sequences; each is iterated frame by frame
                (presumably B x T x D — TODO confirm).
            dur: per-sequence durations aligned with the frames of ``x``.

        Returns:
            Stacked, zero-padded tensor of the expanded sequences.
        """
        expanded = [self.expand(seq, seq_dur) for seq, seq_dur in zip(x, dur)]
        return self.pad(expanded)

    def expand(self, x, dur):
        """Repeat frame ``idx`` of ``x`` ``int(dur[idx] + 0.5)`` times and concatenate."""
        repeated = []
        for idx, frame in enumerate(x):
            n_repeats = int(dur[idx] + 0.5)
            repeated.append(frame.expand(n_repeats, -1))
        return torch.cat(repeated, 0)

    def pad(self, x):
        """Right-pad each tensor in ``x`` along dim 0 to the longest one and stack."""
        target_len = max(seq.size(0) for seq in x)
        padded = [
            F.pad(seq, [0, 0, 0, target_len - seq.size(0)], "constant", 0.0)
            for seq in x
        ]
        return torch.stack(padded)
class ConvLSTMLinear(nn.Module):
    """Weight-normalised conv stack -> optional (Bi)LSTM -> optional linear layer.

    Args:
        in_dim: channels of the input ``context``.
        out_dim: channels produced by the final linear layer.
        n_layers: number of conv layers (ReLU + dropout after each).
        n_channels: channels of the conv layers.
        kernel_size: kernel size of the conv layers.
        p_dropout: dropout probability applied after every conv.
        lstm_type: "bilstm" for a bidirectional LSTM, another non-empty string
            for a unidirectional LSTM, "" to skip the LSTM.
        use_linear: project the result to ``out_dim`` with ``nn.Linear``.
            When False, the LSTM is built with ``out_dim`` input channels, so
            presumably ``n_channels`` should equal ``out_dim`` in that
            configuration — NOTE(review): confirm.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        n_layers=2,
        n_channels=256,
        kernel_size=3,
        p_dropout=0.1,
        lstm_type="bilstm",
        use_linear=True,
    ):
        super(ConvLSTMLinear, self).__init__()
        self.out_dim = out_dim
        self.lstm_type = lstm_type
        self.use_linear = use_linear
        self.dropout = nn.Dropout(p=p_dropout)

        convolutions = []
        for i in range(n_layers):
            conv_layer = ConvNorm(
                in_dim if i == 0 else n_channels,
                n_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=int((kernel_size - 1) / 2),
                dilation=1,
                w_init_gain="relu",
            )
            # Only the inner Conv1d is kept (the ConvNorm wrapper is dropped)
            # and it is weight-normalised.
            conv_layer = torch.nn.utils.parametrizations.weight_norm(
                conv_layer.conv, name="weight"
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        if not self.use_linear:
            n_channels = out_dim

        if self.lstm_type != "":
            use_bilstm = False
            lstm_channels = n_channels
            if self.lstm_type == "bilstm":
                use_bilstm = True
                # half per direction so the concatenated output has n_channels
                lstm_channels = int(n_channels // 2)
            self.bilstm = nn.LSTM(
                n_channels, lstm_channels, 1, batch_first=True, bidirectional=use_bilstm
            )
            # spectral norm on the recurrent weights of each direction
            lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
            self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0")
            if self.lstm_type == "bilstm":
                self.bilstm = lstm_norm_fn_pntr(self.bilstm, "weight_hh_l0_reverse")

        if self.use_linear:
            self.dense = nn.Linear(n_channels, out_dim)

    def run_padded_sequence(self, context, lens):
        """Run the conv stack on each sequence trimmed to its length, then re-pad.

        Args:
            context: B x C x T padded input.
            lens: per-sequence valid lengths.

        Returns:
            B x max(lens) x C' tensor (time-major per item, batch-first).
        """
        context_embedded = []
        for b_ind in range(context.size()[0]):  # TODO: speed up
            curr_context = context[b_ind : b_ind + 1, :, : lens[b_ind]].clone()
            for conv in self.convolutions:
                curr_context = self.dropout(F.relu(conv(curr_context)))
            context_embedded.append(curr_context[0].transpose(0, 1))
        context = torch.nn.utils.rnn.pad_sequence(context_embedded, batch_first=True)
        return context

    def run_unsorted_inputs(self, fn, context, lens):
        """Run recurrent ``fn`` on a packed, length-sorted batch; restore order.

        Args:
            fn: recurrent module accepting a ``PackedSequence``.
            context: B x T x C batch-first input.
            lens: 1D tensor of sequence lengths (any order).

        Returns:
            B x T' x C' output in the original batch order.
        """
        lens_sorted, ids_sorted = torch.sort(lens, descending=True)
        # argsort of a permutation is its inverse: unsort_ids[ids_sorted[i]] == i.
        # One tensor op instead of a Python loop over (possibly GPU) scalars.
        unsort_ids = torch.argsort(ids_sorted)
        lens_sorted = lens_sorted.long().cpu()
        context = context[ids_sorted]
        context = nn.utils.rnn.pack_padded_sequence(
            context, lens_sorted, batch_first=True
        )
        context = fn(context)[0]
        context = nn.utils.rnn.pad_packed_sequence(context, batch_first=True)[0]
        # map back to original indices
        context = context[unsort_ids]
        return context

    def forward(self, context, lens):
        """Run convs, optional LSTM and optional linear projection.

        Args:
            context: B x in_dim x T padded input.
            lens: 1D per-sequence lengths, or None. Used to trim each item
                before the convs when B > 1, and to pack the LSTM input.

        Returns:
            B x C x T' tensor (C is ``out_dim`` when ``use_linear`` is True).
        """
        if context.size()[0] > 1:
            context = self.run_padded_sequence(context, lens)
            # to B, D, T
            context = context.transpose(1, 2)
        else:
            for conv in self.convolutions:
                context = self.dropout(F.relu(conv(context)))

        if self.lstm_type != "":
            context = context.transpose(1, 2)
            self.bilstm.flatten_parameters()
            if lens is not None:
                context = self.run_unsorted_inputs(self.bilstm, context, lens)
            else:
                context = self.bilstm(context)[0]
            context = context.transpose(1, 2)

        x_hat = context
        if self.use_linear:
            x_hat = self.dense(context.transpose(1, 2)).transpose(1, 2)
        return x_hat

    def infer(self, z, txt_enc, spk_emb):
        # NOTE(review): ``forward`` returns a tensor (not a dict) and this class
        # never defines ``feature_processing``, so this method cannot succeed
        # unless a subclass overrides both. Kept unchanged for interface
        # compatibility.
        x_hat = self.forward(txt_enc, spk_emb)["x_hat"]
        x_hat = self.feature_processing.denormalize(x_hat)
        return x_hat
class Encoder(nn.Module):
    """Text encoder: a stack of 1-d convolution banks followed by a BiLSTM.

    Each bank is a partial-padding ``ConvNorm`` followed by ``norm_fn``; ReLU
    and dropout (p=0.5) follow every bank. The BiLSTM uses
    ``encoder_embedding_dim / 2`` units per direction, so outputs keep
    ``encoder_embedding_dim`` channels.

    Args:
        encoder_n_convolutions: number of conv banks.
        encoder_embedding_dim: channels of the input, conv banks and output.
        encoder_kernel_size: conv kernel size.
        norm_fn: normalisation constructor, called as ``norm_fn(dim, affine=True)``.
        lstm_norm_fn: optional string; if it contains "spectral", spectral norm
            is applied to the LSTM recurrent weights. Otherwise, if it contains
            "weight", weight norm is applied.

    Raises:
        ValueError: if ``lstm_norm_fn`` is given but names neither norm.
    """

    def __init__(
        self,
        encoder_n_convolutions=3,
        encoder_embedding_dim=512,
        encoder_kernel_size=5,
        norm_fn=nn.BatchNorm1d,
        lstm_norm_fn=None,
    ):
        super(Encoder, self).__init__()
        convolutions = []
        for _ in range(encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(
                    encoder_embedding_dim,
                    encoder_embedding_dim,
                    kernel_size=encoder_kernel_size,
                    stride=1,
                    padding=int((encoder_kernel_size - 1) / 2),
                    dilation=1,
                    w_init_gain="relu",
                    use_partial_padding=True,
                ),
                norm_fn(encoder_embedding_dim, affine=True),
            )
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)
        self.lstm = nn.LSTM(
            encoder_embedding_dim,
            int(encoder_embedding_dim / 2),
            1,
            batch_first=True,
            bidirectional=True,
        )
        if lstm_norm_fn is not None:
            if "spectral" in lstm_norm_fn:
                print("Applying spectral norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.parametrizations.spectral_norm
            elif "weight" in lstm_norm_fn:
                print("Applying weight norm to text encoder LSTM")
                lstm_norm_fn_pntr = torch.nn.utils.parametrizations.weight_norm
            else:
                # previously fell through to an UnboundLocalError below
                raise ValueError(
                    "{} lstm norm fn not supported".format(lstm_norm_fn)
                )
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0")
            self.lstm = lstm_norm_fn_pntr(self.lstm, "weight_hh_l0_reverse")

    @torch.autocast(device, enabled=False)
    def forward(self, x, in_lens):
        """Encode a padded batch (autocast is disabled for this method).

        Args:
            x (torch.tensor): N x C x L padded input of text embeddings
            in_lens (torch.tensor): 1D tensor of sequence lengths

        Returns:
            N x max(in_lens) x C batch-first BiLSTM outputs.
        """
        if x.size()[0] > 1:
            # Run the convs per item on its exact length so padding never
            # leaks into the partial convolutions.
            x_embedded = []
            for b_ind in range(x.size()[0]):  # TODO: improve speed
                curr_x = x[b_ind : b_ind + 1, :, : in_lens[b_ind]].clone()
                for conv in self.convolutions:
                    curr_x = F.dropout(F.relu(conv(curr_x)), 0.5, self.training)
                x_embedded.append(curr_x[0].transpose(0, 1))
            x = torch.nn.utils.rnn.pad_sequence(x_embedded, batch_first=True)
        else:
            for conv in self.convolutions:
                x = F.dropout(F.relu(conv(x)), 0.5, self.training)
            x = x.transpose(1, 2)
        # recent amp change -- change in_lens to int
        in_lens = in_lens.int().cpu()
        x = nn.utils.rnn.pack_padded_sequence(x, in_lens, batch_first=True)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
        return outputs

    @torch.autocast(device, enabled=False)
    def infer(self, x):
        """Encode an unpacked N x C x L input; returns N x L x C LSTM outputs."""
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)
        x = x.transpose(1, 2)
        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)
        return outputs
class Invertible1x1ConvLUS(torch.nn.Module):
    """Invertible 1x1 convolution with an LU-parameterised c x c weight.

    The weight is assembled as ``W = P @ L @ U``:
    - ``P`` is a fixed permutation buffer.
    - ``L`` is a learned strictly-lower part plus a fixed diagonal of ones.
    - ``U`` is a learned strictly-upper part plus a learned diagonal.

    Forward returns ``(conv(z), sum(log|diag(U)|))``. With ``inverse=True`` it
    convolves with ``W^-1`` and returns only the result.

    Args:
        c: number of channels.
        cache_inverse: keep ``W_inverse`` between inverse calls instead of
            recomputing and deleting it each time.
    """

    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1ConvLUS, self).__init__()
        # Random orthonormal init from the QR decomposition of a Gaussian matrix.
        ortho = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
        # Flip the first column if needed so the determinant is +1, not -1.
        if torch.det(ortho) < 0:
            ortho[:, 0] = -1 * ortho[:, 0]
        perm, lower, upper = torch.lu_unpack(*torch.linalg.lu_factor(ortho))
        self.register_buffer("p", perm)
        # The unit diagonal of L is a constant buffer; only the strictly-lower
        # part is learned.
        self.register_buffer("lower_diag", torch.diag(torch.eye(c, c)))
        self.lower = nn.Parameter(torch.tril(lower, -1))
        self.upper_diag = nn.Parameter(torch.diag(upper))
        self.upper = nn.Parameter(torch.triu(upper, 1))
        self.cache_inverse = cache_inverse

    @torch.autocast(device, enabled=False)
    def forward(self, z, inverse=False):
        upper_tri = torch.triu(self.upper, 1) + torch.diag(self.upper_diag)
        lower_tri = torch.tril(self.lower, -1) + torch.diag(self.lower_diag)
        weight = torch.mm(self.p, torch.mm(lower_tri, upper_tri))

        if not inverse:
            z = F.conv1d(z, weight[..., None], bias=None, stride=1, padding=0)
            log_det_W = torch.sum(torch.log(torch.abs(self.upper_diag)))
            return z, log_det_W

        if not hasattr(self, "W_inverse"):
            # invert in fp32; cast back to half only for CUDA fp16 inputs
            inv = weight.float().inverse()
            if z.is_cuda and z.dtype == torch.float16:
                inv = inv.half()
            self.W_inverse = inv[..., None]
        z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
        if not self.cache_inverse:
            delattr(self, "W_inverse")
        return z
class Invertible1x1Conv(torch.nn.Module):
    """
    The layer outputs both the convolution, and the log determinant
    of its weight matrix. If inverse=True it does convolution with
    inverse

    Args:
        c: number of channels.
        cache_inverse: keep ``W_inverse`` between inverse calls instead of
            recomputing and deleting it each time.
    """

    def __init__(self, c, cache_inverse=False):
        super(Invertible1x1Conv, self).__init__()
        self.conv = torch.nn.Conv1d(
            c, c, kernel_size=1, stride=1, padding=0, bias=False
        )
        # Sample a random orthonormal matrix to initialize weights.
        # torch.linalg.qr replaces the deprecated torch.qr (same reduced QR),
        # matching Invertible1x1ConvLUS.
        W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
        # Ensure determinant is 1.0 not -1.0
        if torch.det(W) < 0:
            W[:, 0] = -1 * W[:, 0]
        W = W.view(c, c, 1)
        self.conv.weight.data = W
        self.cache_inverse = cache_inverse

    def forward(self, z, inverse=False):
        # DO NOT apply n_of_groups, as it doesn't account for padded sequences
        # Drop only the kernel axis so c == 1 still yields a 1 x 1 matrix
        # (a bare squeeze() would produce a 0-d tensor).
        W = self.conv.weight.squeeze(-1)
        if inverse:
            if not hasattr(self, "W_inverse"):
                # Inverse computation in fp32; cast to half for CUDA fp16 input
                W_inverse = W.float().inverse()
                if z.is_cuda and z.dtype == torch.float16:
                    W_inverse = W_inverse.half()
                self.W_inverse = W_inverse[..., None]
            z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
            if not self.cache_inverse:
                delattr(self, "W_inverse")
            return z
        else:
            # Forward computation
            log_det_W = torch.logdet(W).clone()
            z = self.conv(z)
            return z, log_det_W
class SimpleConvNet(torch.nn.Module):
    """Stack of ReLU ``ConvNorm`` layers followed by a 1x1 output convolution.

    Hidden channels double at each layer (capped at ``max_channels``); with
    ``with_dilation`` layer ``i`` uses dilation ``2**i``.

    Args:
        n_mel_channels, n_context_dim: the input has
            ``n_mel_channels + n_context_dim`` channels.
        final_out_channels: channels produced by the output conv.
        n_layers: number of hidden layers; 0 feeds the input straight into
            the output conv.
        kernel_size, with_dilation, max_channels, use_partial_padding:
            hidden-layer configuration.
        zero_init: zero the output conv's weight and bias so the network
            initially outputs zeros.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        final_out_channels,
        n_layers=2,
        kernel_size=5,
        with_dilation=True,
        max_channels=1024,
        zero_init=True,
        use_partial_padding=True,
    ):
        super(SimpleConvNet, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.n_layers = n_layers
        in_channels = n_mel_channels + n_context_dim
        # With no hidden layers the output conv reads the input directly
        # (previously -1, which made Conv1d construction fail).
        out_channels = in_channels
        self.use_partial_padding = use_partial_padding
        for i in range(n_layers):
            dilation = 2**i if with_dilation else 1
            # "same" padding for odd kernel sizes
            padding = int((kernel_size * dilation - dilation) / 2)
            out_channels = min(max_channels, in_channels * 2)
            self.layers.append(
                ConvNorm(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                    dilation=dilation,
                    bias=True,
                    w_init_gain="relu",
                    use_partial_padding=use_partial_padding,
                )
            )
            in_channels = out_channels
        self.last_layer = torch.nn.Conv1d(
            out_channels, final_out_channels, kernel_size=1
        )
        if zero_init:
            self.last_layer.weight.data *= 0
            self.last_layer.bias.data *= 0

    def forward(self, z_w_context, seq_lens: torch.Tensor = None):
        """Run the stack.

        Args:
            z_w_context: B x (n_mel_channels + n_context_dim) x T input.
            seq_lens: optional 1D tensor of sequence lengths; when given, a
                B x 1 x T float mask is passed to every hidden layer.

        Returns:
            B x final_out_channels x T' tensor (T' == T for odd kernel sizes).
        """
        mask = None
        if seq_lens is not None:
            mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
        for layer in self.layers:
            z_w_context = torch.relu(layer(z_w_context, mask))
        return self.last_layer(z_w_context)
class WN(torch.nn.Module):
"""
Adapted from WN() module in WaveGlow with modififcations to variable names
"""
def __init__(
self,
n_in_channels,
n_context_dim,
n_layers,
n_channels,
kernel_size=5,
affine_activation="softplus",
use_partial_padding=True,
):
super(WN, self).__init__()
assert kernel_size % 2 == 1
assert n_channels % 2 == 0
self.n_layers = n_layers
self.n_channels = n_channels
self.in_layers = torch.nn.ModuleList()
self.res_skip_layers = torch.nn.ModuleList()
start = torch.nn.Conv1d(n_in_channels + n_context_dim, n_channels, 1)
start = torch.nn.utils.parametrizations.weight_norm(start, name="weight")
self.start = start
self.softplus = torch.nn.Softplus()
self.affine_activation = affine_activation
self.use_partial_padding = use_partial_padding
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1)
end.weight.data.zero_()
end.bias.data.zero_()
self.end = end
for i in range(n_layers):
dilation = 2**i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = ConvNorm(
n_channels,
n_channels,
kernel_size=kernel_size,
dilation=dilation,
padding=padding,
use_partial_padding=use_partial_padding,
use_weight_norm=True,
)
# in_layer = nn.Conv1d(n_channels, n_channels, kernel_size,
# dilation=dilation, padding=padding)
# in_layer = nn.utils.weight_norm(in_layer)
self.in_layers.append(in_layer)
res_skip_layer = nn.Conv1d(n_channels, n_channels, 1)
res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer)
self.res_skip_layers.append(res_skip_layer)
def forward(
self,
forward_input: Tuple[torch.Tensor, torch.Tensor],
seq_lens: torch.Tensor = None,
):
z, context = forward_input
z = torch.cat((z, context), 1) # append context to z as well
z = self.start(z)
output = torch.zeros_like(z)
mask = None
if seq_lens is not None:
mask = get_mask_from_lengths(seq_lens).unsqueeze(1).float()
non_linearity = torch.relu
if self.affine_activation == "softplus":
non_linearity = self.softplus
for i in range(self.n_layers):
z = non_linearity(self.in_layers[i](z, mask))
res_skip_acts = non_linearity(self.res_skip_layers[i](z))
output = output + res_skip_acts
output = self.end(output) # [B, dim, seq_len]
return output
# Affine Coupling Layers
class SplineTransformationLayerAR(torch.nn.Module):
    """Element-wise spline flow whose bin parameters come from ``context`` only.

    ``z`` is rescaled to ``[0, 1]``. It is rescaled from ``[left, right]`` in
    the forward direction and from ``[bottom, top]`` in the inverse direction.
    It is then transformed by a piecewise-linear spline, or by an unbounded
    piecewise-quadratic spline with ``use_quadratic``. Finally it is mapped
    back to the other interval. The spline parameters are predicted by a
    kernel-size-1 ``SimpleConvNet`` applied to ``context``.

    Args:
        n_in_channels: channels of ``z``.
        n_context_dim: channels of ``context``.
        n_layers: hidden layers of the parameter predictor.
        affine_model, kernel_size, scaling_fn, affine_activation, n_channels:
            accepted for interface compatibility; unused here.
        n_bins: spline bins (quadratic splines use ``2 * n_bins + 1``
            parameters per channel).
        left, right: input interval of the forward direction.
        bottom, top: output interval of the forward direction.
        use_quadratic: use the quadratic spline.
    """

    def __init__(
        self,
        n_in_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        kernel_size=1,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-6,
        right=6,
        bottom=-6,
        top=6,
        use_quadratic=False,
    ):
        super(SplineTransformationLayerAR, self).__init__()
        self.n_in_channels = n_in_channels  # input dimensions
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic
        if self.use_quadratic:
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.n_in_channels * self.n_bins
        # autoregressive flow, kernel size 1 and no dilation
        self.param_predictor = SimpleConvNet(
            n_context_dim,
            0,
            final_out_channels,
            n_layers,
            with_dilation=False,
            kernel_size=1,
            zero_init=True,
            use_partial_padding=False,
        )
        # output is unnormalized bin weights

    def normalize(self, z, inverse):
        """Map ``z`` to [0, 1] from [bottom, top] (inverse) or [left, right]."""
        if inverse:
            z = (z - self.bottom) / (self.top - self.bottom)
        else:
            z = (z - self.left) / (self.right - self.left)
        return z

    def denormalize(self, z, inverse):
        """Map [0, 1] back to [left, right] (inverse) or [bottom, top]."""
        if inverse:
            z = z * (self.right - self.left) + self.left
        else:
            z = z * (self.top - self.bottom) + self.bottom
        return z

    def forward(self, z, context, inverse=False):
        """Apply (or invert) the spline.

        Args:
            z: B x C x T tensor.
            context: conditioning input for the parameter predictor.
            inverse: run the inverse transform.

        Returns:
            ``z`` when ``inverse``; otherwise ``(z, log_s)``.
        """
        b_s, c_s, t_s = z.size(0), z.size(1), z.size(2)
        z = self.normalize(z, inverse)
        if z.min() < 0.0 or z.max() > 1.0:
            print("spline z scaled beyond [0, 1]", z.min(), z.max())
        z_reshaped = z.permute(0, 2, 1).reshape(b_s * t_s, -1)
        affine_params = self.param_predictor(context)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, c_s, -1)
        with torch.autocast(device, enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_tformed, log_s = self.spline_fn(
                    z_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
            elif inverse:
                # The linear spline has a dedicated inverse; previously the
                # forward spline was applied here, so inversion was wrong.
                # Mirrors SplineTransformationLayer's inverse path.
                z_tformed, _ = self.inv_spline_fn(
                    z_reshaped.float(), q_tilde.float(), False
                )
            else:
                z_tformed, log_s = self.spline_fn(z_reshaped.float(), q_tilde.float())
        z = z_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
        z = self.denormalize(z, inverse)
        if inverse:
            return z
        log_s = log_s.reshape(b_s, t_s, -1)
        log_s = log_s.permute(0, 2, 1)
        # Jacobian of the affine [left, right] -> [bottom, top] rescaling.
        # NOTE(review): for the quadratic spline log_s looks per-channel
        # (not summed), so adding c_s times the constant per element may
        # over-count; kept unchanged pending confirmation of downstream use.
        log_s = log_s + c_s * (
            np.log(self.top - self.bottom) - np.log(self.right - self.left)
        )
        return z, log_s
class SplineTransformationLayer(torch.nn.Module):
    """Spline coupling layer: transform the second channel half given the first.

    ``z`` is split into ``z_0 = z[:, :n_half]`` and ``z_1 = z[:, n_half:]``.
    A ``SimpleConvNet`` on ``[z_0; context]`` predicts the spline parameters
    for ``z_1``. ``z_1`` is rescaled to [0, 1]: from ``[left, right]`` in the
    forward direction and from ``[bottom, top]`` in the inverse. It is then
    transformed by a piecewise-linear spline, or by an unbounded
    piecewise-quadratic spline with ``use_quadratic``, and mapped back to the
    other interval.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        n_bins=8,
        left=-4,
        right=4,
        bottom=-4,
        top=4,
        use_quadratic=False,
    ):
        # scaling_fn, affine_activation and n_channels are accepted but not
        # used by this class.
        super(SplineTransformationLayer, self).__init__()
        self.n_mel_channels = n_mel_channels  # input dimensions
        self.half_mel_channels = int(n_mel_channels / 2)  # half, because we split
        self.left = left
        self.right = right
        self.bottom = bottom
        self.top = top
        self.n_bins = n_bins
        self.spline_fn = piecewise_linear_transform
        self.inv_spline_fn = piecewise_linear_inverse_transform
        self.use_quadratic = use_quadratic
        if self.use_quadratic:
            # the quadratic spline handles both directions via ``inverse=``
            # and needs 2 * n_bins + 1 parameters per channel
            self.spline_fn = unbounded_piecewise_quadratic_transform
            self.inv_spline_fn = unbounded_piecewise_quadratic_transform
            self.n_bins = 2 * self.n_bins + 1
        final_out_channels = self.half_mel_channels * self.n_bins
        self.param_predictor = SimpleConvNet(
            self.half_mel_channels,
            n_context_dim,
            final_out_channels,
            n_layers,
            with_dilation=with_dilation,
            kernel_size=kernel_size,
            zero_init=False,
        )
        # output is unnormalized bin weights

    def forward(self, z, context, inverse=False, seq_lens=None):
        """Apply (or invert) the coupling transform.

        Args:
            z: B x n_mel_channels x T tensor.
            context: conditioning tensor concatenated with ``z_0`` on dim 1.
            inverse: run the inverse transform.
            seq_lens: optional lengths forwarded to the parameter predictor.

        Returns:
            ``z`` when ``inverse``; otherwise ``(z, log_s)`` with ``log_s`` of
            shape B x 1 x T.
        """
        b_s, _, t_s = z.size(0), z.size(1), z.size(2)
        # condition on z_0, transform z_1
        n_half = self.half_mel_channels
        z_0, z_1 = z[:, :n_half], z[:, n_half:]
        # normalize to [0,1]
        if inverse:
            z_1 = (z_1 - self.bottom) / (self.top - self.bottom)
        else:
            z_1 = (z_1 - self.left) / (self.right - self.left)
        z_w_context = torch.cat((z_0, context), 1)
        affine_params = self.param_predictor(z_w_context, seq_lens)
        # flatten batch and time so each row is one frame's n_half values
        z_1_reshaped = z_1.permute(0, 2, 1).reshape(b_s * t_s, -1)
        q_tilde = affine_params.permute(0, 2, 1).reshape(b_s * t_s, n_half, self.n_bins)
        # spline math runs in fp32 with autocast disabled
        with torch.autocast(device, enabled=False):
            if self.use_quadratic:
                w = q_tilde[:, :, : self.n_bins // 2]
                v = q_tilde[:, :, self.n_bins // 2 :]
                z_1_tformed, log_s = self.spline_fn(
                    z_1_reshaped.float(), w.float(), v.float(), inverse=inverse
                )
                if not inverse:
                    # reduce over channels to one value per frame
                    log_s = torch.sum(log_s, 1)
            else:
                if inverse:
                    z_1_tformed, _dc = self.inv_spline_fn(
                        z_1_reshaped.float(), q_tilde.float(), False
                    )
                else:
                    z_1_tformed, log_s = self.spline_fn(
                        z_1_reshaped.float(), q_tilde.float()
                    )
        z_1 = z_1_tformed.reshape(b_s, t_s, -1).permute(0, 2, 1)
        # undo [0, 1] normalization
        if inverse:
            z_1 = z_1 * (self.right - self.left) + self.left
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:  # training
            z_1 = z_1 * (self.top - self.bottom) + self.bottom
            z = torch.cat((z_0, z_1), dim=1)
            # add the log-Jacobian of the two affine rescalings, once per
            # transformed channel
            log_s = log_s.reshape(b_s, t_s).unsqueeze(1) + n_half * (
                np.log(self.top - self.bottom) - np.log(self.right - self.left)
            )
            return z, log_s
class AffineTransformationLayer(torch.nn.Module):
    """Affine coupling layer: transform the second channel half given the first.

    ``z`` is split into ``z_0 = z[:, :n_half]`` and ``z_1 = z[:, n_half:]``.
    The parameter predictor sees ``z_0`` and ``context`` and outputs
    ``n_mel_channels`` channels. The first ``n_half`` channels are the
    unconstrained scale; the rest are the shift ``b``.
    - Forward: ``z_1 <- s * z_1 + b``, returning ``(z, log_s)``.
    - Inverse: ``z_1 <- (z_1 - b) / s``, returning ``z``.

    Args:
        n_mel_channels: channels of ``z``.
        n_context_dim: channels of ``context``.
        n_layers: layers of the parameter predictor.
        affine_model: "wavenet" (``WN``) or "simple_conv" (``SimpleConvNet``).
        with_dilation, kernel_size: ``SimpleConvNet`` options.
        scaling_fn: one of "translate", "exp", "tanh", "sigmoid", or a list
            of such names, one per scale channel.
        affine_activation, n_channels: ``WN`` options.
        use_partial_padding: forwarded to the parameter predictor.

    Raises:
        Exception: for an unsupported ``affine_model`` or ``scaling_fn``.
    """

    def __init__(
        self,
        n_mel_channels,
        n_context_dim,
        n_layers,
        affine_model="simple_conv",
        with_dilation=True,
        kernel_size=5,
        scaling_fn="exp",
        affine_activation="softplus",
        n_channels=1024,
        use_partial_padding=False,
    ):
        super(AffineTransformationLayer, self).__init__()
        if affine_model not in ("wavenet", "simple_conv"):
            raise Exception("{} affine model not supported".format(affine_model))
        if isinstance(scaling_fn, list):
            if not all(
                [x in ("translate", "exp", "tanh", "sigmoid") for x in scaling_fn]
            ):
                raise Exception("{} scaling fn not supported".format(scaling_fn))
        else:
            if scaling_fn not in ("translate", "exp", "tanh", "sigmoid"):
                raise Exception("{} scaling fn not supported".format(scaling_fn))
        self.affine_model = affine_model
        self.scaling_fn = scaling_fn
        if affine_model == "wavenet":
            self.affine_param_predictor = WN(
                int(n_mel_channels / 2),
                n_context_dim,
                n_layers=n_layers,
                n_channels=n_channels,
                affine_activation=affine_activation,
                use_partial_padding=use_partial_padding,
            )
        elif affine_model == "simple_conv":
            self.affine_param_predictor = SimpleConvNet(
                int(n_mel_channels / 2),
                n_context_dim,
                n_mel_channels,
                n_layers,
                with_dilation=with_dilation,
                kernel_size=kernel_size,
                use_partial_padding=use_partial_padding,
            )
        self.n_mel_channels = n_mel_channels

    def get_scaling_and_logs(self, scale_unconstrained):
        """Map the unconstrained scale to ``(s, log_s)`` per ``self.scaling_fn``.

        With a list ``scaling_fn``, entry ``i`` selects the function applied
        to channel ``i`` (dim 1) of ``scale_unconstrained``.
        """
        if self.scaling_fn == "translate":
            s = torch.exp(scale_unconstrained * 0)
            log_s = scale_unconstrained * 0
        elif self.scaling_fn == "exp":
            s = torch.exp(scale_unconstrained)
            log_s = scale_unconstrained  # log(exp(x)) == x
        elif self.scaling_fn == "tanh":
            s = torch.tanh(scale_unconstrained) + 1 + 1e-6
            log_s = torch.log(s)
        elif self.scaling_fn == "sigmoid":
            s = torch.sigmoid(scale_unconstrained + 10) + 1e-6
            log_s = torch.log(s)
        elif isinstance(self.scaling_fn, list):
            s_list, log_s_list = [], []
            for i in range(scale_unconstrained.shape[1]):
                scaling_i = self.scaling_fn[i]
                # channel i of the scale (previously "translate" sliced
                # batch rows with [:i], giving the wrong shape)
                scale_i = scale_unconstrained[:, i]
                if scaling_i == "translate":
                    s_i = torch.exp(scale_i * 0)
                    log_s_i = scale_i * 0
                elif scaling_i == "exp":
                    s_i = torch.exp(scale_i)
                    log_s_i = scale_i
                elif scaling_i == "tanh":
                    s_i = torch.tanh(scale_i) + 1 + 1e-6
                    log_s_i = torch.log(s_i)
                elif scaling_i == "sigmoid":
                    # NOTE(review): unlike the scalar "sigmoid" branch above,
                    # no +10 offset is applied here; kept as-is.
                    s_i = torch.sigmoid(scale_i) + 1e-6
                    log_s_i = torch.log(s_i)
                s_list.append(s_i[:, None])
                log_s_list.append(log_s_i[:, None])
            s = torch.cat(s_list, dim=1)
            log_s = torch.cat(log_s_list, dim=1)
        return s, log_s

    def forward(self, z, context, inverse=False, seq_lens=None):
        """Apply (or invert) the coupling transform.

        Returns:
            ``z`` when ``inverse``; otherwise ``(z, log_s)``.
        """
        n_half = int(self.n_mel_channels / 2)
        z_0, z_1 = z[:, :n_half], z[:, n_half:]
        if self.affine_model == "wavenet":
            affine_params = self.affine_param_predictor(
                (z_0, context), seq_lens=seq_lens
            )
        elif self.affine_model == "simple_conv":
            z_w_context = torch.cat((z_0, context), 1)
            affine_params = self.affine_param_predictor(z_w_context, seq_lens=seq_lens)
        scale_unconstrained = affine_params[:, :n_half, :]
        b = affine_params[:, n_half:, :]
        s, log_s = self.get_scaling_and_logs(scale_unconstrained)
        if inverse:
            z_1 = (z_1 - b) / s
            z = torch.cat((z_0, z_1), dim=1)
            return z
        else:
            z_1 = s * z_1 + b
            z = torch.cat((z_0, z_1), dim=1)
            return z, log_s
class ConvAttention(torch.nn.Module):
    """Squared-distance attention between mel frames (queries) and text (keys).

    Keys and queries are each projected to ``n_att_channels`` by small conv
    stacks. The score for a query/key pair is the negative, scaled squared L2
    distance between them, and a softmax is taken over the key axis.

    NOTE(review): ``self.temperature`` is stored but ``forward`` uses the
    hard-coded ``temp = 0.0005``; ``run_padded_sequence`` is not called by
    this class.
    """

    def __init__(
        self, n_mel_channels=80, n_text_channels=512, n_att_channels=80, temperature=1.0
    ):
        super(ConvAttention, self).__init__()
        self.temperature = temperature
        self.softmax = torch.nn.Softmax(dim=3)
        self.log_softmax = torch.nn.LogSoftmax(dim=3)
        self.key_proj = nn.Sequential(
            ConvNorm(
                n_text_channels,
                n_text_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
        )
        self.query_proj = nn.Sequential(
            ConvNorm(
                n_mel_channels,
                n_mel_channels * 2,
                kernel_size=3,
                bias=True,
                w_init_gain="relu",
            ),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
            torch.nn.ReLU(),
            ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
        )

    def run_padded_sequence(
        self, sorted_idx, unsort_idx, lens, padded_data, recurrent_model
    ):
        """Sorts input data by provided ordering (and un-ordering) and runs the
        packed data through the recurrent model
        Args:
            sorted_idx (torch.tensor): 1D sorting index
            unsort_idx (torch.tensor): 1D unsorting index (inverse of sorted_idx)
            lens: lengths of input data (sorted in descending order)
            padded_data (torch.tensor): input sequences (padded), time-major
                (batch index in dim=1)
            recurrent_model (nn.Module): recurrent model to run data through
        Returns:
            hidden_vectors (torch.tensor): outputs of the RNN, in the original,
                unsorted, ordering
        """
        # sort the data by decreasing length using provided index
        # we assume batch index is in dim=1
        padded_data = padded_data[:, sorted_idx]
        padded_data = nn.utils.rnn.pack_padded_sequence(padded_data, lens)
        hidden_vectors = recurrent_model(padded_data)[0]
        hidden_vectors, _ = nn.utils.rnn.pad_packed_sequence(hidden_vectors)
        # unsort the results at dim=1 and return
        hidden_vectors = hidden_vectors[:, unsort_idx]
        return hidden_vectors

    def forward(
        self, queries, keys, query_lens, mask=None, key_lens=None, attn_prior=None
    ):
        """Attention mechanism for radtts. Unlike in Flowtron, we have no
        restrictions such as causality etc, since we only need this during
        training.
        Args:
            queries (torch.tensor): B x C x T1 tensor (likely mel data)
            keys (torch.tensor): B x C2 x T2 tensor (text data)
            query_lens: lengths for sorting the queries in descending order
                (not used by this method)
            mask (torch.tensor): binary mask, True at key positions to mask;
                presumably B x T2 x 1 given the permute/unsqueeze below
                (should be in the T2 domain). Older PyTorch accepted uint8;
                current masked_fill_ expects bool — TODO confirm callers.
            key_lens: not used by this method.
            attn_prior (torch.tensor): optional prior, presumably B x T1 x T2;
                when given, scores are log-softmaxed and log(prior) is added.
        Output:
            attn (torch.tensor): B x 1 x T1 x T2 attention mask.
                Final dim T2 should sum to 1
            attn_logprob (torch.tensor): B x 1 x T1 x T2 scores captured
                before masking and softmax.
        """
        temp = 0.0005
        keys_enc = self.key_proj(keys)  # B x n_attn_dims x T2
        # Both projections output n_att_channels, so queries_enc and keys_enc
        # can be compared channel-wise below.
        queries_enc = self.query_proj(queries)
        # Gaussian Isotropic Attention
        # B x n_attn_dims x T1 x T2
        attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None]) ** 2
        # compute log-likelihood from gaussian
        eps = 1e-8
        attn = -temp * attn.sum(1, keepdim=True)
        if attn_prior is not None:
            attn = self.log_softmax(attn) + torch.log(attn_prior[:, None] + eps)
        # cloned before masking, so the returned log-probs are unmasked
        attn_logprob = attn.clone()
        if mask is not None:
            # in-place fill through .data (outside autograd tracking)
            attn.data.masked_fill_(mask.permute(0, 2, 1).unsqueeze(2), -float("inf"))
        attn = self.softmax(attn)  # softmax along T2
        return attn, attn_logprob
def update_params(config, params):
    """Override entries of a nested config dict from ``"key=value"`` strings.

    Dotted keys (``"a.b.c=1"``) descend into nested dicts. Each value is parsed
    once, at the leaf, with ``ast.literal_eval``. A value that is not a Python
    literal is kept as the raw string (the parse error is printed). Leaf keys
    that do not already exist in the config are reported and left unchanged.

    Args:
        config: (nested) dict to update in place.
        params: iterable of ``"dotted.key=value"`` strings.

    Raises:
        KeyError: if an intermediate key of a dotted path is missing.
        ValueError: if a param contains no ``=``.
    """
    for param in params:
        print(param)
        # split on the first "=" only so values may themselves contain "="
        k, v = param.split("=", 1)
        k_split = k.split(".")
        if len(k_split) > 1:
            parent_k = k_split[0]
            # Forward the raw value string; evaluating here and re-serialising
            # with str() would evaluate it twice (e.g. "'123'" -> 123).
            cur_param = [".".join(k_split[1:]) + "=" + v]
            update_params(config[parent_k], cur_param)
            continue
        try:
            v = ast.literal_eval(v)
        except Exception as e:
            # deliberate best-effort: non-literals stay plain strings
            print(e)
        if k in config:
            print(f"overriding {k} with {v}")
            config[k] = v
        else:
            print("{}, {} params not updated".format(k, v))