# Uploaded by Ionut-Bostan ("Upload 82 files", commit feec0bf).
import os
import json
import copy
import math
from collections import OrderedDict
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from utils.tools import get_mask_from_lengths, pad
# Default compute device for tensors created in this module: the GPU when
# torch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class VarianceAdaptor(nn.Module):
    """Variance Adaptor.

    Predicts durations, pitch and energy from the hidden sequence, adds
    quantized pitch/energy embeddings to it, and expands it from phoneme to
    frame resolution with the length regulator. Pitch/energy are applied
    before or after the expansion depending on their configured feature level.
    """

    def __init__(self, preprocess_config, model_config):
        super(VarianceAdaptor, self).__init__()
        self.duration_predictor = VariancePredictor(model_config)
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor(model_config)
        self.energy_predictor = VariancePredictor(model_config)

        # "phoneme_level": predicted/embedded before length regulation (src_mask);
        # "frame_level": after it (mel_mask).
        self.pitch_feature_level = preprocess_config["preprocessing"]["pitch"][
            "feature"
        ]
        self.energy_feature_level = preprocess_config["preprocessing"]["energy"][
            "feature"
        ]
        assert self.pitch_feature_level in ["phoneme_level", "frame_level"]
        assert self.energy_feature_level in ["phoneme_level", "frame_level"]

        pitch_quantization = model_config["variance_embedding"]["pitch_quantization"]
        energy_quantization = model_config["variance_embedding"]["energy_quantization"]
        n_bins = model_config["variance_embedding"]["n_bins"]
        assert pitch_quantization in ["linear", "log"]
        assert energy_quantization in ["linear", "log"]

        with open(
            os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json"),
            encoding="utf-8",
        ) as f:
            stats = json.load(f)
        # The first two entries of each stats list are used as the (min, max)
        # range spanned by the quantization boundaries.
        pitch_min, pitch_max = stats["pitch"][:2]
        energy_min, energy_max = stats["energy"][:2]

        # Frozen Parameters (not buffers) to keep the original parameter /
        # state_dict layout; registration order is unchanged.
        self.pitch_bins = self._build_bins(
            pitch_min, pitch_max, n_bins, pitch_quantization
        )
        self.energy_bins = self._build_bins(
            energy_min, energy_max, n_bins, energy_quantization
        )

        self.pitch_embedding = nn.Embedding(
            n_bins, model_config["transformer"]["encoder_hidden"]
        )
        self.energy_embedding = nn.Embedding(
            n_bins, model_config["transformer"]["encoder_hidden"]
        )

    @staticmethod
    def _build_bins(low, high, n_bins, quantization):
        """Return ``n_bins - 1`` frozen bucket boundaries spanning [low, high].

        ``n_bins - 1`` boundaries split values into ``n_bins`` buckets, one per
        embedding row. ``"log"`` spaces them evenly in log-space (so ``low``
        and ``high`` must be positive); anything else spaces them linearly.
        """
        if quantization == "log":
            boundaries = torch.exp(
                torch.linspace(np.log(low), np.log(high), n_bins - 1)
            )
        else:
            boundaries = torch.linspace(low, high, n_bins - 1)
        return nn.Parameter(boundaries, requires_grad=False)

    @staticmethod
    def _predict_and_embed(predictor, table, bins, x, target, mask, control):
        """Run ``predictor`` and embed either ``target`` or the scaled prediction.

        With a target (teacher forcing) the ground truth is bucketized and the
        raw prediction is returned unscaled. Without one, the prediction is
        multiplied by ``control`` and that scaled value is both embedded and
        returned.
        """
        prediction = predictor(x, mask)
        if target is not None:
            embedding = table(torch.bucketize(target, bins))
        else:
            prediction = prediction * control
            embedding = table(torch.bucketize(prediction, bins))
        return prediction, embedding

    def get_pitch_embedding(self, x, target, mask, control):
        """Predict pitch and return ``(prediction, pitch embedding)``."""
        return self._predict_and_embed(
            self.pitch_predictor,
            self.pitch_embedding,
            self.pitch_bins,
            x,
            target,
            mask,
            control,
        )

    def get_energy_embedding(self, x, target, mask, control):
        """Predict energy and return ``(prediction, energy embedding)``."""
        return self._predict_and_embed(
            self.energy_predictor,
            self.energy_embedding,
            self.energy_bins,
            x,
            target,
            mask,
            control,
        )

    def forward(
        self,
        x,
        src_mask,
        mel_mask=None,
        max_len=None,
        pitch_target=None,
        energy_target=None,
        duration_target=None,
        p_control=1.0,
        e_control=1.0,
        d_control=1.0,
    ):
        """Apply duration/pitch/energy adaptation to the hidden sequence ``x``.

        Args:
            x: hidden sequence fed to the predictors.
            src_mask: mask for phoneme-level predictions.
            mel_mask: mask for frame-level predictions when ``duration_target``
                is given; recomputed from the predicted lengths otherwise.
            max_len: optional padded length forwarded to the length regulator.
            pitch_target / energy_target / duration_target: ground truth used
                instead of predictions when not None.
            p_control / e_control / d_control: multiplicative scales applied to
                predicted pitch, energy and duration respectively.

        Returns:
            (x, pitch_prediction, energy_prediction, log_duration_prediction,
             duration_rounded, mel_len, mel_mask)
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)

        if self.pitch_feature_level == "phoneme_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, src_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "phoneme_level":
            # Energy is controlled by e_control (p_control was used by mistake).
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, src_mask, e_control
            )
            x = x + energy_embedding

        if duration_target is not None:
            x, mel_len = self.length_regulator(x, duration_target, max_len)
            duration_rounded = duration_target
        else:
            # Undo the log(d + 1) parameterization, scale by d_control, THEN
            # round, so the durations handed to the length regulator (which
            # truncates non-integers) are genuinely rounded.
            duration_rounded = torch.clamp(
                torch.round((torch.exp(log_duration_prediction) - 1) * d_control),
                min=0,
            )
            x, mel_len = self.length_regulator(x, duration_rounded, max_len)
            mel_mask = get_mask_from_lengths(mel_len)

        if self.pitch_feature_level == "frame_level":
            pitch_prediction, pitch_embedding = self.get_pitch_embedding(
                x, pitch_target, mel_mask, p_control
            )
            x = x + pitch_embedding
        if self.energy_feature_level == "frame_level":
            energy_prediction, energy_embedding = self.get_energy_embedding(
                x, energy_target, mel_mask, e_control
            )
            x = x + energy_embedding

        return (
            x,
            pitch_prediction,
            energy_prediction,
            log_duration_prediction,
            duration_rounded,
            mel_len,
            mel_mask,
        )
class LengthRegulator(nn.Module):
    """Length Regulator.

    Repeats each position of every sequence in the batch according to its
    duration, then pads the expanded sequences back into one batch.
    """

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand each ``x[b]`` by ``duration[b]`` and pad the results.

        Args:
            x: batch of sequences; each ``x[b]`` is expanded along dim 0.
            duration: per-position repeat counts, iterated alongside ``x``.
            max_len: optional target length forwarded to ``pad``.

        Returns:
            (padded output, LongTensor of the expanded lengths)
        """
        output = [self.expand(seq, counts) for seq, counts in zip(x, duration)]
        mel_len = [seq.shape[0] for seq in output]
        if max_len is not None:
            output = pad(output, max_len)
        else:
            output = pad(output)
        # Place the lengths on the input's device rather than the module-level
        # default, so they stay co-located with the model's tensors.
        target_device = x.device if torch.is_tensor(x) else device
        return output, torch.tensor(mel_len, dtype=torch.long, device=target_device)

    def expand(self, batch, predicted):
        """Repeat row ``i`` of ``batch`` ``int(predicted[i])`` times.

        Non-integer counts are truncated toward zero and negative counts are
        treated as 0. Only the first ``len(batch)`` counts are used.
        """
        # clamp(min=0) followed by .long() equals max(int(c), 0) per element,
        # but runs as one tensor op instead of a per-row .item() loop; it also
        # handles an empty sequence (torch.cat([]) would raise).
        repeats = predicted[: batch.shape[0]].clamp(min=0).long().to(batch.device)
        return torch.repeat_interleave(batch, repeats, dim=0)

    def forward(self, x, duration, max_len):
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
class VariancePredictor(nn.Module):
    """Duration, Pitch and Energy Predictor

    Two (Conv -> ReLU -> LayerNorm -> Dropout) stages followed by a linear
    projection to one scalar per time step.
    """

    def __init__(self, model_config):
        super(VariancePredictor, self).__init__()

        self.input_size = model_config["transformer"]["encoder_hidden"]
        self.filter_size = model_config["variance_predictor"]["filter_size"]
        self.kernel = model_config["variance_predictor"]["kernel_size"]
        self.conv_output_size = model_config["variance_predictor"]["filter_size"]
        self.dropout = model_config["variance_predictor"]["dropout"]

        # Length-preserving padding for odd kernel sizes, used by BOTH convs
        # so the output stays aligned with ``mask``. A hard-coded padding=1 on
        # the second conv is only correct for kernel_size=3.
        same_padding = (self.kernel - 1) // 2

        # Submodule names are kept stable so existing checkpoints still load.
        self.conv_layer = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1d_1",
                        Conv(
                            self.input_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=same_padding,
                        ),
                    ),
                    ("relu_1", nn.ReLU()),
                    ("layer_norm_1", nn.LayerNorm(self.filter_size)),
                    ("dropout_1", nn.Dropout(self.dropout)),
                    (
                        "conv1d_2",
                        Conv(
                            self.filter_size,
                            self.filter_size,
                            kernel_size=self.kernel,
                            padding=same_padding,
                        ),
                    ),
                    ("relu_2", nn.ReLU()),
                    ("layer_norm_2", nn.LayerNorm(self.filter_size)),
                    ("dropout_2", nn.Dropout(self.dropout)),
                ]
            )
        )

        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        """Predict one scalar per time step.

        Args:
            encoder_output: assumed channel-last ``(batch, time,
                encoder_hidden)``, since the LayerNorms normalize the last
                dim; TODO confirm against callers.
            mask: optional boolean mask; positions where it is True are set
                to 0 in the output.

        Returns:
            The prediction with the trailing singleton dim squeezed away.
        """
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)

        if mask is not None:
            out = out.masked_fill(mask, 0.0)

        return out
class Conv(nn.Module):
    """
    Convolution Module

    1-D convolution over channel-last input: takes ``(batch, time,
    in_channels)`` and returns ``(batch, time', out_channels)``, transposing
    to and from the channel-first layout ``nn.Conv1d`` expects.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
        w_init="linear",
    ):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. Accepted for interface compatibility but currently
            unused: weights keep ``nn.Conv1d``'s default initialization (no
            xavier init is applied).
        """
        super(Conv, self).__init__()

        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        # nn.Conv1d copes with non-contiguous input itself, so plain
        # transposes suffice; calling .contiguous() before them was a no-op
        # for the result. The return value is a channel-last view.
        return self.conv(x.transpose(1, 2)).transpose(1, 2)