|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Tuple, Union
|
|
|
|
import torch
|
|
from coqpit import Coqpit
|
|
from torch import nn
|
|
from torch.cuda.amp.autocast_mode import autocast
|
|
|
|
from TTS.tts.layers.feed_forward.decoder import Decoder
|
|
from TTS.tts.layers.feed_forward.encoder import Encoder
|
|
from TTS.tts.layers.generic.aligner import AlignmentNetwork
|
|
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
|
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
|
from TTS.tts.models.base_tts import BaseTTS
|
|
from TTS.tts.utils.helpers import average_over_durations, generate_path, maximum_path, sequence_mask
|
|
from TTS.tts.utils.speakers import SpeakerManager
|
|
from TTS.tts.utils.text.tokenizer import TTSTokenizer
|
|
from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram
|
|
from TTS.utils.io import load_fsspec
|
|
|
|
|
|
@dataclass
|
|
class ForwardTTSArgs(Coqpit):
|
|
"""ForwardTTS Model arguments.
|
|
|
|
Args:
|
|
|
|
num_chars (int):
|
|
Number of characters in the vocabulary. Defaults to 100.
|
|
|
|
out_channels (int):
|
|
Number of output channels. Defaults to 80.
|
|
|
|
hidden_channels (int):
|
|
Number of base hidden channels of the model. Defaults to 512.
|
|
|
|
use_aligner (bool):
|
|
Whether to use aligner network to learn the text to speech alignment or use pre-computed durations.
|
|
If set False, durations should be computed by `TTS/bin/compute_attention_masks.py` and path to the
|
|
pre-computed durations must be provided to `config.datasets[0].meta_file_attn_mask`. Defaults to True.
|
|
|
|
use_pitch (bool):
|
|
Use pitch predictor to learn the pitch. Defaults to True.
|
|
|
|
use_energy (bool):
|
|
Use energy predictor to learn the energy. Defaults to True.
|
|
|
|
duration_predictor_hidden_channels (int):
|
|
Number of hidden channels in the duration predictor. Defaults to 256.
|
|
|
|
duration_predictor_dropout_p (float):
|
|
Dropout rate for the duration predictor. Defaults to 0.1.
|
|
|
|
duration_predictor_kernel_size (int):
|
|
Kernel size of conv layers in the duration predictor. Defaults to 3.
|
|
|
|
pitch_predictor_hidden_channels (int):
|
|
Number of hidden channels in the pitch predictor. Defaults to 256.
|
|
|
|
pitch_predictor_dropout_p (float):
|
|
Dropout rate for the pitch predictor. Defaults to 0.1.
|
|
|
|
pitch_predictor_kernel_size (int):
|
|
Kernel size of conv layers in the pitch predictor. Defaults to 3.
|
|
|
|
pitch_embedding_kernel_size (int):
|
|
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
|
|
|
|
energy_predictor_hidden_channels (int):
|
|
Number of hidden channels in the energy predictor. Defaults to 256.
|
|
|
|
energy_predictor_dropout_p (float):
|
|
Dropout rate for the energy predictor. Defaults to 0.1.
|
|
|
|
energy_predictor_kernel_size (int):
|
|
Kernel size of conv layers in the energy predictor. Defaults to 3.
|
|
|
|
energy_embedding_kernel_size (int):
|
|
Kernel size of the projection layer in the energy predictor. Defaults to 3.
|
|
|
|
positional_encoding (bool):
|
|
Whether to use positional encoding. Defaults to True.
|
|
|
|
positional_encoding_use_scale (bool):
|
|
Whether to use a learnable scale coeff in the positional encoding. Defaults to True.
|
|
|
|
length_scale (int):
|
|
Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0.
|
|
|
|
encoder_type (str):
|
|
Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`.
|
|
Defaults to `fftransformer` as in the paper.
|
|
|
|
encoder_params (dict):
|
|
Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
|
|
|
|
decoder_type (str):
|
|
Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`.
|
|
Defaults to `fftransformer` as in the paper.
|
|
|
|
decoder_params (str):
|
|
Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
|
|
|
|
detach_duration_predictor (bool):
|
|
Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
|
|
does not pass to the earlier layers. Defaults to True.
|
|
|
|
max_duration (int):
|
|
Maximum duration accepted by the model. Defaults to 75.
|
|
|
|
num_speakers (int):
|
|
Number of speakers for the speaker embedding layer. Defaults to 0.
|
|
|
|
speakers_file (str):
|
|
Path to the speaker mapping file for the Speaker Manager. Defaults to None.
|
|
|
|
speaker_embedding_channels (int):
|
|
Number of speaker embedding channels. Defaults to 256.
|
|
|
|
use_d_vector_file (bool):
|
|
Enable/Disable the use of d-vectors for multi-speaker training. Defaults to False.
|
|
|
|
d_vector_dim (int):
|
|
Number of d-vector channels. Defaults to 0.
|
|
|
|
"""
|
|
|
|
num_chars: int = None
|
|
out_channels: int = 80
|
|
hidden_channels: int = 384
|
|
use_aligner: bool = True
|
|
|
|
use_pitch: bool = True
|
|
pitch_predictor_hidden_channels: int = 256
|
|
pitch_predictor_kernel_size: int = 3
|
|
pitch_predictor_dropout_p: float = 0.1
|
|
pitch_embedding_kernel_size: int = 3
|
|
|
|
|
|
use_energy: bool = False
|
|
energy_predictor_hidden_channels: int = 256
|
|
energy_predictor_kernel_size: int = 3
|
|
energy_predictor_dropout_p: float = 0.1
|
|
energy_embedding_kernel_size: int = 3
|
|
|
|
|
|
duration_predictor_hidden_channels: int = 256
|
|
duration_predictor_kernel_size: int = 3
|
|
duration_predictor_dropout_p: float = 0.1
|
|
|
|
positional_encoding: bool = True
|
|
poisitonal_encoding_use_scale: bool = True
|
|
length_scale: int = 1
|
|
encoder_type: str = "fftransformer"
|
|
encoder_params: dict = field(
|
|
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
|
|
)
|
|
decoder_type: str = "fftransformer"
|
|
decoder_params: dict = field(
|
|
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
|
|
)
|
|
detach_duration_predictor: bool = False
|
|
max_duration: int = 75
|
|
num_speakers: int = 1
|
|
use_speaker_embedding: bool = False
|
|
speakers_file: str = None
|
|
use_d_vector_file: bool = False
|
|
d_vector_dim: int = None
|
|
d_vector_file: str = None
|
|
|
|
|
|
class ForwardTTS(BaseTTS):
|
|
"""General forward TTS model implementation that uses an encoder-decoder architecture with an optional alignment
|
|
network and a pitch predictor.
|
|
|
|
If the alignment network is used, the model learns the text-to-speech alignment
|
|
from the data instead of using pre-computed durations.
|
|
|
|
If the pitch predictor is used, the model trains a pitch predictor that predicts average pitch value for each
|
|
input character as in the FastPitch model.
|
|
|
|
`ForwardTTS` can be configured to one of these architectures,
|
|
|
|
- FastPitch
|
|
- SpeedySpeech
|
|
- FastSpeech
|
|
- FastSpeech2 (requires average speech energy predictor)
|
|
|
|
Args:
|
|
config (Coqpit): Model coqpit class.
|
|
speaker_manager (SpeakerManager): Speaker manager for multi-speaker training. Only used for multi-speaker models.
|
|
Defaults to None.
|
|
|
|
Examples:
|
|
>>> from TTS.tts.models.fast_pitch import ForwardTTS, ForwardTTSArgs
|
|
>>> config = ForwardTTSArgs()
|
|
>>> model = ForwardTTS(config)
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
self,
|
|
config: Coqpit,
|
|
ap: "AudioProcessor" = None,
|
|
tokenizer: "TTSTokenizer" = None,
|
|
speaker_manager: SpeakerManager = None,
|
|
):
|
|
super().__init__(config, ap, tokenizer, speaker_manager)
|
|
self._set_model_args(config)
|
|
|
|
self.init_multispeaker(config)
|
|
|
|
self.max_duration = self.args.max_duration
|
|
self.use_aligner = self.args.use_aligner
|
|
self.use_pitch = self.args.use_pitch
|
|
self.use_energy = self.args.use_energy
|
|
self.binary_loss_weight = 0.0
|
|
|
|
self.length_scale = (
|
|
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
|
|
)
|
|
|
|
self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
|
|
|
|
self.encoder = Encoder(
|
|
self.args.hidden_channels,
|
|
self.args.hidden_channels,
|
|
self.args.encoder_type,
|
|
self.args.encoder_params,
|
|
self.embedded_speaker_dim,
|
|
)
|
|
|
|
if self.args.positional_encoding:
|
|
self.pos_encoder = PositionalEncoding(self.args.hidden_channels)
|
|
|
|
self.decoder = Decoder(
|
|
self.args.out_channels,
|
|
self.args.hidden_channels,
|
|
self.args.decoder_type,
|
|
self.args.decoder_params,
|
|
)
|
|
|
|
self.duration_predictor = DurationPredictor(
|
|
self.args.hidden_channels,
|
|
self.args.duration_predictor_hidden_channels,
|
|
self.args.duration_predictor_kernel_size,
|
|
self.args.duration_predictor_dropout_p,
|
|
)
|
|
|
|
if self.args.use_pitch:
|
|
self.pitch_predictor = DurationPredictor(
|
|
self.args.hidden_channels,
|
|
self.args.pitch_predictor_hidden_channels,
|
|
self.args.pitch_predictor_kernel_size,
|
|
self.args.pitch_predictor_dropout_p,
|
|
)
|
|
self.pitch_emb = nn.Conv1d(
|
|
1,
|
|
self.args.hidden_channels,
|
|
kernel_size=self.args.pitch_embedding_kernel_size,
|
|
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
|
|
)
|
|
|
|
if self.args.use_energy:
|
|
self.energy_predictor = DurationPredictor(
|
|
self.args.hidden_channels,
|
|
self.args.energy_predictor_hidden_channels,
|
|
self.args.energy_predictor_kernel_size,
|
|
self.args.energy_predictor_dropout_p,
|
|
)
|
|
self.energy_emb = nn.Conv1d(
|
|
1,
|
|
self.args.hidden_channels,
|
|
kernel_size=self.args.energy_embedding_kernel_size,
|
|
padding=int((self.args.energy_embedding_kernel_size - 1) / 2),
|
|
)
|
|
|
|
if self.args.use_aligner:
|
|
self.aligner = AlignmentNetwork(
|
|
in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels
|
|
)
|
|
|
|
def init_multispeaker(self, config: Coqpit):
|
|
"""Init for multi-speaker training.
|
|
|
|
Args:
|
|
config (Coqpit): Model configuration.
|
|
"""
|
|
self.embedded_speaker_dim = 0
|
|
|
|
if self.speaker_manager is None and (config.use_d_vector_file or config.use_speaker_embedding):
|
|
raise ValueError(
|
|
" > SpeakerManager is not provided. You must provide the SpeakerManager before initializing a multi-speaker model."
|
|
)
|
|
|
|
if self.speaker_manager is not None:
|
|
self.num_speakers = self.speaker_manager.num_speakers
|
|
|
|
if config.use_d_vector_file:
|
|
self.embedded_speaker_dim = config.d_vector_dim
|
|
if self.args.d_vector_dim != self.args.hidden_channels:
|
|
|
|
self.proj_g = nn.Linear(in_features=self.args.d_vector_dim, out_features=self.args.hidden_channels)
|
|
|
|
if config.use_speaker_embedding and not config.use_d_vector_file:
|
|
print(" > Init speaker_embedding layer.")
|
|
self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels)
|
|
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
|
|
|
@staticmethod
|
|
def generate_attn(dr, x_mask, y_mask=None):
|
|
"""Generate an attention mask from the durations.
|
|
|
|
Shapes
|
|
- dr: :math:`(B, T_{en})`
|
|
- x_mask: :math:`(B, T_{en})`
|
|
- y_mask: :math:`(B, T_{de})`
|
|
"""
|
|
|
|
if y_mask is None:
|
|
y_lengths = dr.sum(1).long()
|
|
y_lengths[y_lengths < 1] = 1
|
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
|
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
|
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
|
|
return attn
|
|
|
|
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
|
|
"""Generate attention alignment map from durations and
|
|
expand encoder outputs
|
|
|
|
Shapes:
|
|
- en: :math:`(B, D_{en}, T_{en})`
|
|
- dr: :math:`(B, T_{en})`
|
|
- x_mask: :math:`(B, T_{en})`
|
|
- y_mask: :math:`(B, T_{de})`
|
|
|
|
Examples::
|
|
|
|
encoder output: [a,b,c,d]
|
|
durations: [1, 3, 2, 1]
|
|
|
|
expanded: [a, b, b, b, c, c, d]
|
|
attention map: [[0, 0, 0, 0, 0, 0, 1],
|
|
[0, 0, 0, 0, 1, 1, 0],
|
|
[0, 1, 1, 1, 0, 0, 0],
|
|
[1, 0, 0, 0, 0, 0, 0]]
|
|
"""
|
|
attn = self.generate_attn(dr, x_mask, y_mask)
|
|
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
|
|
return o_en_ex, attn
|
|
|
|
def format_durations(self, o_dr_log, x_mask):
|
|
"""Format predicted durations.
|
|
1. Convert to linear scale from log scale
|
|
2. Apply the length scale for speed adjustment
|
|
3. Apply masking.
|
|
4. Cast 0 durations to 1.
|
|
5. Round the duration values.
|
|
|
|
Args:
|
|
o_dr_log: Log scale durations.
|
|
x_mask: Input text mask.
|
|
|
|
Shapes:
|
|
- o_dr_log: :math:`(B, T_{de})`
|
|
- x_mask: :math:`(B, T_{en})`
|
|
"""
|
|
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
|
|
o_dr[o_dr < 1] = 1.0
|
|
o_dr = torch.round(o_dr)
|
|
return o_dr
|
|
|
|
def _forward_encoder(
|
|
self, x: torch.LongTensor, x_mask: torch.FloatTensor, g: torch.FloatTensor = None
|
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
|
"""Encoding forward pass.
|
|
|
|
1. Embed speaker IDs if multi-speaker mode.
|
|
2. Embed character sequences.
|
|
3. Run the encoder network.
|
|
4. Sum encoder outputs and speaker embeddings
|
|
|
|
Args:
|
|
x (torch.LongTensor): Input sequence IDs.
|
|
x_mask (torch.FloatTensor): Input squence mask.
|
|
g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
|
|
|
|
Returns:
|
|
Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
|
|
encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings,
|
|
character embeddings
|
|
|
|
Shapes:
|
|
- x: :math:`(B, T_{en})`
|
|
- x_mask: :math:`(B, 1, T_{en})`
|
|
- g: :math:`(B, C)`
|
|
"""
|
|
if hasattr(self, "emb_g"):
|
|
g = g.type(torch.LongTensor)
|
|
g = self.emb_g(g)
|
|
if g is not None:
|
|
g = g.unsqueeze(-1)
|
|
|
|
x_emb = self.emb(x)
|
|
|
|
|
|
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask, g)
|
|
|
|
|
|
if g is not None:
|
|
if hasattr(self, "proj_g"):
|
|
g = self.proj_g(g.view(g.shape[0], -1)).unsqueeze(-1)
|
|
o_en = o_en + g
|
|
return o_en, x_mask, g, x_emb
|
|
|
|
def _forward_decoder(
|
|
self,
|
|
o_en: torch.FloatTensor,
|
|
dr: torch.IntTensor,
|
|
x_mask: torch.FloatTensor,
|
|
y_lengths: torch.IntTensor,
|
|
g: torch.FloatTensor,
|
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
|
"""Decoding forward pass.
|
|
|
|
1. Compute the decoder output mask
|
|
2. Expand encoder output with the durations.
|
|
3. Apply position encoding.
|
|
4. Add speaker embeddings if multi-speaker mode.
|
|
5. Run the decoder.
|
|
|
|
Args:
|
|
o_en (torch.FloatTensor): Encoder output.
|
|
dr (torch.IntTensor): Ground truth durations or alignment network durations.
|
|
x_mask (torch.IntTensor): Input sequence mask.
|
|
y_lengths (torch.IntTensor): Output sequence lengths.
|
|
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
|
|
|
|
Returns:
|
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
|
|
"""
|
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
|
|
|
|
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
|
|
|
|
if hasattr(self, "pos_encoder"):
|
|
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
|
|
|
|
o_de = self.decoder(o_en_ex, y_mask, g=g)
|
|
return o_de.transpose(1, 2), attn.transpose(1, 2)
|
|
|
|
def _forward_pitch_predictor(
|
|
self,
|
|
o_en: torch.FloatTensor,
|
|
x_mask: torch.IntTensor,
|
|
pitch: torch.FloatTensor = None,
|
|
dr: torch.IntTensor = None,
|
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
|
"""Pitch predictor forward pass.
|
|
|
|
1. Predict pitch from encoder outputs.
|
|
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
|
|
3. Embed average pitch values.
|
|
|
|
Args:
|
|
o_en (torch.FloatTensor): Encoder output.
|
|
x_mask (torch.IntTensor): Input sequence mask.
|
|
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
|
|
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
|
|
|
|
Returns:
|
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
|
|
|
|
Shapes:
|
|
- o_en: :math:`(B, C, T_{en})`
|
|
- x_mask: :math:`(B, 1, T_{en})`
|
|
- pitch: :math:`(B, 1, T_{de})`
|
|
- dr: :math:`(B, T_{en})`
|
|
"""
|
|
o_pitch = self.pitch_predictor(o_en, x_mask)
|
|
if pitch is not None:
|
|
avg_pitch = average_over_durations(pitch, dr)
|
|
o_pitch_emb = self.pitch_emb(avg_pitch)
|
|
return o_pitch_emb, o_pitch, avg_pitch
|
|
o_pitch_emb = self.pitch_emb(o_pitch)
|
|
return o_pitch_emb, o_pitch
|
|
|
|
def _forward_energy_predictor(
|
|
self,
|
|
o_en: torch.FloatTensor,
|
|
x_mask: torch.IntTensor,
|
|
energy: torch.FloatTensor = None,
|
|
dr: torch.IntTensor = None,
|
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
|
"""Energy predictor forward pass.
|
|
|
|
1. Predict energy from encoder outputs.
|
|
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
|
|
3. Embed average energy values.
|
|
|
|
Args:
|
|
o_en (torch.FloatTensor): Encoder output.
|
|
x_mask (torch.IntTensor): Input sequence mask.
|
|
energy (torch.FloatTensor, optional): Ground truth energy values. Defaults to None.
|
|
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
|
|
|
|
Returns:
|
|
Tuple[torch.FloatTensor, torch.FloatTensor]: Energy embedding, energy prediction.
|
|
|
|
Shapes:
|
|
- o_en: :math:`(B, C, T_{en})`
|
|
- x_mask: :math:`(B, 1, T_{en})`
|
|
- pitch: :math:`(B, 1, T_{de})`
|
|
- dr: :math:`(B, T_{en})`
|
|
"""
|
|
o_energy = self.energy_predictor(o_en, x_mask)
|
|
if energy is not None:
|
|
avg_energy = average_over_durations(energy, dr)
|
|
o_energy_emb = self.energy_emb(avg_energy)
|
|
return o_energy_emb, o_energy, avg_energy
|
|
o_energy_emb = self.energy_emb(o_energy)
|
|
return o_energy_emb, o_energy
|
|
|
|
def _forward_aligner(
|
|
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
|
|
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
|
"""Aligner forward pass.
|
|
|
|
1. Compute a mask to apply to the attention map.
|
|
2. Run the alignment network.
|
|
3. Apply MAS to compute the hard alignment map.
|
|
4. Compute the durations from the hard alignment map.
|
|
|
|
Args:
|
|
x (torch.FloatTensor): Input sequence.
|
|
y (torch.FloatTensor): Output sequence.
|
|
x_mask (torch.IntTensor): Input sequence mask.
|
|
y_mask (torch.IntTensor): Output sequence mask.
|
|
|
|
Returns:
|
|
Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
|
Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
|
|
hard alignment map.
|
|
|
|
Shapes:
|
|
- x: :math:`[B, T_en, C_en]`
|
|
- y: :math:`[B, T_de, C_de]`
|
|
- x_mask: :math:`[B, 1, T_en]`
|
|
- y_mask: :math:`[B, 1, T_de]`
|
|
|
|
- o_alignment_dur: :math:`[B, T_en]`
|
|
- alignment_soft: :math:`[B, T_en, T_de]`
|
|
- alignment_logprob: :math:`[B, 1, T_de, T_en]`
|
|
- alignment_mas: :math:`[B, T_en, T_de]`
|
|
"""
|
|
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
|
|
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
|
|
alignment_mas = maximum_path(
|
|
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
|
|
)
|
|
o_alignment_dur = torch.sum(alignment_mas, -1).int()
|
|
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
|
|
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
|
|
|
|
def _set_speaker_input(self, aux_input: Dict):
|
|
d_vectors = aux_input.get("d_vectors", None)
|
|
speaker_ids = aux_input.get("speaker_ids", None)
|
|
|
|
if d_vectors is not None and speaker_ids is not None:
|
|
raise ValueError("[!] Cannot use d-vectors and speaker-ids together.")
|
|
|
|
if speaker_ids is not None and not hasattr(self, "emb_g"):
|
|
raise ValueError("[!] Cannot use speaker-ids without enabling speaker embedding.")
|
|
|
|
g = speaker_ids if speaker_ids is not None else d_vectors
|
|
return g
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.LongTensor,
|
|
x_lengths: torch.LongTensor,
|
|
y_lengths: torch.LongTensor,
|
|
y: torch.FloatTensor = None,
|
|
dr: torch.IntTensor = None,
|
|
pitch: torch.FloatTensor = None,
|
|
energy: torch.FloatTensor = None,
|
|
aux_input: Dict = {"d_vectors": None, "speaker_ids": None},
|
|
) -> Dict:
|
|
"""Model's forward pass.
|
|
|
|
Args:
|
|
x (torch.LongTensor): Input character sequences.
|
|
x_lengths (torch.LongTensor): Input sequence lengths.
|
|
y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None.
|
|
y (torch.FloatTensor): Spectrogram frames. Only used when the alignment network is on. Defaults to None.
|
|
dr (torch.IntTensor): Character durations over the spectrogram frames. Only used when the alignment network is off. Defaults to None.
|
|
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Only used when the pitch predictor is on. Defaults to None.
|
|
energy (torch.FloatTensor): energy values for each spectrogram frame. Only used when the energy predictor is on. Defaults to None.
|
|
aux_input (Dict): Auxiliary model inputs for multi-speaker training. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
|
|
|
|
Shapes:
|
|
- x: :math:`[B, T_max]`
|
|
- x_lengths: :math:`[B]`
|
|
- y_lengths: :math:`[B]`
|
|
- y: :math:`[B, T_max2]`
|
|
- dr: :math:`[B, T_max]`
|
|
- g: :math:`[B, C]`
|
|
- pitch: :math:`[B, 1, T]`
|
|
"""
|
|
g = self._set_speaker_input(aux_input)
|
|
|
|
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).float()
|
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).float()
|
|
|
|
o_en, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
|
|
|
|
if self.args.detach_duration_predictor:
|
|
o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
|
|
else:
|
|
o_dr_log = self.duration_predictor(o_en, x_mask)
|
|
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
|
|
|
|
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
|
|
|
|
o_alignment_dur = None
|
|
alignment_soft = None
|
|
alignment_logprob = None
|
|
alignment_mas = None
|
|
if self.use_aligner:
|
|
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
|
|
x_emb, y, x_mask, y_mask
|
|
)
|
|
alignment_soft = alignment_soft.transpose(1, 2)
|
|
alignment_mas = alignment_mas.transpose(1, 2)
|
|
dr = o_alignment_dur
|
|
|
|
o_pitch = None
|
|
avg_pitch = None
|
|
if self.args.use_pitch:
|
|
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en, x_mask, pitch, dr)
|
|
o_en = o_en + o_pitch_emb
|
|
|
|
o_energy = None
|
|
avg_energy = None
|
|
if self.args.use_energy:
|
|
o_energy_emb, o_energy, avg_energy = self._forward_energy_predictor(o_en, x_mask, energy, dr)
|
|
o_en = o_en + o_energy_emb
|
|
|
|
o_de, attn = self._forward_decoder(
|
|
o_en, dr, x_mask, y_lengths, g=None
|
|
)
|
|
outputs = {
|
|
"model_outputs": o_de,
|
|
"durations_log": o_dr_log.squeeze(1),
|
|
"durations": o_dr.squeeze(1),
|
|
"attn_durations": o_attn,
|
|
"pitch_avg": o_pitch,
|
|
"pitch_avg_gt": avg_pitch,
|
|
"energy_avg": o_energy,
|
|
"energy_avg_gt": avg_energy,
|
|
"alignments": attn,
|
|
"alignment_soft": alignment_soft,
|
|
"alignment_mas": alignment_mas,
|
|
"o_alignment_dur": o_alignment_dur,
|
|
"alignment_logprob": alignment_logprob,
|
|
"x_mask": x_mask,
|
|
"y_mask": y_mask,
|
|
}
|
|
return outputs
|
|
|
|
@torch.no_grad()
|
|
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}):
|
|
"""Model's inference pass.
|
|
|
|
Args:
|
|
x (torch.LongTensor): Input character sequence.
|
|
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
|
|
|
|
Shapes:
|
|
- x: [B, T_max]
|
|
- x_lengths: [B]
|
|
- g: [B, C]
|
|
"""
|
|
g = self._set_speaker_input(aux_input)
|
|
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
|
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
|
|
|
|
o_en, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
|
|
|
|
o_dr_log = self.duration_predictor(o_en.squeeze(), x_mask)
|
|
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
|
y_lengths = o_dr.sum(1)
|
|
|
|
|
|
o_pitch = None
|
|
if self.args.use_pitch:
|
|
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en, x_mask)
|
|
o_en = o_en + o_pitch_emb
|
|
|
|
o_energy = None
|
|
if self.args.use_energy:
|
|
o_energy_emb, o_energy = self._forward_energy_predictor(o_en, x_mask)
|
|
o_en = o_en + o_energy_emb
|
|
|
|
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=None)
|
|
outputs = {
|
|
"model_outputs": o_de,
|
|
"alignments": attn,
|
|
"pitch": o_pitch,
|
|
"energy": o_energy,
|
|
"durations_log": o_dr_log,
|
|
}
|
|
return outputs
|
|
|
|
def train_step(self, batch: dict, criterion: nn.Module):
|
|
text_input = batch["text_input"]
|
|
text_lengths = batch["text_lengths"]
|
|
mel_input = batch["mel_input"]
|
|
mel_lengths = batch["mel_lengths"]
|
|
pitch = batch["pitch"] if self.args.use_pitch else None
|
|
energy = batch["energy"] if self.args.use_energy else None
|
|
d_vectors = batch["d_vectors"]
|
|
speaker_ids = batch["speaker_ids"]
|
|
durations = batch["durations"]
|
|
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
|
|
|
|
|
outputs = self.forward(
|
|
text_input,
|
|
text_lengths,
|
|
mel_lengths,
|
|
y=mel_input,
|
|
dr=durations,
|
|
pitch=pitch,
|
|
energy=energy,
|
|
aux_input=aux_input,
|
|
)
|
|
|
|
if self.use_aligner:
|
|
durations = outputs["o_alignment_dur"]
|
|
|
|
with autocast(enabled=False):
|
|
|
|
loss_dict = criterion(
|
|
decoder_output=outputs["model_outputs"],
|
|
decoder_target=mel_input,
|
|
decoder_output_lens=mel_lengths,
|
|
dur_output=outputs["durations_log"],
|
|
dur_target=durations,
|
|
pitch_output=outputs["pitch_avg"] if self.use_pitch else None,
|
|
pitch_target=outputs["pitch_avg_gt"] if self.use_pitch else None,
|
|
energy_output=outputs["energy_avg"] if self.use_energy else None,
|
|
energy_target=outputs["energy_avg_gt"] if self.use_energy else None,
|
|
input_lens=text_lengths,
|
|
alignment_logprob=outputs["alignment_logprob"] if self.use_aligner else None,
|
|
alignment_soft=outputs["alignment_soft"],
|
|
alignment_hard=outputs["alignment_mas"],
|
|
binary_loss_weight=self.binary_loss_weight,
|
|
)
|
|
|
|
durations_pred = outputs["durations"]
|
|
duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
|
|
loss_dict["duration_error"] = duration_error
|
|
|
|
return outputs, loss_dict
|
|
|
|
def _create_logs(self, batch, outputs, ap):
|
|
"""Create common logger outputs."""
|
|
model_outputs = outputs["model_outputs"]
|
|
alignments = outputs["alignments"]
|
|
mel_input = batch["mel_input"]
|
|
|
|
pred_spec = model_outputs[0].data.cpu().numpy()
|
|
gt_spec = mel_input[0].data.cpu().numpy()
|
|
align_img = alignments[0].data.cpu().numpy()
|
|
|
|
figures = {
|
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
|
"alignment": plot_alignment(align_img, output_fig=False),
|
|
}
|
|
|
|
|
|
if self.args.use_pitch:
|
|
pitch_avg = abs(outputs["pitch_avg_gt"][0, 0].data.cpu().numpy())
|
|
pitch_avg_hat = abs(outputs["pitch_avg"][0, 0].data.cpu().numpy())
|
|
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
|
|
pitch_figures = {
|
|
"pitch_ground_truth": plot_avg_pitch(pitch_avg, chars, output_fig=False),
|
|
"pitch_avg_predicted": plot_avg_pitch(pitch_avg_hat, chars, output_fig=False),
|
|
}
|
|
figures.update(pitch_figures)
|
|
|
|
|
|
if self.args.use_energy:
|
|
energy_avg = abs(outputs["energy_avg_gt"][0, 0].data.cpu().numpy())
|
|
energy_avg_hat = abs(outputs["energy_avg"][0, 0].data.cpu().numpy())
|
|
chars = self.tokenizer.decode(batch["text_input"][0].data.cpu().numpy())
|
|
energy_figures = {
|
|
"energy_ground_truth": plot_avg_energy(energy_avg, chars, output_fig=False),
|
|
"energy_avg_predicted": plot_avg_energy(energy_avg_hat, chars, output_fig=False),
|
|
}
|
|
figures.update(energy_figures)
|
|
|
|
|
|
if "attn_durations" in outputs:
|
|
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
|
|
figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
|
|
|
|
|
|
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
|
return figures, {"audio": train_audio}
|
|
|
|
def train_log(
|
|
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
|
) -> None:
|
|
figures, audios = self._create_logs(batch, outputs, self.ap)
|
|
logger.train_figures(steps, figures)
|
|
logger.train_audios(steps, audios, self.ap.sample_rate)
|
|
|
|
def eval_step(self, batch: dict, criterion: nn.Module):
|
|
return self.train_step(batch, criterion)
|
|
|
|
def eval_log(self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int) -> None:
|
|
figures, audios = self._create_logs(batch, outputs, self.ap)
|
|
logger.eval_figures(steps, figures)
|
|
logger.eval_audios(steps, audios, self.ap.sample_rate)
|
|
|
|
def load_checkpoint(
|
|
self, config, checkpoint_path, eval=False, cache=False
|
|
):
|
|
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
|
|
self.load_state_dict(state["model"])
|
|
if eval:
|
|
self.eval()
|
|
assert not self.training
|
|
|
|
def get_criterion(self):
|
|
from TTS.tts.layers.losses import ForwardTTSLoss
|
|
|
|
return ForwardTTSLoss(self.config)
|
|
|
|
def on_train_step_start(self, trainer):
|
|
"""Schedule binary loss weight."""
|
|
self.binary_loss_weight = min(trainer.epochs_done / self.config.binary_loss_warmup_epochs, 1.0) * 1.0
|
|
|
|
@staticmethod
|
|
def init_from_config(config: "ForwardTTSConfig", samples: Union[List[List], List[Dict]] = None):
|
|
"""Initiate model from config
|
|
|
|
Args:
|
|
config (ForwardTTSConfig): Model config.
|
|
samples (Union[List[List], List[Dict]]): Training samples to parse speaker ids for training.
|
|
Defaults to None.
|
|
"""
|
|
from TTS.utils.audio import AudioProcessor
|
|
|
|
ap = AudioProcessor.init_from_config(config)
|
|
tokenizer, new_config = TTSTokenizer.init_from_config(config)
|
|
speaker_manager = SpeakerManager.init_from_config(config, samples)
|
|
return ForwardTTS(new_config, ap, tokenizer, speaker_manager)
|
|
|