import gradio as gr
from hyper_parameters import tacotron_params as hparams
from training import load_model
from audio_processing import griffin_lim
from nn_layers import TacotronSTFT
from text import text_to_sequence
from hifigan.env import AttrDict
from examples_taco2 import *
from hifigan.models import Generator

import torch
import numpy as np
import json
import os

from matplotlib import pyplot as plt

torch.manual_seed(1234)

MAX_WAV_VALUE = 32768.0

DESCRIPTION = """
This is a Tacotron2 model based on NVIDIA's model, plus three unsupervised Global Style Tokens (GST).
The whole architecture has been trained from scratch with the LJSpeech dataset. In order to control the
relevance of each style token, we configured the attention module as single-head. Keep in mind that, for
a better synthetic output, the sum of the three style weights should be around 1. A combination that sums
to less than 1 may still work, but with higher totals the generated speech may show more distortion and
mispronunciations.
"""


def load_checkpoint(filepath, device):
    # Load a checkpoint dictionary from disk onto the given device.
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict


def plot_spec_align_sep(mel, align):
    # Plot the mel-scale spectrogram and the attention alignment as two separate figures.
    fig_mel = plt.figure(figsize=(4, 3))
    ax_mel = fig_mel.add_subplot(111)
    fig_mel.tight_layout()
    ax_mel.imshow(mel)
    # ax_mel.set_title('Mel-Scale Spectrogram', fontsize=12)

    fig_align = plt.figure(figsize=(4, 3))
    ax_align = fig_align.add_subplot(111)
    fig_align.tight_layout()
    ax_align.imshow(align)
    # ax_align.set_title('Alignment', fontsize=12)

    return fig_mel, fig_align


# Load the trained Tacotron2 + GST model:
model = load_model(hparams)
checkpoint_path = "models/checkpoint_78000.model"
model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")['state_dict'])
# model.to('cuda')
_ = model.eval()

# Load the pre-trained HiFi-GAN vocoder for mel-to-audio synthesis:
hifigan_checkpoint_path = "models/generator_v1"
config_file = os.path.join(os.path.split(hifigan_checkpoint_path)[0], 'config.json')
with open(config_file) as f:
    data = f.read()

json_config = json.loads(data)
h = AttrDict(json_config)

device = torch.device("cpu")
generator = Generator(h).to(device)
state_dict_g = load_checkpoint(hifigan_checkpoint_path, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()


def synthesize(text, gst_1, gst_2, gst_3, voc):
    # Convert the input text into a sequence of symbol IDs:
    sequence = np.array(text_to_sequence(text, ['english_cleaners']))[None, :]
    sequence = torch.from_numpy(sequence).to(device='cpu', dtype=torch.int64)

    # Style-token weights set by the user (e.g. [0.5, 0.15, 0.35]):
    gst_head_scores = np.array([gst_1, gst_2, gst_3])
    gst_scores = torch.from_numpy(gst_head_scores).float()

    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence, gst_scores)

    if voc == 0:
        # HiFi-GAN mel-to-wav inference:
        with torch.no_grad():
            y_g_hat = generator(mel_outputs_postnet)
            audio = y_g_hat.squeeze()
            audio = audio * MAX_WAV_VALUE
            audio_numpy = audio.cpu().numpy().astype('int16')
    else:
        # Griffin-Lim vocoder synthesis:
        griffin_iters = 60
        taco_stft = TacotronSTFT(hparams['filter_length'], hparams['hop_length'], hparams['win_length'],
                                 sampling_rate=hparams['sampling_rate'])

        # Invert the mel normalization, then project back to a linear-frequency spectrogram:
        mel_decompress = taco_stft.spectral_de_normalize(mel_outputs_postnet)
        mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
        spec_from_mel_scaling = 60
        spec_from_mel = torch.mm(mel_decompress[0], taco_stft.mel_basis)
        spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
        spec_from_mel = spec_from_mel * spec_from_mel_scaling

        # Estimate the phase and reconstruct the waveform:
        audio = griffin_lim(spec_from_mel[:, :, :-1], taco_stft.stft_fn, griffin_iters)
        audio = audio.squeeze()
        audio_numpy = audio.cpu().numpy()

    # Prepare the plots for the output:
    mel_outputs_postnet = torch.flip(mel_outputs_postnet.squeeze(), [0])
    mel_outputs_postnet = mel_outputs_postnet.detach().numpy()
    alignments = alignments.squeeze().T.detach().numpy()
    # fig_mel, fig_align = plot_spec_align_sep(mel_outputs_postnet, alignments)

    # Normalize the numpy arrays to [-1, 1]:
    min_val = np.min(mel_outputs_postnet)
    max_val = np.max(mel_outputs_postnet)
    scaled_mel = (mel_outputs_postnet - min_val) / (max_val - min_val)
    normalized_mel = 2 * scaled_mel - 1

    min_val = np.min(alignments)
    max_val = np.max(alignments)
    scaled_align = (alignments - min_val) / (max_val - min_val)
    normalized_align = 2 * scaled_align - 1

    aw = gr.make_waveform((22050, audio_numpy), bg_image='background_images/wallpaper_test_1_crop_3.jpg',
                          bars_color=('#f3df4b', '#63edb7'), bar_count=100, bar_width=0.7, animate=True)

    return aw, normalized_mel, normalized_align


# Custom demo interface:
# theme='ysharma/steampunk',
# css=".gradio-container {background: url('file=background_images/wallpaper_test_mod_2.jpg')}"
with gr.Blocks() as demo:
    gr.Markdown("