"""Gradio demo: Japanese text-to-speech with SpeechT5 and a HiFi-GAN vocoder.

Loads ``esnya/japanese_speecht5_tts`` together with an OpenJTalk-based
tokenizer and exposes a web UI that synthesizes speech from a randomly
sampled speaker embedding on every call.
"""

from typing import cast

import gradio as gr
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor

from speecht5_openjtalk_tokenizer import SpeechT5OpenjtalkTokenizer

# Register the custom tokenizer under the ``transformers`` module namespace so
# that ``SpeechT5Processor.from_pretrained`` can resolve it by class name.
setattr(transformers, SpeechT5OpenjtalkTokenizer.__name__, SpeechT5OpenjtalkTokenizer)


class SpeechT5OpenjtalkProcessor(SpeechT5Processor):
    """SpeechT5 processor wired to the OpenJTalk tokenizer."""

    # ``from_pretrained`` looks this class name up on ``transformers`` —
    # satisfied by the ``setattr`` registration above.
    tokenizer_class = SpeechT5OpenjtalkTokenizer.__name__


# NOTE: model loading happens at import time (network I/O to the Hugging Face
# hub on first run). The ``assert isinstance`` checks narrow the return types
# for static analysis; they are stripped under ``python -O``.
model = SpeechT5ForTextToSpeech.from_pretrained("esnya/japanese_speecht5_tts")
assert isinstance(model, SpeechT5ForTextToSpeech)

processor = SpeechT5OpenjtalkProcessor.from_pretrained("esnya/japanese_speecht5_tts")
assert isinstance(processor, SpeechT5OpenjtalkProcessor)

vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
assert isinstance(vocoder, SpeechT5HifiGan)

if torch.cuda.is_available():
    model = model.cuda()
    vocoder = vocoder.cuda()


def convert_float32_to_int16(wav: np.ndarray) -> np.ndarray:
    """Convert a float32 waveform (nominally in [-1, 1]) to int16 PCM.

    Values are scaled by 32768 and clipped to the int16 range, so +1.0 maps
    to 32767 rather than overflowing.
    """
    assert wav.dtype == np.float32
    return np.clip(wav * 32768.0, -32768.0, 32767.0).astype(np.int16)


@torch.inference_mode()
def text_to_speech(
    text: str,
    threshold: float = 0.5,
    minlenratio: float = 0.0,
    maxlenratio: float = 10.0,
):
    """Synthesize ``text`` with a freshly sampled random speaker embedding.

    Args:
        text: Input text to synthesize.
        threshold: Stop-token probability threshold for ``generate_speech``.
        minlenratio / maxlenratio: Output-length bounds relative to the input,
            forwarded to ``generate_speech``.

    Returns:
        A two-element list: ``(sampling_rate, int16 waveform)`` for the audio
        widget, and a DataFrame of the speaker-embedding values for the plot.
    """
    # Uniform random embedding in [-1, 1) — a new "voice" on every call.
    speaker_embeddings = (
        torch.rand(
            (1, model.config.speaker_embedding_dim),
            dtype=torch.float32,
            device=model.device,
        )
        * 2
        - 1
    )

    input_ids = processor(text=text, return_tensors="pt")
    assert input_ids is not None
    input_ids = input_ids.input_ids.to(model.device)

    speaker_embeddings = cast(torch.FloatTensor, speaker_embeddings)

    wav = model.generate_speech(
        input_ids,
        speaker_embeddings,
        threshold=threshold,
        minlenratio=minlenratio,
        maxlenratio=maxlenratio,
        vocoder=vocoder,
    )
    wav = cast(torch.FloatTensor, wav)
    wav = convert_float32_to_int16(wav.reshape(-1).cpu().float().numpy())

    return [
        (vocoder.config.sampling_rate, wav),
        pd.DataFrame(
            {
                "dim": range(speaker_embeddings.shape[-1]),
                "value": speaker_embeddings[0].cpu().float().numpy(),
            }
        ),
    ]


demo = gr.Interface(
    fn=text_to_speech,
    inputs=[
        "text",
        # NOTE(review): slider maximum of 0.5 caps ``threshold`` at its
        # default; presumably intentional, but confirm 1.0 was not intended.
        gr.Slider(0, 0.5, 0.5, label="threshold"),
        gr.Slider(0, 100, 0, label="minlenratio"),
        gr.Slider(0, 100, 10, label="maxlenratio"),
    ],
    outputs=[
        "audio",
        gr.BarPlot(
            label="speaker_embedding (random generated)",
            x="dim",
            y="value",
            y_lim=[-1, 1],
        ),
    ],
)

if __name__ == "__main__":
    # Guard so importing this module (e.g. by deployment tooling or tests)
    # does not start the web server. Model loading above still runs at import.
    demo.launch()