AshwinSankar committed on
Commit 4d17da3 · verified · 1 Parent(s): 1e1bace

Create modeling_vits.py

Files changed (1):
  1. modeling_vits.py +1508 -0
modeling_vits.py ADDED
@@ -0,0 +1,1508 @@
1
+ """PyTorch VITS model."""
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+
12
+ from transformers.activations import ACT2FN
13
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
14
+ from transformers.integrations.fsdp import is_fsdp_managed_module
15
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
16
+ from transformers.modeling_outputs import (
17
+ BaseModelOutput,
18
+ ModelOutput,
19
+ )
20
+ from transformers.modeling_utils import PreTrainedModel
21
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
22
+ from .configuration_vits import IndicVitsConfig
23
+
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ # General docstring
29
+ _CONFIG_FOR_DOC = "IndicVitsConfig"
30
+
31
+
32
+ @dataclass
33
+ class VitsModelOutput(ModelOutput):
34
+ """
35
+ Describes the outputs for the VITS model, with potential hidden states and attentions.
36
+
37
+ Args:
38
+ waveform (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
39
+ The final audio waveform predicted by the model.
40
+ sequence_lengths (`torch.FloatTensor` of shape `(batch_size,)`):
41
+ The length in samples of each element in the `waveform` batch.
42
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_bins)`):
43
+ The log-mel spectrogram predicted at the output of the flow model. This spectrogram is passed to the Hi-Fi
44
+ GAN decoder model to obtain the final audio waveform.
45
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
46
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
47
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
48
+
49
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
50
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
51
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
52
+ sequence_length)`.
53
+
54
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
55
+ heads.
56
+ """
57
+
58
+ waveform: torch.FloatTensor = None
59
+ sequence_lengths: torch.FloatTensor = None
60
+ spectrogram: Optional[Tuple[torch.FloatTensor]] = None
61
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
62
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
63
+
64
+
65
+ @dataclass
66
+ class VitsTextEncoderOutput(ModelOutput):
67
+ """
68
+ Describes the outputs for the VITS text encoder model, with potential hidden states and attentions.
69
+
70
+ Args:
71
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
72
+ Sequence of hidden-states at the output of the last layer of the model.
73
+ prior_means (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
74
+ The predicted mean values of the prior distribution for the latent text variables.
75
+ prior_log_variances (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
76
+ The predicted log-variance values of the prior distribution for the latent text variables.
77
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
78
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
79
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
80
+
81
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
82
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
83
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
84
+ sequence_length)`.
85
+
86
+ Attention weights after the attention softmax, used to compute the weighted average in the self-attention
87
+ heads.
88
+ """
89
+
90
+ last_hidden_state: torch.FloatTensor = None
91
+ prior_means: torch.FloatTensor = None
92
+ prior_log_variances: torch.FloatTensor = None
93
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
94
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
95
+
96
+
97
+ @torch.jit.script
98
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels):
99
+ in_act = input_a + input_b
100
+ t_act = torch.tanh(in_act[:, :num_channels, :])
101
+ s_act = torch.sigmoid(in_act[:, num_channels:, :])
102
+ acts = t_act * s_act
103
+ return acts
104
+
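+ # Illustrative shape sketch for the fused gated activation above (hypothetical sizes,
+ # assuming hidden_size = 4): the inputs carry 2 * hidden_size channels, the first half
+ # goes through tanh (filter), the second half through sigmoid (gate), and their
+ # elementwise product has hidden_size channels:
+ #
+ #     input_a = torch.randn(1, 8, 10)           # (batch, 2 * hidden_size, time)
+ #     input_b = torch.zeros(1, 8, 10)           # e.g. no global conditioning
+ #     num_channels = torch.IntTensor([4])[0]
+ #     acts = fused_add_tanh_sigmoid_multiply(input_a, input_b, num_channels)
+ #     acts.shape                                # torch.Size([1, 4, 10])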
105
+
106
+ def _unconstrained_rational_quadratic_spline(
107
+ inputs,
108
+ unnormalized_widths,
109
+ unnormalized_heights,
110
+ unnormalized_derivatives,
111
+ reverse=False,
112
+ tail_bound=5.0,
113
+ min_bin_width=1e-3,
114
+ min_bin_height=1e-3,
115
+ min_derivative=1e-3,
116
+ ):
117
+ """
118
+ This transformation represents a monotonically increasing piecewise rational quadratic function. Outside of the
119
+ `tail_bound`, the transform behaves as an identity function.
120
+
121
+ Args:
122
+ inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
123
+ Second half of the hidden-states input to the Vits convolutional flow module.
124
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
125
+ First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
126
+ layer in the convolutional flow module
127
+ unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
128
+ Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
129
+ layer in the convolutional flow module
130
+ unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
131
+ Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
132
+ layer in the convolutional flow module
133
+ reverse (`bool`, *optional*, defaults to `False`):
134
+ Whether the model is being run in reverse mode.
135
+ tail_bound (`float`, *optional*, defaults to 5.0):
136
+ Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
137
+ transform behaves as an identity function.
138
+ min_bin_width (`float`, *optional*, defaults to 1e-3):
139
+ Minimum bin value across the width dimension for the piecewise rational quadratic function.
140
+ min_bin_height (`float`, *optional*, defaults to 1e-3):
141
+ Minimum bin value across the height dimension for the piecewise rational quadratic function.
142
+ min_derivative (`float`, *optional*, defaults to 1e-3):
143
+ Minimum bin value across the derivatives for the piecewise rational quadratic function.
144
+ Returns:
145
+ outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
146
+ Hidden-states as transformed by the piecewise rational quadratic function with the `tail_bound` limits
147
+ applied.
148
+ log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
149
+ Logarithm of the absolute value of the determinants corresponding to the `outputs` with the `tail_bound`
150
+ limits applied.
151
+ """
152
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
153
+ outside_interval_mask = ~inside_interval_mask
154
+
155
+ outputs = torch.zeros_like(inputs)
156
+ log_abs_det = torch.zeros_like(inputs)
157
+ constant = np.log(np.exp(1 - min_derivative) - 1)
158
+
159
+ unnormalized_derivatives = nn.functional.pad(unnormalized_derivatives, pad=(1, 1))
160
+ unnormalized_derivatives[..., 0] = constant
161
+ unnormalized_derivatives[..., -1] = constant
162
+
163
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
164
+ log_abs_det[outside_interval_mask] = 0.0
165
+
166
+ outputs[inside_interval_mask], log_abs_det[inside_interval_mask] = _rational_quadratic_spline(
167
+ inputs=inputs[inside_interval_mask],
168
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
169
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
170
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
171
+ reverse=reverse,
172
+ tail_bound=tail_bound,
173
+ min_bin_width=min_bin_width,
174
+ min_bin_height=min_bin_height,
175
+ min_derivative=min_derivative,
176
+ )
177
+ return outputs, log_abs_det
178
+
179
+
180
+ def _rational_quadratic_spline(
181
+ inputs,
182
+ unnormalized_widths,
183
+ unnormalized_heights,
184
+ unnormalized_derivatives,
185
+ reverse,
186
+ tail_bound,
187
+ min_bin_width,
188
+ min_bin_height,
189
+ min_derivative,
190
+ ):
191
+ """
192
+ This transformation represents a monotonically increasing piecewise rational quadratic function. Unlike the
193
+ function `_unconstrained_rational_quadratic_spline`, the function behaves the same across the `tail_bound`.
194
+
195
+ Args:
196
+ inputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
197
+ Second half of the hidden-states input to the Vits convolutional flow module.
198
+ unnormalized_widths (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
199
+ First `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
200
+ layer in the convolutional flow module
201
+ unnormalized_heights (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
202
+ Second `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
203
+ layer in the convolutional flow module
204
+ unnormalized_derivatives (`torch.FloatTensor` of shape `(batch_size, channels, seq_len, duration_predictor_flow_bins)`):
205
+ Third `duration_predictor_flow_bins` of the hidden-states from the output of the convolution projection
206
+ layer in the convolutional flow module
207
+ reverse (`bool`):
208
+ Whether the model is being run in reverse mode.
209
+ tail_bound (`float`):
210
+ Upper and lower limit bound for the rational quadratic function. Outside of this `tail_bound`, the
211
+ transform behaves as an identity function.
212
+ min_bin_width (`float`):
213
+ Minimum bin value across the width dimension for the piecewise rational quadratic function.
214
+ min_bin_height (`float`):
215
+ Minimum bin value across the height dimension for the piecewise rational quadratic function.
216
+ min_derivative (`float`):
217
+ Minimum bin value across the derivatives for the piecewise rational quadratic function.
218
+ Returns:
219
+ outputs (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
220
+ Hidden-states as transformed by the piecewise rational quadratic function.
221
+ log_abs_det (`torch.FloatTensor` of shape `(batch_size, channels, seq_len)`):
222
+ Logarithm of the absolute value of the determinants corresponding to the `outputs`.
223
+ """
224
+ upper_bound = tail_bound
225
+ lower_bound = -tail_bound
226
+
227
+ if torch.min(inputs) < lower_bound or torch.max(inputs) > upper_bound:
228
+ raise ValueError("Input to a transform is not within its domain")
229
+
230
+ num_bins = unnormalized_widths.shape[-1]
231
+
232
+ if min_bin_width * num_bins > 1.0:
233
+ raise ValueError(f"Minimal bin width {min_bin_width} too large for the number of bins {num_bins}")
234
+ if min_bin_height * num_bins > 1.0:
235
+ raise ValueError(f"Minimal bin height {min_bin_height} too large for the number of bins {num_bins}")
236
+
237
+ widths = nn.functional.softmax(unnormalized_widths, dim=-1)
238
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
239
+ cumwidths = torch.cumsum(widths, dim=-1)
240
+ cumwidths = nn.functional.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
241
+ cumwidths = (upper_bound - lower_bound) * cumwidths + lower_bound
242
+ cumwidths[..., 0] = lower_bound
243
+ cumwidths[..., -1] = upper_bound
244
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
245
+
246
+ derivatives = min_derivative + nn.functional.softplus(unnormalized_derivatives)
247
+
248
+ heights = nn.functional.softmax(unnormalized_heights, dim=-1)
249
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
250
+ cumheights = torch.cumsum(heights, dim=-1)
251
+ cumheights = nn.functional.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
252
+ cumheights = (upper_bound - lower_bound) * cumheights + lower_bound
253
+ cumheights[..., 0] = lower_bound
254
+ cumheights[..., -1] = upper_bound
255
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
256
+
257
+ bin_locations = cumheights if reverse else cumwidths
258
+ bin_locations[..., -1] += 1e-6
259
+ bin_idx = torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
260
+ bin_idx = bin_idx[..., None]
261
+
262
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
263
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
264
+
265
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
266
+ delta = heights / widths
267
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
268
+
269
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
270
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
271
+
272
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
273
+
274
+ intermediate1 = input_derivatives + input_derivatives_plus_one - 2 * input_delta
275
+ if not reverse:
276
+ theta = (inputs - input_cumwidths) / input_bin_widths
277
+ theta_one_minus_theta = theta * (1 - theta)
278
+
279
+ numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
280
+ denominator = input_delta + intermediate1 * theta_one_minus_theta
281
+ outputs = input_cumheights + numerator / denominator
282
+
283
+ derivative_numerator = input_delta.pow(2) * (
284
+ input_derivatives_plus_one * theta.pow(2)
285
+ + 2 * input_delta * theta_one_minus_theta
286
+ + input_derivatives * (1 - theta).pow(2)
287
+ )
288
+ log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
289
+ return outputs, log_abs_det
290
+ else:
291
+ # find the roots of a quadratic equation
292
+ intermediate2 = inputs - input_cumheights
293
+ intermediate3 = intermediate2 * intermediate1
294
+ a = input_heights * (input_delta - input_derivatives) + intermediate3
295
+ b = input_heights * input_derivatives - intermediate3
296
+ c = -input_delta * intermediate2
297
+
298
+ discriminant = b.pow(2) - 4 * a * c
299
+ if not (discriminant >= 0).all():
300
+ raise RuntimeError(f"invalid discriminant {discriminant}")
301
+
302
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
303
+ outputs = root * input_bin_widths + input_cumwidths
304
+
305
+ theta_one_minus_theta = root * (1 - root)
306
+ denominator = input_delta + intermediate1 * theta_one_minus_theta
307
+ derivative_numerator = input_delta.pow(2) * (
308
+ input_derivatives_plus_one * root.pow(2)
309
+ + 2 * input_delta * theta_one_minus_theta
310
+ + input_derivatives * (1 - root).pow(2)
311
+ )
312
+ log_abs_det = torch.log(derivative_numerator) - 2 * torch.log(denominator)
313
+ return outputs, -log_abs_det
314
+
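+ # Reference form of the transform above (a sketch; symbols map onto the code's variables):
+ # within the selected bin, with theta = (x - cumwidth_k) / width_k, s_k = height_k / width_k
+ # (the code's `input_delta`), and d_k, d_{k+1} the bin-edge derivatives,
+ #
+ #     y = cumheight_k + height_k * (s_k * theta**2 + d_k * theta * (1 - theta))
+ #                       / (s_k + (d_k + d_{k+1} - 2 * s_k) * theta * (1 - theta))
+ #
+ # which is the `numerator / denominator` expression in the forward branch; the reverse
+ # branch solves the same relation for theta with the quadratic formula (hence the
+ # discriminant check) and negates the log-determinant.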
315
+
316
+ class VitsWaveNet(torch.nn.Module):
317
+ def __init__(self, config: IndicVitsConfig, num_layers: int):
318
+ super().__init__()
319
+ self.hidden_size = config.hidden_size
320
+ self.num_layers = num_layers
+ # stored so that remove_weight_norm() can check whether cond_layer was created
+ self.speaker_embedding_size = config.speaker_embedding_size
321
+
322
+ self.in_layers = torch.nn.ModuleList()
323
+ self.res_skip_layers = torch.nn.ModuleList()
324
+ self.dropout = nn.Dropout(config.wavenet_dropout)
325
+
326
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
327
+ weight_norm = nn.utils.parametrizations.weight_norm
328
+ else:
329
+ weight_norm = nn.utils.weight_norm
330
+
331
+ if config.speaker_embedding_size != 0:
332
+ cond_layer = torch.nn.Conv1d(config.speaker_embedding_size, 2 * config.hidden_size * num_layers, 1)
333
+ self.cond_layer = weight_norm(cond_layer, name="weight")
334
+
335
+ for i in range(num_layers):
336
+ dilation = config.wavenet_dilation_rate**i
337
+ padding = (config.wavenet_kernel_size * dilation - dilation) // 2
338
+ in_layer = torch.nn.Conv1d(
339
+ in_channels=config.hidden_size,
340
+ out_channels=2 * config.hidden_size,
341
+ kernel_size=config.wavenet_kernel_size,
342
+ dilation=dilation,
343
+ padding=padding,
344
+ )
345
+ in_layer = weight_norm(in_layer, name="weight")
346
+ self.in_layers.append(in_layer)
347
+
348
+ # last one is not necessary
349
+ if i < num_layers - 1:
350
+ res_skip_channels = 2 * config.hidden_size
351
+ else:
352
+ res_skip_channels = config.hidden_size
353
+
354
+ res_skip_layer = torch.nn.Conv1d(config.hidden_size, res_skip_channels, 1)
355
+ res_skip_layer = weight_norm(res_skip_layer, name="weight")
356
+ self.res_skip_layers.append(res_skip_layer)
357
+
358
+ def forward(self, inputs, padding_mask, global_conditioning=None):
359
+ outputs = torch.zeros_like(inputs)
360
+ num_channels_tensor = torch.IntTensor([self.hidden_size])
361
+
362
+ if global_conditioning is not None:
363
+ global_conditioning = self.cond_layer(global_conditioning)
364
+
365
+ for i in range(self.num_layers):
366
+ hidden_states = self.in_layers[i](inputs)
367
+
368
+ if global_conditioning is not None:
369
+ cond_offset = i * 2 * self.hidden_size
370
+ global_states = global_conditioning[:, cond_offset : cond_offset + 2 * self.hidden_size, :]
371
+ else:
372
+ global_states = torch.zeros_like(hidden_states)
373
+
374
+ acts = fused_add_tanh_sigmoid_multiply(hidden_states, global_states, num_channels_tensor[0])
375
+ acts = self.dropout(acts)
376
+
377
+ res_skip_acts = self.res_skip_layers[i](acts)
378
+ if i < self.num_layers - 1:
379
+ res_acts = res_skip_acts[:, : self.hidden_size, :]
380
+ inputs = (inputs + res_acts) * padding_mask
381
+ outputs = outputs + res_skip_acts[:, self.hidden_size :, :]
382
+ else:
383
+ outputs = outputs + res_skip_acts
384
+
385
+ return outputs * padding_mask
386
+
387
+ def remove_weight_norm(self):
388
+ if self.speaker_embedding_size != 0:
389
+ torch.nn.utils.parametrize.remove_parametrizations(self.cond_layer, "weight")
390
+ for layer in self.in_layers:
391
+ torch.nn.utils.parametrize.remove_parametrizations(layer, "weight")
392
+ for layer in self.res_skip_layers:
393
+ torch.nn.utils.parametrize.remove_parametrizations(layer, "weight")
394
+
395
+
396
+ class VitsPosteriorEncoder(nn.Module):
397
+ def __init__(self, config: IndicVitsConfig):
398
+ super().__init__()
399
+ self.out_channels = config.flow_size
400
+
401
+ self.conv_pre = nn.Conv1d(config.spectrogram_bins, config.hidden_size, 1)
402
+ self.wavenet = VitsWaveNet(config, num_layers=config.posterior_encoder_num_wavenet_layers)
403
+ self.conv_proj = nn.Conv1d(config.hidden_size, self.out_channels * 2, 1)
404
+
405
+ def forward(self, inputs, padding_mask, global_conditioning=None):
406
+ inputs = self.conv_pre(inputs) * padding_mask
407
+ inputs = self.wavenet(inputs, padding_mask, global_conditioning)
408
+ stats = self.conv_proj(inputs) * padding_mask
409
+ mean, log_stddev = torch.split(stats, self.out_channels, dim=1)
410
+ sampled = (mean + torch.randn_like(mean) * torch.exp(log_stddev)) * padding_mask
411
+ return sampled, mean, log_stddev
412
+
413
+
414
+ # Copied from transformers.models.speecht5.modeling_speecht5.HifiGanResidualBlock
415
+ class HifiGanResidualBlock(nn.Module):
416
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), leaky_relu_slope=0.1):
417
+ super().__init__()
418
+ self.leaky_relu_slope = leaky_relu_slope
419
+
420
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
421
+ weight_norm = nn.utils.parametrizations.weight_norm
422
+ else:
423
+ weight_norm = nn.utils.weight_norm
424
+
425
+ self.convs1 = nn.ModuleList(
426
+ [
427
+ weight_norm(
428
+ nn.Conv1d(
429
+ channels,
430
+ channels,
431
+ kernel_size,
432
+ stride=1,
433
+ dilation=dilation[i],
434
+ padding=self.get_padding(kernel_size, dilation[i]),
435
+ )
436
+ )
437
+ for i in range(len(dilation))
438
+ ]
439
+ )
440
+ self.convs2 = nn.ModuleList(
441
+ [
442
+ weight_norm(
443
+ nn.Conv1d(
444
+ channels,
445
+ channels,
446
+ kernel_size,
447
+ stride=1,
448
+ dilation=1,
449
+ padding=self.get_padding(kernel_size, 1),
450
+ )
451
+ )
452
+ for _ in range(len(dilation))
453
+ ]
454
+ )
455
+
456
+ def get_padding(self, kernel_size, dilation=1):
457
+ return (kernel_size * dilation - dilation) // 2
458
+
459
+ # def apply_weight_norm(self):
460
+ # # Determine the correct weight_norm function to use
461
+ # weight_norm = nn.utils.weight_norm
462
+ # if hasattr(nn.utils.parametrizations, "weight_norm"):
463
+ # weight_norm = nn.utils.parametrizations.weight_norm
464
+
465
+ # # Apply weight_norm to each layer and replace the original layer with the wrapped layer
466
+ # self.convs1 = nn.ModuleList([weight_norm(layer) for layer in self.convs1])
467
+ # self.convs2 = nn.ModuleList([weight_norm(layer) for layer in self.convs2])
468
+
469
+
470
+ def remove_weight_norm(self):
471
+ for layer in self.convs1:
472
+ nn.utils.parametrize.remove_parametrizations(layer, "weight")
473
+ for layer in self.convs2:
474
+ nn.utils.parametrize.remove_parametrizations(layer, "weight")
475
+
476
+ def forward(self, hidden_states):
477
+ for conv1, conv2 in zip(self.convs1, self.convs2):
478
+ residual = hidden_states
479
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
480
+ hidden_states = conv1(hidden_states)
481
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.leaky_relu_slope)
482
+ hidden_states = conv2(hidden_states)
483
+ hidden_states = hidden_states + residual
484
+ return hidden_states
485
+
486
+
487
+ class VitsHifiGan(nn.Module):
488
+ def __init__(self, config: IndicVitsConfig):
489
+ super().__init__()
490
+ self.config = config
491
+ self.num_kernels = len(config.resblock_kernel_sizes)
492
+ self.num_upsamples = len(config.upsample_rates)
493
+ self.conv_pre = nn.Conv1d(
494
+ config.flow_size,
495
+ config.upsample_initial_channel,
496
+ kernel_size=7,
497
+ stride=1,
498
+ padding=3,
499
+ )
500
+
501
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
502
+ weight_norm = nn.utils.parametrizations.weight_norm
503
+ else:
504
+ weight_norm = nn.utils.weight_norm
505
+
506
+ self.upsampler = nn.ModuleList()
507
+ for i, (upsample_rate, kernel_size) in enumerate(zip(config.upsample_rates, config.upsample_kernel_sizes)):
508
+ self.upsampler.append(
509
+ weight_norm(
510
+ nn.ConvTranspose1d(
511
+ config.upsample_initial_channel // (2**i),
512
+ config.upsample_initial_channel // (2 ** (i + 1)),
513
+ kernel_size=kernel_size,
514
+ stride=upsample_rate,
515
+ padding=(kernel_size - upsample_rate) // 2,
516
+ )
517
+ )
518
+ )
519
+
520
+ self.resblocks = nn.ModuleList()
521
+ for i in range(len(self.upsampler)):
522
+ channels = config.upsample_initial_channel // (2 ** (i + 1))
523
+ for kernel_size, dilation in zip(config.resblock_kernel_sizes, config.resblock_dilation_sizes):
524
+ self.resblocks.append(HifiGanResidualBlock(channels, kernel_size, dilation, config.leaky_relu_slope))
525
+
526
+ self.conv_post = nn.Conv1d(channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
527
+
528
+ if config.speaker_embedding_size != 0:
529
+ self.cond = nn.Conv1d(config.speaker_embedding_size, config.upsample_initial_channel, 1)
530
+
531
+ # self.apply_weight_norm()
532
+
533
+ # def apply_weight_norm(self):
534
+ # weight_norm = nn.utils.weight_norm
535
+ # if hasattr(nn.utils.parametrizations, "weight_norm"):
536
+ # weight_norm = nn.utils.parametrizations.weight_norm
537
+
538
+ # self.upsampler = nn.ModuleList([weight_norm(layer) for layer in self.upsampler])
539
+ # for layer in self.resblocks:
540
+ # layer.apply_weight_norm()
541
+
542
+ def remove_weight_norm(self):
543
+ for layer in self.upsampler:
544
+ nn.utils.parametrize.remove_parametrizations(layer, "weight")
545
+ for layer in self.resblocks:
546
+ layer.remove_weight_norm()
547
+
548
+ def forward(
549
+ self, spectrogram: torch.FloatTensor, global_conditioning: Optional[torch.FloatTensor] = None
550
+ ) -> torch.FloatTensor:
551
+ r"""
552
+ Converts a spectrogram into a speech waveform.
553
+
554
+ Args:
555
+ spectrogram (`torch.FloatTensor` of shape `(batch_size, config.flow_size, sequence_length)`):
556
+ Tensor containing the spectrograms.
557
+ global_conditioning (`torch.FloatTensor` of shape `(batch_size, config.speaker_embedding_size, 1)`, *optional*):
558
+ Tensor containing speaker embeddings, for multispeaker models.
559
+
560
+ Returns:
561
+ `torch.FloatTensor`: Tensor of shape `(batch_size, 1, num_frames)` containing the speech waveform.
562
+ """
563
+ hidden_states = self.conv_pre(spectrogram)
564
+
565
+ if global_conditioning is not None:
566
+ hidden_states = hidden_states + self.cond(global_conditioning)
567
+
568
+ for i in range(self.num_upsamples):
569
+ hidden_states = nn.functional.leaky_relu(hidden_states, self.config.leaky_relu_slope)
570
+ hidden_states = self.upsampler[i](hidden_states)
571
+
572
+ res_state = self.resblocks[i * self.num_kernels](hidden_states)
573
+ for j in range(1, self.num_kernels):
574
+ res_state += self.resblocks[i * self.num_kernels + j](hidden_states)
575
+ hidden_states = res_state / self.num_kernels
576
+
577
+ hidden_states = nn.functional.leaky_relu(hidden_states)
578
+ hidden_states = self.conv_post(hidden_states)
579
+ waveform = torch.tanh(hidden_states)
580
+ return waveform
581
+
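+ # Usage sketch (hypothetical, written as comments so the module body is unaffected;
+ # assumes an `IndicVitsConfig` instance named `config` is available):
+ #
+ #     decoder = VitsHifiGan(config)
+ #     latents = torch.randn(1, config.flow_size, 100)   # matches conv_pre's in_channels
+ #     waveform = decoder(latents)                        # (1, 1, 100 * prod(config.upsample_rates))
+ #                                                        # for the usual kernel/stride settings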
582
+
583
+ class VitsResidualCouplingLayer(nn.Module):
584
+ def __init__(self, config: IndicVitsConfig):
585
+ super().__init__()
586
+ self.half_channels = config.flow_size // 2
587
+
588
+ self.conv_pre = nn.Conv1d(self.half_channels, config.hidden_size, 1)
589
+ self.wavenet = VitsWaveNet(config, num_layers=config.prior_encoder_num_wavenet_layers)
590
+ self.conv_post = nn.Conv1d(config.hidden_size, self.half_channels, 1)
591
+
592
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
593
+ first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
594
+ hidden_states = self.conv_pre(first_half) * padding_mask
595
+ hidden_states = self.wavenet(hidden_states, padding_mask, global_conditioning)
596
+ mean = self.conv_post(hidden_states) * padding_mask
597
+ log_stddev = torch.zeros_like(mean)
598
+
599
+ if not reverse:
600
+ second_half = mean + second_half * torch.exp(log_stddev) * padding_mask
601
+ outputs = torch.cat([first_half, second_half], dim=1)
602
+ log_determinant = torch.sum(log_stddev, [1, 2])
603
+ return outputs, log_determinant
604
+ else:
605
+ second_half = (second_half - mean) * torch.exp(-log_stddev) * padding_mask
606
+ outputs = torch.cat([first_half, second_half], dim=1)
607
+ return outputs, None
608
+
609
+
610
+ class VitsResidualCouplingBlock(nn.Module):
611
+ def __init__(self, config: IndicVitsConfig):
612
+ super().__init__()
613
+ self.flows = nn.ModuleList()
614
+ for _ in range(config.prior_encoder_num_flows):
615
+ self.flows.append(VitsResidualCouplingLayer(config))
616
+
617
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
618
+ if not reverse:
619
+ for flow in self.flows:
620
+ inputs, _ = flow(inputs, padding_mask, global_conditioning)
621
+ inputs = torch.flip(inputs, [1])
622
+ else:
623
+ for flow in reversed(self.flows):
624
+ inputs = torch.flip(inputs, [1])
625
+ inputs, _ = flow(inputs, padding_mask, global_conditioning, reverse=True)
626
+ return inputs
627
+
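+ # Note on the block above: every coupling layer leaves its first half-channel group
+ # untouched, so the torch.flip between layers alternates which half gets transformed;
+ # running the same layers in reversed order with reverse=True (flipping before each
+ # layer instead of after) inverts the block exactly, which is what inference relies on.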
628
+
629
+ class VitsDilatedDepthSeparableConv(nn.Module):
630
+ def __init__(self, config: IndicVitsConfig, dropout_rate=0.0):
631
+ super().__init__()
632
+ kernel_size = config.duration_predictor_kernel_size
633
+ channels = config.hidden_size
634
+ self.num_layers = config.depth_separable_num_layers
635
+
636
+ self.dropout = nn.Dropout(dropout_rate)
637
+ self.convs_dilated = nn.ModuleList()
638
+ self.convs_pointwise = nn.ModuleList()
639
+ self.norms_1 = nn.ModuleList()
640
+ self.norms_2 = nn.ModuleList()
641
+ for i in range(self.num_layers):
642
+ dilation = kernel_size**i
643
+ padding = (kernel_size * dilation - dilation) // 2
644
+ self.convs_dilated.append(
645
+ nn.Conv1d(
646
+ in_channels=channels,
647
+ out_channels=channels,
648
+ kernel_size=kernel_size,
649
+ groups=channels,
650
+ dilation=dilation,
651
+ padding=padding,
652
+ )
653
+ )
654
+ self.convs_pointwise.append(nn.Conv1d(channels, channels, 1))
655
+ self.norms_1.append(nn.LayerNorm(channels))
656
+ self.norms_2.append(nn.LayerNorm(channels))
657
+
658
+ def forward(self, inputs, padding_mask, global_conditioning=None):
659
+ if global_conditioning is not None:
660
+ inputs = inputs + global_conditioning
661
+
662
+ for i in range(self.num_layers):
663
+ hidden_states = self.convs_dilated[i](inputs * padding_mask)
664
+ hidden_states = self.norms_1[i](hidden_states.transpose(1, -1)).transpose(1, -1)
665
+ hidden_states = nn.functional.gelu(hidden_states)
666
+ hidden_states = self.convs_pointwise[i](hidden_states)
667
+ hidden_states = self.norms_2[i](hidden_states.transpose(1, -1)).transpose(1, -1)
668
+ hidden_states = nn.functional.gelu(hidden_states)
669
+ hidden_states = self.dropout(hidden_states)
670
+ inputs = inputs + hidden_states
671
+
672
+ return inputs * padding_mask
673
+
674
+
675
+ class VitsConvFlow(nn.Module):
676
+ def __init__(self, config: IndicVitsConfig):
677
+ super().__init__()
678
+ self.filter_channels = config.hidden_size
679
+ self.half_channels = config.depth_separable_channels // 2
680
+ self.num_bins = config.duration_predictor_flow_bins
681
+ self.tail_bound = config.duration_predictor_tail_bound
682
+
683
+ self.conv_pre = nn.Conv1d(self.half_channels, self.filter_channels, 1)
684
+ self.conv_dds = VitsDilatedDepthSeparableConv(config)
685
+ self.conv_proj = nn.Conv1d(self.filter_channels, self.half_channels * (self.num_bins * 3 - 1), 1)
686
+
687
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
688
+ first_half, second_half = torch.split(inputs, [self.half_channels] * 2, dim=1)
689
+
690
+ hidden_states = self.conv_pre(first_half)
691
+ hidden_states = self.conv_dds(hidden_states, padding_mask, global_conditioning)
692
+ hidden_states = self.conv_proj(hidden_states) * padding_mask
693
+
694
+ batch_size, channels, length = first_half.shape
695
+ hidden_states = hidden_states.reshape(batch_size, channels, -1, length).permute(0, 1, 3, 2)
696
+
697
+ unnormalized_widths = hidden_states[..., : self.num_bins] / math.sqrt(self.filter_channels)
698
+ unnormalized_heights = hidden_states[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
699
+ unnormalized_derivatives = hidden_states[..., 2 * self.num_bins :]
700
+
701
+ second_half, log_abs_det = _unconstrained_rational_quadratic_spline(
702
+ second_half,
703
+ unnormalized_widths,
704
+ unnormalized_heights,
705
+ unnormalized_derivatives,
706
+ reverse=reverse,
707
+ tail_bound=self.tail_bound,
708
+ )
709
+
710
+ outputs = torch.cat([first_half, second_half], dim=1) * padding_mask
711
+ if not reverse:
712
+ log_determinant = torch.sum(log_abs_det * padding_mask, [1, 2])
713
+ return outputs, log_determinant
714
+ else:
715
+ return outputs, None
716
+
717
+
718
+ class VitsElementwiseAffine(nn.Module):
719
+ def __init__(self, config: IndicVitsConfig):
720
+ super().__init__()
721
+ self.channels = config.depth_separable_channels
722
+ self.translate = nn.Parameter(torch.zeros(self.channels, 1))
723
+ self.log_scale = nn.Parameter(torch.zeros(self.channels, 1))
724
+
725
+ def forward(self, inputs, padding_mask, global_conditioning=None, reverse=False):
726
+ if not reverse:
727
+ outputs = self.translate + torch.exp(self.log_scale) * inputs
728
+ outputs = outputs * padding_mask
729
+ log_determinant = torch.sum(self.log_scale * padding_mask, [1, 2])
730
+ return outputs, log_determinant
731
+ else:
732
+ outputs = (inputs - self.translate) * torch.exp(-self.log_scale) * padding_mask
733
+ return outputs, None
734
+
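+ # Note on the affine flow above: the forward pass computes y = translate + exp(log_scale) * x,
+ # so its log-determinant is simply log_scale summed over the non-padded positions, and the
+ # reverse pass recovers x = (y - translate) * exp(-log_scale) exactly.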
735
+
736
+ class VitsStochasticDurationPredictor(nn.Module):
737
+ def __init__(self, config):
738
+ super().__init__()
739
+ embed_dim = config.speaker_embedding_size
740
+ filter_channels = config.hidden_size
741
+
742
+ self.conv_pre = nn.Conv1d(filter_channels, filter_channels, 1)
743
+ self.conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
744
+ self.conv_dds = VitsDilatedDepthSeparableConv(
745
+ config,
746
+ dropout_rate=config.duration_predictor_dropout,
747
+ )
748
+
749
+ if embed_dim != 0:
750
+ self.cond = nn.Conv1d(embed_dim, filter_channels, 1)
751
+
752
+ self.flows = nn.ModuleList()
753
+ self.flows.append(VitsElementwiseAffine(config))
754
+ for _ in range(config.duration_predictor_num_flows):
755
+ self.flows.append(VitsConvFlow(config))
756
+
757
+ self.post_conv_pre = nn.Conv1d(1, filter_channels, 1)
758
+ self.post_conv_proj = nn.Conv1d(filter_channels, filter_channels, 1)
759
+ self.post_conv_dds = VitsDilatedDepthSeparableConv(
760
+ config,
761
+ dropout_rate=config.duration_predictor_dropout,
762
+ )
763
+
764
+ self.post_flows = nn.ModuleList()
765
+ self.post_flows.append(VitsElementwiseAffine(config))
766
+ for _ in range(config.duration_predictor_num_flows):
767
+ self.post_flows.append(VitsConvFlow(config))
768
+
769
+ def forward(self, inputs, padding_mask, global_conditioning=None, durations=None, reverse=False, noise_scale=1.0):
770
+ inputs = torch.detach(inputs)
771
+ inputs = self.conv_pre(inputs)
772
+
773
+ if global_conditioning is not None:
774
+ global_conditioning = torch.detach(global_conditioning)
775
+ inputs = inputs + self.cond(global_conditioning)
776
+
777
+ inputs = self.conv_dds(inputs, padding_mask)
778
+ inputs = self.conv_proj(inputs) * padding_mask
779
+
780
+ if not reverse:
781
+ hidden_states = self.post_conv_pre(durations)
782
+ hidden_states = self.post_conv_dds(hidden_states, padding_mask)
783
+ hidden_states = self.post_conv_proj(hidden_states) * padding_mask
784
+
785
+ random_posterior = (
786
+ torch.randn(durations.size(0), 2, durations.size(2)).to(device=inputs.device, dtype=inputs.dtype)
787
+ * padding_mask
788
+ )
789
+ log_determinant_posterior_sum = 0
790
+ latents_posterior = random_posterior
791
+ for flow in self.post_flows:
792
+ latents_posterior, log_determinant = flow(
793
+ latents_posterior, padding_mask, global_conditioning=inputs + hidden_states
794
+ )
795
+ latents_posterior = torch.flip(latents_posterior, [1])
796
+ log_determinant_posterior_sum += log_determinant
797
+
798
+ first_half, second_half = torch.split(latents_posterior, [1, 1], dim=1)
799
+
800
+ log_determinant_posterior_sum += torch.sum(
801
+ (nn.functional.logsigmoid(first_half) + nn.functional.logsigmoid(-first_half)) * padding_mask, [1, 2]
802
+ )
803
+ logq = (
804
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (random_posterior**2)) * padding_mask, [1, 2])
805
+ - log_determinant_posterior_sum
806
+ )
807
+
808
+ first_half = (durations - torch.sigmoid(first_half)) * padding_mask
809
+ first_half = torch.log(torch.clamp_min(first_half, 1e-5)) * padding_mask
810
+ log_determinant_sum = torch.sum(-first_half, [1, 2])
811
+
812
+ latents = torch.cat([first_half, second_half], dim=1)
813
+ for flow in self.flows:
814
+ latents, log_determinant = flow(latents, padding_mask, global_conditioning=inputs)
815
+ latents = torch.flip(latents, [1])
816
+ log_determinant_sum += log_determinant
817
+
818
+ nll = torch.sum(0.5 * (math.log(2 * math.pi) + (latents**2)) * padding_mask, [1, 2]) - log_determinant_sum
819
+ return nll + logq
820
+ else:
821
+ flows = list(reversed(self.flows))
822
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
823
+
824
+ latents = (
825
+ torch.randn(inputs.size(0), 2, inputs.size(2)).to(device=inputs.device, dtype=inputs.dtype)
826
+ * noise_scale
827
+ )
828
+ for flow in flows:
829
+ latents = torch.flip(latents, [1])
830
+ latents, _ = flow(latents, padding_mask, global_conditioning=inputs, reverse=True)
831
+
832
+ log_duration, _ = torch.split(latents, [1, 1], dim=1)
833
+ return log_duration
834
+
835
+
836
+ class VitsDurationPredictor(nn.Module):
837
+ def __init__(self, config):
838
+ super().__init__()
839
+ kernel_size = config.duration_predictor_kernel_size
840
+ filter_channels = config.duration_predictor_filter_channels
841
+
842
+ self.dropout = nn.Dropout(config.duration_predictor_dropout)
843
+ self.conv_1 = nn.Conv1d(config.hidden_size, filter_channels, kernel_size, padding=kernel_size // 2)
844
+ self.norm_1 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
845
+ self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
846
+ self.norm_2 = nn.LayerNorm(filter_channels, eps=config.layer_norm_eps)
847
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
848
+
849
+ if config.speaker_embedding_size != 0:
850
+ self.cond = nn.Conv1d(config.speaker_embedding_size, config.hidden_size, 1)
851
+
852
+ def forward(self, inputs, padding_mask, global_conditioning=None):
853
+ inputs = torch.detach(inputs)
854
+
855
+ if global_conditioning is not None:
856
+ global_conditioning = torch.detach(global_conditioning)
857
+ inputs = inputs + self.cond(global_conditioning)
858
+
859
+ inputs = self.conv_1(inputs * padding_mask)
860
+ inputs = torch.relu(inputs)
861
+ inputs = self.norm_1(inputs.transpose(1, -1)).transpose(1, -1)
862
+ inputs = self.dropout(inputs)
863
+
864
+ inputs = self.conv_2(inputs * padding_mask)
865
+ inputs = torch.relu(inputs)
866
+ inputs = self.norm_2(inputs.transpose(1, -1)).transpose(1, -1)
867
+ inputs = self.dropout(inputs)
868
+
869
+ inputs = self.proj(inputs * padding_mask)
870
+ return inputs * padding_mask
871
+
872
+
873
+ class VitsAttention(nn.Module):
874
+ """Multi-headed attention with relative positional representation."""
875
+
876
+ def __init__(self, config: IndicVitsConfig):
877
+ super().__init__()
878
+ self.embed_dim = config.hidden_size
879
+ self.num_heads = config.num_attention_heads
880
+ self.dropout = config.attention_dropout
881
+ self.window_size = config.window_size
882
+
883
+ self.head_dim = self.embed_dim // self.num_heads
884
+ self.scaling = self.head_dim**-0.5
885
+
886
+ if (self.head_dim * self.num_heads) != self.embed_dim:
887
+ raise ValueError(
888
+ f"hidden_size must be divisible by num_attention_heads (got `hidden_size`: {self.embed_dim}"
889
+ f" and `num_attention_heads`: {self.num_heads})."
890
+ )
891
+
892
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
893
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
894
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
895
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.use_bias)
896
+
897
+ if self.window_size:
898
+ self.emb_rel_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
899
+ self.emb_rel_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.head_dim) * self.scaling)
900
+
901
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
902
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
903
+
904
+ def forward(
905
+ self,
906
+ hidden_states: torch.Tensor,
907
+ key_value_states: Optional[torch.Tensor] = None,
908
+ attention_mask: Optional[torch.Tensor] = None,
909
+ layer_head_mask: Optional[torch.Tensor] = None,
910
+ output_attentions: bool = False,
911
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
912
+ """Input shape: Batch x Time x Channel"""
913
+
914
+ # if key_value_states are provided this layer is used as a cross-attention layer
915
+ # for the decoder
916
+
917
+ bsz, tgt_len, _ = hidden_states.size()
918
+
919
+ # get query proj
920
+ query_states = self.q_proj(hidden_states) * self.scaling
921
+
922
+ # self_attention
923
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
924
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
925
+
926
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
927
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
928
+ key_states = key_states.view(*proj_shape)
929
+ value_states = value_states.view(*proj_shape)
930
+
931
+ src_len = key_states.size(1)
932
+ attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
933
+
934
+ if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
935
+ raise ValueError(
936
+ f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
937
+ f" {attn_weights.size()}"
938
+ )
939
+
940
+ if self.window_size is not None:
941
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, src_len)
942
+ relative_logits = torch.matmul(query_states, key_relative_embeddings.transpose(-2, -1))
943
+ rel_pos_bias = self._relative_position_to_absolute_position(relative_logits)
944
+ attn_weights += rel_pos_bias
945
+
946
+ if attention_mask is not None:
947
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
948
+ raise ValueError(
949
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
950
+ )
951
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
952
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
953
+
954
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
955
+
956
+ if layer_head_mask is not None:
957
+ if layer_head_mask.size() != (self.num_heads,):
958
+ raise ValueError(
959
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
960
+ f" {layer_head_mask.size()}"
961
+ )
962
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
963
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
964
+
965
+ if output_attentions:
966
+ # this operation is a bit awkward, but it's required to
967
+ # make sure that attn_weights keeps its gradient.
968
+ # In order to do so, attn_weights have to be reshaped
969
+ # twice and have to be reused in the following
970
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
971
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
972
+ else:
973
+ attn_weights_reshaped = None
974
+
975
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
976
+
977
+ attn_output = torch.bmm(attn_probs, value_states)
978
+
979
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
980
+ raise ValueError(
981
+ f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
982
+ f" {attn_output.size()}"
983
+ )
984
+
985
+ if self.window_size is not None:
986
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, src_len)
987
+ relative_weights = self._absolute_position_to_relative_position(attn_probs)
988
+ rel_pos_bias = torch.matmul(relative_weights, value_relative_embeddings)
989
+ attn_output += rel_pos_bias
990
+
991
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
992
+ attn_output = attn_output.transpose(1, 2)
993
+
994
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
995
+ # partitioned across GPUs when using tensor-parallelism.
996
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
997
+
998
+ attn_output = self.out_proj(attn_output)
999
+
1000
+ return attn_output, attn_weights_reshaped
1001
+
1002
+ def _get_relative_embeddings(self, relative_embeddings, length):
1003
+ pad_length = max(length - (self.window_size + 1), 0)
1004
+ if pad_length > 0:
1005
+ relative_embeddings = nn.functional.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
1006
+
1007
+ slice_start_position = max((self.window_size + 1) - length, 0)
1008
+ slice_end_position = slice_start_position + 2 * length - 1
1009
+ return relative_embeddings[:, slice_start_position:slice_end_position]
1010
+
1011
+ def _relative_position_to_absolute_position(self, x):
1012
+ batch_heads, length, _ = x.size()
1013
+
1014
+ # Concat columns of pad to shift from relative to absolute indexing.
1015
+ x = nn.functional.pad(x, [0, 1, 0, 0, 0, 0])
1016
+
1017
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
1018
+ x_flat = x.view([batch_heads, length * 2 * length])
1019
+ x_flat = nn.functional.pad(x_flat, [0, length - 1, 0, 0])
1020
+
1021
+ # Reshape and slice out the padded elements.
1022
+ x_final = x_flat.view([batch_heads, length + 1, 2 * length - 1])
1023
+ x_final = x_final[:, :length, length - 1 :]
1024
+ return x_final
1025
+
1026
+ def _absolute_position_to_relative_position(self, x):
1027
+ batch_heads, length, _ = x.size()
1028
+
1029
+ # Pad along column
1030
+ x = nn.functional.pad(x, [0, length - 1, 0, 0, 0, 0])
1031
+ x_flat = x.view([batch_heads, length * (2 * length - 1)])
1032
+
1033
+ # Add 0's in the beginning that will skew the elements after reshape
1034
+ x_flat = nn.functional.pad(x_flat, [length, 0, 0, 0])
1035
+ x_final = x_flat.view([batch_heads, length, 2 * length])[:, :, 1:]
1036
+ return x_final
1037
+
1038
+
1039
+ class VitsFeedForward(nn.Module):
1040
+ def __init__(self, config):
1041
+ super().__init__()
1042
+ self.conv_1 = nn.Conv1d(config.hidden_size, config.ffn_dim, config.ffn_kernel_size)
1043
+ self.conv_2 = nn.Conv1d(config.ffn_dim, config.hidden_size, config.ffn_kernel_size)
1044
+ self.dropout = nn.Dropout(config.activation_dropout)
1045
+
1046
+ if isinstance(config.hidden_act, str):
1047
+ self.act_fn = ACT2FN[config.hidden_act]
1048
+ else:
1049
+ self.act_fn = config.hidden_act
1050
+
1051
+ if config.ffn_kernel_size > 1:
1052
+ pad_left = (config.ffn_kernel_size - 1) // 2
1053
+ pad_right = config.ffn_kernel_size // 2
1054
+ self.padding = [pad_left, pad_right, 0, 0, 0, 0]
1055
+ else:
1056
+ self.padding = None
1057
+
1058
+ def forward(self, hidden_states, padding_mask):
1059
+ hidden_states = hidden_states.permute(0, 2, 1)
1060
+ padding_mask = padding_mask.permute(0, 2, 1)
1061
+
1062
+ hidden_states = hidden_states * padding_mask
1063
+ if self.padding is not None:
1064
+ hidden_states = nn.functional.pad(hidden_states, self.padding)
1065
+
1066
+ hidden_states = self.conv_1(hidden_states)
1067
+ hidden_states = self.act_fn(hidden_states)
1068
+ hidden_states = self.dropout(hidden_states)
1069
+
1070
+ hidden_states = hidden_states * padding_mask
1071
+ if self.padding is not None:
1072
+ hidden_states = nn.functional.pad(hidden_states, self.padding)
1073
+
1074
+ hidden_states = self.conv_2(hidden_states)
1075
+ hidden_states = hidden_states * padding_mask
1076
+
1077
+ hidden_states = hidden_states.permute(0, 2, 1)
1078
+ return hidden_states
1079
+
1080
+
1081
+ class VitsEncoderLayer(nn.Module):
1082
+ def __init__(self, config: IndicVitsConfig):
1083
+ super().__init__()
1084
+ self.attention = VitsAttention(config)
1085
+ self.dropout = nn.Dropout(config.hidden_dropout)
1086
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1087
+ self.feed_forward = VitsFeedForward(config)
1088
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1089
+
1090
+ def forward(
1091
+ self,
1092
+ hidden_states: torch.Tensor,
1093
+ padding_mask: torch.FloatTensor,
1094
+ attention_mask: Optional[torch.Tensor] = None,
1095
+ output_attentions: bool = False,
1096
+ ):
1097
+ residual = hidden_states
1098
+ hidden_states, attn_weights = self.attention(
1099
+ hidden_states=hidden_states,
1100
+ attention_mask=attention_mask,
1101
+ output_attentions=output_attentions,
1102
+ )
1103
+
1104
+ hidden_states = self.dropout(hidden_states)
1105
+ hidden_states = self.layer_norm(residual + hidden_states)
1106
+
1107
+ residual = hidden_states
1108
+ hidden_states = self.feed_forward(hidden_states, padding_mask)
1109
+ hidden_states = self.dropout(hidden_states)
1110
+ hidden_states = self.final_layer_norm(residual + hidden_states)
1111
+
1112
+ outputs = (hidden_states,)
1113
+
1114
+ if output_attentions:
1115
+ outputs += (attn_weights,)
1116
+
1117
+ return outputs
1118
+
1119
+
1120
+ class VitsEncoder(nn.Module):
1121
+ def __init__(self, config: IndicVitsConfig):
1122
+ super().__init__()
1123
+ self.config = config
1124
+ self.layers = nn.ModuleList([VitsEncoderLayer(config) for _ in range(config.num_hidden_layers)])
1125
+ self.gradient_checkpointing = False
1126
+ self.layerdrop = config.layerdrop
1127
+
1128
+ def forward(
1129
+ self,
1130
+ hidden_states: torch.FloatTensor,
1131
+ padding_mask: torch.FloatTensor,
1132
+ attention_mask: Optional[torch.Tensor] = None,
1133
+ output_attentions: Optional[bool] = None,
1134
+ output_hidden_states: Optional[bool] = None,
1135
+ return_dict: Optional[bool] = None,
1136
+ ) -> Union[Tuple, BaseModelOutput]:
1137
+ all_hidden_states = () if output_hidden_states else None
1138
+ all_self_attentions = () if output_attentions else None
1139
+
1140
+ # expand attention_mask
1141
+ if attention_mask is not None:
1142
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
1143
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
1144
+
1145
+ hidden_states = hidden_states * padding_mask
1146
+
1147
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
1148
+
1149
+ for encoder_layer in self.layers:
1150
+ if output_hidden_states:
1151
+ all_hidden_states = all_hidden_states + (hidden_states,)
1152
+
1153
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
1154
+ dropout_probability = np.random.uniform(0, 1)
1155
+
1156
+ skip_the_layer = self.training and (dropout_probability < self.layerdrop)
1157
+ if not skip_the_layer or synced_gpus:
1158
+ # under fsdp or deepspeed zero3 all gpus must run in sync
1159
+ if self.gradient_checkpointing and self.training:
1160
+ layer_outputs = self._gradient_checkpointing_func(
1161
+ encoder_layer.__call__,
1162
+ hidden_states,
1163
+ padding_mask,
1164
+ attention_mask,
1165
+ output_attentions,
1166
+ )
1167
+ else:
1168
+ layer_outputs = encoder_layer(
1169
+ hidden_states,
1170
+ attention_mask=attention_mask,
1171
+ padding_mask=padding_mask,
1172
+ output_attentions=output_attentions,
1173
+ )
1174
+ hidden_states = layer_outputs[0]
1175
+
1176
+ if skip_the_layer:
1177
+ layer_outputs = (None, None)
1178
+
1179
+ if output_attentions:
1180
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1181
+
1182
+ hidden_states = hidden_states * padding_mask
1183
+
1184
+ if output_hidden_states:
1185
+ all_hidden_states = all_hidden_states + (hidden_states,)
1186
+
1187
+ if not return_dict:
1188
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
1189
+
1190
+ return BaseModelOutput(
1191
+ last_hidden_state=hidden_states,
1192
+ hidden_states=all_hidden_states,
1193
+ attentions=all_self_attentions,
1194
+ )
1195
+
1196
+
1197
+ class VitsTextEncoder(nn.Module):
1198
+ """
1199
+ Transformer encoder that uses relative positional representation instead of absolute positional encoding.
1200
+ """
1201
+
1202
+ def __init__(self, config: IndicVitsConfig):
1203
+ super().__init__()
1204
+ self.config = config
1205
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
1206
+ self.encoder = VitsEncoder(config)
1207
+ self.project = nn.Conv1d(config.hidden_size, config.flow_size * 2, kernel_size=1)
1208
+
1209
+ def get_input_embeddings(self):
1210
+ return self.embed_tokens
1211
+
1212
+ def set_input_embeddings(self, value):
1213
+ self.embed_tokens = value
1214
+
1215
+ def forward(
1216
+ self,
1217
+ input_ids: torch.Tensor,
1218
+ padding_mask: torch.FloatTensor,
1219
+ attention_mask: Optional[torch.Tensor] = None,
1220
+ output_attentions: Optional[bool] = None,
1221
+ output_hidden_states: Optional[bool] = None,
1222
+ return_dict: Optional[bool] = True,
1223
+ ) -> Union[Tuple[torch.Tensor], VitsTextEncoderOutput]:
1224
+ hidden_states = self.embed_tokens(input_ids) * math.sqrt(self.config.hidden_size)
1225
+
1226
+ encoder_outputs = self.encoder(
1227
+ hidden_states=hidden_states,
1228
+ padding_mask=padding_mask,
1229
+ attention_mask=attention_mask,
1230
+ output_attentions=output_attentions,
1231
+ output_hidden_states=output_hidden_states,
1232
+ return_dict=return_dict,
1233
+ )
1234
+
1235
+ last_hidden_state = encoder_outputs[0] if not return_dict else encoder_outputs.last_hidden_state
1236
+
1237
        stats = self.project(last_hidden_state.transpose(1, 2)).transpose(1, 2) * padding_mask
        prior_means, prior_log_variances = torch.split(stats, self.config.flow_size, dim=2)

        if not return_dict:
            outputs = (last_hidden_state, prior_means, prior_log_variances) + encoder_outputs[1:]
            return outputs

        return VitsTextEncoderOutput(
            last_hidden_state=last_hidden_state,
            prior_means=prior_means,
            prior_log_variances=prior_log_variances,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class VitsPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = IndicVitsConfig
    base_model_prefix = "vits"
    main_input_name = "input_ids"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight)
            if module.bias is not None:
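                # Mirror PyTorch's default Conv1d bias initialization: uniform in [-k, k]
                # with k = sqrt(groups / (in_channels * kernel_size)).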
                k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
                nn.init.uniform_(module.bias, a=-k, b=k)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


VITS_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`IndicVitsConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


VITS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        speaker_id (`int`, *optional*):
            Which speaker embedding to use. Only used for multispeaker models.
        emotion_id (`int`, *optional*):
            Which emotion embedding to use. Only used for models trained with more than one emotion.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""


@add_start_docstrings(
    "The complete VITS model, for text-to-speech synthesis.",
    VITS_START_DOCSTRING,
)
class IndicVitsModel(VitsPreTrainedModel):
    def __init__(self, config: IndicVitsConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

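        # Optional emotion embedding table; it is combined with the speaker embedding into a
        # single conditioning vector (see `forward`).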
        if config.num_emotions > 1:
            self.embed_emotion = nn.Embedding(config.num_emotions, config.emotion_embedding_size)

        # This is used only for training.
        self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    def get_encoder(self):
        return self.text_encoder

    @add_start_docstrings_to_model_forward(VITS_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=VitsModelOutput, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[int] = None,
        emotion_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[Any], VitsModelOutput]:
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
            Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
            computation.

        Returns:

        Example:

        ```python
        >>> from transformers import VitsTokenizer, VitsModel, set_seed
        >>> import torch

        >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
        >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")

        >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

        >>> set_seed(555)  # make deterministic

        >>> with torch.no_grad():
        ...     outputs = model(inputs["input_ids"])
        >>> outputs.waveform.shape
        torch.Size([1, 45824])
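        >>> # Checkpoints trained with several speakers or emotions additionally accept
        >>> # `speaker_id` / `emotion_id`, e.g. model(inputs["input_ids"], speaker_id=0)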
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).float()
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).float()

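        # Speaker and emotion embeddings (when available) are summed into a single conditioning
        # tensor of shape (batch_size, embedding_size, 1) used by the duration predictor, flow and decoder.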
        speaker_and_style_embeddings = None

        if self.config.num_speakers > 1 and speaker_id is not None:
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            speaker_and_style_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1)

        if self.config.num_emotions > 1 and emotion_id is not None:
            if not 0 <= emotion_id < self.config.num_emotions:
                raise ValueError(f"Set `emotion_id` in the range 0-{self.config.num_emotions - 1}.")
            if isinstance(emotion_id, int):
                emotion_id = torch.full(size=(1,), fill_value=emotion_id, device=self.device)
            emotion_embeddings = self.embed_emotion(emotion_id).unsqueeze(-1)
            if speaker_and_style_embeddings is not None:
                speaker_and_style_embeddings += emotion_embeddings
            else:
                speaker_and_style_embeddings = emotion_embeddings

        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
        prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances

        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
                hidden_states,
                input_padding_mask,
                speaker_and_style_embeddings,
                reverse=True,
                noise_scale=self.noise_scale_duration,
            )
        else:
            log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_and_style_embeddings)

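        # Convert the predicted log-durations into integer frame counts per input token; a higher
        # `speaking_rate` shrinks the durations and therefore speeds up the synthesised speech.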
        length_scale = 1.0 / self.speaking_rate
        duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
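        # Build a hard monotonic alignment: every output frame is assigned to exactly one input token,
        # based on the cumulative durations computed below.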
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

        # Expand prior distribution
        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)

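        # Sample latents from the text-conditioned prior (scaled by `noise_scale`) and invert the
        # normalizing flow to map them into the latent space expected by the decoder.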
        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
        latents = self.flow(prior_latents, output_padding_mask, speaker_and_style_embeddings, reverse=True)

        spectrogram = latents * output_padding_mask
        waveform = self.decoder(spectrogram, speaker_and_style_embeddings)
        waveform = waveform.squeeze(1)
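        # The HiFi-GAN decoder upsamples each latent frame, so the waveform length equals the number
        # of predicted frames times the product of the upsample rates.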
        sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)

        if not return_dict:
            outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
            return outputs

        return VitsModelOutput(
            waveform=waveform,
            sequence_lengths=sequence_lengths,
            spectrogram=spectrogram,
            hidden_states=text_encoder_output.hidden_states,
            attentions=text_encoder_output.attentions,
        )