File size: 6,053 Bytes
320e69e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
import copy
import math
import hparams as hp
import utils
# Device chosen once at import time: CUDA when available, otherwise CPU.
# LengthRegulator.LR moves the mel-length tensor it builds onto this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        # deepcopy so that no parameters are shared between the replicas
        replicas.append(copy.deepcopy(module))
    return replicas
class VarianceAdaptor(nn.Module):
    """Variance adaptor.

    Predicts log-duration, pitch and energy from the input sequence, adds
    convolutional pitch and energy embeddings to it, and finally expands the
    sequence through the length regulator.
    """

    def __init__(self):
        super().__init__()

        # NOTE: keep this creation order; it fixes the parameter registration order.
        self.duration_predictor = VariancePredictor()
        self.length_regulator = LengthRegulator()
        self.pitch_predictor = VariancePredictor()
        self.energy_predictor = VariancePredictor()

        # Both producers map one scalar per position to hp.encoder_hidden channels.
        embed_kwargs = dict(kernel_size=9, bias=False, padding=4)
        self.energy_embedding_producer = Conv(1, hp.encoder_hidden, **embed_kwargs)
        self.pitch_embedding_producer = Conv(1, hp.encoder_hidden, **embed_kwargs)

    def forward(self, x, src_mask, mel_mask=None, duration_target=None,
                pitch_target=None, energy_target=None, max_len=None):
        """Apply the variance adaptor.

        :param x: input sequence (presumably (batch, seq, hidden) -- TODO confirm).
        :param src_mask: mask handed to every variance predictor.
        :param mel_mask: returned unchanged when ``duration_target`` is given;
            otherwise replaced by a mask built from the regulated lengths.
        :param duration_target: durations to expand with; when ``None`` the
            rounded, clamped prediction is used instead.
        :param pitch_target: pitch values to embed; when ``None`` the
            prediction is embedded instead.
        :param energy_target: energy values to embed; when ``None`` the
            prediction is embedded instead.
        :param max_len: forwarded to the length regulator.
        :return: tuple ``(x, log_duration_prediction, pitch_prediction,
            energy_prediction, mel_len, mel_mask)``.
        """
        log_duration_prediction = self.duration_predictor(x, src_mask)

        pitch_prediction = self.pitch_predictor(x, src_mask)
        pitch_source = pitch_prediction if pitch_target is None else pitch_target
        pitch_embedding = self.pitch_embedding_producer(pitch_source.unsqueeze(2))

        energy_prediction = self.energy_predictor(x, src_mask)
        energy_source = energy_prediction if energy_target is None else energy_target
        energy_embedding = self.energy_embedding_producer(energy_source.unsqueeze(2))

        x = x + pitch_embedding + energy_embedding

        use_predicted_duration = duration_target is None
        if use_predicted_duration:
            # Undo the log, remove the offset, round, and forbid negative durations.
            durations = torch.clamp(
                torch.round(torch.exp(log_duration_prediction) - hp.log_offset),
                min=0,
            )
        else:
            durations = duration_target

        x, mel_len = self.length_regulator(x, durations, max_len)

        if use_predicted_duration:
            # Output lengths are only known after regulation in this mode.
            mel_mask = utils.get_mask_from_lengths(mel_len)

        return (x, log_duration_prediction, pitch_prediction,
                energy_prediction, mel_len, mel_mask)
class LengthRegulator(nn.Module):
    """Length regulator: repeats every position of a sequence by its duration."""

    def __init__(self):
        super(LengthRegulator, self).__init__()

    def LR(self, x, duration, max_len):
        """Expand every sequence in the batch and pad the results together.

        :param x: batch of sequences; iterated along its first dimension.
        :param duration: per-sequence durations, iterated in lockstep with ``x``.
        :param max_len: when not ``None``, forwarded to ``utils.pad`` as the
            padding length; otherwise ``utils.pad`` is called without it.
        :return: ``(padded_output, mel_len)`` where ``mel_len`` is a
            ``LongTensor`` of the expanded lengths placed on the module-level
            ``device``.
        """
        expanded_batch = [self.expand(seq, dur) for seq, dur in zip(x, duration)]
        mel_len = [seq.shape[0] for seq in expanded_batch]

        if max_len is not None:
            output = utils.pad(expanded_batch, max_len)
        else:
            output = utils.pad(expanded_batch)

        return output, torch.LongTensor(mel_len).to(device)

    def expand(self, batch, predicted):
        """Repeat row ``i`` of ``batch`` ``int(predicted[i])`` times.

        Implemented with a single ``torch.repeat_interleave`` call instead of
        a Python loop with one ``.item()`` per position; this avoids a host
        synchronisation per token and also works for zero-length sequences.

        :param batch: one sequence; rows along dim 0 are the positions
            (assumed 2-D, (positions, hidden) -- TODO confirm).
        :param predicted: 1-D durations, at least one per position; fractional
            values are truncated toward zero, extra trailing entries are ignored.
        :return: tensor whose dim 0 has ``sum(durations)`` rows.
        :raises IndexError: if ``predicted`` has fewer entries than ``batch``
            has positions.
        """
        n_positions = batch.shape[0]
        if predicted.shape[0] < n_positions:
            raise IndexError(
                "duration has {} entries but the sequence has {} positions".format(
                    predicted.shape[0], n_positions))

        # repeat_interleave needs integer repeats on the same device as the input.
        repeats = predicted[:n_positions].to(device=batch.device, dtype=torch.long)
        return torch.repeat_interleave(batch, repeats, dim=0)

    def forward(self, x, duration, max_len):
        """Return ``(expanded_and_padded_x, mel_len)``; see :meth:`LR`."""
        output, mel_len = self.LR(x, duration, max_len)
        return output, mel_len
class VariancePredictor(nn.Module):
    """Duration, pitch and energy predictor.

    Two Conv -> ReLU -> LayerNorm -> Dropout stages followed by a linear
    projection that yields one scalar per position.
    """

    def __init__(self):
        super(VariancePredictor, self).__init__()

        self.input_size = hp.encoder_hidden
        self.filter_size = hp.variance_predictor_filter_size
        self.kernel = hp.variance_predictor_kernel_size
        self.conv_output_size = hp.variance_predictor_filter_size
        self.dropout = hp.variance_predictor_dropout

        # Use the same padding for BOTH convolutions. The second one used to be
        # hard-coded to 1, which only preserves the sequence length when the
        # kernel size is 3; any other odd kernel then broke the masking below.
        same_padding = (self.kernel - 1) // 2

        self.conv_layer = nn.Sequential(OrderedDict([
            ("conv1d_1", Conv(self.input_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=same_padding)),
            ("relu_1", nn.ReLU()),
            ("layer_norm_1", nn.LayerNorm(self.filter_size)),
            ("dropout_1", nn.Dropout(self.dropout)),
            ("conv1d_2", Conv(self.filter_size,
                              self.filter_size,
                              kernel_size=self.kernel,
                              padding=same_padding)),
            ("relu_2", nn.ReLU()),
            ("layer_norm_2", nn.LayerNorm(self.filter_size)),
            ("dropout_2", nn.Dropout(self.dropout)),
        ]))

        self.linear_layer = nn.Linear(self.conv_output_size, 1)

    def forward(self, encoder_output, mask):
        """Predict one value per position.

        :param encoder_output: input sequence with ``hp.encoder_hidden``
            features in the last dim (presumably (batch, seq, hidden) -- TODO confirm).
        :param mask: optional mask; positions where it is true are set to 0
            in the output. Skipped when ``None``.
        :return: predictions with the trailing singleton dim squeezed away.
        """
        out = self.conv_layer(encoder_output)
        out = self.linear_layer(out)
        out = out.squeeze(-1)

        if mask is not None:
            out = out.masked_fill(mask, 0.)

        return out
class Conv(nn.Module):
    """
    Convolution Module

    Wraps ``nn.Conv1d`` and swaps dims 1 and 2 before and after the
    convolution, so inputs and outputs carry their channels in the last dim.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=True, w_init='linear'):
        """
        :param in_channels: dimension of input
        :param out_channels: dimension of output
        :param kernel_size: size of kernel
        :param stride: size of stride
        :param padding: size of padding
        :param dilation: dilation rate
        :param bias: boolean. if True, bias is included.
        :param w_init: str. accepted for interface compatibility; it is not
            used here, the layer keeps ``nn.Conv1d``'s default initialisation.
        """
        super().__init__()

        conv_options = {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "bias": bias,
        }
        self.conv = nn.Conv1d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        """Convolve ``x`` along dim 1, treating dim 2 as the channel dim."""
        # nn.Conv1d expects channels in dim 1, so swap on the way in ...
        channels_first = torch.transpose(x.contiguous(), 1, 2)
        convolved = self.conv(channels_first)
        # ... and swap back on the way out.
        return torch.transpose(convolved.contiguous(), 1, 2)
|