import IPython.display as ipd
import gradio as gr
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
from Utils.PLBERT.util import load_plbert
from text_utils import TextCleaner
from utils import *
from models import *
from nltk.tokenize import word_tokenize
import librosa
import torchaudio
import torch.nn.functional as F
from torch import nn
from munch import Munch
import yaml
import time
import os
import numpy as np
import random
import torch
import nltk
nltk.download('punkt_tab')
from phonemizer import phonemize
from phonemizer.backend import EspeakBackend



torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

random.seed(0)

np.random.seed(0)

# set up text cleaning and audio feature extraction

textcleaner = TextCleaner()

# Transformation from a waveform (one amplitude per sample) to an 80-band mel spectrogram.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
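# At 24 kHz input audio, hop_length=300 corresponds to 80 mel frames per second;
# mean and std below are the normalization constants applied to the log-mel values in preprocess().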

mean, std = -4, 4

# Creates a boolean mask that is True for padding positions beyond each sequence length and False for valid positions.
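# Example (illustrative): length_to_mask(torch.LongTensor([3, 5])) ->
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])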


def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(
        0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask+1, lengths.unsqueeze(1))
    return mask

# Converts a waveform to a normalized log-Mel spectrogram tensor.
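# The output has shape (1, n_mels=80, n_frames), with n_frames roughly samples / hop_length.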


def preprocess(wave):
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
    return mel_tensor

# Loads, trims, resamples an audio file, and computes its style and predictor encodings.
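# Example usage (hypothetical path): ref_s = compute_style("reference.wav")
# ref_s concatenates the acoustic-style and prosodic-style embeddings along dim=1
# (128 + 128 = 256 dimensions with this configuration).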


def compute_style(path):
    wave, sr = librosa.load(path, sr=24000)
    audio, index = librosa.effects.trim(wave, top_db=30)
    if sr != 24000:
        audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
    mel_tensor = preprocess(audio).to(device)

    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))      # acoustic style embedding
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))  # prosodic style embedding

    return torch.cat([ref_s, ref_p], dim=1)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device != 'cuda':
    print("Using CPU as CUDA is not available!")
else:
    print("Using CUDA")

# load phonemizer (converts text into phonemes)
global_phonemizer = EspeakBackend(
    language='en-us', preserve_punctuation=True, with_stress=True)


# model_folder_path="Models/LibriTTS-lora-ft/merged" # for inferencing the merged lora
# config = yaml.safe_load(open(model_folder_path + '/config.yml'))

# for inferencing the full fine-tuned model
model_folder_path = "Models/LibriTTS-fft"
# Note: this directory uses config_ft.yml, while the LoRA folder above uses config.yml;
# confirm that the intended config is being loaded.
config = yaml.safe_load(open(model_folder_path + '/config_ft.yml'))

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
_ = [model[key].eval() for key in model]
_ = [model[key].to(device) for key in model]
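# `model` is a dict-like collection of sub-modules used below: text_encoder, bert,
# bert_encoder, predictor, decoder, style_encoder, predictor_encoder, diffusion, ...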

files = [f for f in os.listdir(model_folder_path) if f.endswith('.pth')]
sorted_files = sorted(files, key=lambda x: int(x.split('_')[-1].split('.')[0]))

print(sorted_files)

# Load the most recent checkpoint (the file with the highest trailing number in its name).
params_whole = torch.load(model_folder_path + '/' +
                          sorted_files[-1], map_location='cpu')

if 'net' in params_whole.keys():
    print("Checkpoint stores weights under the 'net' key")
    params = params_whole['net']
else:
    params = params_whole
    print("Checkpoint stores weights at the top level")


for key in model:
    if key in params:
        print('%s loaded' % key)
        try:
            model[key].load_state_dict(params[key])
        except Exception:
            # Checkpoints saved from a DataParallel-wrapped model prefix parameter
            # names with 'module.'; strip the prefix and retry.
            from collections import OrderedDict
            state_dict = params[key]
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove `module.`
                new_state_dict[name] = v
            model[key].load_state_dict(new_state_dict, strict=False)


# Loading the diffusion sampler
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(
        sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters
    clamp=False
)
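# During inference the sampler draws a 256-dim style vector from noise, conditioned on the
# BERT text embedding (`embedding`) and the reference features (`features`); see `inference` below.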


def inference(text, ref_s, alpha=0.2, beta=0.2, diffusion_steps=10, embedding_scale=1):
    """
    Generate speech from text using a diffusion-based approach with reference style blending.

    Parameters:
    - text: The input text to convert to speech.
    - ref_s: The reference style and predictor encoder features from an audio snippet.
    - alpha: Blending factor for the reference style (lower alpha means more like the reference).
    - beta: Blending factor for the predictor features (lower beta means more like the reference).
    - diffusion_steps: Number of steps in the diffusion process (more steps generally improve quality at the cost of speed).
    - embedding_scale: Guidance scale on the BERT text embedding; higher values condition the generated style more strongly on the text.

    Returns a 1-D numpy array of audio samples (24 kHz).
    """

    # Clean up and tokenize the input text
    text = text.strip()
    ps = global_phonemizer.phonemize([text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    tokens = textcleaner(ps)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        # Get the length of the input tokens
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        # Create a mask for the input text to handle variable lengths
        text_mask = length_to_mask(input_lengths).to(device)

        # Encode the text using the text encoder
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        # Use BERT to get the prosodic text encoding (to be used for style prediction).
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        # Further reduce the dimensions of the BERT embeddings to be suitable for the predictor
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Generate an output style + predictor vector
        s_pred = sampler(
            noise=torch.randn((1, 256)).unsqueeze(1).to(device),
            embedding=bert_dur,  # BERT output embeddings
            embedding_scale=embedding_scale,
            features=ref_s,  # Style and predictor features from reference audio
            num_steps=diffusion_steps
        ).squeeze(1)

        # Split the generated features into style and predictor components
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        # Blend the generated style features with the reference style
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        # Blend the generated predictor features with the reference predictor
        s = beta * s + (1 - beta) * ref_s[:, 128:]
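        # With the defaults (alpha = beta = 0.2), the blend keeps 80% of the reference
        # characteristics and 20% of the newly sampled style.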

        # Use the predictor to encode the text with the generated features
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        # Pass through the LSTM to get duration predictions
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)

        # Process the duration predictions
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Create a target alignment for the predicted durations
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # Encode the prosody using the target alignment
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            # Adjust for HiFi-GAN decoder input format
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new

        # Predict F0 and N features (fundamental frequency and noise)
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        # Create the alignment for the text encoder output
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            # Adjust for HiFi-GAN decoder input format
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new

        # Decode the final audio output using the decoder
        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # The model produces a brief artifact (pulse) at the very end of the waveform;
    # trim the last 50 samples as a workaround until this is fixed.
    return out.squeeze().cpu().numpy()[..., :-50]



def tts_model(text):
    # Reference audio that provides the target voice/style (adjust this path as needed)
    ref_s = compute_style("Trelis_Data/wavs/med5_0.wav")

    # Run inference to generate the output wav
    wav = inference(text, ref_s, alpha=0.3, beta=0.3,
                    diffusion_steps=10, embedding_scale=1)

    # Reshape to (samples, 1) so Gradio treats the output as single-channel audio
    wav = np.expand_dims(wav, axis=1)

    # Return the audio as a tuple with sample rate
    return 24000, wav  # the model generates audio at a 24 kHz sample rate


# Create a Gradio interface
interface = gr.Interface(
    fn=tts_model,
    inputs=gr.Textbox(label="Input Text"),  # Input text for speech generation
    outputs=gr.Audio(label="Generated Audio", type="numpy"),  # Generated TTS audio
    live=False
)

# Launch the Gradio interface
interface.launch(share=True)