import os
from dataclasses import dataclass

import librosa
import torch
import torch.nn.functional as F
import torchaudio
from coqpit import Coqpit

from TTS.tts.layers.xtts.gpt import GPT
from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder
from TTS.tts.layers.xtts.stream_generator import init_stream_support
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence
from TTS.tts.layers.xtts.xtts_manager import SpeakerManager, LanguageManager
from TTS.tts.models.base_tts import BaseTTS
from TTS.utils.io import load_fsspec

init_stream_support()


def wav_to_mel_cloning(
    wav,
    mel_norms_file="../experiments/clips_mel_norms.pth",
    mel_norms=None,
    device=torch.device("cpu"),
    n_fft=4096,
    hop_length=1024,
    win_length=4096,
    power=2,
    normalized=False,
    sample_rate=22050,
    f_min=0,
    f_max=8000,
    n_mels=80,
):
    """
    Convert waveform to mel-spectrogram with hard-coded parameters for cloning.

    Args:
        wav (torch.Tensor): Input waveform tensor.
        mel_norms_file (str): Path to mel-spectrogram normalization file.
        mel_norms (torch.Tensor): Mel-spectrogram normalization tensor.
        device (torch.device): Device to use for computation.

    Returns:
        torch.Tensor: Mel-spectrogram tensor.
    """
    mel_stft = torchaudio.transforms.MelSpectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        power=power,
        normalized=normalized,
        sample_rate=sample_rate,
        f_min=f_min,
        f_max=f_max,
        n_mels=n_mels,
        norm="slaney",
    ).to(device)
    wav = wav.to(device)
    mel = mel_stft(wav)
    mel = torch.log(torch.clamp(mel, min=1e-5))
    if mel_norms is None:
        mel_norms = torch.load(mel_norms_file, map_location=device)
    mel = mel / mel_norms.unsqueeze(0).unsqueeze(-1)
    return mel


def load_audio(audiopath, sampling_rate):
    # torchaudio picks the proper backend to load the audio depending on the platform
    audio, lsr = torchaudio.load(audiopath)

    # stereo to mono if needed
    if audio.size(0) != 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    if lsr != sampling_rate:
        audio = torchaudio.functional.resample(audio, lsr, sampling_rate)

    # Check some assumptions about the audio range; report suspicious files instead of failing.
    if torch.any(audio > 10) or not torch.any(audio < 0):
        print(f"Error with {audiopath}. Max={audio.max()} min={audio.min()}")
    # clip invalid values to the [-1, 1] range
    audio.clip_(-1, 1)
    return audio


def pad_or_truncate(t, length):
    """
    Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it.

    Args:
        t (torch.Tensor): The input tensor to be padded or truncated.
        length (int): The desired length of the tensor.

    Returns:
        torch.Tensor: The padded or truncated tensor.
    """
    tp = t[..., :length]
    if t.shape[-1] == length:
        tp = t
    elif t.shape[-1] < length:
        tp = F.pad(t, (0, length - t.shape[-1]))
    return tp


@dataclass
class XttsAudioConfig(Coqpit):
    """
    Configuration class for audio-related parameters in the XTTS model.

    Args:
        sample_rate (int): The sample rate in which the GPT operates.
        output_sample_rate (int): The sample rate of the output audio waveform.
    """

    sample_rate: int = 22050
    output_sample_rate: int = 24000


@dataclass
class XttsArgs(Coqpit):
    """A dataclass to represent XTTS model arguments that define the model structure.

    Args:
        gpt_batch_size (int): The size of the auto-regressive batch.
        enable_redaction (bool, optional): Whether to enable redaction. Defaults to False.
        kv_cache (bool, optional): Whether to use the kv_cache. Defaults to True.
        gpt_checkpoint (str, optional): The checkpoint for the autoregressive model. Defaults to None.
        clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None.
        decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None.
        num_chars (int, optional): The maximum number of characters to generate. Defaults to 255.

    For GPT model:
        gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 605.
        gpt_max_text_tokens (int, optional): The maximum text tokens for the autoregressive model. Defaults to 402.
        gpt_max_prompt_tokens (int, optional): The maximum prompt tokens of the autoregressive model. Defaults to 70.
        gpt_layers (int, optional): The number of layers for the autoregressive model. Defaults to 30.
        gpt_n_model_channels (int, optional): The model dimension for the autoregressive model. Defaults to 1024.
        gpt_n_heads (int, optional): The number of heads for the autoregressive model. Defaults to 16.
        gpt_number_text_tokens (int, optional): The number of text tokens for the autoregressive model. Defaults to 255.
        gpt_start_text_token (int, optional): The start text token for the autoregressive model. Defaults to 255.
        gpt_checkpointing (bool, optional): Whether to use checkpointing for the autoregressive model. Defaults to False.
        gpt_train_solo_embeddings (bool, optional): Whether to train embeddings for the autoregressive model. Defaults to False.
        gpt_code_stride_len (int, optional): The hop_size of dvae and consequently of the gpt output. Defaults to 1024.
        gpt_use_masking_gt_prompt_approach (bool, optional): If True, it will use ground truth as prompt and it will mask the loss to avoid repetition. Defaults to True.
        gpt_use_perceiver_resampler (bool, optional): If True, it will use perceiver resampler from flamingo paper - https://arxiv.org/abs/2204.14198. Defaults to False.
    """

    gpt_batch_size: int = 1
    enable_redaction: bool = False
    kv_cache: bool = True
    gpt_checkpoint: str = None
    clvp_checkpoint: str = None
    decoder_checkpoint: str = None
    num_chars: int = 255

    # XTTS GPT encoder params
    tokenizer_file: str = ""
    gpt_max_audio_tokens: int = 605
    gpt_max_text_tokens: int = 402
    gpt_max_prompt_tokens: int = 70
    gpt_layers: int = 30
    gpt_n_model_channels: int = 1024
    gpt_n_heads: int = 16
    gpt_number_text_tokens: int = None
    gpt_start_text_token: int = None
    gpt_stop_text_token: int = None
    gpt_num_audio_tokens: int = 8194
    gpt_start_audio_token: int = 8192
    gpt_stop_audio_token: int = 8193
    gpt_code_stride_len: int = 1024
    gpt_use_masking_gt_prompt_approach: bool = True
    gpt_use_perceiver_resampler: bool = False

    # HifiGAN decoder params
    input_sample_rate: int = 22050
    output_sample_rate: int = 24000
    output_hop_length: int = 256
    decoder_input_dim: int = 1024
    d_vector_dim: int = 512
    cond_d_vector_in_each_upsampling_layer: bool = True

    # constants
    duration_const: int = 102400


class Xtts(BaseTTS):
    """ⓍTTS model implementation.

    ❗ Currently it only supports inference.

    Examples:
        >>> from TTS.tts.configs.xtts_config import XttsConfig
        >>> from TTS.tts.models.xtts import Xtts
        >>> config = XttsConfig()
        >>> model = Xtts.init_from_config(config)
        >>> model.load_checkpoint(config, checkpoint_dir="paths/to/models_dir/", eval=True)
    """

    def __init__(self, config: Coqpit):
        super().__init__(config, ap=None, tokenizer=None)
        self.mel_stats_path = None
        self.config = config
        self.gpt_checkpoint = self.args.gpt_checkpoint
        self.decoder_checkpoint = self.args.decoder_checkpoint
        self.models_dir = config.model_dir
        self.gpt_batch_size = self.args.gpt_batch_size

        self.tokenizer = VoiceBpeTokenizer()
        self.gpt = None
        self.init_models()
        self.register_buffer("mel_stats", torch.ones(80))

    def init_models(self):
        """Initialize the models. We do it here since we need to load the tokenizer first."""
        if self.tokenizer.tokenizer is not None:
            self.args.gpt_number_text_tokens = self.tokenizer.get_number_tokens()
            self.args.gpt_start_text_token = self.tokenizer.tokenizer.token_to_id("[START]")
            self.args.gpt_stop_text_token = self.tokenizer.tokenizer.token_to_id("[STOP]")

        if self.args.gpt_number_text_tokens:
            self.gpt = GPT(
                layers=self.args.gpt_layers,
                model_dim=self.args.gpt_n_model_channels,
                start_text_token=self.args.gpt_start_text_token,
                stop_text_token=self.args.gpt_stop_text_token,
                heads=self.args.gpt_n_heads,
                max_text_tokens=self.args.gpt_max_text_tokens,
                max_mel_tokens=self.args.gpt_max_audio_tokens,
                max_prompt_tokens=self.args.gpt_max_prompt_tokens,
                number_text_tokens=self.args.gpt_number_text_tokens,
                num_audio_tokens=self.args.gpt_num_audio_tokens,
                start_audio_token=self.args.gpt_start_audio_token,
                stop_audio_token=self.args.gpt_stop_audio_token,
                use_perceiver_resampler=self.args.gpt_use_perceiver_resampler,
                code_stride_len=self.args.gpt_code_stride_len,
            )

        self.hifigan_decoder = HifiDecoder(
            input_sample_rate=self.args.input_sample_rate,
            output_sample_rate=self.args.output_sample_rate,
            output_hop_length=self.args.output_hop_length,
            ar_mel_length_compression=self.args.gpt_code_stride_len,
            decoder_input_dim=self.args.decoder_input_dim,
            d_vector_dim=self.args.d_vector_dim,
            cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer,
        )

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.inference_mode()
    def get_gpt_cond_latents(self, audio, sr, length: int = 30, chunk_length: int = 6):
        """Compute the conditioning latents for the GPT model from the given audio.

        Args:
            audio (tensor): audio tensor.
            sr (int): Sample rate of the audio.
            length (int): Length of the audio in seconds. If < 0, use the whole audio. Defaults to 30.
            chunk_length (int): Length of the audio chunks in seconds. When `length == chunk_length`, the whole
                audio is used without chunking. It must be <= `length`. Defaults to 6.
        """
        if sr != 22050:
            audio = torchaudio.functional.resample(audio, sr, 22050)
        if length > 0:
            audio = audio[:, : 22050 * length]
        if self.args.gpt_use_perceiver_resampler:
            style_embs = []
            for i in range(0, audio.shape[1], 22050 * chunk_length):
                audio_chunk = audio[:, i : i + 22050 * chunk_length]

                # if the chunk is too short, ignore it
                if audio_chunk.size(-1) < 22050 * 0.33:
                    continue

                mel_chunk = wav_to_mel_cloning(
                    audio_chunk,
                    mel_norms=self.mel_stats.cpu(),
                    n_fft=2048,
                    hop_length=256,
                    win_length=1024,
                    power=2,
                    normalized=False,
                    sample_rate=22050,
                    f_min=0,
                    f_max=8000,
                    n_mels=80,
                )
                style_emb = self.gpt.get_style_emb(mel_chunk.to(self.device), None)
                style_embs.append(style_emb)

            # average the style embeddings over all chunks
            cond_latent = torch.stack(style_embs).mean(dim=0)
        else:
            mel = wav_to_mel_cloning(
                audio,
                mel_norms=self.mel_stats.cpu(),
                n_fft=4096,
                hop_length=1024,
                win_length=4096,
                power=2,
                normalized=False,
                sample_rate=22050,
                f_min=0,
                f_max=8000,
                n_mels=80,
            )
            cond_latent = self.gpt.get_style_emb(mel.to(self.device))
        return cond_latent.transpose(1, 2)

    @torch.inference_mode()
    def get_speaker_embedding(self, audio, sr):
        audio_16k = torchaudio.functional.resample(audio, sr, 16000)
        return (
            self.hifigan_decoder.speaker_encoder.forward(audio_16k.to(self.device), l2_norm=True)
            .unsqueeze(-1)
            .to(self.device)
        )

    @torch.inference_mode()
    def get_conditioning_latents(
        self,
        audio_path,
        max_ref_length=30,
        gpt_cond_len=6,
        gpt_cond_chunk_len=6,
        librosa_trim_db=None,
        sound_norm_refs=False,
        load_sr=22050,
    ):
        """Get the conditioning latents for the GPT model from the given audio.

        Args:
            audio_path (str or List[str]): Path to reference audio file(s).
            max_ref_length (int): Maximum length of each reference audio in seconds. Defaults to 30.
            gpt_cond_len (int): Length of the audio used for gpt latents. Defaults to 6.
            gpt_cond_chunk_len (int): Chunk length used for gpt latents. It must be <= `gpt_cond_len`. Defaults to 6.
            librosa_trim_db (int, optional): Trim the audio using this value. If None, no trimming. Defaults to None.
            sound_norm_refs (bool, optional): Whether to normalize the audio. Defaults to False.
            load_sr (int, optional): Sample rate to load the audio. Defaults to 22050.
        """
        # deal with multiple references
        if not isinstance(audio_path, list):
            audio_paths = [audio_path]
        else:
            audio_paths = audio_path

        speaker_embeddings = []
        audios = []
        speaker_embedding = None
        for file_path in audio_paths:
            audio = load_audio(file_path, load_sr)
            audio = audio[:, : load_sr * max_ref_length].to(self.device)
            if sound_norm_refs:
                audio = (audio / torch.abs(audio).max()) * 0.75
            if librosa_trim_db is not None:
                audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0]

            # compute the speaker embedding used by the HiFi-GAN decoder
            speaker_embedding = self.get_speaker_embedding(audio, load_sr)
            speaker_embeddings.append(speaker_embedding)

            audios.append(audio)

        # merge all the audios and compute the latents for the gpt
        full_audio = torch.cat(audios, dim=-1)
        gpt_cond_latents = self.get_gpt_cond_latents(
            full_audio, load_sr, length=gpt_cond_len, chunk_length=gpt_cond_chunk_len
        )

        if speaker_embeddings:
            speaker_embedding = torch.stack(speaker_embeddings)
            speaker_embedding = speaker_embedding.mean(dim=0)

        return gpt_cond_latents, speaker_embedding
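
    # Illustrative usage sketch for `get_conditioning_latents` (the reference paths below are
    # assumptions for the example, not shipped assets):
    #
    #   gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    #       audio_path=["ref_1.wav", "ref_2.wav"],
    #       gpt_cond_len=30,
    #       gpt_cond_chunk_len=6,
    #   )
    #
    # The returned latents can be cached and reused across `inference()` calls for the same speaker.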

    def synthesize(self, text, config, speaker_wav, language, speaker_id=None, **kwargs):
        """Synthesize speech with the given input text.

        Args:
            text (str): Input text.
            config (XttsConfig): Config with inference parameters.
            speaker_wav (list): List of paths to the speaker audio files to be used for cloning.
            language (str): Language ID of the speaker.
            **kwargs: Inference settings. See `inference()`.

        Returns:
            A dictionary of the output values with `wav` as the output waveform, `gpt_latents` as the GPT latents
            fed to the decoder, and `speaker_embedding` as the speaker embedding used for cloning.
        """
        language = "zh-cn" if language == "zh" else language
        assert (
            language in self.config.languages
        ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}"

        settings = {
            "temperature": config.temperature,
            "length_penalty": config.length_penalty,
            "repetition_penalty": config.repetition_penalty,
            "top_k": config.top_k,
            "top_p": config.top_p,
        }
        settings.update(kwargs)
        if speaker_id is not None:
            gpt_cond_latent, speaker_embedding = self.speaker_manager.speakers[speaker_id].values()
            return self.inference(text, language, gpt_cond_latent, speaker_embedding, **settings)
        settings.update(
            {
                "gpt_cond_len": config.gpt_cond_len,
                "gpt_cond_chunk_len": config.gpt_cond_chunk_len,
                "max_ref_len": config.max_ref_len,
                "sound_norm_refs": config.sound_norm_refs,
            }
        )
        return self.full_inference(text, speaker_wav, language, **settings)
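
    # Illustrative usage sketch for `synthesize` (the wav path is an assumption for the example):
    #
    #   outputs = model.synthesize(
    #       "Hello world.",
    #       config,
    #       speaker_wav=["speaker_ref.wav"],
    #       language="en",
    #   )
    #   wav = outputs["wav"]  # 24 kHz numpy waveform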

    @torch.inference_mode()
    def full_inference(
        self,
        text,
        ref_audio_path,
        language,
        # GPT inference
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=50,
        top_p=0.85,
        do_sample=True,
        # Cloning
        gpt_cond_len=30,
        gpt_cond_chunk_len=6,
        max_ref_len=10,
        sound_norm_refs=False,
        **hf_generate_kwargs,
    ):
        """
        This function produces an audio clip of the given text being spoken with the given reference voice.

        Args:
            text: (str) Text to be spoken.

            ref_audio_path: (str) Path to a reference audio file to be used for cloning. This audio file should be >3
                seconds long.

            language: (str) Language of the voice to be generated.

            temperature: (float) The softmax temperature of the autoregressive model. Defaults to 0.75.

            length_penalty: (float) A length penalty applied to the autoregressive decoder. Higher settings cause the
                model to produce more terse outputs. Defaults to 1.0.

            repetition_penalty: (float) A penalty that prevents the autoregressive decoder from repeating itself during
                decoding. Can be used to reduce the incidence of long silences or "uhhhhhhs", etc. Defaults to 10.0.

            top_k: (int) K value used in top-k sampling. [0,inf]. Lower values mean the decoder produces more "likely"
                (aka boring) outputs. Defaults to 50.

            top_p: (float) P value used in nucleus sampling. (0,1]. Lower values mean the decoder produces more "likely"
                (aka boring) outputs. Defaults to 0.85.

            gpt_cond_len: (int) Length of the audio used for cloning. If the audio is shorter, then the audio length is
                used, else the first `gpt_cond_len` secs is used. Defaults to 30 seconds.

            gpt_cond_chunk_len: (int) Chunk length used for cloning. It must be <= `gpt_cond_len`.
                If gpt_cond_len == gpt_cond_chunk_len, no chunking. Defaults to 6 seconds.

            hf_generate_kwargs: (**kwargs) The huggingface Transformers generate API is used for the autoregressive
                transformer. Extra keyword args fed to this function get forwarded directly to that API. Documentation
                here: https://huggingface.co/docs/transformers/internal/generation_utils

        Returns:
            Generated audio clip(s) as a torch tensor. Shape (1, S) if k == 1, else (k, 1, S), where S is the sample
            length. Sample rate is 24kHz.
        """
        (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents(
            audio_path=ref_audio_path,
            gpt_cond_len=gpt_cond_len,
            gpt_cond_chunk_len=gpt_cond_chunk_len,
            max_ref_length=max_ref_len,
            sound_norm_refs=sound_norm_refs,
        )

        return self.inference(
            text,
            language,
            gpt_cond_latent,
            speaker_embedding,
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            do_sample=do_sample,
            **hf_generate_kwargs,
        )
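
    # Illustrative usage sketch for `full_inference` (the reference path and sampling values are
    # assumptions for the example):
    #
    #   outputs = model.full_inference(
    #       "Hello world.",
    #       ref_audio_path="speaker_ref.wav",
    #       language="en",
    #       temperature=0.75,
    #       top_p=0.85,
    #   )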

    @torch.inference_mode()
    def inference(
        self,
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        # GPT inference
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=50,
        top_p=0.85,
        do_sample=True,
        num_beams=1,
        speed=1.0,
        enable_text_splitting=False,
        **hf_generate_kwargs,
    ):
        language = language.split("-")[0]  # remove the country code
        length_scale = 1.0 / max(speed, 0.05)
        gpt_cond_latent = gpt_cond_latent.to(self.device)
        speaker_embedding = speaker_embedding.to(self.device)
        if enable_text_splitting:
            text = split_sentence(text, language, self.tokenizer.char_limits[language])
        else:
            text = [text]

        wavs = []
        gpt_latents_list = []
        for sent in text:
            sent = sent.strip().lower()
            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

            assert (
                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

            with torch.no_grad():
                gpt_codes = self.gpt.generate(
                    cond_latents=gpt_cond_latent,
                    text_inputs=text_tokens,
                    input_tokens=None,
                    do_sample=do_sample,
                    top_p=top_p,
                    top_k=top_k,
                    temperature=temperature,
                    num_return_sequences=self.gpt_batch_size,
                    num_beams=num_beams,
                    length_penalty=length_penalty,
                    repetition_penalty=repetition_penalty,
                    output_attentions=False,
                    **hf_generate_kwargs,
                )
                expected_output_len = torch.tensor(
                    [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
                )

                text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
                gpt_latents = self.gpt(
                    text_tokens,
                    text_len,
                    gpt_codes,
                    expected_output_len,
                    cond_latents=gpt_cond_latent,
                    return_attentions=False,
                    return_latent=True,
                )

                if length_scale != 1.0:
                    gpt_latents = F.interpolate(
                        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
                    ).transpose(1, 2)

                gpt_latents_list.append(gpt_latents.cpu())
                wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze())

        return {
            "wav": torch.cat(wavs, dim=0).numpy(),
            "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(),
            "speaker_embedding": speaker_embedding,
        }
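
    # Illustrative usage sketch for `inference` with precomputed conditioning latents, e.g. when
    # cloning the same speaker repeatedly (the path is an assumption for the example):
    #
    #   gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=["speaker_ref.wav"])
    #   outputs = model.inference("Hello world.", "en", gpt_cond_latent, speaker_embedding)
    #   wav = outputs["wav"]  # 24 kHz numpy waveform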

    def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
        """Handle chunk formatting in streaming mode."""
        wav_chunk = wav_gen[:-overlap_len]
        if wav_gen_prev is not None:
            wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
        if wav_overlap is not None:
            # cross-fade the overlap section with the previous chunk
            if overlap_len > len(wav_chunk):
                # wav_chunk is smaller than overlap_len, so pass the last wav_gen on
                if wav_gen_prev is not None:
                    wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :]
                else:
                    # not expected to be hit, as the problem happens on the last chunk
                    wav_chunk = wav_gen[-overlap_len:]
                return wav_chunk, wav_gen, None
            else:
                # linear fade-in of the new chunk, fade-out of the previous overlap
                crossfade_wav = wav_chunk[:overlap_len]
                crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
                wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
                wav_chunk[:overlap_len] += crossfade_wav

        wav_overlap = wav_gen[-overlap_len:]
        wav_gen_prev = wav_gen
        return wav_chunk, wav_gen_prev, wav_overlap

    @torch.inference_mode()
    def inference_stream(
        self,
        text,
        language,
        gpt_cond_latent,
        speaker_embedding,
        # Streaming
        stream_chunk_size=20,
        overlap_wav_len=1024,
        # GPT inference
        temperature=0.75,
        length_penalty=1.0,
        repetition_penalty=10.0,
        top_k=50,
        top_p=0.85,
        do_sample=True,
        speed=1.0,
        enable_text_splitting=False,
        **hf_generate_kwargs,
    ):
        language = language.split("-")[0]  # remove the country code
        length_scale = 1.0 / max(speed, 0.05)
        gpt_cond_latent = gpt_cond_latent.to(self.device)
        speaker_embedding = speaker_embedding.to(self.device)
        if enable_text_splitting:
            text = split_sentence(text, language, self.tokenizer.char_limits[language])
        else:
            text = [text]

        for sent in text:
            sent = sent.strip().lower()
            text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device)

            assert (
                text_tokens.shape[-1] < self.args.gpt_max_text_tokens
            ), " ❗ XTTS can only generate text with a maximum of 400 tokens."

            fake_inputs = self.gpt.compute_embeddings(
                gpt_cond_latent.to(self.device),
                text_tokens,
            )
            gpt_generator = self.gpt.get_generator(
                fake_inputs=fake_inputs,
                top_k=top_k,
                top_p=top_p,
                temperature=temperature,
                do_sample=do_sample,
                num_beams=1,
                num_return_sequences=1,
                length_penalty=float(length_penalty),
                repetition_penalty=float(repetition_penalty),
                output_attentions=False,
                output_hidden_states=True,
                **hf_generate_kwargs,
            )

            last_tokens = []
            all_latents = []
            wav_gen_prev = None
            wav_overlap = None
            is_end = False

            while not is_end:
                try:
                    x, latent = next(gpt_generator)
                    last_tokens += [x]
                    all_latents += [latent]
                except StopIteration:
                    is_end = True

                if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                    gpt_latents = torch.cat(all_latents, dim=0)[None, :]
                    if length_scale != 1.0:
                        gpt_latents = F.interpolate(
                            gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
                        ).transpose(1, 2)
                    wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                    wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                        wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
                    )
                    last_tokens = []
                    yield wav_chunk
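
    # Illustrative usage sketch for `inference_stream`: the method is a generator that yields audio
    # chunks as they are decoded (latents/embedding obtained as in `get_conditioning_latents`):
    #
    #   chunks = model.inference_stream("Hello world.", "en", gpt_cond_latent, speaker_embedding)
    #   wav = torch.cat([chunk.cpu() for chunk in chunks], dim=0)  # assemble the full waveform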

    def forward(self):
        raise NotImplementedError(
            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
        )

    def eval_step(self):
        raise NotImplementedError(
            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
        )

    @staticmethod
    def init_from_config(config: "XttsConfig", **kwargs):
        return Xtts(config)

    def eval(self):
        """Sets the model to evaluation mode. Overrides the default eval() method to also set the GPT model to eval mode."""
        self.gpt.init_gpt_for_inference()
        super().eval()

    def get_compatible_checkpoint_state_dict(self, model_path):
        checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
        # remove the xtts gpt trainer extra keys
        ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"]
        for key in list(checkpoint.keys()):
            # if the key comes from the trainer, rename it to match the inference model
            if key.startswith("xtts."):
                new_key = key.replace("xtts.", "")
                checkpoint[new_key] = checkpoint[key]
                del checkpoint[key]
                key = new_key

            # remove unused keys
            if key.split(".")[0] in ignore_keys:
                del checkpoint[key]

        return checkpoint

    def load_checkpoint(
        self,
        config,
        checkpoint_dir=None,
        checkpoint_path=None,
        vocab_path=None,
        eval=True,
        strict=True,
        use_deepspeed=False,
        speaker_file_path=None,
    ):
        """
        Loads a checkpoint from disk and initializes the model's state and tokenizer.

        Args:
            config (dict): The configuration dictionary for the model.
            checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
            checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
            vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
            eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
            strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
            use_deepspeed (bool, optional): Whether to use DeepSpeed for GPT inference. Defaults to False.
            speaker_file_path (str, optional): The path to the speaker file. If None, it is looked up in `checkpoint_dir`. Defaults to None.

        Returns:
            None
        """

        model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
        vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")

        if speaker_file_path is None and checkpoint_dir is not None:
            speaker_file_path = os.path.join(checkpoint_dir, "speakers_xtts.pth")

        self.language_manager = LanguageManager(config)
        self.speaker_manager = None
        if speaker_file_path is not None and os.path.exists(speaker_file_path):
            self.speaker_manager = SpeakerManager(speaker_file_path)

        if os.path.exists(vocab_path):
            self.tokenizer = VoiceBpeTokenizer(vocab_file=vocab_path)

        self.init_models()

        checkpoint = self.get_compatible_checkpoint_state_dict(model_path)

        # older checkpoints already contain the init_gpt_for_inference keys, newer ones do not
        try:
            self.load_state_dict(checkpoint, strict=strict)
        except Exception:
            if eval:
                self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
            self.load_state_dict(checkpoint, strict=strict)

        if eval:
            self.hifigan_decoder.eval()
            self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed)
            self.gpt.eval()
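
    # Illustrative usage sketch for `load_checkpoint` (the checkpoint directory is an assumption
    # for the example; `eval=True` prepares the GPT for inference):
    #
    #   config = XttsConfig()
    #   config.load_json("/path/to/xtts/config.json")
    #   model = Xtts.init_from_config(config)
    #   model.load_checkpoint(config, checkpoint_dir="/path/to/xtts/", eval=True)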

    def train_step(self):
        raise NotImplementedError(
            "XTTS has a dedicated trainer, please check the XTTS docs: https://tts.readthedocs.io/en/dev/models/xtts.html#training"
        )