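"""Gradio demo for zero-shot TTS with Amphion's NaturalSpeech2.

Given a reference speech clip and a target text, the model synthesizes the
text in the reference speaker's timbre, using EnCodec codes as the acoustic
representation and a diffusion model for generation.
"""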
import os

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
from utils.util import load_config
def build_codec(device):
    # Pretrained 24 kHz EnCodec codec. At the 12 kbps target bandwidth the
    # 24 kHz model uses 16 residual codebooks.
    encodec_model = EncodecModel.encodec_model_24khz()
    encodec_model = encodec_model.to(device=device)
    encodec_model.set_target_bandwidth(12.0)
    return encodec_model
def build_model(cfg, device):
    # Instantiate NaturalSpeech2 and load the pretrained checkpoint weights.
    model = NaturalSpeech2(cfg.model)
    model.load_state_dict(
        torch.load(
            "ckpts/ns2/pytorch_model.bin",
            map_location="cpu",
        )
    )
    model = model.to(device=device)
    model.eval()
    return model
def ns2_inference(
    prompt_audio_path,
    text,
    diffusion_steps=100,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tts/NaturalSpeech2/exp_config.json")
    model = build_model(cfg, device)
    codec = build_codec(device)

    # Load the reference (prompt) audio and resample/remix it to the codec's
    # expected sample rate and channel layout.
    ref_wav, sr = torchaudio.load(prompt_audio_path)
    ref_wav = convert_audio(ref_wav, sr, codec.sample_rate, codec.channels)
    ref_wav = ref_wav.unsqueeze(0).to(device=device)

    # Encode the reference audio into discrete EnCodec codes.
    with torch.no_grad():
        encoded_frames = codec.encode(ref_wav)
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
    ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)

    # Map the input text to phone IDs via the lexicon-based G2P frontend.
    symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
    phone2id = {s: i for i, s in enumerate(symbols)}
    lexicon = read_lexicon(cfg.preprocess.lexicon_path)
    phone_seq = preprocess_english(text, lexicon)
    phone_id = np.array(
        [
            *map(
                phone2id.get,
                phone_seq.replace("{", "").replace("}", "").split(),
            )
        ]
    )
    phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=device)

    # Run diffusion inference conditioned on the reference codes, then decode
    # the predicted latents back to a waveform with the codec decoder.
    with torch.no_grad():
        x0, prior_out = model.inference(ref_code, phone_id, ref_mask, diffusion_steps)
        rec_wav = codec.decoder(x0)

    # Write the result to disk, named after the prompt audio file.
    os.makedirs("result", exist_ok=True)
    base_name = os.path.splitext(os.path.basename(prompt_audio_path))[0]
    result_file = "result/{}_zero_shot_result.wav".format(base_name)
    sf.write(
        result_file,
        rec_wav[0, 0].detach().cpu().numpy(),
        samplerate=24000,
    )
    return result_file
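
# Minimal sketch of calling the pipeline directly, without the Gradio UI.
# The reference path below is hypothetical; point it at any short speech clip.
#
#     out = ns2_inference("examples/ref.wav", "Hello from Amphion!", diffusion_steps=200)
#     print("Saved synthesized audio to", out)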
demo_inputs = [
    gr.Audio(
        sources=["upload", "microphone"],
        label="Upload a reference speech whose timbre you want to clone",
        type="filepath",
    ),
    gr.Textbox(
        value="Amphion is a toolkit that can speak, make sounds, and sing.",
        label="Text you want to generate",
        type="text",
    ),
    gr.Slider(
        10,
        1000,
        value=200,
        step=1,
        label="Diffusion Inference Steps",
        info="More steps generally improve synthesis quality at the cost of slower inference",
    ),
]
demo_outputs = gr.Audio(label="Generated Audio")
demo = gr.Interface(
    fn=ns2_inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="Amphion Zero-Shot TTS NaturalSpeech2",
)
if __name__ == "__main__":
    demo.launch()
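    # Gradio serves locally (http://127.0.0.1:7860 by default); passing
    # share=True to demo.launch() creates a temporary public URL.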