|
import argparse
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
|
|
import gradio as gr
|
|
import librosa.display
|
|
import numpy as np
|
|
|
|
import os
|
|
import torch
|
|
import torchaudio
|
|
import traceback
|
|
from TTS.demos.xtts_ft_demo.utils.formatter import format_audio_list
|
|
from TTS.demos.xtts_ft_demo.utils.gpt_train import train_gpt
|
|
|
|
from TTS.tts.configs.xtts_config import XttsConfig
|
|
from TTS.tts.models.xtts import Xtts
|
|
|
|
|
|
def clear_gpu_cache():
    """Return cached-but-unused CUDA memory from PyTorch's allocator.

    This is a no-op when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
|
|
|
|
# Model used by the inference tab; stays None until `load_model` succeeds.
XTTS_MODEL = None


def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    """Load an XTTS model and publish it in the module-level ``XTTS_MODEL``.

    Args:
        xtts_checkpoint: Path to the model checkpoint file.
        xtts_config: Path to the model's JSON config.
        xtts_vocab: Path to the tokenizer vocab file.

    Returns:
        A status message for the Gradio progress label.
    """
    global XTTS_MODEL

    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"

    config = XttsConfig()
    config.load_json(xtts_config)

    # Drop the reference to any previously loaded model *before* emptying the
    # CUDA cache; while the old model is still referenced its memory cannot
    # be released.
    XTTS_MODEL = None
    clear_gpu_cache()

    # Build into a local so a failure while loading the checkpoint does not
    # leave a half-initialised model visible to `run_tts`.
    model = Xtts.init_from_config(config)
    print("Loading XTTS model! ")
    model.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False)
    if torch.cuda.is_available():
        model.cuda()

    XTTS_MODEL = model
    print("Model Loaded!")
    return "Model Loaded!"
|
|
|
|
def run_tts(lang, tts_text, speaker_audio_file):
    """Synthesise ``tts_text`` with the loaded model in the reference speaker's voice.

    Args:
        lang: Language code passed to ``XTTS_MODEL.inference``.
        tts_text: Text to synthesise.
        speaker_audio_file: Path to the reference audio used for conditioning.

    Returns:
        ``(status message, path of the generated .wav or None,
        reference audio path or None)``.
    """
    if XTTS_MODEL is None or not speaker_audio_file:
        return "You need to run the previous step to load the model !!", None, None

    config = XTTS_MODEL.config
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(
        audio_path=speaker_audio_file,
        gpt_cond_len=config.gpt_cond_len,
        max_ref_length=config.max_ref_len,
        sound_norm_refs=config.sound_norm_refs,
    )
    out = XTTS_MODEL.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=config.temperature,
        length_penalty=config.length_penalty,
        repetition_penalty=config.repetition_penalty,
        top_k=config.top_k,
        top_p=config.top_p,
    )

    # Add a leading channel dimension for torchaudio.save (channels-first).
    wav = torch.tensor(out["wav"]).unsqueeze(0)

    # Only reserve the file name here and close the handle before writing:
    # torchaudio.save reopens the path itself, which fails on Windows while
    # the NamedTemporaryFile handle is still open.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out_path = fp.name
    torchaudio.save(out_path, wav, 24000)

    return "Speech generated !", out_path, speaker_audio_file
|
|
|
|
|
|
|
|
|
|
|
|
class Logger:
    """File-like stream that mirrors every write to the terminal and a log file.

    Args:
        filename: Path of the log file; it is truncated on construction.
    """

    def __init__(self, filename="log.out"):
        self.log_file = filename
        # Whatever sys.stdout is at construction time; writes are mirrored there.
        self.terminal = sys.stdout
        # Explicit UTF-8: dataset text may be in non-Latin languages, and the
        # platform default encoding (e.g. cp1252 on Windows) would raise on it.
        self.log = open(self.log_file, "w", encoding="utf-8", errors="replace")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report a non-interactive stream (presumably so progress bars
        # and colour codes are not written into the log file).
        return False

    def __getattr__(self, name):
        # Forward attributes this class does not define (e.g. ``encoding``,
        # ``fileno``) to the wrapped terminal, so code that inspects
        # sys.stdout / sys.stderr keeps working. Look `terminal` up via
        # __dict__ to avoid infinite recursion before __init__ has set it.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
|
|
|
|
|
|
# Route both stdout and stderr through a single Logger so everything printed
# (including tracebacks) also lands in its log file.
sys.stdout = sys.stderr = Logger()

import logging

# Root logging goes to the (now redirected) stdout with timestamped lines.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
|
|
|
|
def read_logs():
    """Return the full contents of the stdout log file.

    Flushes ``sys.stdout`` first so recently printed text is included. Assumes
    ``sys.stdout`` has been replaced by an object exposing ``log_file``.
    """
    sys.stdout.flush()
    # Decode leniently: this is polled while output may still be mid-write,
    # so a multi-byte character can be cut in half; never crash the poller.
    with open(sys.stdout.log_file, "r", encoding="utf-8", errors="replace") as f:
        return f.read()
|
|
|
|
|
|
if __name__ == "__main__":
    # Only the script entry point needs shutil (copying training artifacts).
    import shutil

    parser = argparse.ArgumentParser(
        description="""XTTS fine-tuning demo\n\n"""
        """
        Example runs:
        python3 TTS/demos/xtts_ft_demo/xtts_demo.py --port 5003
        """,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--port",
        type=int,
        help="Port to run the gradio demo. Default: 5003",
        default=5003,
    )
    parser.add_argument(
        "--out_path",
        type=str,
        help="Output path (where data and checkpoints will be saved) Default: /tmp/xtts_ft/",
        default="/tmp/xtts_ft/",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        help="Number of epochs to train. Default: 10",
        default=10,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Batch size. Default: 4",
        default=4,
    )
    parser.add_argument(
        "--grad_acumm",
        type=int,
        help="Grad accumulation steps. Default: 1",
        default=1,
    )
    parser.add_argument(
        "--max_audio_length",
        type=int,
        help="Max permitted audio size in seconds. Default: 11",
        default=11,
    )
    args = parser.parse_args()

    # Language codes offered by both the dataset and the inference dropdowns.
    supported_languages = [
        "en", "es", "fr", "de", "it", "pt", "pl", "tr",
        "ru", "nl", "cs", "ar", "zh", "hu", "ko", "ja",
    ]

    with gr.Blocks() as demo:
        with gr.Tab("1 - Data processing"):
            out_path = gr.Textbox(
                label="Output path (where data and checkpoints will be saved):",
                value=args.out_path,
            )
            upload_file = gr.File(
                file_count="multiple",
                label="Select here the audio files that you want to use for XTTS training (Supported formats: wav, mp3, and flac)",
            )
            lang = gr.Dropdown(
                label="Dataset Language",
                value="en",
                choices=supported_languages,
            )
            progress_data = gr.Label(label="Progress:")
            logs = gr.Textbox(
                label="Logs:",
                interactive=False,
            )
            # Refresh the log view every second.
            demo.load(read_logs, None, logs, every=1)

            prompt_compute_btn = gr.Button(value="Step 1 - Create dataset")

            def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)):
                """Build a train/eval dataset from the uploaded audio via `format_audio_list`.

                Output goes under ``<out_path>/dataset``.

                Returns:
                    ``(status message, train metadata, eval metadata)``; the
                    metadata values fill the Train CSV / Eval CSV fields and
                    are empty strings on failure.
                """
                clear_gpu_cache()

                # Reject both "no upload" (None) and an empty upload list.
                if not audio_path:
                    return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", ""

                out_path = os.path.join(out_path, "dataset")
                os.makedirs(out_path, exist_ok=True)

                try:
                    train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress)
                except Exception:
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", ""

                clear_gpu_cache()

                # Require at least 120 seconds (2 minutes) of audio in total.
                if audio_total_size < 120:
                    message = "The sum of the duration of the audios that you provided should be at least 2 minutes!"
                    print(message)
                    return message, "", ""

                print("Dataset Processed!")
                return "Dataset Processed!", train_meta, eval_meta

        with gr.Tab("2 - Fine-tuning XTTS Encoder"):
            train_csv = gr.Textbox(label="Train CSV:")
            eval_csv = gr.Textbox(label="Eval CSV:")
            num_epochs = gr.Slider(
                label="Number of epochs:",
                minimum=1,
                maximum=100,
                step=1,
                value=args.num_epochs,
            )
            batch_size = gr.Slider(
                label="Batch size:",
                minimum=2,
                maximum=512,
                step=1,
                value=args.batch_size,
            )
            # minimum=1 so the documented CLI default (--grad_acumm 1) is a
            # valid slider value.
            grad_acumm = gr.Slider(
                label="Grad accumulation steps:",
                minimum=1,
                maximum=128,
                step=1,
                value=args.grad_acumm,
            )
            max_audio_length = gr.Slider(
                label="Max permitted audio size in seconds:",
                minimum=2,
                maximum=20,
                step=1,
                value=args.max_audio_length,
            )
            progress_train = gr.Label(label="Progress:")
            logs_tts_train = gr.Textbox(
                label="Logs:",
                interactive=False,
            )
            demo.load(read_logs, None, logs_tts_train, every=1)

            train_btn = gr.Button(value="Step 2 - Run the training")

            def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length):
                """Fine-tune with `train_gpt` and return paths for the inference tab.

                Returns:
                    ``(status message, config path, vocab path, fine-tuned
                    checkpoint path, speaker reference wav)``; the path values
                    are empty strings on failure.
                """
                clear_gpu_cache()
                if not train_csv or not eval_csv:
                    return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", ""

                try:
                    # Convert the slider value (seconds) to a sample count at 22050 Hz.
                    max_audio_length = int(max_audio_length * 22050)
                    config_path, _original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length)
                except Exception:
                    traceback.print_exc()
                    error = traceback.format_exc()
                    return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", ""

                # Copy the config and vocab next to the trained checkpoint.
                # shutil.copy avoids shell quoting/injection issues and works on
                # every platform; a failed copy is reported but does not abort.
                for artifact in (config_path, vocab_file):
                    try:
                        shutil.copy(artifact, exp_path)
                    except OSError as err:
                        print(f"Warning: could not copy {artifact} to {exp_path}: {err}")

                ft_xtts_checkpoint = os.path.join(exp_path, "best_model.pth")
                print("Model training done!")
                clear_gpu_cache()
                return "Model training done!", config_path, vocab_file, ft_xtts_checkpoint, speaker_wav

        with gr.Tab("3 - Inference"):
            with gr.Row():
                with gr.Column():
                    xtts_checkpoint = gr.Textbox(
                        label="XTTS checkpoint path:",
                        value="",
                    )
                    xtts_config = gr.Textbox(
                        label="XTTS config path:",
                        value="",
                    )
                    xtts_vocab = gr.Textbox(
                        label="XTTS vocab path:",
                        value="",
                    )
                    progress_load = gr.Label(label="Progress:")
                    load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model")

                with gr.Column():
                    speaker_reference_audio = gr.Textbox(
                        label="Speaker reference audio:",
                        value="",
                    )
                    tts_language = gr.Dropdown(
                        label="Language",
                        value="en",
                        choices=supported_languages,
                    )
                    tts_text = gr.Textbox(
                        label="Input Text.",
                        value="This model sounds really good and above all, it's reasonably fast.",
                    )
                    tts_btn = gr.Button(value="Step 4 - Inference")

                with gr.Column():
                    progress_gen = gr.Label(label="Progress:")
                    tts_output_audio = gr.Audio(label="Generated Audio.")
                    reference_audio = gr.Audio(label="Reference audio used.")

            prompt_compute_btn.click(
                fn=preprocess_dataset,
                inputs=[
                    upload_file,
                    lang,
                    out_path,
                ],
                outputs=[
                    progress_data,
                    train_csv,
                    eval_csv,
                ],
            )

            train_btn.click(
                fn=train_model,
                inputs=[
                    lang,
                    train_csv,
                    eval_csv,
                    num_epochs,
                    batch_size,
                    grad_acumm,
                    out_path,
                    max_audio_length,
                ],
                outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio],
            )

            load_btn.click(
                fn=load_model,
                inputs=[
                    xtts_checkpoint,
                    xtts_config,
                    xtts_vocab,
                ],
                outputs=[progress_load],
            )

            tts_btn.click(
                fn=run_tts,
                inputs=[
                    tts_language,
                    tts_text,
                    speaker_reference_audio,
                ],
                outputs=[progress_gen, tts_output_audio, reference_audio],
            )

    # NOTE(review): share=True asks Gradio for a public share link in
    # addition to listening on all interfaces (0.0.0.0); anyone with the link
    # can upload files and start training. Consider making this opt-in.
    demo.launch(
        share=True,
        debug=False,
        server_port=args.port,
        server_name="0.0.0.0",
    )
|
|
|