# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
from torch import nn

from common import ConvNorm, Invertible1x1Conv
from common import AffineTransformationLayer, SplineTransformationLayer
from common import ConvLSTMLinear
from transformer import FFTransformer
from autoregressive_flow import AR_Step, AR_Back_Step

def get_attribute_prediction_model(config):
    name = config["name"]
    hparams = config["hparams"]
    if name == "dap":
        model = DAP(**hparams)
    elif name == "bgap":
        model = BGAP(**hparams)
    elif name == "agap":
        model = AGAP(**hparams)
    else:
        raise ValueError(f"{name} model is not supported")
    return model

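# Example usage (illustrative sketch; the hparams values below are assumptions,
# not a released configuration, and "arch_hparams" must match the constructor
# of the chosen architecture):
#
#   config = {
#       "name": "dap",
#       "hparams": {
#           "n_speaker_dim": 16,
#           "take_log_of_input": True,
#           "bottleneck_hparams": {"in_dim": 512, "reduction_factor": 16},
#           "arch_hparams": {...},  # forwarded to ConvLSTMLinear (or FFTransformer)
#       },
#   }
#   model = get_attribute_prediction_model(config)
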
class AttributeProcessing:
    """Optional log-scale normalization for attribute values (e.g. token
    durations, pitch, or energy)."""

    def __init__(self, take_log_of_input=False):
        super().__init__()
        self.take_log_of_input = take_log_of_input

    def normalize(self, x):
        if self.take_log_of_input:
            x = torch.log(x + 1)
        return x

    def denormalize(self, x):
        if self.take_log_of_input:
            x = torch.exp(x) - 1
        return x

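# Round-trip sanity check (illustrative; tensor values are arbitrary):
#
#   proc = AttributeProcessing(take_log_of_input=True)
#   x = torch.tensor([0.0, 1.0, 9.0])
#   torch.allclose(proc.denormalize(proc.normalize(x)), x)  # True
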
class BottleneckLayerLayer(nn.Module):
    """Projects text encodings down to in_dim / reduction_factor channels with
    a single normalized convolution followed by a non-linearity."""

    def __init__(
        self,
        in_dim,
        reduction_factor,
        norm="weightnorm",
        non_linearity="relu",
        kernel_size=3,
        use_partial_padding=False,
    ):
        super(BottleneckLayerLayer, self).__init__()
        self.reduction_factor = reduction_factor
        reduced_dim = int(in_dim / reduction_factor)
        self.out_dim = reduced_dim
        if self.reduction_factor > 1:
            fn = ConvNorm(
                in_dim,
                reduced_dim,
                kernel_size=kernel_size,
                use_weight_norm=(norm == "weightnorm"),
            )
            if norm == "instancenorm":
                fn = nn.Sequential(fn, nn.InstanceNorm1d(reduced_dim, affine=True))

            self.projection_fn = fn
            self.non_linearity = nn.ReLU()
            if non_linearity == "leakyrelu":
                self.non_linearity = nn.LeakyReLU()

    def forward(self, x):
        # identity when reduction_factor <= 1 (no projection layer is created)
        if self.reduction_factor > 1:
            x = self.projection_fn(x)
            x = self.non_linearity(x)
        return x

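# Shape sketch (illustrative; assumes ConvNorm's default padding preserves
# sequence length): with in_dim=512 and reduction_factor=16, a B x 512 x T
# text encoding is projected to B x 32 x T.
#
#   bottleneck = BottleneckLayerLayer(in_dim=512, reduction_factor=16)
#   y = bottleneck(torch.randn(2, 512, 37))  # -> torch.Size([2, 32, 37])
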
class DAP(nn.Module):
    """Deterministic attribute predictor: regresses the attribute directly
    from the bottlenecked text encoding and the speaker embedding."""

    def __init__(
        self,
        n_speaker_dim,
        bottleneck_hparams,
        take_log_of_input,
        arch_hparams,
        use_transformer=False,
    ):
        super(DAP, self).__init__()
        self.attribute_processing = AttributeProcessing(take_log_of_input)
        self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)

        # note: this mutates the caller's arch_hparams dict
        arch_hparams["in_dim"] = self.bottleneck_layer.out_dim + n_speaker_dim
        if use_transformer:
            self.feat_pred_fn = FFTransformer(**arch_hparams)
        else:
            self.feat_pred_fn = ConvLSTMLinear(**arch_hparams)

    def forward(self, txt_enc, spk_emb, x, lens):
        if x is not None:
            x = self.attribute_processing.normalize(x)

        txt_enc = self.bottleneck_layer(txt_enc)
        # broadcast the speaker embedding across time and concatenate it with
        # the text encoding along the channel dimension
        spk_emb_expanded = spk_emb[..., None].expand(-1, -1, txt_enc.shape[2])
        context = torch.cat((txt_enc, spk_emb_expanded), 1)

        x_hat = self.feat_pred_fn(context, lens)
        outputs = {"x_hat": x_hat, "x": x}
        return outputs

    def infer(self, z, txt_enc, spk_emb, lens=None):
        # z is unused: DAP is deterministic. The argument is kept so DAP,
        # BGAP, and AGAP share the same infer() interface.
        x_hat = self.forward(txt_enc, spk_emb, x=None, lens=lens)["x_hat"]
        x_hat = self.attribute_processing.denormalize(x_hat)
        return x_hat

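# Inference sketch (illustrative shapes; assumes a model constructed with
# matching hparams):
#
#   txt_enc: B x C_txt x T, spk_emb: B x n_speaker_dim
#   x_hat = dap.infer(None, txt_enc, spk_emb, lens)  # predicted attribute
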
class BGAP(torch.nn.Module):
    """Normalizing-flow attribute predictor built from Glow-style bipartite
    steps. Each flow applies a coupling transform followed by an invertible
    1x1 convolution; the last n_spline_steps flows use spline couplings, the
    rest use affine couplings."""

    def __init__(
        self,
        n_in_dim,
        n_speaker_dim,
        bottleneck_hparams,
        n_flows,
        n_group_size,
        n_layers,
        with_dilation,
        kernel_size,
        scaling_fn,
        take_log_of_input=False,
        n_channels=1024,
        use_quadratic=False,
        n_bins=8,
        n_spline_steps=2,
    ):
        super(BGAP, self).__init__()
        # assert(n_group_size % 2 == 0)
        self.n_flows = n_flows
        self.n_group_size = n_group_size
        self.transforms = torch.nn.ModuleList()
        self.convinv = torch.nn.ModuleList()
        self.n_speaker_dim = n_speaker_dim
        self.scaling_fn = scaling_fn
        self.attribute_processing = AttributeProcessing(take_log_of_input)
        self.n_spline_steps = n_spline_steps
        self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
        n_txt_reduced_dim = self.bottleneck_layer.out_dim
        context_dim = n_txt_reduced_dim * n_group_size + n_speaker_dim

        if self.n_group_size > 1:
            self.unfold_params = {
                "kernel_size": (n_group_size, 1),
                "stride": n_group_size,
                "padding": 0,
                "dilation": 1,
            }
            self.unfold = nn.Unfold(**self.unfold_params)

        for k in range(n_flows):
            self.convinv.append(Invertible1x1Conv(n_in_dim * n_group_size))
            if k >= n_flows - self.n_spline_steps:
                # final flows use spline couplings on the [-3, 3] x [-3, 3] box
                left, right, top, bottom = -3, 3, 3, -3
                self.transforms.append(
                    SplineTransformationLayer(
                        n_in_dim * n_group_size,
                        context_dim,
                        n_layers,
                        with_dilation=with_dilation,
                        kernel_size=kernel_size,
                        scaling_fn=scaling_fn,
                        n_channels=n_channels,
                        top=top,
                        bottom=bottom,
                        left=left,
                        right=right,
                        use_quadratic=use_quadratic,
                        n_bins=n_bins,
                    )
                )
            else:
                self.transforms.append(
                    AffineTransformationLayer(
                        n_in_dim * n_group_size,
                        context_dim,
                        n_layers,
                        with_dilation=with_dilation,
                        kernel_size=kernel_size,
                        scaling_fn=scaling_fn,
                        affine_model="simple_conv",
                        n_channels=n_channels,
                    )
                )
    def fold(self, data):
        """Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
        the grouping or "squeeze" operation on input.

        Args:
            data: B x C x T tensor of temporal data
        """
        output_size = (data.shape[2] * self.n_group_size, 1)
        data = nn.functional.fold(
            data, output_size=output_size, **self.unfold_params
        ).squeeze(-1)
        return data
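    # Shape sketch for the grouping ("squeeze") round trip with
    # n_group_size = 2 (illustrative):
    #
    #   x: B x C x T  --unfold-->  B x (2 * C) x (T // 2)  --fold-->  B x C x T
    #
    # unfold drops trailing frames when T is not divisible by n_group_size,
    # so the folded output can be shorter than the original input.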
    def preprocess_context(self, txt_emb, speaker_vecs, std_scale=None):
        # std_scale is currently unused
        if self.n_group_size > 1:
            txt_emb = self.unfold(txt_emb[..., None])
        speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
        context = torch.cat((txt_emb, speaker_vecs), 1)
        return context
    def forward(self, txt_enc, spk_emb, x, lens):
        """x<tensor>: duration or pitch or energy average"""
        assert txt_enc.size(2) >= x.size(1)
        if len(x.shape) == 2:
            # add channel dimension
            x = x[:, None]
        txt_enc = self.bottleneck_layer(txt_enc)

        # lens including padded values
        lens_grouped = (lens // self.n_group_size).long()
        context = self.preprocess_context(txt_enc, spk_emb)
        # assumes n_group_size > 1, since self.unfold is only created then
        x = self.unfold(x[..., None])
        log_s_list, log_det_W_list = [], []
        for k in range(self.n_flows):
            x, log_s = self.transforms[k](x, context, seq_lens=lens_grouped)
            x, log_det_W = self.convinv[k](x)
            log_det_W_list.append(log_det_W)
            log_s_list.append(log_s)

        # prepare outputs
        outputs = {"z": x, "log_det_W_list": log_det_W_list, "log_s_list": log_s_list}
        return outputs
    def infer(self, z, txt_enc, spk_emb, seq_lens):
        txt_enc = self.bottleneck_layer(txt_enc)
        context = self.preprocess_context(txt_enc, spk_emb)
        lens_grouped = (seq_lens // self.n_group_size).long()
        z = self.unfold(z[..., None])
        # run the flows in reverse to map latents back to the data domain
        for k in reversed(range(self.n_flows)):
            z = self.convinv[k](z, inverse=True)
            z = self.transforms[k].forward(
                z, context, inverse=True, seq_lens=lens_grouped
            )
        # z mapped to input domain
        x_hat = self.fold(z)
        # no padding is applied here ("pad on the way out"): x_hat can be
        # shorter than the input when T is not divisible by n_group_size
        return x_hat

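# Sampling sketch for BGAP (illustrative; the sigma temperature scaling is a
# common flow-sampling convention, not prescribed by this module):
#
#   z = torch.randn(B, n_in_dim, T) * sigma
#   x_hat = bgap.infer(z, txt_enc, spk_emb, seq_lens)
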
class AGAP(torch.nn.Module):
    """Normalizing-flow attribute predictor built from autoregressive steps
    that alternate between forward (AR_Step) and backward (AR_Back_Step)
    passes over time."""

    def __init__(
        self,
        n_in_dim,
        n_speaker_dim,
        n_flows,
        n_hidden,
        n_lstm_layers,
        bottleneck_hparams,
        scaling_fn="exp",
        take_log_of_input=False,
        p_dropout=0.0,
        setup="",
        spline_flow_params=None,
        n_group_size=1,
    ):
        super(AGAP, self).__init__()
        self.flows = torch.nn.ModuleList()
        self.n_group_size = n_group_size
        self.n_speaker_dim = n_speaker_dim
        self.attribute_processing = AttributeProcessing(take_log_of_input)
        self.n_in_dim = n_in_dim
        self.bottleneck_layer = BottleneckLayerLayer(**bottleneck_hparams)
        n_txt_reduced_dim = self.bottleneck_layer.out_dim

        if self.n_group_size > 1:
            self.unfold_params = {
                "kernel_size": (n_group_size, 1),
                "stride": n_group_size,
                "padding": 0,
                "dilation": 1,
            }
            self.unfold = nn.Unfold(**self.unfold_params)

        if spline_flow_params is not None:
            spline_flow_params["n_in_channels"] *= self.n_group_size

        # alternate forward and backward autoregressive steps so successive
        # flows condition on opposite temporal directions
        for i in range(n_flows):
            if i % 2 == 0:
                self.flows.append(
                    AR_Step(
                        n_in_dim * n_group_size,
                        n_speaker_dim,
                        n_txt_reduced_dim * n_group_size,
                        n_hidden,
                        n_lstm_layers,
                        scaling_fn,
                        spline_flow_params,
                    )
                )
            else:
                self.flows.append(
                    AR_Back_Step(
                        n_in_dim * n_group_size,
                        n_speaker_dim,
                        n_txt_reduced_dim * n_group_size,
                        n_hidden,
                        n_lstm_layers,
                        scaling_fn,
                        spline_flow_params,
                    )
                )
    def fold(self, data):
        """Inverse of the self.unfold(data.unsqueeze(-1)) operation used for
        the grouping or "squeeze" operation on input.

        Args:
            data: B x C x T tensor of temporal data
        """
        output_size = (data.shape[2] * self.n_group_size, 1)
        data = nn.functional.fold(
            data, output_size=output_size, **self.unfold_params
        ).squeeze(-1)
        return data
    def preprocess_context(self, txt_emb, speaker_vecs):
        if self.n_group_size > 1:
            txt_emb = self.unfold(txt_emb[..., None])
        speaker_vecs = speaker_vecs[..., None].expand(-1, -1, txt_emb.shape[2])
        context = torch.cat((txt_emb, speaker_vecs), 1)
        return context
    def forward(self, txt_emb, spk_emb, x, lens):
        """x<tensor>: duration or pitch or energy average"""
        x = x[:, None] if len(x.shape) == 2 else x  # add channel dimension
        if self.n_group_size > 1:
            x = self.unfold(x[..., None])
        x = x.permute(2, 0, 1)  # permute to time, batch, dims
        x = self.attribute_processing.normalize(x)

        txt_emb = self.bottleneck_layer(txt_emb)
        context = self.preprocess_context(txt_emb, spk_emb)
        context = context.permute(2, 0, 1)  # permute to time, batch, dims

        lens_grouped = (lens // self.n_group_size).long()
        log_s_list = []
        for flow in self.flows:
            x, log_s = flow(x, context, lens_grouped)
            log_s_list.append(log_s)

        x = x.permute(1, 2, 0)  # x mapped to z
        log_s_list = [log_s_elt.permute(1, 2, 0) for log_s_elt in log_s_list]
        outputs = {"z": x, "log_s_list": log_s_list, "log_det_W_list": []}
        return outputs
    def infer(self, z, txt_emb, spk_emb, seq_lens=None):
        if self.n_group_size > 1:
            n_frames = z.shape[2]
            z = self.unfold(z[..., None])
        z = z.permute(2, 0, 1)  # permute to time, batch, dims

        txt_emb = self.bottleneck_layer(txt_emb)
        context = self.preprocess_context(txt_emb, spk_emb)
        context = context.permute(2, 0, 1)  # permute to time, batch, dims

        # invert the flows in reverse order
        for flow in reversed(self.flows):
            z = flow.infer(z, context)

        x_hat = z.permute(1, 2, 0)
        if self.n_group_size > 1:
            x_hat = self.fold(x_hat)
            if n_frames > x_hat.shape[2]:
                # unfold dropped trailing frames; reflection-pad back to the
                # original frame count
                m = nn.ReflectionPad1d((0, n_frames - x_hat.shape[2]))
                x_hat = m(x_hat)

        x_hat = self.attribute_processing.denormalize(x_hat)
        return x_hat
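
# Sampling sketch for AGAP (illustrative): the autoregressive steps are
# inverted sequentially over time inside each flow's infer().
#
#   z = torch.randn(B, n_in_dim, T)
#   x_hat = agap.infer(z, txt_emb, spk_emb)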