import re
import tempfile
from importlib.resources import files
import gradio as gr
import soundfile as sf
import torch
import torchcodec
from cached_path import cached_path
from omegaconf import OmegaConf
from ipa.ipa import g2p_object, text_to_ipa
try:
import spaces
USING_SPACES = True
except ImportError:
USING_SPACES = False
from f5_tts.infer.utils_infer import (
device,
hop_length,
infer_process,
load_checkpoint,
load_vocoder,
mel_spec_type,
n_fft,
n_mel_channels,
ode_method,
preprocess_ref_audio_text,
remove_silence_for_generated_wav,
save_spectrogram,
target_sample_rate,
win_length,
)
from f5_tts.model import CFM, DiT
from f5_tts.model.utils import get_tokenizer
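
# On Hugging Face Spaces the optional `spaces` package is available, and wrapping a
# function with `spaces.GPU` runs it on the shared ZeroGPU workers; outside Spaces the
# decorator is a no-op and the function is returned unchanged.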
def gpu_decorator(func):
if USING_SPACES:
return spaces.GPU(func)
else:
return func
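
# The neural vocoder is loaded once at import time and shared by every request.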
vocoder = load_vocoder()
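
# Builds a CFM (conditional flow matching) model around the given transformer backbone,
# loads its checkpoint, and moves it to the target device.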
def load_model(
model_cls,
model_cfg,
ckpt_path,
mel_spec_type=mel_spec_type,
vocab_file="",
ode_method=ode_method,
use_ema=True,
device=device,
fp16=False,
):
if vocab_file == "":
vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
tokenizer = "custom"
print("\nvocab : ", vocab_file)
print("token : ", tokenizer)
print("model : ", ckpt_path, "\n")
vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
model = CFM(
transformer=model_cls(
**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
),
mel_spec_kwargs=dict(
n_fft=n_fft,
hop_length=hop_length,
win_length=win_length,
n_mel_channels=n_mel_channels,
target_sample_rate=target_sample_rate,
mel_spec_type=mel_spec_type,
),
odeint_kwargs=dict(
method=ode_method,
),
vocab_char_map=vocab_char_map,
).to(device)
    # Use float32 when the vocoder is BigVGAN or when fp16 is disabled; otherwise leave
    # the dtype to load_checkpoint's default.
    dtype = torch.float32 if (mel_spec_type == "bigvgan" or not fp16) else None
model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)
return model
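
# Resolves (possibly remote) checkpoint and vocab paths through cached_path and builds a
# DiT-based F5-TTS model with the Base hyper-parameters. `old` selects the behaviour
# expected by older checkpoints (no text padding mask, pe_attn_head=1) and whether the
# EMA weights are loaded.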
def load_f5tts(ckpt_path, vocab_path, old=False, fp16=False):
ckpt_path = str(cached_path(ckpt_path))
F5TTS_model_cfg = dict(
dim=1024,
depth=22,
heads=16,
ff_mult=2,
text_dim=512,
conv_layers=4,
text_mask_padding=not old,
pe_attn_head=1 if old else None,
)
vocab_path = str(cached_path(vocab_path))
return load_model(
DiT,
F5TTS_model_cfg,
ckpt_path,
vocab_file=vocab_path,
use_ema=old,
fp16=fp16,
)
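
# Registering `load_f5tts` as an OmegaConf resolver lets configs/models.yaml describe
# models declaratively; an entry of the (illustrative) form
#   some_model: ${load_f5tts:hf://.../model.pt, hf://.../vocab.txt}
# would be replaced by the loaded model object when the configs are materialised below.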
OmegaConf.register_new_resolver("load_f5tts", load_f5tts)
models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
refs_config = OmegaConf.to_object(OmegaConf.load("configs/refs.yaml"))
examples_config = OmegaConf.to_object(OmegaConf.load("configs/examples.yaml"))
DEFAULT_MODEL_ID = list(models_config.keys())[0]
# De-duplicate while preserving order so the ethnicity dropdown is stable across runs.
ETHNICITIES = list(dict.fromkeys(k.split("_")[0] for k in g2p_object.keys()))
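
# Core synthesis routine: preprocesses the reference clip and transcript, runs the
# F5-TTS inference pipeline, optionally trims silences from the result, and returns
# (sample_rate, waveform) together with the path of a rendered spectrogram image.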
@gpu_decorator
def infer(
ref_audio_orig,
ref_text,
gen_text,
model,
remove_silence=False,
cross_fade_duration=0.15,
nfe_step=32,
speed=1,
show_info=gr.Info,
):
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update()

    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update()
ref_audio, ref_text = preprocess_ref_audio_text(
ref_audio_orig, ref_text, show_info=show_info
)
final_wave, final_sample_rate, combined_spectrogram = infer_process(
ref_audio,
ref_text,
gen_text,
model,
vocoder,
cross_fade_duration=cross_fade_duration,
nfe_step=nfe_step,
speed=speed,
show_info=show_info,
progress=gr.Progress(),
)
# Remove silence
if remove_silence:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
sf.write(f.name, final_wave, final_sample_rate)
remove_silence_for_generated_wav(f.name)
final_wave = torchcodec.decoders.AudioDecoder(f.name).get_all_samples().data
final_wave = final_wave.squeeze().cpu().numpy()
# Save the spectrogram
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
spectrogram_path = tmp_spectrogram.name
save_spectrogram(combined_spectrogram, spectrogram_path)
return (final_sample_rate, final_wave), spectrogram_path
def get_title():
with open("DEMO.md", encoding="utf-8") as tong:
        # The first line of DEMO.md is a level-one heading; use it as the page title.
        return tong.readline().lstrip("# ").strip()
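
# Gradio Blocks app: custom font and CSS plus a small JS snippet that, on load, makes
# the special-character buttons in the header copy their text to the clipboard when
# clicked.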
demo = gr.Blocks(
title=get_title(),
css="""@import url(https://tauhu.tw/tauhu-oo.css);
.textonly textarea {border-width: 0px !important; }
""",
theme=gr.themes.Default(
font=(
"tauhu-oo",
gr.themes.GoogleFont("Source Sans Pro"),
"ui-sans-serif",
"system-ui",
"sans-serif",
)
),
js="""
function addButtonsEvent() {
const buttons = document.querySelectorAll("#head-html-block button");
buttons.forEach(button => {
button.addEventListener("click", () => {
navigator.clipboard.writeText(button.innerText);
});
});
}
""",
)
with demo:
with open("DEMO.md") as tong:
gr.Markdown(tong.read())
gr.HTML(
"特殊符號請複製使用(滑鼠點擊即可複製): ",
padding=False,
elem_id="head-html-block",
)
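
    # Tab 1 "預設配音員" (preset speakers): pick an ethnicity and one of the bundled
    # reference speakers, type the text to synthesise, and generate.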
with gr.Tab("預設配音員"):
with gr.Row():
with gr.Column():
default_speaker_ethnicity = gr.Dropdown(
choices=ETHNICITIES,
label="步驟一:選擇族別",
value="阿美",
filterable=False,
)
                def get_refs_by_prefix(prefix: str):
                    return [r for r in refs_config.keys() if r.startswith(prefix)]
default_speaker_refs = gr.Dropdown(
                    choices=get_refs_by_prefix(default_speaker_ethnicity.value),
label="步驟二:選擇配音員",
                    value=get_refs_by_prefix(default_speaker_ethnicity.value)[0],
filterable=False,
)
default_speaker_gen_text_input = gr.Textbox(
label="步驟三:輸入文字(上限 300 字元)",
value="",
)
default_speaker_generate_btn = gr.Button(
"步驟四:開始合成", variant="primary"
)
with gr.Column():
default_speaker_audio_output = gr.Audio(
label="合成結果", show_share_button=False, show_download_button=True
)
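
    # Tab 2 "自己當配音員" (use your own voice): record or upload the displayed
    # reference sentence, then synthesise new text in that voice.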
with gr.Tab("自己當配音員"):
with gr.Row():
with gr.Column():
custom_speaker_ethnicity = gr.Dropdown(
choices=ETHNICITIES,
label="步驟一:選擇族別與語別",
value="阿美",
filterable=False,
)
custom_speaker_language = gr.Dropdown(
choices=[
k
for k in g2p_object.keys()
if k.startswith(custom_speaker_ethnicity.value)
],
value=[
k
for k in g2p_object.keys()
if k.startswith(custom_speaker_ethnicity.value)
][0],
filterable=False,
show_label=False,
)
custom_speaker_ref_text_input = gr.Textbox(
value=refs_config[
                        get_refs_by_prefix(custom_speaker_language.value)[0]
]["text"],
interactive=False,
label="步驟二:點選🎙️錄製下方句子,或上傳與句子相符的音檔",
elem_classes="textonly",
)
custom_speaker_audio_input = gr.Audio(
type="filepath",
sources=["microphone", "upload"],
waveform_options=gr.WaveformOptions(
sample_rate=24000,
),
label="錄製或上傳",
)
custom_speaker_gen_text_input = gr.Textbox(
label="步驟三:輸入合成文字(上限 300 字元)",
value="",
)
custom_speaker_generate_btn = gr.Button(
"步驟四:開始合成", variant="primary"
)
with gr.Column():
custom_speaker_audio_output = gr.Audio(
label="合成結果", show_share_button=False, show_download_button=True
)
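
    # Repopulate the speaker dropdown whenever the selected ethnicity changes.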
default_speaker_ethnicity.change(
lambda ethnicity: gr.Dropdown(
            choices=get_refs_by_prefix(ethnicity),
            value=get_refs_by_prefix(ethnicity)[0],
),
inputs=[default_speaker_ethnicity],
outputs=[default_speaker_refs],
)
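
    # Synthesis callback for the preset-speaker tab: the reference audio and transcript
    # come from refs.yaml, and both transcript and target text are converted to IPA
    # before inference.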
@gpu_decorator
def default_speaker_tts(
ref: str,
gen_text_input: str,
):
        # Reference ids carry a speaker suffix (e.g. _男聲1 or _女聲); stripping it
        # leaves the language id expected by the G2P front-end.
        language = re.sub(r"_[男女]聲[12]?", "", ref)
ref_text_input = refs_config[ref]["text"]
ref_audio_input = refs_config[ref]["wav"]
gen_text_input = gen_text_input.strip()
if len(gen_text_input) == 0:
raise gr.Error("請勿輸入空字串。")
        # Make sure the text ends with punctuation before IPA conversion.
        if gen_text_input[-1] not in [".", "?", "!", ",", ";", ":"]:
            gen_text_input += "."
ignore_punctuation = False
ipa_with_ng = False
ref_text_input = text_to_ipa(
ref_text_input, language, ignore_punctuation, ipa_with_ng
)
gen_text_input = text_to_ipa(
gen_text_input, language, ignore_punctuation, ipa_with_ng
)
audio_out, spectrogram_path = infer(
ref_audio_input,
ref_text_input,
gen_text_input,
models_config[DEFAULT_MODEL_ID],
)
return audio_out
default_speaker_generate_btn.click(
default_speaker_tts,
inputs=[
default_speaker_refs,
default_speaker_gen_text_input,
],
outputs=[default_speaker_audio_output],
)
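
    # For the custom-voice tab, the language dropdown follows the selected ethnicity and
    # is hidden when only one language exists for it.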
custom_speaker_ethnicity.change(
lambda ethnicity: gr.Dropdown(
choices=[k for k in g2p_object.keys() if k.startswith(ethnicity)],
value=[k for k in g2p_object.keys() if k.startswith(ethnicity)][0],
visible=len([k for k in g2p_object.keys() if k.startswith(ethnicity)]) > 1,
),
inputs=[custom_speaker_ethnicity],
outputs=[custom_speaker_language],
)
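
    # Show the reference sentence matching the newly selected language.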
custom_speaker_language.change(
lambda lang: gr.Textbox(
            value=refs_config[get_refs_by_prefix(lang)[0]]["text"],
),
inputs=[custom_speaker_language],
outputs=[custom_speaker_ref_text_input],
)
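
    # Synthesis callback for the custom-voice tab: the user supplies the reference
    # recording; its fixed transcript and the target text are converted to IPA before
    # inference.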
@gpu_decorator
def custom_speaker_tts(
language: str,
ref_text_input: str,
ref_audio_input: str,
gen_text_input: str,
):
ref_text_input = ref_text_input.strip()
if len(ref_text_input) == 0:
raise gr.Error("請勿輸入空字串。")
gen_text_input = gen_text_input.strip()
if len(gen_text_input) == 0:
raise gr.Error("請勿輸入空字串。")
ignore_punctuation = False
ipa_with_ng = False
        # Make sure the text ends with punctuation before IPA conversion.
        if gen_text_input[-1] not in [".", "?", "!", ",", ";", ":"]:
            gen_text_input += "."
ref_text_input = text_to_ipa(
ref_text_input, language, ignore_punctuation, ipa_with_ng
)
gen_text_input = text_to_ipa(
gen_text_input, language, ignore_punctuation, ipa_with_ng
)
audio_out, spectrogram_path = infer(
ref_audio_input,
ref_text_input,
gen_text_input,
models_config[DEFAULT_MODEL_ID],
)
return audio_out
custom_speaker_generate_btn.click(
custom_speaker_tts,
inputs=[
custom_speaker_language,
custom_speaker_ref_text_input,
custom_speaker_audio_input,
custom_speaker_gen_text_input,
],
outputs=[custom_speaker_audio_output],
)
demo.launch()