import re
import tempfile
from importlib.resources import files

import gradio as gr
import soundfile as sf
import torch
import torchcodec
from cached_path import cached_path
from omegaconf import OmegaConf

from ipa.ipa import g2p_object, text_to_ipa

try:
    import spaces

    USING_SPACES = True
except ImportError:
    USING_SPACES = False

from f5_tts.infer.utils_infer import (
    device,
    hop_length,
    infer_process,
    load_checkpoint,
    load_vocoder,
    mel_spec_type,
    n_fft,
    n_mel_channels,
    ode_method,
    preprocess_ref_audio_text,
    remove_silence_for_generated_wav,
    save_spectrogram,
    target_sample_rate,
    win_length,
)
from f5_tts.model import CFM, DiT
from f5_tts.model.utils import get_tokenizer
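
# These names come from the F5-TTS package (https://github.com/SWivid/F5-TTS):
# load_vocoder / load_checkpoint build the vocoder and model weights, and
# infer_process runs the reference-conditioned synthesis pipeline end to end.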


def gpu_decorator(func):
    # On Hugging Face Spaces, wrap GPU-bound functions with spaces.GPU so that
    # (Zero)GPU hardware is allocated for the call; locally this is a no-op.
    if USING_SPACES:
        return spaces.GPU(func)
    else:
        return func
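
# Sketch of intended use (hypothetical function name): wrapping a GPU-heavy
# entry point lets Spaces allocate a GPU for the duration of each call:
#
#   @gpu_decorator
#   def synthesize(...):
#       ...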


vocoder = load_vocoder()


def load_model(
    model_cls,
    model_cfg,
    ckpt_path,
    mel_spec_type=mel_spec_type,
    vocab_file="",
    ode_method=ode_method,
    use_ema=True,
    device=device,
    fp16=False,
):
    if vocab_file == "":
        vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt"))
    tokenizer = "custom"

    print("\nvocab : ", vocab_file)
    print("token : ", tokenizer)
    print("model : ", ckpt_path, "\n")

    vocab_char_map, vocab_size = get_tokenizer(vocab_file, tokenizer)
    model = CFM(
        transformer=model_cls(
            **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
        ),
        mel_spec_kwargs=dict(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        ),
        odeint_kwargs=dict(
            method=ode_method,
        ),
        vocab_char_map=vocab_char_map,
    ).to(device)

    # Keep full precision for BigVGAN or when fp16 is disabled; otherwise pass
    # dtype=None and let load_checkpoint choose its default dtype.
    dtype = torch.float32 if mel_spec_type == "bigvgan" or not fp16 else None
    model = load_checkpoint(model, ckpt_path, device, dtype=dtype, use_ema=use_ema)

    return model


def load_f5tts(ckpt_path, vocab_path, old=False, fp16=False):
    ckpt_path = str(cached_path(ckpt_path))
    F5TTS_model_cfg = dict(
        dim=1024,
        depth=22,
        heads=16,
        ff_mult=2,
        text_dim=512,
        conv_layers=4,
        text_mask_padding=not old,
        pe_attn_head=1 if old else None,
    )
    vocab_path = str(cached_path(vocab_path))
    return load_model(
        DiT,
        F5TTS_model_cfg,
        ckpt_path,
        vocab_file=vocab_path,
        use_ema=old,
        fp16=fp16,
    )


OmegaConf.register_new_resolver("load_f5tts", load_f5tts)
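
# For illustration only: the resolver above lets configs/models.yaml construct
# a model straight from YAML. The actual config files are not shown here; a
# hypothetical entry might look like
#
#   some_model: ${load_f5tts:hf://org/repo/model.safetensors,hf://org/repo/vocab.txt}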

models_config = OmegaConf.to_object(OmegaConf.load("configs/models.yaml"))
refs_config = OmegaConf.to_object(OmegaConf.load("configs/refs.yaml"))
examples_config = OmegaConf.to_object(OmegaConf.load("configs/examples.yaml"))

DEFAULT_MODEL_ID = list(models_config.keys())[0]
# dict.fromkeys preserves first-seen order, so the dropdown contents are
# deterministic; list(set(...)) would shuffle them between runs.
ETHNICITIES = list(dict.fromkeys(k.split("_")[0] for k in g2p_object.keys()))
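
# Key conventions, as implied by the code in this file: g2p_object keys look
# like "<ethnicity>_<dialect>" (ethnicity recovered by splitting on "_"), and
# refs_config keys add a voice suffix such as "_男聲1" or "_女聲", which
# default_speaker_tts strips with a regex to recover the language key.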


@gpu_decorator
def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    model,
    remove_silence=False,
    cross_fade_duration=0.15,
    nfe_step=32,
    speed=1,
    show_info=gr.Info,
):
    # Placeholder returns match the (audio, spectrogram) success signature.
    if not ref_audio_orig:
        gr.Warning("Please provide reference audio.")
        return gr.update(), gr.update()
    if not gen_text.strip():
        gr.Warning("Please enter text to generate.")
        return gr.update(), gr.update()

    ref_audio, ref_text = preprocess_ref_audio_text(
        ref_audio_orig, ref_text, show_info=show_info
    )

    final_wave, final_sample_rate, combined_spectrogram = infer_process(
        ref_audio,
        ref_text,
        gen_text,
        model,
        vocoder,
        cross_fade_duration=cross_fade_duration,
        nfe_step=nfe_step,
        speed=speed,
        show_info=show_info,
        progress=gr.Progress(),
    )

    # Remove silence
    if remove_silence:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            sf.write(f.name, final_wave, final_sample_rate)
            remove_silence_for_generated_wav(f.name)
            final_wave = (
                torchcodec.decoders.AudioDecoder(f.name).get_all_samples().data
            )
        final_wave = final_wave.squeeze().cpu().numpy()

    # Save the spectrogram
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_spectrogram:
        spectrogram_path = tmp_spectrogram.name
    save_spectrogram(combined_spectrogram, spectrogram_path)

    return (final_sample_rate, final_wave), spectrogram_path
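
# Minimal usage sketch (hypothetical file paths):
#   (sr, wave), spec_path = infer(
#       "ref.wav", "reference transcript", "text to synthesize",
#       models_config[DEFAULT_MODEL_ID],
#   )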


def get_title():
    # DEMO.md starts with an "# <title>" heading; use it as the page title.
    with open("DEMO.md", encoding="utf-8") as tong:
        return tong.readline().strip().lstrip("# ")


demo = gr.Blocks(
    title=get_title(),
    css="""@import url(https://tauhu.tw/tauhu-oo.css);
    .textonly textarea {border-width: 0px !important; }
    """,
    theme=gr.themes.Default(
        font=(
            "tauhu-oo",
            gr.themes.GoogleFont("Source Sans Pro"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        )
    ),
    js="""
    function addButtonsEvent() {
        const buttons = document.querySelectorAll("#head-html-block button");
        buttons.forEach(button => {
            button.addEventListener("click", () => {
                navigator.clipboard.writeText(button.innerText);
            });
        });
    }
    """,
)

with demo:
    with open("DEMO.md", encoding="utf-8") as tong:
        gr.Markdown(tong.read())

    # UI note: "Special symbols — click a button to copy it to the clipboard."
    gr.HTML(
        "特殊符號請複製使用(滑鼠點擊即可複製):<button>é</button> <button>ṟ</button> <button>ɨ</button> <button>ʉ</button>",
        padding=False,
        elem_id="head-html-block",
    )
| with gr.Tab("預設配音員"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| default_speaker_ethnicity = gr.Dropdown( | |
| choices=ETHNICITIES, | |
| label="步驟一:選擇族別", | |
| value="阿美", | |
| filterable=False, | |
| ) | |
| def get_refs_by_perfix(prefix: str): | |
| return [r for r in refs_config.keys() if r.startswith(prefix)] | |
| default_speaker_refs = gr.Dropdown( | |
| choices=get_refs_by_perfix(default_speaker_ethnicity.value), | |
| label="步驟二:選擇配音員", | |
| value=get_refs_by_perfix(default_speaker_ethnicity.value)[0], | |
| filterable=False, | |
| ) | |
| default_speaker_gen_text_input = gr.Textbox( | |
| label="步驟三:輸入文字(上限 300 字元)", | |
| value="", | |
| ) | |
| default_speaker_generate_btn = gr.Button( | |
| "步驟四:開始合成", variant="primary" | |
| ) | |
| with gr.Column(): | |
| default_speaker_audio_output = gr.Audio( | |
| label="合成結果", show_share_button=False, show_download_button=True | |
| ) | |
| with gr.Tab("自己當配音員"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| custom_speaker_ethnicity = gr.Dropdown( | |
| choices=ETHNICITIES, | |
| label="步驟一:選擇族別與語別", | |
| value="阿美", | |
| filterable=False, | |
| ) | |
| custom_speaker_language = gr.Dropdown( | |
| choices=[ | |
| k | |
| for k in g2p_object.keys() | |
| if k.startswith(custom_speaker_ethnicity.value) | |
| ], | |
| value=[ | |
| k | |
| for k in g2p_object.keys() | |
| if k.startswith(custom_speaker_ethnicity.value) | |
| ][0], | |
| filterable=False, | |
| show_label=False, | |
| ) | |
| custom_speaker_ref_text_input = gr.Textbox( | |
| value=refs_config[ | |
| get_refs_by_perfix(custom_speaker_language.value)[0] | |
| ]["text"], | |
| interactive=False, | |
| label="步驟二:點選🎙️錄製下方句子,或上傳與句子相符的音檔", | |
| elem_classes="textonly", | |
| ) | |
| custom_speaker_audio_input = gr.Audio( | |
| type="filepath", | |
| sources=["microphone", "upload"], | |
| waveform_options=gr.WaveformOptions( | |
| sample_rate=24000, | |
| ), | |
| label="錄製或上傳", | |
| ) | |
| custom_speaker_gen_text_input = gr.Textbox( | |
| label="步驟三:輸入合成文字(上限 300 字元)", | |
| value="", | |
| ) | |
| custom_speaker_generate_btn = gr.Button( | |
| "步驟四:開始合成", variant="primary" | |
| ) | |
| with gr.Column(): | |
| custom_speaker_audio_output = gr.Audio( | |
| label="合成結果", show_share_button=False, show_download_button=True | |
| ) | |

    default_speaker_ethnicity.change(
        lambda ethnicity: gr.Dropdown(
            choices=get_refs_by_prefix(ethnicity),
            value=get_refs_by_prefix(ethnicity)[0],
        ),
        inputs=[default_speaker_ethnicity],
        outputs=[default_speaker_refs],
    )

    def default_speaker_tts(
        ref: str,
        gen_text_input: str,
    ):
        # Strip the voice suffix (e.g. "_男聲1", "_女聲") to recover the language key.
        language = re.sub(r"_[男女]聲[12]?", "", ref)
        ref_text_input = refs_config[ref]["text"]
        ref_audio_input = refs_config[ref]["wav"]

        gen_text_input = gen_text_input.strip()
        if len(gen_text_input) == 0:
            raise gr.Error("請勿輸入空字串。")  # "Input must not be empty."
        # Ensure the text ends with sentence-final punctuation.
        if gen_text_input[-1] not in [".", "?", "!", ",", ";", ":"]:
            gen_text_input += "."

        ignore_punctuation = False
        ipa_with_ng = False
        # Convert both reference and target text to IPA before synthesis.
        ref_text_input = text_to_ipa(
            ref_text_input, language, ignore_punctuation, ipa_with_ng
        )
        gen_text_input = text_to_ipa(
            gen_text_input, language, ignore_punctuation, ipa_with_ng
        )

        audio_out, spectrogram_path = infer(
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            models_config[DEFAULT_MODEL_ID],
        )
        return audio_out

    default_speaker_generate_btn.click(
        default_speaker_tts,
        inputs=[
            default_speaker_refs,
            default_speaker_gen_text_input,
        ],
        outputs=[default_speaker_audio_output],
    )

    def update_custom_speaker_language(ethnicity: str):
        languages = [k for k in g2p_object.keys() if k.startswith(ethnicity)]
        # Hide the dropdown when the ethnicity has only one language.
        return gr.Dropdown(
            choices=languages,
            value=languages[0],
            visible=len(languages) > 1,
        )

    custom_speaker_ethnicity.change(
        update_custom_speaker_language,
        inputs=[custom_speaker_ethnicity],
        outputs=[custom_speaker_language],
    )

    custom_speaker_language.change(
        lambda lang: gr.Textbox(
            value=refs_config[get_refs_by_prefix(lang)[0]]["text"],
        ),
        inputs=[custom_speaker_language],
        outputs=[custom_speaker_ref_text_input],
    )

    def custom_speaker_tts(
        language: str,
        ref_text_input: str,
        ref_audio_input: str,
        gen_text_input: str,
    ):
        ref_text_input = ref_text_input.strip()
        if len(ref_text_input) == 0:
            raise gr.Error("請勿輸入空字串。")
        gen_text_input = gen_text_input.strip()
        if len(gen_text_input) == 0:
            raise gr.Error("請勿輸入空字串。")

        ignore_punctuation = False
        ipa_with_ng = False
        # Ensure the text ends with sentence-final punctuation.
        if gen_text_input[-1] not in [".", "?", "!", ",", ";", ":"]:
            gen_text_input += "."

        ref_text_input = text_to_ipa(
            ref_text_input, language, ignore_punctuation, ipa_with_ng
        )
        gen_text_input = text_to_ipa(
            gen_text_input, language, ignore_punctuation, ipa_with_ng
        )

        audio_out, spectrogram_path = infer(
            ref_audio_input,
            ref_text_input,
            gen_text_input,
            models_config[DEFAULT_MODEL_ID],
        )
        return audio_out

    custom_speaker_generate_btn.click(
        custom_speaker_tts,
        inputs=[
            custom_speaker_language,
            custom_speaker_ref_text_input,
            custom_speaker_audio_input,
            custom_speaker_gen_text_input,
        ],
        outputs=[custom_speaker_audio_output],
    )

demo.launch()