import itertools
import os
import warnings

import matplotlib.pyplot as plt
import pyloudnorm
import sounddevice
import soundfile
import torch

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from speechbrain.pretrained import EncoderClassifier
    from torchaudio.transforms import Resample

from Architectures.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Architectures.Vocoder.HiFiGAN_Generator import HiFiGAN
from Preprocessing.AudioPreprocessor import AudioPreprocessor
from Preprocessing.TextFrontend import ArticulatoryCombinedTextFrontend
from Preprocessing.TextFrontend import get_language_id
from Utility.storage_config import MODELS_DIR
from Utility.utils import cumsum_durations
from Utility.utils import float2pcm


class ToucanTTSInterface(torch.nn.Module):

    def __init__(self,
                 device="cpu",
                 tts_model_path=os.path.join(MODELS_DIR, "ToucanTTS_Meta", "best.pt"),
                 vocoder_model_path=os.path.join(MODELS_DIR, "Vocoder", "best.pt"),
                 language="eng",
                 enhance=None
                 ):
        super().__init__()
        self.device = device
        if not tts_model_path.endswith(".pt"):
            # a model name rather than a full path was given, so resolve it to the default checkpoint location
            tts_model_path = os.path.join(MODELS_DIR, f"ToucanTTS_{tts_model_path}", "best.pt")

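        # build the frontend that turns raw text into articulatory phoneme features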
        self.text2phone = ArticulatoryCombinedTextFrontend(language=language, add_silence_to_end=True)

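        # load the phoneme-to-spectrogram model and move it to the target device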
        checkpoint = torch.load(tts_model_path, map_location='cpu')
        self.phone2mel = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"])
        with torch.no_grad():
            self.phone2mel.store_inverse_all()
        self.phone2mel = self.phone2mel.to(torch.device(device))

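        # load the speaker embedding function (ECAPA-TDNN trained on VoxCeleb, provided by SpeechBrain)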
        self.speaker_embedding_func_ecapa = EncoderClassifier.from_hparams(source="speechbrain/spkrec-ecapa-voxceleb",
                                                                           run_opts={"device": str(device)},
                                                                           savedir=os.path.join(MODELS_DIR, "Embedding", "speechbrain_speaker_embedding_ecapa"))

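        # load the spectrogram-to-wave model (HiFi-GAN vocoder) and set up the loudness meter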
        vocoder_checkpoint = torch.load(vocoder_model_path, map_location="cpu")
        self.vocoder = HiFiGAN()
        self.vocoder.load_state_dict(vocoder_checkpoint)
        self.vocoder = self.vocoder.to(device).eval()
        self.vocoder.remove_weight_norm()
        self.meter = pyloudnorm.Meter(24000)

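        # set defaults and switch everything to inference mode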
        self.default_utterance_embedding = checkpoint["default_emb"].to(self.device)
        self.ap = AudioPreprocessor(input_sr=100, output_sr=16000, device=device)
        self.phone2mel.eval()
        self.vocoder.eval()
        self.lang_id = get_language_id(language)
        self.to(torch.device(device))
        self.eval()

    def set_utterance_embedding(self, path_to_reference_audio="", embedding=None):
        if embedding is not None:
            self.default_utterance_embedding = embedding.squeeze().to(self.device)
            return
        if type(path_to_reference_audio) != list:
            path_to_reference_audio = [path_to_reference_audio]

        if len(path_to_reference_audio) > 0:
            for path in path_to_reference_audio:
                assert os.path.exists(path)
            speaker_embs = list()
            for path in path_to_reference_audio:
                wave, sr = soundfile.read(path)
                wave = Resample(orig_freq=sr, new_freq=16000).to(self.device)(torch.tensor(wave, device=self.device, dtype=torch.float32))
                speaker_embedding = self.speaker_embedding_func_ecapa.encode_batch(wavs=wave.to(self.device).unsqueeze(0)).squeeze()
                speaker_embs.append(speaker_embedding)
            self.default_utterance_embedding = sum(speaker_embs) / len(speaker_embs)

    def set_language(self, lang_id):
        """
        The lang_id parameter here refers to the language shorthand (e.g. "eng"), not the numeric language ID.
        The naming has become ambiguous since the introduction of actual numeric language IDs.
        """
        self.set_phonemizer_language(lang_id=lang_id)
        self.set_accent_language(lang_id=lang_id)

    def set_phonemizer_language(self, lang_id):
        self.text2phone.change_lang(language=lang_id, add_silence_to_end=True)

    def set_accent_language(self, lang_id):
        if lang_id in ['ajp', 'ajt', 'lak', 'lno', 'nul', 'pii', 'plj', 'slq', 'smd', 'snb', 'tpw', 'wya', 'zua', 'en-us', 'en-sc', 'fr-be', 'fr-sw', 'pt-br', 'spa-lat', 'vi-ctr', 'vi-so']:
            # map dialect codes to their base language for the accent embedding
            if lang_id == 'vi-so' or lang_id == 'vi-ctr':
                lang_id = 'vie'
            elif lang_id == 'spa-lat':
                lang_id = 'spa'
            elif lang_id == 'pt-br':
                lang_id = 'por'
            elif lang_id == 'fr-sw' or lang_id == 'fr-be':
                lang_id = 'fra'
            elif lang_id == 'en-sc' or lang_id == 'en-us':
                lang_id = 'eng'
            else:
                # codes without a close match fall back to English as the accent
                lang_id = 'eng'
        self.lang_id = get_language_id(lang_id).to(self.device)

    def forward(self,
                text,
                view=False,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0,
                durations=None,
                pitch=None,
                energy=None,
                input_is_phones=False,
                return_plot_as_filepath=False,
                loudness_in_db=-24.0,
                glow_sampling_temperature=0.2):
        """
        duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                 1.0 means no scaling happens, higher values increase durations for the whole
                                 utterance, lower values decrease durations for the whole utterance.
        pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                              1.0 means no scaling happens, higher values increase variance of the pitch curve,
                              lower values decrease variance of the pitch curve.
        energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                               1.0 means no scaling happens, higher values increase variance of the energy curve,
                               lower values decrease variance of the energy curve.
        """

        with torch.inference_mode():
            # text -> phoneme features -> spectrogram (plus predicted durations, pitch and energy)
            phones = self.text2phone.string_to_tensor(text, input_phonemes=input_is_phones).to(torch.device(self.device))
            mel, durations, pitch, energy = self.phone2mel(phones,
                                                           return_duration_pitch_energy=True,
                                                           utterance_embedding=self.default_utterance_embedding.to(self.device),
                                                           durations=durations,
                                                           pitch=pitch,
                                                           energy=energy,
                                                           lang_id=self.lang_id.to(self.device),
                                                           duration_scaling_factor=duration_scaling_factor,
                                                           pitch_variance_scale=pitch_variance_scale,
                                                           energy_variance_scale=energy_variance_scale,
                                                           pause_duration_scaling_factor=pause_duration_scaling_factor,
                                                           glow_sampling_temperature=glow_sampling_temperature)

            # spectrogram -> waveform, followed by loudness normalization
            wave, _, _ = self.vocoder(mel.unsqueeze(0))
            wave = wave.squeeze().cpu().numpy()
            sr = 24000
            try:
                loudness = self.meter.integrated_loudness(wave)
                wave = pyloudnorm.normalize.loudness(wave, loudness, loudness_in_db)
            except ValueError:
                # the audio is too short for a loudness measurement, so skip normalization
                pass

            if view or return_plot_as_filepath:
                try:
                    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 5))

                    ax.imshow(mel.cpu().numpy(), origin="lower", cmap='GnBu')
                    ax.yaxis.set_visible(False)
                    duration_splits, label_positions = cumsum_durations(durations.cpu().numpy())
                    ax.xaxis.grid(True, which='minor')
                    ax.set_xticks(label_positions, minor=False)
                    if input_is_phones:
                        phones = text.replace(" ", "|")
                    else:
                        phones = self.text2phone.get_phone_string(text, for_plot_labels=True)
                    try:
                        ax.set_xticklabels(phones)
                    except IndexError:
                        pass
                    word_boundaries = list()
                    for label_index, phone in enumerate(phones):
                        if phone == "|":
                            word_boundaries.append(label_positions[label_index])

                    try:
                        prev_word_boundary = 0
                        word_label_positions = list()
                        for word_boundary in word_boundaries:
                            word_label_positions.append((word_boundary + prev_word_boundary) / 2)
                            prev_word_boundary = word_boundary
                        word_label_positions.append((duration_splits[-1] + prev_word_boundary) / 2)

                        secondary_ax = ax.secondary_xaxis('bottom')
                        secondary_ax.tick_params(axis="x", direction="out", pad=24)
                        secondary_ax.set_xticks(word_label_positions, minor=False)
                        secondary_ax.set_xticklabels(text.split())
                        secondary_ax.tick_params(axis='x', colors='orange')
                        secondary_ax.xaxis.label.set_color('orange')
                    except (ValueError, IndexError, RuntimeError):
                        ax.set_title(text)

                    ax.vlines(x=duration_splits, colors="green", linestyles="solid", ymin=0, ymax=120, linewidth=0.5)
                    ax.vlines(x=word_boundaries, colors="orange", linestyles="solid", ymin=0, ymax=120, linewidth=1.0)
                    plt.subplots_adjust(left=0.02, bottom=0.2, right=0.98, top=.9, wspace=0.0, hspace=0.0)
                    ax.set_aspect("auto")
                except Exception:
                    # plotting is best-effort and must never break synthesis
                    pass

            if return_plot_as_filepath:
                try:
                    plt.savefig("tmp.png")
                    plt.close()
                except Exception:
                    pass
                return wave, sr, "tmp.png"

        return wave, sr

    def read_to_file(self,
                     text_list,
                     file_location,
                     duration_scaling_factor=1.0,
                     pitch_variance_scale=1.0,
                     energy_variance_scale=1.0,
                     pause_duration_scaling_factor=1.0,
                     silent=False,
                     dur_list=None,
                     pitch_list=None,
                     energy_list=None,
                     glow_sampling_temperature=0.2):
        """
        Args:
            silent: Whether to suppress the progress printouts
            text_list: A list of strings to be read
            file_location: The path and name of the file it should be saved to
            energy_list: list of energy tensors to be used for the texts
            pitch_list: list of pitch tensors to be used for the texts
            dur_list: list of duration tensors to be used for the texts
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase variance of the pitch curve,
                                  lower values decrease variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase variance of the energy curve,
                                   lower values decrease variance of the energy curve.
        """

        if not dur_list:
            dur_list = []
        if not pitch_list:
            pitch_list = []
        if not energy_list:
            energy_list = []
        sr = 24000  # sample rate of the vocoder, used as fallback if nothing gets synthesized
        silence = torch.zeros([14300])
        wav = silence.clone()
        for (text, durations, pitch, energy) in itertools.zip_longest(text_list, dur_list, pitch_list, energy_list):
            if text.strip() != "":
                if not silent:
                    print("Now synthesizing: {}".format(text))
                spoken_sentence, sr = self(text,
                                           durations=durations.to(self.device) if durations is not None else None,
                                           pitch=pitch.to(self.device) if pitch is not None else None,
                                           energy=energy.to(self.device) if energy is not None else None,
                                           duration_scaling_factor=duration_scaling_factor,
                                           pitch_variance_scale=pitch_variance_scale,
                                           energy_variance_scale=energy_variance_scale,
                                           pause_duration_scaling_factor=pause_duration_scaling_factor,
                                           glow_sampling_temperature=glow_sampling_temperature)
                spoken_sentence = torch.tensor(spoken_sentence).cpu()
                wav = torch.cat((wav, spoken_sentence, silence), 0)
        soundfile.write(file=file_location, data=float2pcm(wav), samplerate=sr, subtype="PCM_16")

    def read_aloud(self,
                   text,
                   view=False,
                   duration_scaling_factor=1.0,
                   pitch_variance_scale=1.0,
                   energy_variance_scale=1.0,
                   blocking=False,
                   glow_sampling_temperature=0.2):
        if text.strip() == "":
            return
        wav, sr = self(text,
                       view,
                       duration_scaling_factor=duration_scaling_factor,
                       pitch_variance_scale=pitch_variance_scale,
                       energy_variance_scale=energy_variance_scale,
                       glow_sampling_temperature=glow_sampling_temperature)
        silence = torch.zeros([sr // 2])
        wav = torch.cat((silence, torch.tensor(wav), silence), 0).numpy()
        sounddevice.play(float2pcm(wav), samplerate=sr)
        if view:
            plt.show()
        if blocking:
            sounddevice.wait()
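

# Minimal usage sketch (illustrative only, not part of the interface): it assumes the default
# "ToucanTTS_Meta" and "Vocoder" checkpoints are already present under MODELS_DIR, and
# "reference.wav" / "test.wav" are hypothetical paths chosen for this example.
if __name__ == "__main__":
    tts = ToucanTTSInterface(device="cuda" if torch.cuda.is_available() else "cpu", language="eng")
    tts.set_language("eng")
    # optionally clone the voice of a reference recording:
    # tts.set_utterance_embedding(path_to_reference_audio="reference.wav")
    tts.read_to_file(text_list=["This is a test sentence.", "And this is another one."],
                     file_location="test.wav",
                     duration_scaling_factor=1.0)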