# Module setup: imports, a pathlib compatibility shim, CPU threading config,
# text tokenizer/collater, compute device, the DeepFilterNet denoiser and the
# two matplotlib figures reused for spectrogram rendering.
# NOTE: statement order is significant (the env var and the pathlib shim must
# run before the imports / loads that follow), so do not regroup this block.
import argparse
import logging
import os
import pathlib
import time
import tempfile
import platform

# Alias the "foreign" pathlib class to the native one. Presumably this lets
# checkpoints pickled on another OS (containing Posix/WindowsPath objects) be
# unpickled by torch.load — TODO confirm. The original class is kept in
# `temp` but is never restored in the visible code.
if platform.system().lower() == 'windows':
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath
elif platform.system().lower() == 'linux':
    temp = pathlib.WindowsPath
    pathlib.WindowsPath = pathlib.PosixPath

# Force the pure-Python protobuf implementation; must be set before any
# later import that loads protobuf (which one needs it is not visible here).
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import langid
# Restrict language identification to the three languages the app supports.
langid.set_languages(['en', 'zh', 'ja'])

import torch
import torchaudio
import random
import numpy as np

from data.tokenizer import (
    AudioTokenizer,
    tokenize_audio,
)
from data.collation import get_text_token_collater
from models.vallex import VALLE
from utils.g2p import PhonemeBpeTokenizer
from descriptions import *
from macros import *

import gradio as gr
import whisper
import multiprocessing
import math
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
from loguru import logger
from PIL import Image
from torch import Tensor
from torchaudio.backend.common import AudioMetaData

from df import config
from df.enhance import enhance, init_df, load_audio, save_audio
from df.io import resample

# Use every CPU core for intra- and inter-op parallelism.
thread_count = multiprocessing.cpu_count()
print("Use", thread_count, "cpu cores for computing")
torch.set_num_threads(thread_count)
torch.set_num_interop_threads(thread_count)
# Disable the TorchScript profiling executor / graph optimizer (private APIs).
torch._C._jit_set_profiling_executor(False)
torch._C._jit_set_profiling_mode(False)
torch._C._set_graph_executor_optimize(False)

text_tokenizer = PhonemeBpeTokenizer(tokenizer_path="./utils/g2p/bpe_69.json")
text_collater = get_text_token_collater()

# Prefer the first CUDA device when available, otherwise run on CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda", 0)

# Denoise: DeepFilterNet model (`model1`) and its state (`df`).
model1, df, _ = init_df("./DeepFilterNet2", config_allow_defaults=True)
model1 = model1.to(device=device).eval()

# Figures/axes reused by demo_fn to render noisy and enhanced spectrograms.
fig_noisy: plt.Figure
fig_enh: plt.Figure
ax_noisy: plt.Axes
ax_enh: plt.Axes
fig_noisy, ax_noisy = plt.subplots(figsize=(15.2, 4))
fig_noisy.set_tight_layout(True)
fig_enh, ax_enh = plt.subplots(figsize=(15.2, 4))
fig_enh.set_tight_layout(True) NOISES = { "None": None, } def mix_at_snr(clean, noise, snr, eps=1e-10): """Mix clean and noise signal at a given SNR. Args: clean: 1D Tensor with the clean signal to mix. noise: 1D Tensor of shape. snr: Signal to noise ratio. Returns: clean: 1D Tensor with gain changed according to the snr. noise: 1D Tensor with the combined noise channels. mix: 1D Tensor with added clean and noise signals. """ clean = torch.as_tensor(clean).mean(0, keepdim=True) noise = torch.as_tensor(noise).mean(0, keepdim=True) if noise.shape[1] < clean.shape[1]: noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1])))) max_start = int(noise.shape[1] - clean.shape[1]) start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0 logger.debug(f"start: {start}, {clean.shape}") noise = noise[:, start : start + clean.shape[1]] E_speech = torch.mean(clean.pow(2)) + eps E_noise = torch.mean(noise.pow(2)) K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps) noise = noise / K mixture = clean + noise logger.debug("mixture: {mixture.shape}") assert torch.isfinite(mixture).all() max_m = mixture.abs().max() if max_m > 1: logger.warning(f"Clipping detected during mixing. Reducing gain by {1/max_m}") clean, noise, mixture = clean / max_m, noise / max_m, mixture / max_m return clean, noise, mixture def load_audio_gradio( audio_or_file: Union[None, str, Tuple[int, np.ndarray]], sr: int ) -> Optional[Tuple[Tensor, AudioMetaData]]: if audio_or_file is None: return None if isinstance(audio_or_file, str): if audio_or_file.lower() == "none": return None # First try default format audio, meta = load_audio(audio_or_file, sr) else: meta = AudioMetaData(-1, -1, -1, -1, "") assert isinstance(audio_or_file, (tuple, list)) meta.sample_rate, audio_np = audio_or_file # Gradio documentation says, the shape is [samples, 2], but apparently sometimes its not. 
audio_np = audio_np.reshape(audio_np.shape[0], -1).T if audio_np.dtype == np.int16: audio_np = (audio_np / (1 << 15)).astype(np.float32) elif audio_np.dtype == np.int32: audio_np = (audio_np / (1 << 31)).astype(np.float32) audio = resample(torch.from_numpy(audio_np), meta.sample_rate, sr) return audio, meta def demo_fn(speech_upl: str, noise_type: str, snr: int, mic_input: str): if mic_input: speech_upl = mic_input sr = config("sr", 48000, int, section="df") logger.info(f"Got parameters speech_upl: {speech_upl}, noise: {noise_type}, snr: {snr}") snr = int(snr) noise_fn = NOISES[noise_type] meta = AudioMetaData(-1, -1, -1, -1, "") max_s = 10 # limit to 10 seconds if speech_upl is not None: sample, meta = load_audio(speech_upl, sr) max_len = max_s * sr if sample.shape[-1] > max_len: start = torch.randint(0, sample.shape[-1] - max_len, ()).item() sample = sample[..., start : start + max_len] else: sample, meta = load_audio("samples/p232_013_clean.wav", sr) sample = sample[..., : max_s * sr] if sample.dim() > 1 and sample.shape[0] > 1: assert ( sample.shape[1] > sample.shape[0] ), f"Expecting channels first, but got {sample.shape}" sample = sample.mean(dim=0, keepdim=True) logger.info(f"Loaded sample with shape {sample.shape}") if noise_fn is not None: noise, _ = load_audio(noise_fn, sr) # type: ignore logger.info(f"Loaded noise with shape {noise.shape}") _, _, sample = mix_at_snr(sample, noise, snr) logger.info("Start denoising audio") enhanced = enhance(model1, df, sample) logger.info("Denoising finished") lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0) lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1) enhanced = enhanced * lim if meta.sample_rate != sr: enhanced = resample(enhanced, sr, meta.sample_rate) sample = resample(sample, sr, meta.sample_rate) sr = meta.sample_rate noisy_wav = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name save_audio(noisy_wav, sample, sr) enhanced_wav = 
tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name save_audio(enhanced_wav, enhanced, sr) logger.info(f"saved audios: {noisy_wav}, {enhanced_wav}") ax_noisy.clear() ax_enh.clear() noisy_im = spec_im(sample, sr=sr, figure=fig_noisy, ax=ax_noisy) enh_im = spec_im(enhanced, sr=sr, figure=fig_enh, ax=ax_enh) # noisy_wav = gr.make_waveform(noisy_fn, bar_count=200) # enh_wav = gr.make_waveform(enhanced_fn, bar_count=200) return noisy_wav, noisy_im, enhanced_wav, enh_im def specshow( spec, ax=None, title=None, xlabel=None, ylabel=None, sr=48000, n_fft=None, hop=None, t=None, f=None, vmin=-100, vmax=0, xlim=None, ylim=None, cmap="inferno", ): """Plots a spectrogram of shape [F, T]""" spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec if ax is not None: set_title = ax.set_title set_xlabel = ax.set_xlabel set_ylabel = ax.set_ylabel set_xlim = ax.set_xlim set_ylim = ax.set_ylim else: ax = plt set_title = plt.title set_xlabel = plt.xlabel set_ylabel = plt.ylabel set_xlim = plt.xlim set_ylim = plt.ylim if n_fft is None: if spec.shape[0] % 2 == 0: n_fft = spec.shape[0] * 2 else: n_fft = (spec.shape[0] - 1) * 2 hop = hop or n_fft // 4 if t is None: t = np.arange(0, spec_np.shape[-1]) * hop / sr if f is None: f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000 im = ax.pcolormesh( t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap ) if title is not None: set_title(title) if xlabel is not None: set_xlabel(xlabel) if ylabel is not None: set_ylabel(ylabel) if xlim is not None: set_xlim(xlim) if ylim is not None: set_ylim(ylim) return im def spec_im( audio: torch.Tensor, figsize=(15, 5), colorbar=False, colorbar_format=None, figure=None, labels=True, **kwargs, ) -> Image: audio = torch.as_tensor(audio) if labels: kwargs.setdefault("xlabel", "Time [s]") kwargs.setdefault("ylabel", "Frequency [Hz]") n_fft = kwargs.setdefault("n_fft", 1024) hop = kwargs.setdefault("hop", 512) w = torch.hann_window(n_fft, 
device=audio.device) spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False) spec = spec.div_(w.pow(2).sum()) spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10) kwargs.setdefault("vmax", max(0.0, spec.max().item())) if figure is None: figure = plt.figure(figsize=figsize) figure.set_tight_layout(True) if spec.dim() > 2: spec = spec.squeeze(0) im = specshow(spec, **kwargs) if colorbar: ckwargs = {} if "ax" in kwargs: if colorbar_format is None: if kwargs.get("vmin", None) is not None or kwargs.get("vmax", None) is not None: colorbar_format = "%+2.0f dB" ckwargs = {"ax": kwargs["ax"]} plt.colorbar(im, format=colorbar_format, **ckwargs) figure.canvas.draw() return Image.frombytes("RGB", figure.canvas.get_width_height(), figure.canvas.tostring_rgb()) def toggle(choice): if choice == "mic": return gr.update(visible=True, value=None), gr.update(visible=False, value=None) else: return gr.update(visible=False, value=None), gr.update(visible=True, value=None) # VALL-E-X model model = VALLE( N_DIM, NUM_HEAD, NUM_LAYERS, norm_first=True, add_prenet=False, prefix_mode=PREFIX_MODE, share_embedding=True, nar_scale_factor=1.0, prepend_bos=True, num_quantizers=NUM_QUANTIZERS, ) checkpoint = torch.load("./epoch-10.pt", map_location='cpu') missing_keys, unexpected_keys = model.load_state_dict( checkpoint["model"], strict=True ) assert not missing_keys model.eval() # Encodec model audio_tokenizer = AudioTokenizer(device) # ASR whisper_model = whisper.load_model("medium").cpu() # Voice Presets preset_list = os.walk("./presets/").__next__()[2] preset_list = [preset[:-4] for preset in preset_list if preset.endswith(".npz")] def clear_prompts(): try: path = tempfile.gettempdir() for eachfile in os.listdir(path): filename = os.path.join(path, eachfile) if os.path.isfile(filename) and filename.endswith(".npz"): lastmodifytime = os.stat(filename).st_mtime endfiletime = time.time() - 60 if endfiletime > lastmodifytime: os.remove(filename) except: return def 
transcribe_one(model, audio_path): # load audio and pad/trim it to fit 30 seconds audio = whisper.load_audio(audio_path) audio = whisper.pad_or_trim(audio) # make log-Mel spectrogram and move to the same device as the model mel = whisper.log_mel_spectrogram(audio).to(model.device) # detect the spoken language _, probs = model.detect_language(mel) print(f"Detected language: {max(probs, key=probs.get)}") lang = max(probs, key=probs.get) # decode the audio options = whisper.DecodingOptions(temperature=1.0, best_of=5, fp16=False if device == torch.device("cpu") else True, sample_len=150) result = whisper.decode(model, mel, options) # print the recognized text print(result.text) text_pr = result.text if text_pr.strip(" ")[-1] not in "?!.,。,?!。、": text_pr += "." return lang, text_pr def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content): global model, text_collater, text_tokenizer, audio_tokenizer clear_prompts() audio_prompt = uploaded_audio if uploaded_audio is not None else recorded_audio sr, wav_pr = audio_prompt if len(wav_pr) / sr > 15: return "Rejected, Audio too long (should be less than 15 seconds)", None if not isinstance(wav_pr, torch.FloatTensor): wav_pr = torch.FloatTensor(wav_pr) if wav_pr.abs().max() > 1: wav_pr /= wav_pr.abs().max() if wav_pr.size(-1) == 2: wav_pr = wav_pr[:, 0] if wav_pr.ndim == 1: wav_pr = wav_pr.unsqueeze(0) assert wav_pr.ndim and wav_pr.size(0) == 1 if transcript_content == "": text_pr, lang_pr = make_prompt(name, wav_pr, sr, save=False) else: lang_pr = langid.classify(str(transcript_content))[0] lang_token = lang2token[lang_pr] text_pr = f"{lang_token}{str(transcript_content)}{lang_token}" # tokenize audio encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr)) audio_tokens = encoded_frames[0][0].transpose(2, 1).cpu().numpy() # tokenize text phonemes, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip()) text_tokens, enroll_x_lens = text_collater( [ phonemes ] ) message = f"Detected language: 
{lang_pr}\n Detected text {text_pr}\n" # save as npz file np.savez(os.path.join(tempfile.gettempdir(), f"{name}.npz"), audio_tokens=audio_tokens, text_tokens=text_tokens, lang_code=lang2code[lang_pr]) return "提取音色成功!", os.path.join(tempfile.gettempdir(), f"{name}.npz") def make_prompt(name, wav, sr, save=True): global whisper_model whisper_model.to(device) if not isinstance(wav, torch.FloatTensor): wav = torch.tensor(wav) if wav.abs().max() > 1: wav /= wav.abs().max() if wav.size(-1) == 2: wav = wav.mean(-1, keepdim=False) if wav.ndim == 1: wav = wav.unsqueeze(0) assert wav.ndim and wav.size(0) == 1 torchaudio.save(f"./prompts/{name}.wav", wav, sr) lang, text = transcribe_one(whisper_model, f"./prompts/{name}.wav") lang_token = lang2token[lang] text = lang_token + text + lang_token with open(f"./prompts/{name}.txt", 'w') as f: f.write(text) if not save: os.remove(f"./prompts/{name}.wav") os.remove(f"./prompts/{name}.txt") whisper_model.cpu() torch.cuda.empty_cache() return text, lang @torch.no_grad() def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content): if len(text) > 150: return "Rejected, Text too long (should be less than 150 characters)", None global model, text_collater, text_tokenizer, audio_tokenizer model.to(device) audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt sr, wav_pr = audio_prompt if len(wav_pr) / sr > 15: return "Rejected, Audio too long (should be less than 15 seconds)", None if not isinstance(wav_pr, torch.FloatTensor): wav_pr = torch.FloatTensor(wav_pr) if wav_pr.abs().max() > 1: wav_pr /= wav_pr.abs().max() if wav_pr.size(-1) == 2: wav_pr = wav_pr[:, 0] if wav_pr.ndim == 1: wav_pr = wav_pr.unsqueeze(0) assert wav_pr.ndim and wav_pr.size(0) == 1 if transcript_content == "": text_pr, lang_pr = make_prompt('dummy', wav_pr, sr, save=False) else: lang_pr = langid.classify(str(transcript_content))[0] lang_token = lang2token[lang_pr] text_pr = 
f"{lang_token}{str(transcript_content)}{lang_token}" if language == 'auto-detect': lang_token = lang2token[langid.classify(text)[0]] else: lang_token = langdropdown2token[language] lang = token2lang[lang_token] text = lang_token + text + lang_token # onload model model.to(device) # tokenize audio encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr)) audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device) # tokenize text logging.info(f"synthesize text: {text}") phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip()) text_tokens, text_tokens_lens = text_collater( [ phone_tokens ] ) enroll_x_lens = None if text_pr: text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip()) text_prompts, enroll_x_lens = text_collater( [ text_prompts ] ) text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) text_tokens_lens += enroll_x_lens lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] encoded_frames = model.inference( text_tokens.to(device), text_tokens_lens.to(device), audio_prompts, enroll_x_lens=enroll_x_lens, top_k=-100, temperature=1, prompt_language=lang_pr, text_language=langs if accent == "no-accent" else lang, ) samples = audio_tokenizer.decode( [(encoded_frames.transpose(2, 1), None)] ) # offload model model.to('cpu') torch.cuda.empty_cache() message = f"text prompt: {text_pr}\nsythesized text: {text}" return message, (24000, samples[0][0].cpu().numpy()) @torch.no_grad() def infer_from_prompt(text, language, accent, preset_prompt, prompt_file): if len(text) > 150: return "Rejected, Text too long (should be less than 150 characters)", None clear_prompts() model.to(device) # text to synthesize if language == 'auto-detect': lang_token = lang2token[langid.classify(text)[0]] else: lang_token = langdropdown2token[language] lang = token2lang[lang_token] text = lang_token + text + lang_token # load prompt if prompt_file is not None: prompt_data = np.load(prompt_file.name) else: prompt_data = 
np.load(os.path.join("./presets/", f"{preset_prompt}.npz")) audio_prompts = prompt_data['audio_tokens'] text_prompts = prompt_data['text_tokens'] lang_pr = prompt_data['lang_code'] lang_pr = code2lang[int(lang_pr)] # numpy to tensor audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device) text_prompts = torch.tensor(text_prompts).type(torch.int32) enroll_x_lens = text_prompts.shape[-1] logging.info(f"synthesize text: {text}") phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip()) text_tokens, text_tokens_lens = text_collater( [ phone_tokens ] ) text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) text_tokens_lens += enroll_x_lens # accent control lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] encoded_frames = model.inference( text_tokens.to(device), text_tokens_lens.to(device), audio_prompts, enroll_x_lens=enroll_x_lens, top_k=-100, temperature=1, prompt_language=lang_pr, text_language=langs if accent == "no-accent" else lang, ) samples = audio_tokenizer.decode( [(encoded_frames.transpose(2, 1), None)] ) model.to('cpu') torch.cuda.empty_cache() message = f"sythesized text: {text}" return message, (24000, samples[0][0].cpu().numpy()) from utils.sentence_cutter import split_text_into_sentences @torch.no_grad() def infer_long_text(text, preset_prompt, prompt=None, language='auto', accent='no-accent'): """ For long audio generation, two modes are available. fixed-prompt: This mode will keep using the same prompt the user has provided, and generate audio sentence by sentence. sliding-window: This mode will use the last sentence as the prompt for the next sentence, but has some concern on speaker maintenance. 
""" if len(text) > 1000: return "Rejected, Text too long (should be less than 1000 characters)", None mode = 'fixed-prompt' global model, audio_tokenizer, text_tokenizer, text_collater model.to(device) if (prompt is None or prompt == "") and preset_prompt == "": mode = 'sliding-window' # If no prompt is given, use sliding-window mode sentences = split_text_into_sentences(text) # detect language if language == "auto-detect": language = langid.classify(text)[0] else: language = token2lang[langdropdown2token[language]] # if initial prompt is given, encode it if prompt is not None and prompt != "": # load prompt prompt_data = np.load(prompt.name) audio_prompts = prompt_data['audio_tokens'] text_prompts = prompt_data['text_tokens'] lang_pr = prompt_data['lang_code'] lang_pr = code2lang[int(lang_pr)] # numpy to tensor audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device) text_prompts = torch.tensor(text_prompts).type(torch.int32) elif preset_prompt is not None and preset_prompt != "": prompt_data = np.load(os.path.join("./presets/", f"{preset_prompt}.npz")) audio_prompts = prompt_data['audio_tokens'] text_prompts = prompt_data['text_tokens'] lang_pr = prompt_data['lang_code'] lang_pr = code2lang[int(lang_pr)] # numpy to tensor audio_prompts = torch.tensor(audio_prompts).type(torch.int32).to(device) text_prompts = torch.tensor(text_prompts).type(torch.int32) else: audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device) text_prompts = torch.zeros([1, 0]).type(torch.int32) lang_pr = language if language != 'mix' else 'en' if mode == 'fixed-prompt': complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device) for text in sentences: text = text.replace("\n", "").strip(" ") if text == "": continue lang_token = lang2token[language] lang = token2lang[lang_token] text = lang_token + text + lang_token enroll_x_lens = text_prompts.shape[-1] logging.info(f"synthesize text: {text}") phone_tokens, langs = 
text_tokenizer.tokenize(text=f"_{text}".strip()) text_tokens, text_tokens_lens = text_collater( [ phone_tokens ] ) text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) text_tokens_lens += enroll_x_lens # accent control lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] encoded_frames = model.inference( text_tokens.to(device), text_tokens_lens.to(device), audio_prompts, enroll_x_lens=enroll_x_lens, top_k=-100, temperature=1, prompt_language=lang_pr, text_language=langs if accent == "no-accent" else lang, ) complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1) samples = audio_tokenizer.decode( [(complete_tokens, None)] ) model.to('cpu') message = f"Cut into {len(sentences)} sentences" return message, (24000, samples[0][0].cpu().numpy()) elif mode == "sliding-window": complete_tokens = torch.zeros([1, NUM_QUANTIZERS, 0]).type(torch.LongTensor).to(device) original_audio_prompts = audio_prompts original_text_prompts = text_prompts for text in sentences: text = text.replace("\n", "").strip(" ") if text == "": continue lang_token = lang2token[language] lang = token2lang[lang_token] text = lang_token + text + lang_token enroll_x_lens = text_prompts.shape[-1] logging.info(f"synthesize text: {text}") phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip()) text_tokens, text_tokens_lens = text_collater( [ phone_tokens ] ) text_tokens = torch.cat([text_prompts, text_tokens], dim=-1) text_tokens_lens += enroll_x_lens # accent control lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]] encoded_frames = model.inference( text_tokens.to(device), text_tokens_lens.to(device), audio_prompts, enroll_x_lens=enroll_x_lens, top_k=-100, temperature=1, prompt_language=lang_pr, text_language=langs if accent == "no-accent" else lang, ) complete_tokens = torch.cat([complete_tokens, encoded_frames.transpose(2, 1)], dim=-1) if torch.rand(1) < 1.0: audio_prompts = encoded_frames[:, :, 
-NUM_QUANTIZERS:] text_prompts = text_tokens[:, enroll_x_lens:] else: audio_prompts = original_audio_prompts text_prompts = original_text_prompts samples = audio_tokenizer.decode( [(complete_tokens, None)] ) model.to('cpu') message = f"Cut into {len(sentences)} sentences" return message, (24000, samples[0][0].cpu().numpy()) else: raise ValueError(f"No such mode {mode}") def main(): app = gr.Blocks() with app: gr.HTML("