Steveeeeeeen (HF Staff) committed on
Commit
7e6946d
·
1 Parent(s): 3cf0e6f
Files changed (45)
  1. cosyvoice2/flow/__init__.py +0 -0
  2. cosyvoice2/flow/decoder_dit.py +585 -0
  3. cosyvoice2/flow/flow.py +225 -0
  4. cosyvoice2/flow/flow_matching.py +205 -0
  5. cosyvoice2/transformer/__init__.py +0 -0
  6. cosyvoice2/transformer/attention.py +328 -0
  7. cosyvoice2/transformer/embedding.py +119 -0
  8. cosyvoice2/transformer/encoder_layer.py +163 -0
  9. cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
  10. cosyvoice2/transformer/subsampling.py +79 -0
  11. cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
  12. cosyvoice2/utils/class_utils.py +41 -0
  13. cosyvoice2/utils/common.py +101 -0
  14. cosyvoice2/utils/mask.py +49 -0
  15. flashcosyvoice/__init__.py +0 -0
  16. flashcosyvoice/cli.py +424 -0
  17. flashcosyvoice/config.py +80 -0
  18. flashcosyvoice/cosyvoice2.py +160 -0
  19. flashcosyvoice/cosyvoice3.py +1 -0
  20. flashcosyvoice/engine/__init__.py +0 -0
  21. flashcosyvoice/engine/block_manager.py +114 -0
  22. flashcosyvoice/engine/llm_engine.py +125 -0
  23. flashcosyvoice/engine/model_runner.py +310 -0
  24. flashcosyvoice/engine/scheduler.py +77 -0
  25. flashcosyvoice/engine/sequence.py +90 -0
  26. flashcosyvoice/modules/__init__.py +0 -0
  27. flashcosyvoice/modules/flow.py +198 -0
  28. flashcosyvoice/modules/flow_components/__init__.py +0 -0
  29. flashcosyvoice/modules/flow_components/estimator.py +974 -0
  30. flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
  31. flashcosyvoice/modules/hifigan.py +249 -0
  32. flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
  33. flashcosyvoice/modules/hifigan_components/layers.py +433 -0
  34. flashcosyvoice/modules/qwen2.py +92 -0
  35. flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
  36. flashcosyvoice/modules/qwen2_components/layers.py +616 -0
  37. flashcosyvoice/modules/sampler.py +231 -0
  38. flashcosyvoice/utils/__init__.py +0 -0
  39. flashcosyvoice/utils/audio.py +77 -0
  40. flashcosyvoice/utils/context.py +28 -0
  41. flashcosyvoice/utils/loader.py +116 -0
  42. flashcosyvoice/utils/memory.py +19 -0
  43. stepaudio2.py +204 -0
  44. token2wav.py +79 -0
  45. utils.py +91 -0
cosyvoice2/flow/__init__.py ADDED
File without changes
cosyvoice2/flow/decoder_dit.py ADDED
@@ -0,0 +1,585 @@
1
+ import math
2
+ import torch
3
+ import numpy as np
4
+ from typing import Optional
5
+ from einops import pack, rearrange, repeat
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+
11
+ """
12
+ DiT-v5
13
+ - Add convolution in DiTBlock to increase high-freq component
14
+ """
15
+
16
+
17
+ class MLP(torch.nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features:int,
21
+ hidden_features:Optional[int]=None,
22
+ out_features:Optional[int]=None,
23
+ act_layer=nn.GELU,
24
+ norm_layer=None,
25
+ bias=True,
26
+ drop=0.,
27
+ ):
28
+ super().__init__()
29
+ hidden_features = hidden_features or in_features
30
+ out_features = out_features or in_features
31
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
32
+ self.act = act_layer()
33
+ self.drop1 = nn.Dropout(drop)
34
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
35
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
36
+ self.drop2 = nn.Dropout(drop)
37
+
38
+ def forward(self, x):
39
+ x = self.fc1(x)
40
+ x = self.act(x)
41
+ x = self.drop1(x)
42
+ x = self.norm(x)
43
+ x = self.fc2(x)
44
+ x = self.drop2(x)
45
+ return x
46
+
47
+
48
+ class Attention(torch.nn.Module):
49
+ def __init__(
50
+ self,
51
+ dim: int,
52
+ num_heads: int = 8,
53
+ head_dim: int = 64,
54
+ qkv_bias: bool = False,
55
+ qk_norm: bool = False,
56
+ attn_drop: float = 0.,
57
+ proj_drop: float = 0.,
58
+ norm_layer: nn.Module = nn.LayerNorm,
59
+ ) -> None:
60
+ super().__init__()
61
+ self.num_heads = num_heads
62
+ self.head_dim = head_dim
63
+ self.inner_dim = num_heads * head_dim
64
+ self.scale = head_dim ** -0.5
65
+
66
+ self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
67
+ self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
68
+ self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
69
+
70
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
71
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
72
+
73
+ self.attn_drop = nn.Dropout(attn_drop)
74
+ self.proj_drop = nn.Dropout(proj_drop)
75
+
76
+ self.proj = nn.Linear(self.inner_dim, dim)
77
+
78
+ def to_heads(self, ts:torch.Tensor):
79
+ b, t, c = ts.shape
80
+ # (b, t, nh, c)
81
+ ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
82
+ ts = ts.transpose(1, 2)
83
+ return ts
84
+
85
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
86
+ """Args:
87
+ x(torch.Tensor): shape (b, t, c)
88
+ attn_mask(torch.Tensor): shape (b, t, t)
89
+ """
90
+ b, t, c = x.shape
91
+
92
+ q = self.to_q(x)
93
+ k = self.to_k(x)
94
+ v = self.to_v(x)
95
+
96
+ q = self.to_heads(q) # (b, nh, t, c)
97
+ k = self.to_heads(k)
98
+ v = self.to_heads(v)
99
+
100
+ q = self.q_norm(q)
101
+ k = self.k_norm(k)
102
+
103
+ attn_mask = attn_mask.unsqueeze(1)
104
+ x = F.scaled_dot_product_attention(
105
+ q, k, v,
106
+ attn_mask=attn_mask,
107
+ dropout_p=self.attn_drop.p if self.training else 0.,
108
+ ) # (b, nh, t, c)
109
+ x = x.transpose(1, 2).reshape(b, t, -1)
110
+ x = self.proj(x)
111
+ x = self.proj_drop(x)
112
+ return x
113
+
114
+ def forward_chunk(self, x: torch.Tensor, att_cache: torch.Tensor=None, attn_mask: torch.Tensor=None):
115
+ """
116
+ Args:
117
+ x: shape (b, dt, c)
118
+ att_cache: shape (b, nh, t, c*2)
119
+ """
120
+ b, t, c = x.shape
121
+
122
+ q = self.to_q(x)
123
+ k = self.to_k(x)
124
+ v = self.to_v(x)
125
+
126
+ q = self.to_heads(q) # (b, nh, t, c)
127
+ k = self.to_heads(k)
128
+ v = self.to_heads(v)
129
+
130
+ q = self.q_norm(q)
131
+ k = self.k_norm(k)
132
+
133
+ # unpack {k,v}_cache
134
+ if att_cache is not None:
135
+ if attn_mask is not None:
136
+ k_cache, v_cache = att_cache.chunk(2, dim=3)
137
+ k = torch.cat([k, k_cache], dim=2)
138
+ v = torch.cat([v, v_cache], dim=2)
139
+
140
+ else:
141
+ k_cache, v_cache = att_cache.chunk(2, dim=3)
142
+ k = torch.cat([k, k_cache], dim=2)
143
+ v = torch.cat([v, v_cache], dim=2)
144
+
145
+ # new {k,v}_cache
146
+ new_att_cache = torch.cat([k, v], dim=3)
147
+ # attn_mask = torch.ones((b, 1, t, t1), dtype=torch.bool, device=x.device)
148
+ if attn_mask is not None:
149
+ attn_mask = attn_mask.unsqueeze(1)
150
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) # (b, nh, t, c)
151
+ x = x.transpose(1, 2).reshape(b, t, -1)
152
+ x = self.proj(x)
153
+ x = self.proj_drop(x)
154
+ return x, new_att_cache
155
+
156
+
157
+ def modulate(x, shift, scale):
158
+ return x * (1 + scale) + shift
159
+
160
+
161
+ class TimestepEmbedder(nn.Module):
162
+ """
163
+ Embeds scalar timesteps into vector representations.
164
+ """
165
+ def __init__(self, hidden_size, frequency_embedding_size=256):
166
+ super().__init__()
167
+ self.mlp = nn.Sequential(
168
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
169
+ nn.SiLU(),
170
+ nn.Linear(hidden_size, hidden_size, bias=True),
171
+ )
172
+ self.frequency_embedding_size = frequency_embedding_size
173
+ # from SinusoidalPosEmb
174
+ self.scale = 1000
175
+
176
+ @staticmethod
177
+ def timestep_embedding(t, dim, max_period=10000):
178
+ """
179
+ Create sinusoidal timestep embeddings.
180
+ :param t: a 1-D Tensor of N indices, one per batch element.
181
+ These may be fractional.
182
+ :param dim: the dimension of the output.
183
+ :param max_period: controls the minimum frequency of the embeddings.
184
+ :return: an (N, D) Tensor of positional embeddings.
185
+ """
186
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
187
+ half = dim // 2
188
+ freqs = torch.exp(
189
+ -math.log(max_period) * torch.arange(start=0, end=half) / half
190
+ ).to(t)
191
+ args = t[:, None] * freqs[None]
192
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
193
+ if dim % 2:
194
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
195
+ return embedding
196
+
197
+ def forward(self, t):
198
+ t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
199
+ t_emb = self.mlp(t_freq)
200
+ return t_emb
201
+
202
+
203
+ # Convolution related
204
+ class Transpose(torch.nn.Module):
205
+ def __init__(self, dim0: int, dim1: int):
206
+ super().__init__()
207
+ self.dim0 = dim0
208
+ self.dim1 = dim1
209
+
210
+ def forward(self, x: torch.Tensor):
211
+ x = torch.transpose(x, self.dim0, self.dim1)
212
+ return x
213
+
214
+
215
+ class CausalConv1d(torch.nn.Conv1d):
216
+ def __init__(
217
+ self,
218
+ in_channels: int,
219
+ out_channels: int,
220
+ kernel_size: int,
221
+ ) -> None:
222
+ super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
223
+ self.causal_padding = (kernel_size - 1, 0)
224
+
225
+ def forward(self, x: torch.Tensor):
226
+ x = F.pad(x, self.causal_padding)
227
+ x = super(CausalConv1d, self).forward(x)
228
+ return x
229
+
230
+ def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor=None):
231
+ if cnn_cache is None:
232
+ cnn_cache = x.new_zeros((x.shape[0], self.in_channels, self.causal_padding[0]))
233
+ x = torch.cat([cnn_cache, x], dim=2)
234
+ new_cnn_cache = x[..., -self.causal_padding[0]:]
235
+ x = super(CausalConv1d, self).forward(x)
236
+ return x, new_cnn_cache
237
+
238
+
239
+ class CausalConvBlock(nn.Module):
240
+ def __init__(self,
241
+ in_channels: int,
242
+ out_channels: int,
243
+ kernel_size: int = 3,
244
+ ):
245
+ super().__init__()
246
+ self.in_channels = in_channels
247
+ self.out_channels = out_channels
248
+ self.kernel_size = kernel_size
249
+
250
+ self.block = torch.nn.Sequential(
251
+ # norm
252
+ # conv1
253
+ Transpose(1, 2),
254
+ CausalConv1d(in_channels, out_channels, kernel_size),
255
+ Transpose(1, 2),
256
+ # norm & act
257
+ nn.LayerNorm(out_channels),
258
+ nn.Mish(),
259
+ # conv2
260
+ Transpose(1, 2),
261
+ CausalConv1d(out_channels, out_channels, kernel_size),
262
+ Transpose(1, 2),
263
+ )
264
+
265
+ def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
266
+ """
267
+ Args:
268
+ x: shape (b, t, c)
269
+ mask: shape (b, t, 1)
270
+ """
271
+ if mask is not None: x = x * mask
272
+ x = self.block(x)
273
+ if mask is not None: x = x * mask
274
+ return x
275
+
276
+ def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor=None):
277
+ """
278
+ Args:
279
+ x: shape (b, dt, c)
280
+ cnn_cache: shape (b, c1+c2, 2)
281
+ """
282
+ if cnn_cache is not None:
283
+ cnn_cache1, cnn_cache2 = cnn_cache.split((self.in_channels, self.out_channels), dim=1)
284
+ else:
285
+ cnn_cache1, cnn_cache2 = None, None
286
+ x = self.block[0](x)
287
+ x, new_cnn_cache1 = self.block[1].forward_chunk(x, cnn_cache1)
288
+ x = self.block[2:6](x)
289
+ x, new_cnn_cache2 = self.block[6].forward_chunk(x, cnn_cache2)
290
+ x = self.block[7](x)
291
+ new_cnn_cache = torch.cat((new_cnn_cache1, new_cnn_cache2), dim=1)
292
+ return x, new_cnn_cache
293
+
294
+
295
+ class DiTBlock(nn.Module):
296
+ """
297
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
298
+ """
299
+ def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
300
+ super().__init__()
301
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
302
+ self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=True, **block_kwargs)
303
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
304
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
305
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
306
+ self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
307
+ self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
308
+ self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
309
+ self.adaLN_modulation = nn.Sequential(
310
+ nn.SiLU(),
311
+ nn.Linear(hidden_size, 9 * hidden_size, bias=True)
312
+ )
313
+
314
+ def forward(self, x:torch.Tensor, c:torch.Tensor, attn_mask:torch.Tensor):
315
+ """Args
316
+ x: shape (b, t, c)
317
+ c: shape (b, 1, c)
318
+ attn_mask: shape (b, t, t), bool type attention mask
319
+ """
320
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
321
+ = self.adaLN_modulation(c).chunk(9, dim=-1)
322
+ # attention
323
+ x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask)
324
+ # conv
325
+ x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv))
326
+ # mlp
327
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
328
+ return x
329
+
330
+ def forward_chunk(self, x: torch.Tensor, c: torch.Tensor, cnn_cache: torch.Tensor=None, att_cache: torch.Tensor=None, mask: torch.Tensor=None):
331
+ """
332
+ Args:
333
+ x: shape (b, dt, c)
334
+ c: shape (b, 1, c)
335
+ cnn_cache: shape (b, c1+c2, 2)
336
+ att_cache: shape (b, nh, t, c * 2)
337
+ """
338
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
339
+ = self.adaLN_modulation(c).chunk(9, dim=-1)
340
+ # attention
341
+ x_att, new_att_cache = self.attn.forward_chunk(modulate(self.norm1(x), shift_msa, scale_msa), att_cache, mask)
342
+ x = x + gate_msa * x_att
343
+ # conv
344
+ x_conv, new_cnn_cache = self.conv.forward_chunk(modulate(self.norm3(x), shift_conv, scale_conv), cnn_cache)
345
+ x = x + gate_conv * x_conv
346
+ # mlp
347
+ x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
348
+ return x, new_cnn_cache, new_att_cache
349
+
350
+
351
+ class FinalLayer(nn.Module):
352
+ """
353
+ The final layer of DiT.
354
+ """
355
+ def __init__(self, hidden_size, out_channels):
356
+ super().__init__()
357
+ self.adaLN_modulation = nn.Sequential(
358
+ nn.SiLU(),
359
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
360
+ )
361
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
362
+ self.linear = nn.Linear(hidden_size, out_channels, bias=True)
363
+
364
+ def forward(self, x, c):
365
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
366
+ x = modulate(self.norm_final(x), shift, scale)
367
+ x = self.linear(x)
368
+ return x
369
+
370
+
371
+ class DiT(nn.Module):
372
+ """
373
+ Diffusion model with a Transformer backbone.
374
+ """
375
+ def __init__(
376
+ self,
377
+ in_channels: int,
378
+ out_channels: int,
379
+ mlp_ratio: float = 4.0,
380
+ depth: int = 28,
381
+ num_heads: int = 8,
382
+ head_dim: int = 64,
383
+ hidden_size: int = 256,
384
+ ):
385
+ super().__init__()
386
+ self.in_channels = in_channels
387
+ self.out_channels = out_channels
388
+ self.t_embedder = TimestepEmbedder(hidden_size)
389
+
390
+ self.in_proj = nn.Linear(in_channels, hidden_size)
391
+
392
+ self.blocks = nn.ModuleList([
393
+ DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
394
+ ])
395
+ self.final_layer = FinalLayer(hidden_size, self.out_channels)
396
+
397
+ self.initialize_weights()
398
+
399
+ self.enable_cuda_graph = False
400
+ self.use_cuda_graph = False
401
+
402
+ self.graph_chunk = {}
403
+ self.inference_buffers_chunk = {}
404
+ self.max_size_chunk = {}
405
+
406
+ self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False)
407
+ self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False)
408
+
409
+ def initialize_weights(self):
410
+ # Initialize transformer layers:
411
+ def _basic_init(module):
412
+ if isinstance(module, nn.Linear):
413
+ torch.nn.init.xavier_uniform_(module.weight)
414
+ if module.bias is not None:
415
+ nn.init.constant_(module.bias, 0)
416
+ self.apply(_basic_init)
417
+
418
+ # Initialize timestep embedding MLP:
419
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
420
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
421
+
422
+ # Zero-out adaLN modulation layers in DiT blocks:
423
+ for block in self.blocks:
424
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
425
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
426
+
427
+ # Zero-out output layers:
428
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
429
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
430
+ nn.init.constant_(self.final_layer.linear.weight, 0)
431
+ nn.init.constant_(self.final_layer.linear.bias, 0)
432
+
433
+ def _init_cuda_graph_chunk(self):
434
+ # get dtype, device from registered buffer
435
+ dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device
436
+ # init cuda graph for streaming forward
437
+ with torch.no_grad():
438
+ for chunk_size in [30, 48, 96]:
439
+ if chunk_size == 30 or chunk_size == 48:
440
+ max_size = 500
441
+ self.max_size_chunk[chunk_size] = max_size
442
+ else:
443
+ max_size = 1000
444
+ self.max_size_chunk[chunk_size] = max_size
445
+ static_x1 = torch.zeros((2, 320, chunk_size), dtype=dtype, device=device)
446
+ static_t1 = torch.zeros((2, 1, 512), dtype=dtype, device=device)
447
+ static_mask1 = torch.ones((2, chunk_size, max_size+chunk_size), dtype=torch.bool, device=device)
448
+ static_att_cache = torch.zeros((16, 2, 8, max_size, 128), dtype=dtype, device=device)
449
+ static_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
450
+ static_inputs1 = [
451
+ static_x1,
452
+ static_t1,
453
+ static_mask1,
454
+ static_cnn_cache,
455
+ static_att_cache,
456
+ ]
457
+ static_new_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
458
+ static_new_att_cache = torch.zeros((16, 2, 8, max_size+chunk_size, 128), dtype=dtype, device=device)
459
+ self.blocks_forward_chunk(
460
+ static_inputs1[0],
461
+ static_inputs1[1],
462
+ static_inputs1[2],
463
+ static_inputs1[3],
464
+ static_inputs1[4],
465
+ static_new_cnn_cache,
466
+ static_new_att_cache)
467
+ graph_chunk = torch.cuda.CUDAGraph()
468
+ with torch.cuda.graph(graph_chunk):
469
+ static_out1 = self.blocks_forward_chunk(static_x1, static_t1, static_mask1, static_cnn_cache, static_att_cache, static_new_cnn_cache, static_new_att_cache)
470
+ static_outputs1 = [static_out1, static_new_cnn_cache, static_new_att_cache]
471
+ self.inference_buffers_chunk[chunk_size] = {
472
+ 'static_inputs': static_inputs1,
473
+ 'static_outputs': static_outputs1
474
+ }
475
+ self.graph_chunk[chunk_size] = graph_chunk
476
+
477
+ def _init_cuda_graph_all(self):
478
+ self._init_cuda_graph_chunk()
479
+ self.use_cuda_graph = True
480
+ print(f"CUDA Graph initialized successfully for chunk decoder")
481
+
482
+ def forward(self, x, mask, mu, t, spks=None, cond=None):
483
+ """Args:
484
+ x: shape (b, c, t)
485
+ mask: shape (b, 1, t)
486
+ t: shape (b,)
487
+ spks: shape (b, c)
488
+ cond: shape (b, c, t)
489
+ """
490
+ # (sfy) chunk training strategy should not be open-sourced
491
+
492
+ # time
493
+ t = self.t_embedder(t).unsqueeze(1) # (b, 1, c)
494
+ x = pack([x, mu], "b * t")[0]
495
+ if spks is not None:
496
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
497
+ x = pack([x, spks], "b * t")[0]
498
+ if cond is not None:
499
+ x = pack([x, cond], "b * t")[0]
500
+
501
+ return self.blocks_forward(x, t, mask)
502
+
503
+ def blocks_forward(self, x, t, mask):
504
+ x = x.transpose(1, 2)
505
+ attn_mask = mask.bool()
506
+ x = self.in_proj(x)
507
+ for block in self.blocks:
508
+ x = block(x, t, attn_mask)
509
+ x = self.final_layer(x, t)
510
+ x = x.transpose(1, 2)
511
+ return x
512
+
513
+ def forward_chunk(self,
514
+ x: torch.Tensor,
515
+ mu: torch.Tensor,
516
+ t: torch.Tensor,
517
+ spks: torch.Tensor,
518
+ cond: torch.Tensor,
519
+ cnn_cache: torch.Tensor = None,
520
+ att_cache: torch.Tensor = None,
521
+ ):
522
+ """
523
+ Args:
524
+ x: shape (b, dt, c)
525
+ mu: shape (b, dt, c)
526
+ t: shape (b,)
527
+ spks: shape (b, c)
528
+ cond: shape (b, dt, c)
529
+ cnn_cache: shape (depth, b, c1+c2, 2)
530
+ att_cache: shape (depth, b, nh, t, c * 2)
531
+ """
532
+
533
+ # time
534
+ t = self.t_embedder(t).unsqueeze(1) # (b, 1, c)
535
+ x = pack([x, mu], "b * t")[0]
536
+ if spks is not None:
537
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
538
+ x = pack([x, spks], "b * t")[0]
539
+ if cond is not None:
540
+ x = pack([x, cond], "b * t")[0]
541
+
542
+ # create fake cache
543
+ if cnn_cache is None:
544
+ cnn_cache = [None] * len(self.blocks)
545
+ if att_cache is None:
546
+ att_cache = [None] * len(self.blocks)
547
+ if att_cache[0] is not None:
548
+ last_att_len = att_cache.shape[3]
549
+ else:
550
+ last_att_len = 0
551
+ chunk_size = x.shape[2]
552
+ mask = torch.ones(x.shape[0], chunk_size, last_att_len+chunk_size, dtype=torch.bool, device=x.device)
553
+ if self.use_cuda_graph and att_cache[0] is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]:
554
+ padded_mask = torch.zeros((2, chunk_size, self.max_size_chunk[chunk_size]+chunk_size), dtype=mask.dtype, device=mask.device)
555
+ padded_mask[:, :, :mask.shape[-1]] = mask
556
+ padded_att_cache = torch.zeros((16, 2, 8, self.max_size_chunk[chunk_size], 128), dtype=att_cache.dtype, device=att_cache.device)
557
+ padded_att_cache[:, :, :, :last_att_len, :] = att_cache
558
+ self.inference_buffers_chunk[chunk_size]['static_inputs'][0].copy_(x)
559
+ self.inference_buffers_chunk[chunk_size]['static_inputs'][1].copy_(t)
560
+ self.inference_buffers_chunk[chunk_size]['static_inputs'][2].copy_(padded_mask)
561
+ self.inference_buffers_chunk[chunk_size]['static_inputs'][3].copy_(cnn_cache)
562
+ self.inference_buffers_chunk[chunk_size]['static_inputs'][4].copy_(padded_att_cache)
563
+ self.graph_chunk[chunk_size].replay()
564
+ x = self.inference_buffers_chunk[chunk_size]['static_outputs'][0][:, :, :chunk_size]
565
+ new_cnn_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][1]
566
+ new_att_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][2][:, :, :, :chunk_size+last_att_len, :]
567
+ else:
568
+ mask = None
569
+ x = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer)
570
+ new_cnn_cache = self.cnn_cache_buffer
571
+ new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len+chunk_size, :]
572
+
573
+ return x, new_cnn_cache, new_att_cache
574
+
575
+ def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None):
576
+ x = x.transpose(1, 2)
577
+ x = self.in_proj(x)
578
+ for b_idx, block in enumerate(self.blocks):
579
+ x, this_new_cnn_cache, this_new_att_cache \
580
+ = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask)
581
+ cnn_cache_buffer[b_idx] = this_new_cnn_cache
582
+ att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache
583
+ x = self.final_layer(x, t)
584
+ x = x.transpose(1, 2)
585
+ return x
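The file above defines the chunk-aware DiT estimator used by the flow-matching decoder: each DiTBlock applies adaLN-Zero-modulated attention, a causal convolution (the DiT-v5 addition meant to recover high-frequency detail), and an MLP. A minimal usage sketch of the non-streaming forward pass follows; the hyper-parameters (in_channels=320, hidden_size=512, depth=16, out_channels=80) are assumptions read off the CUDA-graph buffer shapes above, not values confirmed by any config in this commit.

import torch
from cosyvoice2.flow.decoder_dit import DiT

# Hypothetical hyper-parameters inferred from the static buffers above
# (x buffer (2, 320, T), t buffer (2, 1, 512), 16 cached blocks, d_k*2 = 128).
estimator = DiT(in_channels=320, out_channels=80, depth=16,
                num_heads=8, head_dim=64, hidden_size=512).eval()

b, t_len = 1, 120
x    = torch.randn(b, 80, t_len)    # noisy mel at the current flow step
mu   = torch.randn(b, 80, t_len)    # encoder output projected to mel dim
spks = torch.randn(b, 80)           # projected speaker embedding
cond = torch.zeros(b, 80, t_len)    # prompt-mel condition (zeros when absent)
mask = torch.ones(b, t_len, t_len, dtype=torch.bool)
t    = torch.rand(b)                # flow time in [0, 1]

with torch.no_grad():
    v = estimator(x, mask, mu, t, spks=spks, cond=cond)  # predicted velocity field
print(v.shape)  # torch.Size([1, 80, 120])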
cosyvoice2/flow/flow.py ADDED
@@ -0,0 +1,225 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
+ from cosyvoice2.utils.mask import make_pad_mask
19
+ from cosyvoice2.flow.flow_matching import CausalConditionalCFM
20
+ from cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2
21
+
22
+
23
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
24
+ def __init__(self,
25
+ input_size: int = 512,
26
+ output_size: int = 80,
27
+ spk_embed_dim: int = 192,
28
+ output_type: str = "mel",
29
+ vocab_size: int = 5121,
30
+ encoder: UpsampleConformerEncoderV2 = None,
31
+ decoder: CausalConditionalCFM = None,
32
+ input_embedding: torch.nn.Module = None,
33
+ ):
34
+ super().__init__()
35
+ self.input_size = input_size
36
+ self.output_size = output_size
37
+ self.vocab_size = vocab_size
38
+ self.output_type = output_type
39
+ self.pre_lookahead_len = int(encoder.pre_lookahead_layer.pre_lookahead_len)
40
+ self.up_rate = int(encoder.up_layer.stride)
41
+ if input_embedding is None:
42
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
43
+ else:
44
+ self.input_embedding = input_embedding
45
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
46
+ self.encoder = encoder
47
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
48
+ self.decoder = decoder
49
+
50
+ # xvec projection with CUDA Graph optimization
51
+ # Initialize CUDA Graph related state
52
+ self.enable_cuda_graph = False
53
+ self.static_embedding = None
54
+ self.static_output = None
55
+ self.graph = None
56
+ self.embedding_shape = None
57
+
58
+ def scatter_cuda_graph(self, enable_cuda_graph: bool):
59
+ self.enable_cuda_graph = enable_cuda_graph
60
+ if self.enable_cuda_graph:
61
+ # self.encoder.scatter_cuda_graph(enable_cuda_graph)
62
+ self.decoder.scatter_cuda_graph(enable_cuda_graph)
63
+
64
+ @torch.inference_mode()
65
+ def inference(self,
66
+ token,
67
+ token_len,
68
+ prompt_token,
69
+ prompt_token_len,
70
+ prompt_feat,
71
+ prompt_feat_len,
72
+ embedding,
73
+ n_timesteps: int = 10,
74
+ ):
75
+ assert token.shape[0] == 1
76
+
77
+ # xvec projection
78
+ embedding = F.normalize(embedding, dim=1)
79
+ embedding = self.spk_embed_affine_layer(embedding)
80
+
81
+ # concat text and prompt_text
82
+ token_len = prompt_token_len + token_len
83
+ token = torch.concat([prompt_token, token], dim=1)
84
+
85
+ mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
86
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
87
+
88
+ # token encode
89
+ h, _ = self.encoder.forward(token, token_len)
90
+ h = self.encoder_proj(h)
91
+
92
+ # condition
93
+ mel_len1 = prompt_feat.shape[1]
94
+ mel_len2 = h.shape[1] - prompt_feat.shape[1]
95
+
96
+ conds = torch.zeros_like(h)
97
+ conds[:, :mel_len1] = prompt_feat
98
+ conds = conds.transpose(1, 2).contiguous()
99
+
100
+ mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
101
+
102
+ feat = self.decoder.forward(
103
+ mu=h.transpose(1, 2).contiguous(),
104
+ mask=mask.unsqueeze(1),
105
+ spks=embedding,
106
+ cond=conds,
107
+ n_timesteps=n_timesteps,
108
+ )
109
+
110
+ feat = feat[:, :, mel_len1:]
111
+ assert feat.shape[2] == mel_len2
112
+ return feat
113
+
114
+ @torch.inference_mode()
115
+ def setup_cache(self,
116
+ token: torch.Tensor,
117
+ mel: torch.Tensor,
118
+ spk: torch.Tensor,
119
+ n_timesteps: int = 10,
120
+ ):
121
+ """
122
+ Args:
123
+ token: shape (b, t), with look ahead tokens
124
+ mel: shape (b, t, c), ground-truth mel
125
+ spk: shape (b, 192), speaker embedding
126
+ Returns:
127
+ cache: dict {
128
+ 'conformer': {'cnn_cache': xxx, 'att_cache': xxx},
129
+ 'estimator': {'cnn_cache': xxx, 'att_cache': xxx}
130
+ }
131
+ """
132
+ # check if look ahead token included
133
+ assert (token.shape[1] - self.pre_lookahead_len) * self.up_rate == mel.shape[1], (token.shape, mel.shape)
134
+
135
+ # xvec projection
136
+ spk = F.normalize(spk, dim=1)
137
+ spk = self.spk_embed_affine_layer(spk)
138
+
139
+ token = self.input_embedding(token)
140
+ # NOTE encoder.forward_chunk will strip the look ahead part
141
+ h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
142
+ xs = token,
143
+ last_chunk = False,
144
+ cnn_cache = None,
145
+ att_cache = None,
146
+ )
147
+ h = self.encoder_proj(h)
148
+
149
+ feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
150
+ mu = h.transpose(1, 2).contiguous(),
151
+ spks = spk,
152
+ cond = mel.transpose(1, 2).contiguous(),
153
+ n_timesteps = n_timesteps,
154
+ temperature = 1.0,
155
+ cnn_cache = None,
156
+ att_cache = None,
157
+ )
158
+
159
+ cache = {
160
+ 'conformer_cnn_cache': conformer_cnn_cache,
161
+ 'conformer_att_cache': conformer_att_cache,
162
+ 'estimator_cnn_cache': estimator_cnn_cache,
163
+ 'estimator_att_cache': estimator_att_cache,
164
+ }
165
+ return cache
166
+
167
+ @torch.inference_mode()
168
+ def inference_chunk(self,
169
+ token: torch.Tensor,
170
+ spk: torch.Tensor,
171
+ cache: dict,
172
+ last_chunk: bool = False,
173
+ n_timesteps: int = 10,
174
+ ):
175
+ """
176
+ Args:
177
+ token: shape (b, t), with look ahead tokens
178
+ spk: shape (b, 192), speaker embedding
179
+ cache: dict {
180
+ 'conformer_cnn_cache': xxx,
181
+ ...
182
+ }
183
+ """
184
+ # unpack cache
185
+ conformer_cnn_cache = cache['conformer_cnn_cache']
186
+ conformer_att_cache = cache['conformer_att_cache']
187
+ estimator_cnn_cache = cache['estimator_cnn_cache']
188
+ estimator_att_cache = cache['estimator_att_cache']
189
+
190
+ # xvec projection
191
+ spk = F.normalize(spk, dim=1)
192
+ spk = self.spk_embed_affine_layer(spk)
193
+
194
+ token = self.input_embedding(token)
195
+ # if not the last chunk, h is shorter than xs for a length of lookahead_length * stride (6)
196
+ h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
197
+ xs = token,
198
+ last_chunk = last_chunk,
199
+ cnn_cache = conformer_cnn_cache,
200
+ att_cache = conformer_att_cache,
201
+ )
202
+ h = self.encoder_proj(h)
203
+
204
+ cond = torch.zeros_like(h)
205
+ # forward estimator
206
+ feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
207
+ mu = h.transpose(1, 2).contiguous(),
208
+ spks = spk,
209
+ cond = cond.transpose(1, 2).contiguous(),
210
+ n_timesteps = n_timesteps,
211
+ temperature = 1.0,
212
+ cnn_cache = estimator_cnn_cache,
213
+ att_cache = estimator_att_cache,
214
+ )
215
+
216
+
217
+ new_cache = {
218
+ 'conformer_cnn_cache': conformer_cnn_cache,
219
+ 'conformer_att_cache': conformer_att_cache,
220
+ 'estimator_cnn_cache': estimator_cnn_cache,
221
+ 'estimator_att_cache': estimator_att_cache,
222
+ }
223
+
224
+ return feat, new_cache
225
+
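CausalMaskedDiffWithXvec above exposes two paths: a one-shot inference() and a streaming pair setup_cache()/inference_chunk() that threads the conformer and estimator caches between chunks. The sketch below illustrates the streaming call pattern only; it assumes flow_model is an already-constructed CausalMaskedDiffWithXvec loaded from a checkpoint and token_chunks is a list of (1, t) speech-token tensors coming from the LLM, neither of which is shown in this file.

import torch

# `flow_model` and `token_chunks` are assumed to exist (see note above).
spk = torch.randn(1, 192)                        # raw x-vector
prompt_token = torch.randint(0, 5121, (1, 30))   # prompt tokens incl. look-ahead
prompt_mel_len = (30 - flow_model.pre_lookahead_len) * flow_model.up_rate
prompt_mel = torch.randn(1, prompt_mel_len, 80)  # ground-truth prompt mel

# 1) Prime the conformer / estimator caches on the prompt.
cache = flow_model.setup_cache(prompt_token, prompt_mel, spk, n_timesteps=10)

# 2) Stream speech-token chunks and collect mel chunks for the vocoder.
for i, chunk in enumerate(token_chunks):
    last = (i == len(token_chunks) - 1)
    mel_chunk, cache = flow_model.inference_chunk(
        chunk, spk, cache, last_chunk=last, n_timesteps=10)
    # mel_chunk: (1, 80, t_chunk), fed to the HiFi-GAN vocoder downstream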
cosyvoice2/flow/flow_matching.py ADDED
@@ -0,0 +1,205 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ import onnxruntime
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+ from cosyvoice2.flow.decoder_dit import DiT
20
+ from cosyvoice2.utils.mask import make_pad_mask
21
+
22
+
23
+ """
24
+ Inference wrapper
25
+ """
26
+ class CausalConditionalCFM(torch.nn.Module):
27
+ def __init__(self, estimator: DiT, inference_cfg_rate:float=0.7):
28
+ super().__init__()
29
+ self.estimator = estimator
30
+ self.inference_cfg_rate = inference_cfg_rate
31
+ self.out_channels = estimator.out_channels
32
+ # a maximum of 600s
33
+ self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False)
34
+
35
+ self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False)
36
+ self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False)
37
+
38
+ def scatter_cuda_graph(self, enable_cuda_graph: bool):
39
+ if enable_cuda_graph:
40
+ self.estimator._init_cuda_graph_all()
41
+
42
+ def solve_euler(self, x, t_span, mu, mask, spks, cond):
43
+ """
44
+ Fixed-step Euler solver for the flow-matching ODE.
45
+ Args:
46
+ x (torch.Tensor): random noise
47
+ t_span (torch.Tensor): n_timesteps interpolated
48
+ shape: (n_timesteps + 1,)
49
+ mu (torch.Tensor): output of encoder
50
+ shape: (batch_size, n_feats, mel_timesteps)
51
+ mask (torch.Tensor): output_mask
52
+ shape: (batch_size, 1, mel_timesteps)
53
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
54
+ shape: (batch_size, spk_emb_dim)
55
+ cond: Not used but kept for future purposes
56
+ """
57
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
58
+ t = t.unsqueeze(dim=0)
59
+ assert self.inference_cfg_rate > 0, 'inference_cfg_rate must be > 0'
60
+
61
+ # constant during denoising
62
+ mask_in = torch.cat([mask, mask], dim=0)
63
+ mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
64
+ spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
65
+ cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)
66
+
67
+ for step in range(1, len(t_span)):
68
+
69
+ x_in = torch.cat([x, x], dim=0)
70
+ t_in = torch.cat([t, t], dim=0)
71
+
72
+ dphi_dt = self.estimator.forward(
73
+ x_in,
74
+ mask_in,
75
+ mu_in,
76
+ t_in,
77
+ spks_in,
78
+ cond_in,
79
+ )
80
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
81
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
82
+ x = x + dt * dphi_dt
83
+ t = t + dt
84
+ if step < len(t_span) - 1:
85
+ dt = t_span[step + 1] - t
86
+
87
+ return x
88
+
89
+ @torch.inference_mode()
90
+ def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0):
91
+ z = self.rand_noise[:, :, :mu.size(2)] * temperature
92
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
93
+ # cosine scheduling
94
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
95
+ return self.solve_euler(z, t_span, mu, mask, spks, cond)
96
+
97
+ def solve_euler_chunk(self,
98
+ x:torch.Tensor,
99
+ t_span:torch.Tensor,
100
+ mu:torch.Tensor,
101
+ spks:torch.Tensor,
102
+ cond:torch.Tensor,
103
+ cnn_cache:torch.Tensor=None,
104
+ att_cache:torch.Tensor=None,
105
+ ):
106
+ """
107
+ Fixed-step Euler solver for the flow-matching ODE (chunk/streaming variant).
108
+ Args:
109
+ x (torch.Tensor): random noise
110
+ t_span (torch.Tensor): n_timesteps interpolated
111
+ shape: (n_timesteps + 1,)
112
+ mu (torch.Tensor): output of encoder
113
+ shape: (batch_size, n_feats, mel_timesteps)
114
+ mask (torch.Tensor): output_mask
115
+ shape: (batch_size, 1, mel_timesteps)
116
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
117
+ shape: (batch_size, spk_emb_dim)
118
+ cond (torch.Tensor): condition features (prompt mel), shape (batch_size, n_feats, mel_timesteps)
119
+ cnn_cache: shape (n_time, depth, b, c1+c2, 2)
120
+ att_cache: shape (n_time, depth, b, nh, t, c * 2)
121
+ """
122
+ assert self.inference_cfg_rate > 0, 'cfg rate should be > 0'
123
+
124
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
125
+ t = t.unsqueeze(dim=0) # (b,)
126
+
127
+ # setup initial cache
128
+ if cnn_cache is None:
129
+ cnn_cache = [None for _ in range(len(t_span)-1)]
130
+ if att_cache is None:
131
+ att_cache = [None for _ in range(len(t_span)-1)]
132
+ # next chunk's cache at each timestep
133
+
134
+ if att_cache[0] is not None:
135
+ last_att_len = att_cache.shape[4]
136
+ else:
137
+ last_att_len = 0
138
+
139
+ # constant during denoising
140
+ mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
141
+ spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
142
+ cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)
143
+ for step in range(1, len(t_span)):
144
+ # torch.cuda.memory._record_memory_history(max_entries=100000)
145
+ # torch.cuda.memory._record_memory_history(max_entries=100000)
146
+ this_att_cache = att_cache[step-1]
147
+ this_cnn_cache = cnn_cache[step-1]
148
+
149
+ dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk(
150
+ x = x.repeat(2, 1, 1),
151
+ mu = mu_in,
152
+ t = t.repeat(2),
153
+ spks = spks_in,
154
+ cond = cond_in,
155
+ cnn_cache = this_cnn_cache,
156
+ att_cache = this_att_cache,
157
+ )
158
+ dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0)
159
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
160
+ x = x + dt * dphi_dt
161
+ t = t + dt
162
+ if step < len(t_span) - 1:
163
+ dt = t_span[step + 1] - t
164
+
165
+ self.cnn_cache_buffer[step-1] = this_new_cnn_cache
166
+ self.att_cache_buffer[step-1][:, :, :, :x.shape[2]+last_att_len, :] = this_new_att_cache
167
+
168
+ cnn_cache = self.cnn_cache_buffer
169
+ att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2]+last_att_len, :]
170
+ return x, cnn_cache, att_cache
171
+
172
+ @torch.inference_mode()
173
+ def forward_chunk(self,
174
+ mu:torch.Tensor,
175
+ spks:torch.Tensor,
176
+ cond:torch.Tensor,
177
+ n_timesteps:int=10,
178
+ temperature:float=1.0,
179
+ cnn_cache:torch.Tensor=None,
180
+ att_cache:torch.Tensor=None,
181
+ ):
182
+ """
183
+ Args:
184
+ mu(torch.Tensor): shape (b, c, t)
185
+ spks(torch.Tensor): shape (b, 192)
186
+ cond(torch.Tensor): shape (b, c, t)
187
+ cnn_cache: shape (n_time, depth, b, c1+c2, 2)
188
+ att_cache: shape (n_time, depth, b, nh, t, c * 2)
189
+ """
190
+ # get offset from att_cache
191
+ offset = att_cache.shape[4] if att_cache is not None else 0
192
+ z = self.rand_noise[:, :, offset:offset+mu.size(2)] * temperature
193
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
194
+ # cosine scheduling
195
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
196
+ x, new_cnn_cache, new_att_cache = self.solve_euler_chunk(
197
+ x=z,
198
+ t_span=t_span,
199
+ mu=mu,
200
+ spks=spks,
201
+ cond=cond,
202
+ att_cache=att_cache,
203
+ cnn_cache=cnn_cache,
204
+ )
205
+ return x, new_cnn_cache, new_att_cache
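To make the cosine time schedule and the classifier-free-guidance arithmetic in solve_euler concrete, here is a small self-contained numeric sketch of one update step; the conditional and unconditional estimator outputs are random stand-ins rather than real DiT calls.

import torch

n_timesteps, cfg_rate = 10, 0.7
t_span = torch.linspace(0, 1, n_timesteps + 1)
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)     # cosine schedule: finer steps near t=0

x = torch.randn(1, 80, 120)                          # current noisy mel
dt = t_span[1] - t_span[0]

# Stand-ins for the two halves of the batched estimator output.
dphi_cond, dphi_uncond = torch.randn_like(x), torch.randn_like(x)
dphi = (1.0 + cfg_rate) * dphi_cond - cfg_rate * dphi_uncond
x = x + dt * dphi                                    # one Euler step of the flow ODE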
cosyvoice2/transformer/__init__.py ADDED
File without changes
cosyvoice2/transformer/attention.py ADDED
@@ -0,0 +1,328 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2022 Xingchen Song ([email protected])
4
+ # 2024 Alibaba Inc (Xiang Lyu)
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ """Multi-Head Attention layer definition."""
18
+
19
+ import math
20
+ from typing import Tuple
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ class MultiHeadedAttention(nn.Module):
27
+ """Multi-Head Attention layer.
28
+
29
+ Args:
30
+ n_head (int): The number of heads.
31
+ n_feat (int): The number of features.
32
+ dropout_rate (float): Dropout rate.
33
+
34
+ """
35
+
36
+ def __init__(self,
37
+ n_head: int,
38
+ n_feat: int,
39
+ dropout_rate: float,
40
+ key_bias: bool = True):
41
+ """Construct an MultiHeadedAttention object."""
42
+ super().__init__()
43
+ assert n_feat % n_head == 0
44
+ # We assume d_v always equals d_k
45
+ self.d_k = n_feat // n_head
46
+ self.h = n_head
47
+ self.linear_q = nn.Linear(n_feat, n_feat)
48
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
49
+ self.linear_v = nn.Linear(n_feat, n_feat)
50
+ self.linear_out = nn.Linear(n_feat, n_feat)
51
+ self.dropout = nn.Dropout(p=dropout_rate)
52
+
53
+ def forward_qkv(
54
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
55
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
56
+ """Transform query, key and value.
57
+
58
+ Args:
59
+ query (torch.Tensor): Query tensor (#batch, time1, size).
60
+ key (torch.Tensor): Key tensor (#batch, time2, size).
61
+ value (torch.Tensor): Value tensor (#batch, time2, size).
62
+
63
+ Returns:
64
+ torch.Tensor: Transformed query tensor, size
65
+ (#batch, n_head, time1, d_k).
66
+ torch.Tensor: Transformed key tensor, size
67
+ (#batch, n_head, time2, d_k).
68
+ torch.Tensor: Transformed value tensor, size
69
+ (#batch, n_head, time2, d_k).
70
+
71
+ """
72
+ n_batch = query.size(0)
73
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
74
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
75
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
76
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
77
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
78
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
79
+
80
+ return q, k, v
81
+
82
+ def forward_attention(
83
+ self,
84
+ value: torch.Tensor,
85
+ scores: torch.Tensor,
86
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
87
+ ) -> torch.Tensor:
88
+ """Compute attention context vector.
89
+
90
+ Args:
91
+ value (torch.Tensor): Transformed value, size
92
+ (#batch, n_head, time2, d_k).
93
+ scores (torch.Tensor): Attention score, size
94
+ (#batch, n_head, time1, time2).
95
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
96
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
97
+
98
+ Returns:
99
+ torch.Tensor: Transformed value (#batch, time1, d_model)
100
+ weighted by the attention score (#batch, time1, time2).
101
+
102
+ """
103
+ n_batch = value.size(0)
104
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
105
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
106
+ # 1st chunk to ease the onnx export.]
107
+ # 2. pytorch training
108
+ if mask.size(2) > 0: # time2 > 0
109
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
110
+ # For last chunk, time2 might be larger than scores.size(-1)
111
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
112
+ scores = scores.masked_fill(mask, -float('inf'))
113
+ attn = torch.softmax(scores, dim=-1).masked_fill(
114
+ mask, 0.0) # (batch, head, time1, time2)
115
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
116
+ # 1. onnx(16/-1, -1/-1, 16/0)
117
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
118
+ else:
119
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
120
+
121
+ p_attn = self.dropout(attn)
122
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
123
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
124
+ self.h * self.d_k)
125
+ ) # (batch, time1, d_model)
126
+
127
+ return self.linear_out(x) # (batch, time1, d_model)
128
+
129
+ def forward(
130
+ self,
131
+ query: torch.Tensor,
132
+ key: torch.Tensor,
133
+ value: torch.Tensor,
134
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
135
+ pos_emb: torch.Tensor = torch.empty(0),
136
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
137
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
138
+ """Compute scaled dot product attention.
139
+
140
+ Args:
141
+ query (torch.Tensor): Query tensor (#batch, time1, size).
142
+ key (torch.Tensor): Key tensor (#batch, time2, size).
143
+ value (torch.Tensor): Value tensor (#batch, time2, size).
144
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
145
+ (#batch, time1, time2).
146
+ 1.When applying cross attention between decoder and encoder,
147
+ the batch padding mask for input is in (#batch, 1, T) shape.
148
+ 2.When applying self attention of encoder,
149
+ the mask is in (#batch, T, T) shape.
150
+ 3.When applying self attention of decoder,
151
+ the mask is in (#batch, L, L) shape.
152
+ 4.If the different position in decoder see different block
153
+ of the encoder, such as Mocha, the passed in mask could be
154
+ in (#batch, L, T) shape. But there is no such case in current
155
+ CosyVoice.
156
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
157
+ where `cache_t == chunk_size * num_decoding_left_chunks`
158
+ and `head * d_k == size`
159
+
160
+
161
+ Returns:
162
+ torch.Tensor: Output tensor (#batch, time1, d_model).
163
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
164
+ where `cache_t == chunk_size * num_decoding_left_chunks`
165
+ and `head * d_k == size`
166
+
167
+ """
168
+ q, k, v = self.forward_qkv(query, key, value)
169
+
170
+ # NOTE(xcsong):
171
+ # when export onnx model, for 1st chunk, we feed
172
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
173
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
174
+ # In all modes, `if cache.size(0) > 0` will always be `True`
175
+ # and we will always do splitting and
176
+ # concatenation (this will simplify onnx export). Note that
177
+ # it's OK to concat & split zero-shaped tensors(see code below).
178
+ # when export jit model, for 1st chunk, we always feed
179
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
180
+ # >>> a = torch.ones((1, 2, 0, 4))
181
+ # >>> b = torch.ones((1, 2, 3, 4))
182
+ # >>> c = torch.cat((a, b), dim=2)
183
+ # >>> torch.equal(b, c) # True
184
+ # >>> d = torch.split(a, 2, dim=-1)
185
+ # >>> torch.equal(d[0], d[1]) # True
186
+ if cache.size(0) > 0:
187
+ key_cache, value_cache = torch.split(cache,
188
+ cache.size(-1) // 2,
189
+ dim=-1)
190
+ k = torch.cat([key_cache, k], dim=2)
191
+ v = torch.cat([value_cache, v], dim=2)
192
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
193
+ # non-trivial to calculate `next_cache_start` here.
194
+ new_cache = torch.cat((k, v), dim=-1)
195
+
196
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
197
+ return self.forward_attention(v, scores, mask), new_cache
198
+
199
+
200
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
201
+ """Multi-Head Attention layer with relative position encoding.
202
+ Paper: https://arxiv.org/abs/1901.02860
203
+ Args:
204
+ n_head (int): The number of heads.
205
+ n_feat (int): The number of features.
206
+ dropout_rate (float): Dropout rate.
207
+ """
208
+
209
+ def __init__(self,
210
+ n_head: int,
211
+ n_feat: int,
212
+ dropout_rate: float,
213
+ key_bias: bool = True):
214
+ """Construct an RelPositionMultiHeadedAttention object."""
215
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
216
+ # linear transformation for positional encoding
217
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
218
+ # these two learnable bias are used in matrix c and matrix d
219
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
220
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
221
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
222
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
223
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
224
+
225
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
226
+ """Compute relative positional encoding.
227
+
228
+ Args:
229
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
230
+ time1 means the length of query vector.
231
+
232
+ Returns:
233
+ torch.Tensor: Output tensor.
234
+
235
+ """
236
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
237
+ device=x.device,
238
+ dtype=x.dtype)
239
+ x_padded = torch.cat([zero_pad, x], dim=-1)
240
+
241
+ x_padded = x_padded.view(x.size()[0],
242
+ x.size()[1],
243
+ x.size(3) + 1, x.size(2))
244
+ x = x_padded[:, :, 1:].view_as(x)[
245
+ :, :, :, : x.size(-1) // 2 + 1
246
+ ] # only keep the positions from 0 to time2
247
+ return x
248
+
249
+ def forward(
250
+ self,
251
+ query: torch.Tensor,
252
+ key: torch.Tensor,
253
+ value: torch.Tensor,
254
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
255
+ pos_emb: torch.Tensor = torch.empty(0),
256
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
257
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
258
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
259
+ Args:
260
+ query (torch.Tensor): Query tensor (#batch, time1, size).
261
+ key (torch.Tensor): Key tensor (#batch, time2, size).
262
+ value (torch.Tensor): Value tensor (#batch, time2, size).
263
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
264
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
265
+ pos_emb (torch.Tensor): Positional embedding tensor
266
+ (#batch, time2, size).
267
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
268
+ where `cache_t == chunk_size * num_decoding_left_chunks`
269
+ and `head * d_k == size`
270
+ Returns:
271
+ torch.Tensor: Output tensor (#batch, time1, d_model).
272
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
273
+ where `cache_t == chunk_size * num_decoding_left_chunks`
274
+ and `head * d_k == size`
275
+ """
276
+ q, k, v = self.forward_qkv(query, key, value)
277
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
278
+
279
+ # NOTE(xcsong):
280
+ # when export onnx model, for 1st chunk, we feed
281
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
282
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
283
+ # In all modes, `if cache.size(0) > 0` will always be `True`
284
+ # and we will always do splitting and
285
+ # concatenation (this will simplify onnx export). Note that
286
+ # it's OK to concat & split zero-shaped tensors(see code below).
287
+ # when export jit model, for 1st chunk, we always feed
288
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
289
+ # >>> a = torch.ones((1, 2, 0, 4))
290
+ # >>> b = torch.ones((1, 2, 3, 4))
291
+ # >>> c = torch.cat((a, b), dim=2)
292
+ # >>> torch.equal(b, c) # True
293
+ # >>> d = torch.split(a, 2, dim=-1)
294
+ # >>> torch.equal(d[0], d[1]) # True
295
+ if cache is not None and cache.size(0) > 0:
296
+ key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
297
+ k = torch.cat([key_cache, k], dim=2)
298
+ v = torch.cat([value_cache, v], dim=2)
299
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
300
+ # non-trivial to calculate `next_cache_start` here.
301
+ new_cache = torch.cat((k, v), dim=-1)
302
+
303
+ n_batch_pos = pos_emb.size(0)
304
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
305
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
306
+
307
+ # (batch, head, time1, d_k)
308
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
309
+ # (batch, head, time1, d_k)
310
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
311
+
312
+ # compute attention score
313
+ # first compute matrix a and matrix c
314
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
315
+ # (batch, head, time1, time2)
316
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
317
+
318
+ # compute matrix b and matrix d
319
+ # (batch, head, time1, time2)
320
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
321
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
322
+ if matrix_ac.shape != matrix_bd.shape:
323
+ matrix_bd = self.rel_shift(matrix_bd)
324
+
325
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
326
+ self.d_k) # (batch, head, time1, time2)
327
+
328
+ return self.forward_attention(v, scores, mask), new_cache
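The comments above describe feeding a zero-length cache on the first chunk and concatenating it onto the fresh keys/values. A short sketch of that incremental pattern with MultiHeadedAttention (8 heads and 512 features are illustrative sizes, not values fixed by this commit):

import torch
from cosyvoice2.transformer.attention import MultiHeadedAttention

attn = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0).eval()

x1 = torch.randn(1, 16, 512)
cache = torch.zeros(1, 8, 0, 128)          # (1, head, 0, d_k * 2): empty first-chunk cache
y1, cache = attn(x1, x1, x1, cache=cache)  # cache now covers 16 frames

x2 = torch.randn(1, 4, 512)                # next 4-frame chunk
y2, cache = attn(x2, x2, x2, cache=cache)  # keys/values now cover 20 frames
print(y2.shape, cache.shape)               # torch.Size([1, 4, 512]) torch.Size([1, 8, 20, 128])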
cosyvoice2/transformer/embedding.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Positonal Encoding Module."""
17
+
18
+ import math
19
+ from typing import Tuple, Union
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ import numpy as np
24
+
25
+
26
+ class EspnetRelPositionalEncoding(torch.nn.Module):
27
+ """Relative positional encoding module (new implementation).
28
+
29
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
30
+
31
+ See : Appendix B in https://arxiv.org/abs/1901.02860
32
+
33
+ Args:
34
+ d_model (int): Embedding dimension.
35
+ dropout_rate (float): Dropout rate.
36
+ max_len (int): Maximum input length.
37
+
38
+ """
39
+
40
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
41
+ """Construct an PositionalEncoding object."""
42
+ super(EspnetRelPositionalEncoding, self).__init__()
43
+ self.d_model = d_model
44
+ self.xscale = math.sqrt(self.d_model)
45
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
46
+ self.pe = None
47
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
48
+
49
+ def extend_pe(self, x: torch.Tensor):
50
+ """Reset the positional encodings."""
51
+ if self.pe is not None:
52
+ # self.pe contains both positive and negative parts
53
+ # the length of self.pe is 2 * input_len - 1
54
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
55
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
56
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
57
+ return
58
+ # Suppose `i` is the position of the query vector and `j` the
59
+ # position of the key vector. We use positive relative positions when keys
60
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
61
+ pe_positive = torch.zeros(x.size(1), self.d_model)
62
+ pe_negative = torch.zeros(x.size(1), self.d_model)
63
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
64
+ div_term = torch.exp(
65
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
66
+ * -(math.log(10000.0) / self.d_model)
67
+ )
68
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
69
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
70
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
71
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
72
+
73
+ # Reverse the order of positive indices and concat both positive and
74
+ # negative indices. This is used to support the shifting trick
75
+ # as in https://arxiv.org/abs/1901.02860
76
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
77
+ pe_negative = pe_negative[1:].unsqueeze(0)
78
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
79
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
80
+
81
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
82
+ -> Tuple[torch.Tensor, torch.Tensor]:
83
+ """Add positional encoding.
84
+
85
+ Args:
86
+ x (torch.Tensor): Input tensor (batch, time, `*`).
87
+
88
+ Returns:
89
+ torch.Tensor: Encoded tensor (batch, time, `*`).
90
+
91
+ """
92
+ self.extend_pe(x)
93
+ x = x * self.xscale
94
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
95
+ return self.dropout(x), self.dropout(pos_emb)
96
+
97
+ def position_encoding(self,
98
+ offset: Union[int, torch.Tensor],
99
+ size: int) -> torch.Tensor:
100
+ """ For getting encoding in a streaming fashion
101
+
102
+ Attention:
103
+ we apply dropout only once at the whole utterance level in a non-
104
+ streaming way, but will call this function several times with
105
+ increasing input size in a streaming scenario, so the dropout will
106
+ be applied several times.
107
+
108
+ Args:
109
+ offset (int or torch.tensor): start offset
110
+ size (int): required size of position encoding
111
+
112
+ Returns:
113
+ torch.Tensor: Corresponding encoding
114
+ """
115
+ pos_emb = self.pe[
116
+ :,
117
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
118
+ ]
119
+ return pos_emb
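A minimal usage sketch of the class above (dropout set to 0 for a deterministic check); the returned pos_emb spans both positive and negative offsets, hence the 2*T - 1 length.

import torch
from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding

pos_enc = EspnetRelPositionalEncoding(d_model=512, dropout_rate=0.0)
x = torch.randn(1, 100, 512)
x_scaled, pos_emb = pos_enc(x)
assert x_scaled.shape == (1, 100, 512)           # input scaled by sqrt(d_model)
assert pos_emb.shape == (1, 2 * 100 - 1, 512)    # positive + negative relative positions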
cosyvoice2/transformer/encoder_layer.py ADDED
@@ -0,0 +1,163 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Encoder self-attention layer definition."""
17
+
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ from torch import nn
22
+
23
+
24
+ class ConformerEncoderLayer(nn.Module):
25
+ """Encoder layer module.
26
+ Args:
27
+ size (int): Input dimension.
28
+ self_attn (torch.nn.Module): Self-attention module instance.
29
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
30
+ instance can be used as the argument.
31
+ feed_forward (torch.nn.Module): Feed-forward module instance.
32
+ `PositionwiseFeedForward` instance can be used as the argument.
33
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
34
+ instance.
35
+ `PositionwiseFeedForward` instance can be used as the argument.
36
+ conv_module (torch.nn.Module): Convolution module instance.
37
+ `ConvolutionModule` instance can be used as the argument.
38
+ dropout_rate (float): Dropout rate.
39
+ normalize_before (bool):
40
+ True: use layer_norm before each sub-block.
41
+ False: use layer_norm after each sub-block.
42
43
+ """
44
+
45
+ def __init__(
46
+ self,
47
+ size: int,
48
+ self_attn: torch.nn.Module,
49
+ feed_forward: Optional[nn.Module] = None,
50
+ feed_forward_macaron: Optional[nn.Module] = None,
51
+ conv_module: Optional[nn.Module] = None,
52
+ dropout_rate: float = 0.1,
53
+ normalize_before: bool = True,
54
+ ):
55
+ """Construct an EncoderLayer object."""
56
+ super().__init__()
57
+ self.self_attn = self_attn
58
+ self.feed_forward = feed_forward
59
+ self.feed_forward_macaron = feed_forward_macaron
60
+ self.conv_module = conv_module
61
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
62
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
63
+ if feed_forward_macaron is not None:
64
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
65
+ self.ff_scale = 0.5
66
+ else:
67
+ self.ff_scale = 1.0
68
+ if self.conv_module is not None:
69
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
70
+ self.norm_final = nn.LayerNorm(
71
+ size, eps=1e-12) # for the final output of the block
72
+ self.dropout = nn.Dropout(dropout_rate)
73
+ self.size = size
74
+ self.normalize_before = normalize_before
75
+
76
+ def forward(
77
+ self,
78
+ x: torch.Tensor,
79
+ mask: torch.Tensor,
80
+ pos_emb: torch.Tensor,
81
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
82
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
83
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
84
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
85
+ """Compute encoded features.
86
+
87
+ Args:
88
+ x (torch.Tensor): (#batch, time, size)
89
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
90
+ (0, 0, 0) means fake mask.
91
+ pos_emb (torch.Tensor): positional encoding, must not be None
92
+ for ConformerEncoderLayer.
93
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
94
+ (#batch, 1, time), (0, 0, 0) means fake mask.
95
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
96
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
97
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
98
+ (#batch=1, size, cache_t2)
99
+ Returns:
100
+ torch.Tensor: Output tensor (#batch, time, size).
101
+ torch.Tensor: Mask tensor (#batch, time, time).
102
+ torch.Tensor: att_cache tensor,
103
+ (#batch=1, head, cache_t1 + time, d_k * 2).
104
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
105
+ """
106
+ return self._forward_impl(x, mask, pos_emb, mask_pad, att_cache, cnn_cache)
107
+
108
+ def _forward_impl(
109
+ self,
110
+ x: torch.Tensor,
111
+ mask: torch.Tensor,
112
+ pos_emb: torch.Tensor,
113
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
114
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
115
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
116
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
117
+ """原始的前向传播实现"""
118
+ # whether to use macaron style
119
+ if self.feed_forward_macaron is not None:
120
+ residual = x
121
+ if self.normalize_before:
122
+ x = self.norm_ff_macaron(x)
123
+ x = residual + self.ff_scale * self.dropout(
124
+ self.feed_forward_macaron(x))
125
+ if not self.normalize_before:
126
+ x = self.norm_ff_macaron(x)
127
+
128
+ # multi-headed self-attention module
129
+ residual = x
130
+ if self.normalize_before:
131
+ x = self.norm_mha(x)
132
+ # att_cache: (b, head, cache_t, d_k*2)
133
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
134
+ att_cache)
135
+ x = residual + self.dropout(x_att)
136
+ if not self.normalize_before:
137
+ x = self.norm_mha(x)
138
+
139
+ # convolution module
140
+ # Fake new cnn cache here, and then change it in conv_module
141
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
142
+ if self.conv_module is not None:
143
+ residual = x
144
+ if self.normalize_before:
145
+ x = self.norm_conv(x)
146
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
147
+ x = residual + self.dropout(x)
148
+
149
+ if not self.normalize_before:
150
+ x = self.norm_conv(x)
151
+
152
+ # feed forward module
153
+ residual = x
154
+ if self.normalize_before:
155
+ x = self.norm_ff(x)
156
+
157
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
158
+ if not self.normalize_before:
159
+ x = self.norm_ff(x)
160
+
161
+ if self.conv_module is not None:
162
+ x = self.norm_final(x)
163
+ return x, mask, new_att_cache, new_cnn_cache
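A minimal sketch of driving one layer, assuming the wenet-style RelPositionMultiHeadedAttention constructor (heads, dim, dropout, key_bias), i.e. the same argument tuple the upsample encoder below passes via encoder_selfattn_layer_args.

import torch
from cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward

size, heads, t = 512, 4, 50
layer = ConformerEncoderLayer(
    size,
    RelPositionMultiHeadedAttention(heads, size, 0.0, True),  # assumed (heads, dim, dropout, key_bias)
    PositionwiseFeedForward(size, 2048, 0.1),
)
x = torch.randn(1, t, size)
mask = torch.ones(1, 1, t, dtype=torch.bool)
_, pos_emb = EspnetRelPositionalEncoding(size, 0.0)(x)        # (1, 2*t - 1, size)
y, mask, att_cache, cnn_cache = layer(x, mask, pos_emb)
assert y.shape == (1, t, size)                                # conv_module is None, so cnn_cache stays empty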
cosyvoice2/transformer/positionwise_feed_forward.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Positionwise feed forward layer definition."""
16
+
17
+ import torch
18
+
19
+
20
+ class PositionwiseFeedForward(torch.nn.Module):
21
+ """Positionwise feed forward layer.
22
+
23
+ FeedForward is applied on each position of the sequence.
24
+ The output dim is the same as the input dim.
25
+
26
+ Args:
27
+ idim (int): Input dimension.
28
+ hidden_units (int): The number of hidden units.
29
+ dropout_rate (float): Dropout rate.
30
+ activation (torch.nn.Module): Activation function
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ idim: int,
36
+ hidden_units: int,
37
+ dropout_rate: float,
38
+ activation: torch.nn.Module = torch.nn.ReLU(),
39
+ ):
40
+ """Construct a PositionwiseFeedForward object."""
41
+ super(PositionwiseFeedForward, self).__init__()
42
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
43
+ self.activation = activation
44
+ self.dropout = torch.nn.Dropout(dropout_rate)
45
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
46
+
47
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
48
+ """Forward function.
49
+
50
+ Args:
51
+ xs: input tensor (B, L, D)
52
+ Returns:
53
+ output tensor, (B, L, D)
54
+ """
55
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56
+
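A quick sanity check for the layer above (hypothetical sizes): the hidden width expands to hidden_units and projects back, so the input shape is preserved.

import torch
from cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward

ffn = PositionwiseFeedForward(idim=256, hidden_units=1024, dropout_rate=0.1)
xs = torch.randn(2, 50, 256)
assert ffn(xs).shape == (2, 50, 256)   # output dim matches input dim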
cosyvoice2/transformer/subsampling.py ADDED
@@ -0,0 +1,79 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2024 Alibaba Inc (Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Subsampling layer definition."""
17
+
18
+ from typing import Tuple, Union
19
+
20
+ import torch
21
+
22
+
23
+ class BaseSubsampling(torch.nn.Module):
24
+
25
+ def __init__(self):
26
+ super().__init__()
27
+ self.right_context = 0
28
+ self.subsampling_rate = 1
29
+
30
+ def position_encoding(self, offset: Union[int, torch.Tensor],
31
+ size: int) -> torch.Tensor:
32
+ return self.pos_enc.position_encoding(offset, size)
33
+
34
+
35
+ class LinearNoSubsampling(BaseSubsampling):
36
+ """Linear transform the input without subsampling
37
+
38
+ Args:
39
+ idim (int): Input dimension.
40
+ odim (int): Output dimension.
41
+ dropout_rate (float): Dropout rate.
42
+
43
+ """
44
+
45
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
46
+ pos_enc_class: torch.nn.Module):
47
+ """Construct an linear object."""
48
+ super().__init__()
49
+ self.out = torch.nn.Sequential(
50
+ torch.nn.Linear(idim, odim),
51
+ torch.nn.LayerNorm(odim, eps=1e-5),
52
+ torch.nn.Dropout(dropout_rate),
53
+ )
54
+ self.pos_enc = pos_enc_class
55
+ self.right_context = 0
56
+ self.subsampling_rate = 1
57
+
58
+ def forward(
59
+ self,
60
+ x: torch.Tensor,
61
+ x_mask: torch.Tensor,
62
+ offset: Union[int, torch.Tensor] = 0
63
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
64
+ """Input x.
65
+
66
+ Args:
67
+ x (torch.Tensor): Input tensor (#batch, time, idim).
68
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
69
+
70
+ Returns:
71
+ torch.Tensor: linear input tensor (#batch, time', odim),
72
+ where time' = time .
73
+ torch.Tensor: linear input mask (#batch, 1, time'),
74
+ where time' = time .
75
+
76
+ """
77
+ x = self.out(x)
78
+ x, pos_emb = self.pos_enc(x, offset)
79
+ return x, pos_emb, x_mask
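A short usage sketch of LinearNoSubsampling with the relative positional encoding defined earlier (sizes are hypothetical); the time axis is unchanged, only the feature dimension is projected.

import torch
from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding
from cosyvoice2.transformer.subsampling import LinearNoSubsampling

embed = LinearNoSubsampling(idim=80, odim=256, dropout_rate=0.1,
                            pos_enc_class=EspnetRelPositionalEncoding(256, 0.1))
x = torch.randn(2, 120, 80)
mask = torch.ones(2, 1, 120, dtype=torch.bool)
y, pos_emb, mask = embed(x, mask)
# y: (2, 120, 256); pos_emb: (1, 239, 256); mask is returned unchanged (no subsampling)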
cosyvoice2/transformer/upsample_encoder_v2.py ADDED
@@ -0,0 +1,483 @@
1
+ # Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
2
+ # 2022 Xingchen Song ([email protected])
3
+ # 2024 Alibaba Inc (Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # Modified from ESPnet(https://github.com/espnet/espnet)
17
+ """Encoder definition."""
18
+ from typing import Tuple, List
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+
24
+ from cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
25
+ from cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
26
+ from cosyvoice2.utils.class_utils import (
27
+ COSYVOICE_EMB_CLASSES,
28
+ COSYVOICE_SUBSAMPLE_CLASSES,
29
+ COSYVOICE_ATTENTION_CLASSES,
30
+ COSYVOICE_ACTIVATION_CLASSES,
31
+ )
32
+ from cosyvoice2.utils.mask import (
33
+ make_pad_mask,
34
+ )
35
+
36
+ import torch._dynamo
37
+ torch._dynamo.config.suppress_errors = True
38
+ torch._dynamo.config.cache_size_limit = 128
39
+
40
+ class Upsample1D(nn.Module):
41
+ """A 1D upsampling layer with an optional convolution.
42
+
43
+ Parameters:
44
+ channels (`int`):
+ number of channels in the inputs.
+ out_channels (`int`):
+ number of channels in the outputs.
+ stride (`int`, default `2`):
+ upsampling factor applied to the time axis.
+ scale_factor (`float`, optional):
+ interpolation scale factor. Defaults to `stride`.
52
+ """
53
+
54
+ def __init__(self, channels: int, out_channels: int, stride: int = 2, scale_factor: float = None):
55
+ super().__init__()
56
+ self.channels = channels
57
+ self.out_channels = out_channels
58
+ self.stride = stride
59
+ # In this mode, first repeat-interpolate, then conv with stride=1
60
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
61
+ self.scale_factor = float(self.stride) if scale_factor is None else float(scale_factor)
62
+
63
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
64
+ outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
65
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
66
+ outputs = self.conv(outputs)
67
+ return outputs, input_lengths * self.stride
68
+
69
+ def forward_chunk(self, inputs: torch.Tensor, input_lengths: torch.Tensor, cache: torch.Tensor = torch.zeros((0, 0, 0))):
70
+ """
71
+ Args:
72
+ inputs(torch.Tensor): shape (b, c, t)
73
+ input_length(torch.Tensor): shape (b), can be None
74
+ cache(torch.Tensor): shape (b, c, cache_t), where cache_t = stride * 2
75
+ """
76
+ outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
77
+
78
+ if cache is None:
79
+ cache = inputs.new_zeros(inputs.shape[0], inputs.shape[1], self.stride * 2)
80
+ outputs = torch.cat([cache, outputs], dim=2)
81
+ new_cache = outputs[..., -self.stride*2:]
82
+ outputs = self.conv(outputs)
83
+
84
+ if input_lengths is not None:
85
+ input_lengths = input_lengths * self.stride
86
+ return outputs, input_lengths, new_cache
87
+
88
+
89
+ class PreLookaheadLayer(nn.Module):
90
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
91
+ super().__init__()
92
+ self.channels = channels
93
+ self.pre_lookahead_len = pre_lookahead_len
94
+ self.conv1 = nn.Conv1d(
95
+ channels, channels,
96
+ kernel_size=pre_lookahead_len + 1,
97
+ stride=1, padding=0,
98
+ )
99
+ self.conv2 = nn.Conv1d(
100
+ channels, channels,
101
+ kernel_size=3, stride=1, padding=0,
102
+ )
103
+
104
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
105
+ """
106
+ inputs: (batch_size, seq_len, channels)
107
+ """
108
+ outputs = inputs.transpose(1, 2).contiguous()
109
+ # look ahead
110
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
111
+ outputs = F.leaky_relu(self.conv1(outputs))
112
+ # outputs
113
+ outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
114
+ outputs = self.conv2(outputs)
115
+ outputs = outputs.transpose(1, 2).contiguous()
116
+
117
+ # residual connection
118
+ outputs = outputs + inputs
119
+ return outputs
120
+
121
+ def forward_chunk(self, inputs: torch.Tensor, cache: torch.Tensor = None):
122
+ """
123
+ Args:
124
+ inputs(torch.Tensor): shape (b, t, c)
125
+ cache(torch.Tensor): shape (b, c, cache_t=2), c = channels
126
+ """
127
+ outputs = inputs.transpose(1, 2).contiguous()
128
+ outputs = F.leaky_relu(self.conv1(outputs))
129
+ # the length of outputs is input length - pre_lookahead_len
130
+ if cache is None:
131
+ cache = outputs.new_zeros(outputs.shape[0], outputs.shape[1], 2)
132
+ # NOTE
133
+ new_cache = outputs[..., -2:]
134
+ outputs = torch.cat([cache, outputs], dim=2)
135
+ outputs = self.conv2(outputs)
136
+ outputs = outputs.transpose(1, 2).contiguous()
137
+ # residual connection
138
+ outputs = outputs + inputs[:, :-self.pre_lookahead_len]
139
+ return outputs, new_cache
140
+
141
+
142
+ """Customize each sample's chunk attention mask
143
+ """
144
+ class UpsampleConformerEncoderV2(torch.nn.Module):
145
+
146
+ def __init__(
147
+ self,
148
+ # input & output
149
+ input_size: int,
150
+ output_size: int = 256,
151
+ input_layer: str = "linear",
152
+ pre_lookahead_len: int = 3,
153
+ # size
154
+ num_blocks: int = 6,
155
+ num_up_blocks: int = 4,
156
+ # upsampling
157
+ up_stride: int = 2,
158
+ up_scale_factor: float = 2,
159
+ # attention
160
+ attention_heads: int = 4,
161
+ pos_enc_layer_type: str = "rel_pos_espnet",
162
+ selfattention_layer_type: str = "rel_selfattn",
163
+ key_bias: bool = True,
164
+ # mlp
165
+ linear_units: int = 2048,
166
+ # dropouts
167
+ dropout_rate: float = 0.1,
168
+ positional_dropout_rate: float = 0.1,
169
+ attention_dropout_rate: float = 0.0,
170
+ # other
171
+ normalize_before: bool = True,
172
+ activation_type: str = "swish",
173
+ **kwargs,
174
+ ):
175
+ super().__init__()
176
+ self._output_size = output_size
177
+ self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
178
+ input_size,
179
+ output_size,
180
+ dropout_rate,
181
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
182
+ output_size,
183
+ positional_dropout_rate
184
+ ),
185
+ )
186
+
187
+ self.normalize_before = normalize_before
188
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
189
+ activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
190
+ # self-attention module definition
191
+ encoder_selfattn_layer_args = (
192
+ attention_heads,
193
+ output_size,
194
+ attention_dropout_rate,
195
+ key_bias,
196
+ )
197
+ # feed-forward module definition
198
+ positionwise_layer_args = (
199
+ output_size,
200
+ linear_units,
201
+ dropout_rate,
202
+ activation,
203
+ )
204
+ self.pre_lookahead_layer = PreLookaheadLayer(
205
+ channels=output_size,
206
+ pre_lookahead_len=pre_lookahead_len
207
+ )
208
+ self.encoders = torch.nn.ModuleList([
209
+ ConformerEncoderLayer(
210
+ output_size,
211
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
212
+ *encoder_selfattn_layer_args
213
+ ),
214
+ PositionwiseFeedForward(*positionwise_layer_args),
215
+ None,
216
+ None,
217
+ dropout_rate,
218
+ normalize_before,
219
+ ) for _ in range(num_blocks)
220
+ ])
221
+ self.up_layer = Upsample1D(
222
+ channels=output_size,
223
+ out_channels=output_size,
224
+ stride=up_stride,
225
+ scale_factor=up_scale_factor
226
+ )
227
+ self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
228
+ input_size,
229
+ output_size,
230
+ dropout_rate,
231
+ COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
232
+ output_size,
233
+ positional_dropout_rate
234
+ ),
235
+ )
236
+ self.up_encoders = torch.nn.ModuleList([
237
+ ConformerEncoderLayer(
238
+ output_size,
239
+ COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
240
+ *encoder_selfattn_layer_args
241
+ ),
242
+ PositionwiseFeedForward(*positionwise_layer_args),
243
+ None,
244
+ None,
245
+ dropout_rate,
246
+ normalize_before,
247
+ ) for _ in range(num_up_blocks)
248
+ ])
249
+
250
+ self.enable_cuda_graph = False
251
+ self.use_cuda_graph = False
252
+ self.graph_encoder = {}
253
+ self.graph_up_encoder = {}
254
+ self.inference_buffers_encoder = {}
255
+ self.inference_buffers_up_encoder = {}
256
+ self.max_static_time = 1500
257
+
258
+ # FIXME(sfy) revert hard-coded bfloat16
259
+ # this method is skipped in CausalMaskedDiffWithXvec.scatter_cuda_graph
260
+ def scatter_cuda_graph(self, enable_cuda_graph: bool):
261
+ self.enable_cuda_graph = enable_cuda_graph
262
+ if self.enable_cuda_graph:
263
+ self._init_cuda_graph()
264
+
265
+ def _init_cuda_graph(self):
266
+ """初始化 CUDA Graph"""
267
+
268
+ for l in range(100, 1500, 10):
269
+ static_x = torch.zeros((1, l, 512),
270
+ dtype=torch.float32, device=torch.device('cuda'))
271
+ static_mask = torch.ones((1, 1, l),
272
+ dtype=torch.bool, device=torch.device('cuda'))
273
+ static_pos_emb = torch.zeros((1, 2*l-1, 512),
274
+ dtype=torch.float32, device=torch.device('cuda'))
275
+
276
+ static_inputs = [
277
+ static_x,
278
+ static_mask,
279
+ static_pos_emb,
280
+ ]
281
+
282
+ self._forward_impl_encoder(
283
+ static_inputs[0],
284
+ static_inputs[1],
285
+ static_inputs[2],
286
+ )
287
+ graph = torch.cuda.CUDAGraph()
288
+ with torch.no_grad():
289
+ with torch.cuda.graph(graph):
290
+ static_out_x = self._forward_impl_encoder(
291
+ static_inputs[0],
292
+ static_inputs[1],
293
+ static_inputs[2]
294
+ )
295
+ self.graph_encoder[l] = graph
296
+ static_outputs = [
297
+ static_out_x,
298
+ ]
299
+ self.inference_buffers_encoder[l] = {
300
+ 'static_inputs': static_inputs,
301
+ 'static_outputs': static_outputs
302
+ }
303
+
304
+ for l in range(100, 1500, 10):
305
+ static_x = torch.zeros((1, l, 512),
306
+ dtype=torch.float32, device=torch.device('cuda'))
307
+ static_mask = torch.ones((1, 1, l),
308
+ dtype=torch.bool, device=torch.device('cuda'))
309
+ static_pos_emb = torch.zeros((1, 2*l-1, 512),
310
+ dtype=torch.float32, device=torch.device('cuda'))
311
+
312
+ static_inputs = [
313
+ static_x,
314
+ static_mask,
315
+ static_pos_emb,
316
+ ]
317
+
318
+ self._forward_impl_up_encoder(
319
+ static_inputs[0],
320
+ static_inputs[1],
321
+ static_inputs[2],
322
+ )
323
+ graph = torch.cuda.CUDAGraph()
324
+ with torch.no_grad():
325
+ with torch.cuda.graph(graph):
326
+ static_out_x = self._forward_impl_up_encoder(
327
+ static_inputs[0],
328
+ static_inputs[1],
329
+ static_inputs[2]
330
+ )
331
+ self.graph_up_encoder[l] = graph
332
+ static_outputs = [
333
+ static_out_x,
334
+ ]
335
+ self.inference_buffers_up_encoder[l] = {
336
+ 'static_inputs': static_inputs,
337
+ 'static_outputs': static_outputs
338
+ }
339
+
340
+ self.use_cuda_graph = True
341
+ print("CUDA Graph initialized successfully for encoder and up_encoder")
342
+
343
+ # @torch.compile(dynamic=True,backend="eager")
344
+ def _forward_impl_encoder(self,
345
+ x: torch.Tensor,
346
+ mask: torch.Tensor,
347
+ pos_emb: torch.Tensor):
348
+ for layer in self.encoders:
349
+ x, _, _, _ = layer(x, mask, pos_emb)
350
+ return x
351
+
352
+ # @torch.compile(dynamic=True,backend="eager")
353
+ def _forward_impl_up_encoder(self,
354
+ x: torch.Tensor,
355
+ mask: torch.Tensor,
356
+ pos_emb: torch.Tensor):
357
+ for layer in self.up_encoders:
358
+ x, _, _, _ = layer(x, mask, pos_emb)
359
+ return x
360
+
361
+ def output_size(self) -> int:
362
+ return self._output_size
363
+
364
+ # @torch.compile(dynamic=True,backend="eager")
365
+ def forward(
366
+ self,
367
+ xs: torch.Tensor,
368
+ xs_lens: torch.Tensor,
369
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
370
+ # (sfy) chunk training strategy should not be open-sourced
371
+ T = xs.size(1)
372
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
373
+ xs, pos_emb, masks = self.embed(xs, masks)
374
+
375
+ # lookahead
376
+ xs = self.pre_lookahead_layer(xs)
377
+ # conformer block
378
+ if self.enable_cuda_graph and xs.shape[1] in self.graph_encoder:
379
+ self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
380
+ self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
381
+ self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
382
+ self.graph_encoder[xs.shape[1]].replay()
383
+ xs = self.inference_buffers_encoder[xs.shape[1]]['static_outputs'][0]
384
+ else:
385
+ xs = self._forward_impl_encoder(xs, masks, pos_emb)
386
+ # upsample
387
+ xs = xs.transpose(1, 2).contiguous()
388
+ xs, xs_lens = self.up_layer(xs, xs_lens)
389
+ xs = xs.transpose(1, 2).contiguous()
390
+
391
+ # 2nd conformer block
392
+ T = xs.size(1)
393
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
394
+ xs, pos_emb, masks = self.up_embed(xs, masks)
395
+ if self.enable_cuda_graph and xs.shape[1] in self.graph_up_encoder:
396
+ self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
397
+ self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
398
+ self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
399
+ self.graph_up_encoder[xs.shape[1]].replay()
400
+ xs = self.inference_buffers_up_encoder[xs.shape[1]]['static_outputs'][0]
401
+ else:
402
+ xs = self._forward_impl_up_encoder(xs, masks, pos_emb)
403
+ # post norm
404
+ if self.normalize_before:
405
+ xs = self.after_norm(xs)
406
+ return xs, masks
407
+
408
+ @torch.compile(dynamic=True, backend="eager")
409
+ def forward_chunk(self,
410
+ xs: torch.Tensor,
411
+ last_chunk: bool = False,
412
+ cnn_cache: torch.Tensor = None,
413
+ att_cache: torch.Tensor = None,
414
+ ):
415
+ """
416
+ Args:
417
+ xs: shape (b, dt, c)
418
+ last_chunk: bool. If last chunk, will pad input with lookaheads
419
+ att_cache: shape (depth1+depth2, b, nh, 2*t1, c).
420
+ cnn_cache: shape (b, c, t1+t2). Where t1=2 (pre_lookahead_layer), t2=4 (up_layer)
421
+ """
422
+ if att_cache is not None:
423
+ assert att_cache.shape[3] % 2 == 0, att_cache.shape
424
+ if cnn_cache is not None:
425
+ assert cnn_cache.shape[2] == 2+self.up_layer.stride*2, cnn_cache.shape
426
+
427
+ # unpack caches
428
+ offset1 = att_cache.shape[3] // 2 if att_cache is not None else 0
429
+ att_cache1 = att_cache[:len(self.encoders), :, :, :offset1] if att_cache is not None else [None] * len(self.encoders)
430
+ att_cache2 = att_cache[len(self.encoders):] if att_cache is not None else [None] * len(self.encoders)
431
+ cnn_cache1 = cnn_cache[:, :, :2] if cnn_cache is not None else None
432
+ cnn_cache2 = cnn_cache[:, :, 2:] if cnn_cache is not None else None
433
+ xs, _, _ = self.embed(xs, None)
434
+ if last_chunk:
435
+ xs = F.pad(xs, (0, 0, 0, self.pre_lookahead_layer.pre_lookahead_len))
436
+
437
+ # this_cnn_cache: shape (b=1, c=512, t=2)
438
+ xs, new_cnn_cache1 = self.pre_lookahead_layer.forward_chunk(xs, cache=cnn_cache1)
439
+
440
+ # remake pos_emb, offset param is ignored by position_encoding
441
+ pos_emb = self.embed.position_encoding(offset=None, size=offset1 + xs.shape[1])
442
+
443
+ # first conformer
444
+ chunk_masks = torch.zeros((0, 0, 0))
445
+ new_att_cache1 = []
446
+
447
+ for idx, layer in enumerate(self.encoders):
448
+ # this_att_cache: shape (b, nh, t, c * 2)
449
+ xs, _, this_new_att_cache1, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache1[idx])
450
+ new_att_cache1.append(this_new_att_cache1)
451
+ new_att_cache1 = torch.stack(new_att_cache1, dim=0)
452
+
453
+ # upsample + conformer encoder, xs: (b, t, c) -> (b, c, t)
454
+ xs = xs.transpose(1, 2).contiguous()
455
+ # this_cnn_cache: shape (b=1, c=512, t=2*2)
456
+ xs, _, new_cnn_cache2 = self.up_layer.forward_chunk(xs, None, cache=cnn_cache2)
457
+ xs = xs.transpose(1, 2).contiguous()
458
+
459
+ # at this time, xs are doubled in length
460
+ xs, _, _ = self.up_embed(xs, None)
461
+
462
+ # remake pos_emb
463
+ pos_emb = self.embed.position_encoding(offset=None, size=offset1 * self.up_layer.stride + xs.shape[1])
464
+
465
+ # second conformer
466
+ chunk_masks = torch.zeros((0, 0, 0), dtype=torch.bfloat16)
467
+ new_att_cache2 = []
468
+
469
+ for idx, layer in enumerate(self.up_encoders):
470
+ xs, _, this_new_att_cache2, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache2[idx])
471
+ new_att_cache2.append(this_new_att_cache2)
472
+ new_att_cache2 = torch.stack(new_att_cache2, dim=0)
473
+
474
+ if self.normalize_before:
475
+ xs = self.after_norm(xs)
476
+
477
+ # pack new cache
478
+ new_att_cache = torch.cat([new_att_cache1.repeat(1, 1, 1, 2, 1), new_att_cache2], dim=0)
479
+ new_cnn_cache = torch.cat([new_cnn_cache1, new_cnn_cache2], dim=2)
480
+
481
+ return xs, new_cnn_cache, new_att_cache
482
+
483
+
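A non-streaming usage sketch of the encoder above, under the assumption that input_size equals output_size (the up_embed projection reuses input_size, and the CUDA-graph buffers assume a width of 512): the up_layer doubles the time axis, so 100 input frames yield 200 output frames.

import torch
from cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2

encoder = UpsampleConformerEncoderV2(input_size=512, output_size=512).eval()
xs = torch.randn(1, 100, 512)
xs_lens = torch.tensor([100])
with torch.no_grad():
    ys, masks = encoder(xs, xs_lens)
assert ys.shape == (1, 200, 512)      # time axis upsampled by up_stride=2
assert masks.shape == (1, 1, 200)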
cosyvoice2/utils/class_utils.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright [2023-11-28] <[email protected], Xingchen Song>
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import torch
16
+
17
+ from cosyvoice2.transformer.subsampling import LinearNoSubsampling
18
+ from cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
19
+ from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding
20
+
21
+
22
+ COSYVOICE_ACTIVATION_CLASSES = {
23
+ "hardtanh": torch.nn.Hardtanh,
24
+ "tanh": torch.nn.Tanh,
25
+ "relu": torch.nn.ReLU,
26
+ "selu": torch.nn.SELU,
27
+ "swish": torch.nn.SiLU,
28
+ "gelu": torch.nn.GELU,
29
+ }
30
+
31
+ COSYVOICE_SUBSAMPLE_CLASSES = {
32
+ "linear": LinearNoSubsampling,
33
+ }
34
+
35
+ COSYVOICE_EMB_CLASSES = {
36
+ "rel_pos_espnet": EspnetRelPositionalEncoding,
37
+ }
38
+
39
+ COSYVOICE_ATTENTION_CLASSES = {
40
+ "rel_selfattn": RelPositionMultiHeadedAttention,
41
+ }
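These registries are how the upsample encoder resolves its components from config strings; a brief sketch follows (the attention arguments mirror the encoder_selfattn_layer_args tuple used above, which is an assumption about that constructor's signature).

from cosyvoice2.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
)

act = COSYVOICE_ACTIVATION_CLASSES["swish"]()                    # torch.nn.SiLU
pos_enc = COSYVOICE_EMB_CLASSES["rel_pos_espnet"](512, 0.1)      # EspnetRelPositionalEncoding
attn = COSYVOICE_ATTENTION_CLASSES["rel_selfattn"](4, 512, 0.0, True)
embed = COSYVOICE_SUBSAMPLE_CLASSES["linear"](80, 512, 0.1, pos_enc)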
cosyvoice2/utils/common.py ADDED
@@ -0,0 +1,101 @@
1
+ # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # Modified from ESPnet(https://github.com/espnet/espnet)
16
+ """Unility functions for Transformer."""
17
+
18
+ import random
19
+ from typing import List
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ IGNORE_ID = -1
25
+
26
+
27
+ def pad_list(xs: List[torch.Tensor], pad_value: int):
28
+ """Perform padding for the list of tensors.
29
+
30
+ Args:
31
+ xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
32
+ pad_value (float): Value for padding.
33
+
34
+ Returns:
35
+ Tensor: Padded tensor (B, Tmax, `*`).
36
+
37
+ Examples:
38
+ >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
39
+ >>> x
40
+ [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
41
+ >>> pad_list(x, 0)
42
+ tensor([[1., 1., 1., 1.],
43
+ [1., 1., 0., 0.],
44
+ [1., 0., 0., 0.]])
45
+
46
+ """
47
+ max_len = max([len(item) for item in xs])
48
+ batchs = len(xs)
49
+ ndim = xs[0].ndim
50
+ if ndim == 1:
51
+ pad_res = torch.zeros(batchs,
52
+ max_len,
53
+ dtype=xs[0].dtype,
54
+ device=xs[0].device)
55
+ elif ndim == 2:
56
+ pad_res = torch.zeros(batchs,
57
+ max_len,
58
+ xs[0].shape[1],
59
+ dtype=xs[0].dtype,
60
+ device=xs[0].device)
61
+ elif ndim == 3:
62
+ pad_res = torch.zeros(batchs,
63
+ max_len,
64
+ xs[0].shape[1],
65
+ xs[0].shape[2],
66
+ dtype=xs[0].dtype,
67
+ device=xs[0].device)
68
+ else:
69
+ raise ValueError(f"Unsupported ndim: {ndim}")
70
+ pad_res.fill_(pad_value)
71
+ for i in range(batchs):
72
+ pad_res[i, :len(xs[i])] = xs[i]
73
+ return pad_res
74
+
75
+
76
+ def get_padding(kernel_size, dilation=1):
77
+ return int((kernel_size * dilation - dilation) / 2)
78
+
79
+
80
+ def init_weights(m, mean=0.0, std=0.01):
81
+ classname = m.__class__.__name__
82
+ if classname.find("Conv") != -1:
83
+ m.weight.data.normal_(mean, std)
84
+
85
+
86
+ def fade_in_out(fade_in_mel, fade_out_mel, window):
87
+ device = fade_in_mel.device
88
+ fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
89
+ mel_overlap_len = int(window.shape[0] / 2)
90
+ if fade_in_mel.device == torch.device('cpu'):
91
+ fade_in_mel = fade_in_mel.clone()
92
+ fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
93
+ fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
94
+ return fade_in_mel.to(device)
95
+
96
+
97
+ def set_all_random_seed(seed):
98
+ random.seed(seed)
99
+ np.random.seed(seed)
100
+ torch.manual_seed(seed)
101
+ torch.cuda.manual_seed_all(seed)
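A small sketch of fade_in_out for cross-fading two mel chunks at a seam; the Hamming window here is only an illustrative choice, any symmetric window of length 2 * overlap works.

import torch
from cosyvoice2.utils.common import fade_in_out

overlap = 10                                        # mel frames to cross-fade (hypothetical)
window = torch.hamming_window(2 * overlap)
mel_prev = torch.randn(1, 80, 50)                   # tail of the previous chunk
mel_next = torch.randn(1, 80, 50)                   # head of the next chunk
mel_next = fade_in_out(mel_next, mel_prev, window)  # first `overlap` frames are blended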
cosyvoice2/utils/mask.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) 2019 Shigeki Karita
2
+ # 2020 Mobvoi Inc (Binbin Zhang)
3
+ # 2024 Alibaba Inc (authors: Xiang Lyu)
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import math
18
+ import torch
19
+ from typing import List
20
+
21
+
22
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
23
+ """Make mask tensor containing indices of padded part.
24
+
25
+ See description of make_non_pad_mask.
26
+
27
+ Args:
28
+ lengths (torch.Tensor): Batch of lengths (B,).
29
+ Returns:
30
+ torch.Tensor: Mask tensor containing indices of padded part.
31
+
32
+ Examples:
33
+ >>> lengths = [5, 3, 2]
34
+ >>> make_pad_mask(lengths)
35
+ masks = [[0, 0, 0, 0 ,0],
36
+ [0, 0, 0, 1, 1],
37
+ [0, 0, 1, 1, 1]]
38
+ """
39
+ batch_size = lengths.size(0)
40
+ max_len = max_len if max_len > 0 else lengths.max().item()
41
+ seq_range = torch.arange(0,
42
+ max_len,
43
+ dtype=torch.int64,
44
+ device=lengths.device)
45
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
46
+ seq_length_expand = lengths.unsqueeze(-1)
47
+ mask = seq_range_expand >= seq_length_expand
48
+ return mask
49
+
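A short usage sketch of make_pad_mask, including the inversion into the (B, 1, T) attention mask built in the encoder's forward.

import torch
from cosyvoice2.utils.mask import make_pad_mask

lengths = torch.tensor([5, 3, 2])
pad_mask = make_pad_mask(lengths)          # True marks padded positions, shape (3, 5)
att_mask = ~pad_mask.unsqueeze(1)          # (B, 1, T), as used by the upsample encoder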
flashcosyvoice/__init__.py ADDED
File without changes
flashcosyvoice/cli.py ADDED
@@ -0,0 +1,424 @@
1
+ # Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """ Example Usage: see README.md
15
+ """
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ import random
21
+ import sys
22
+ import time
23
+ from concurrent.futures import ThreadPoolExecutor
24
+ from datetime import datetime
25
+
26
+ import numpy as np
27
+ import onnxruntime
28
+ import s3tokenizer
29
+ import torch
30
+ import torch.distributed as dist
31
+ import torchaudio
32
+ import torchaudio.compliance.kaldi as kaldi
33
+ from torch.utils.data import DataLoader, Dataset, DistributedSampler
34
+ from tqdm import tqdm
35
+
36
+ from flashcosyvoice.config import Config, CosyVoice2LLMConfig, SamplingParams
37
+ from flashcosyvoice.cosyvoice2 import CosyVoice2
38
+ from flashcosyvoice.utils.audio import mel_spectrogram
39
+
40
+
41
+ def set_all_random_seed(seed):
42
+ random.seed(seed)
43
+ np.random.seed(seed)
44
+ torch.manual_seed(seed)
45
+ torch.cuda.manual_seed_all(seed)
46
+
47
+
48
+ def save_file_async(
49
+ wav, prompt_speech_tokens, generated_speech_tokens,
50
+ info, timing_stats
51
+ ):
52
+ """Save audio asynchronously."""
53
+ try:
54
+ os.makedirs(os.path.dirname(info['wav']), exist_ok=True)
55
+ if wav is not None:
56
+ wav = wav.cpu()
57
+ torchaudio.save(info['wav'], wav, 24000)
58
+ duration = wav.shape[-1] / 24000.0
59
+ rtf = ((timing_stats['dataloader_time'] + timing_stats['model_inference_time']) / timing_stats['batch_size']) / duration
60
+ timing_stats['rtf'] = rtf
61
+ else:
62
+ duration = 0.0
63
+ info['timing_stats'] = timing_stats
64
+ info['prompt_speech_tokens'] = prompt_speech_tokens
65
+ info['generated_speech_tokens'] = generated_speech_tokens
66
+ with open(f"{info['wav'].replace('.wav', '.json')}", "w") as f:
67
+ json.dump(info, f, ensure_ascii=False, indent=4)
68
+ return duration
69
+ except Exception as e:
70
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
71
+ tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
72
+ return 0.0
73
+
74
+
75
+ class AudioDataset(Dataset):
76
+
77
+ def __init__(self, text_norm, text_tokenizer, data_list, model_config: Config):
78
+ self.datas = []
79
+ self.text_norm = text_norm
80
+ self.model_config = model_config
81
+
82
+ """Example data_list:
83
+ ```
84
+ {"key": "uttid_1", "prompt_text": "你好,我是小明。", "text": "你好,我是小红。", "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"}
85
+ {"key": "uttid_2", "prompt_text": "你好,我是小红。", "text": "你好,我是小明。", "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"}
86
+ ```
87
+ Note:
88
+ - `key` is the key of this sample.
89
+ - `prompt_text` is the text used for prompt.
90
+ - `text` is the text used for generating real audio.
91
+ - `prompt_wav` is the audio used for prompt.
92
+ - `wav` is the path to the generated audio to be saved (we highly recommend to pre-define the save path before running the script).
93
+ """
94
+ missing = 0
95
+ with open(data_list, 'r', encoding='utf-8') as f:
96
+ lines = f.readlines()
97
+ total_lines = len(lines)
98
+ if torch.distributed.get_node_local_rank() == 0:
99
+ iterator = tqdm(lines, desc='Loading data')
100
+ else:
101
+ iterator = lines
102
+ for line in iterator:
103
+ data = json.loads(line.strip())
104
+ valid = True
105
+ for k in ['key', 'prompt_text', 'text', 'prompt_wav']:
106
+ if k not in data:
107
+ valid = False
108
+ break
109
+ if data[k] is None:
110
+ valid = False
111
+ break
112
+ if not os.path.exists(data['prompt_wav']):
113
+ valid = False
114
+ if valid:
115
+ self.datas.append(data)
116
+ else:
117
+ missing += 1
118
+ if torch.distributed.get_node_local_rank() == 0:
119
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
120
+ tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, found {missing} missing lines, total valid lines == {len(self.datas)}.')
121
+
122
+ self.text_tokenizer = text_tokenizer
123
+
124
+ option = onnxruntime.SessionOptions()
125
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
126
+ option.intra_op_num_threads = 1
127
+ self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
128
+ providers=["CPUExecutionProvider"])
129
+
130
+ def __len__(self):
131
+ return len(self.datas)
132
+
133
+ def __getitem__(self, idx):
134
+ data = self.datas[idx]
135
+
136
+ try:
137
+ # 1. feature for s3tokenizer
138
+ audio = s3tokenizer.load_audio(data['prompt_wav'], sr=16000) # [T]
139
+ log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
140
+
141
+ # 2. feature for speaker embedding
142
+ spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
143
+ spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
144
+ spk_emb = self.spk_model.run(
145
+ None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
146
+ )[0].flatten().tolist()
147
+
148
+ # 3. feature for flow
149
+ audio, sample_rate = torchaudio.load(data['prompt_wav'], backend='soundfile')
150
+ audio = audio.mean(dim=0, keepdim=True) # [1, T]
151
+ if sample_rate != 24000:
152
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
153
+ mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
154
+ mel_len = mel.shape[0]
155
+
156
+ # 4. feature for llm
157
+ if self.text_norm is not None:
158
+ prompt_texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['prompt_text'].strip()))["sentences"]]
159
+ prompt_text = ''.join(prompt_texts)
160
+ texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['text'].strip()))["sentences"]]
161
+ text = ''.join(texts)
162
+ else:
163
+ prompt_text = data['prompt_text']
164
+ text = data['text']
165
+ prompt_text_ids = self.text_tokenizer.encode(prompt_text)
166
+ prompt_text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in prompt_text_ids]
167
+ text_ids = self.text_tokenizer.encode(text)
168
+ text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in text_ids]
169
+ item = {
170
+ "prompt_text_tokens": prompt_text_ids, "text_tokens": text_ids,
171
+ "spk_emb": spk_emb, "mel": mel, "mel_len": mel_len, "log_mel": log_mel, "info": data,
172
+ "min_tokens": len(text_ids) * self.model_config.min_token_text_ratio,
173
+ "max_tokens": len(text_ids) * self.model_config.max_token_text_ratio,
174
+ }
175
+ except Exception as e:
176
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
177
+ tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}")
178
+ return None
179
+ return item
180
+
181
+
182
+ def collate_fn(batch):
183
+ prompt_mels_for_llm = [item["log_mel"] for item in batch if item is not None]
184
+ prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_mels_for_llm) # [B, num_mels=128, T]
185
+ prompt_text_tokens_for_llm = [item["prompt_text_tokens"] for item in batch if item is not None]
186
+ text_tokens_for_llm = [item["text_tokens"] for item in batch if item is not None]
187
+ prompt_mels_for_flow = [item["mel"] for item in batch if item is not None]
188
+ prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
189
+ prompt_mels_lens_for_flow = [item["mel_len"] for item in batch if item is not None]
190
+ prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
191
+ spk_emb_for_flow = [item["spk_emb"] for item in batch if item is not None]
192
+ spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
193
+ sampling_params = [SamplingParams(min_tokens=item["min_tokens"], max_tokens=item["max_tokens"], use_ras=True) for item in batch if item is not None]
194
+ infos = [item["info"] for item in batch if item is not None]
195
+ return {
196
+ "prompt_mels_for_llm": prompt_mels_for_llm,
197
+ "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
198
+ "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
199
+ "text_tokens_for_llm": text_tokens_for_llm,
200
+ "prompt_mels_for_flow": prompt_mels_for_flow,
201
+ "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
202
+ "spk_emb_for_flow": spk_emb_for_flow,
203
+ "sampling_params": sampling_params,
204
+ "infos": infos,
205
+ }
206
+
207
+
208
+ def init_distributed():
209
+ world_size = int(os.environ.get('WORLD_SIZE', 1))
210
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
211
+ rank = int(os.environ.get('RANK', 0))
212
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
213
+ tqdm.write(f'[{timestamp}] - [INFO] - Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
214
+ torch.cuda.set_device(local_rank)
215
+ dist.init_process_group("nccl")
216
+ return world_size, local_rank, rank
217
+
218
+
219
+ def get_args():
220
+ parser = argparse.ArgumentParser(description='FlashCosyVoice')
221
+ parser.add_argument('--model_path',
222
+ required=True,
223
+ type=str,
224
+ help='model path')
225
+ parser.add_argument('--data_list',
226
+ required=True,
227
+ type=str,
228
+ help='data list')
229
+ parser.add_argument('--batch_size_dataloader',
230
+ required=True,
231
+ type=int,
232
+ help='batch size (per-device) for dataloading')
233
+ parser.add_argument('--batch_size_flow',
234
+ required=True,
235
+ type=int,
236
+ help='batch size (per-device) for flow-matching')
237
+ parser.add_argument('--num_workers',
238
+ type=int,
239
+ default=4,
240
+ help='workers for dataloader')
241
+ parser.add_argument('--prefetch',
242
+ type=int,
243
+ default=5,
244
+ help='prefetch for dataloader')
245
+ parser.add_argument('--enable_tn',
246
+ action='store_true',
247
+ help='enable text normalization')
248
+ parser.add_argument('--only_llm',
249
+ action='store_true',
250
+ help='only generate speech tokens from llm')
251
+ parser.add_argument('--fp16_flow',
252
+ action='store_true',
253
+ help='enable fp16 flow')
254
+ parser.add_argument('--seed',
255
+ type=int,
256
+ default=1986,
257
+ help='random seed for generation')
258
+ args = parser.parse_args()
259
+ return args
260
+
261
+
262
+ def main():
263
+ args = get_args()
264
+
265
+ if args.enable_tn:
266
+ # Check python version, if == 3.10, use ttsfrd
267
+ if sys.version_info.major == 3 and sys.version_info.minor == 10:
268
+ # Check if ttsfrd is installed
269
+ try:
270
+ import ttsfrd
271
+ from cosyvoice_ttsfrd import get_resource_path
272
+ except ImportError as e:
273
+ raise ImportError("ttsfrd is not installed, please install it first, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for installation guide.") from e
274
+ text_norm = ttsfrd.TtsFrontendEngine()
275
+ text_norm.initialize(get_resource_path())
276
+ text_norm.set_lang_type('pinyinvg')
277
+ else:
278
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
279
+ tqdm.write(f"[{timestamp}] - [WARNING] - Only python 3.10 is supported for ttsfrd, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for more info. Setting enable_tn to False...")
280
+ # TODO: maybe we should use wetext if python version is not 3.10?
281
+ args.enable_tn = False
282
+ text_norm = None
283
+ else:
284
+ text_norm = None
285
+
286
+ assert (torch.cuda.is_available())
287
+ world_size, local_rank, rank = init_distributed()
288
+ config = Config(model=args.model_path, enforce_eager=True, tensor_parallel_size=1,
289
+ max_num_seqs=args.batch_size_dataloader,
290
+ hf_config=CosyVoice2LLMConfig(fp16_flow=args.fp16_flow), rank=local_rank)
291
+ model = CosyVoice2(config)
292
+
293
+ set_all_random_seed(args.seed)
294
+
295
+ dataset = AudioDataset(text_norm, model.llm.tokenizer, args.data_list, config)
296
+ sampler = DistributedSampler(dataset,
297
+ num_replicas=world_size,
298
+ rank=rank)
299
+ dataloader = DataLoader(dataset, batch_size=args.batch_size_dataloader, num_workers=args.num_workers, pin_memory=True,
300
+ sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn)
301
+ total_steps = len(dataset)
302
+
303
+ if local_rank == 0:
304
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
305
+ tqdm.write(f"[{timestamp}] - [INFO] - {args}")
306
+ progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
307
+ position=0, leave=True, dynamic_ncols=True)
308
+
309
+ cpu_counts = os.cpu_count()
310
+ executor = ThreadPoolExecutor(max_workers=min(args.batch_size_dataloader, cpu_counts // 8))
311
+ pending_futures = []
312
+ dataloader_iter = iter(dataloader)
313
+ succeed_duration = 0.01 # avoid division by zero
314
+ start_time = time.time()
315
+ estimated_total_wavs = 0
316
+ succeed_wavs = 0
317
+ failed_wavs = 0
318
+ last_print_time = start_time
319
+
320
+ while True:
321
+ try:
322
+ dataloader_start = time.time()
323
+ batch = next(dataloader_iter)
324
+ dataloader_time = time.time() - dataloader_start
325
+
326
+ if len(batch['infos']) == 0:
327
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
328
+ tqdm.write(f"[{timestamp}] - [WARNING] - rank {rank} of {world_size}: No valid batch found, skipping this batch...")
329
+ continue
330
+
331
+ model_start = time.time()
332
+ results_dict, timing_stats = model(**batch, batch_size_flow=args.batch_size_flow,
333
+ only_llm=args.only_llm)
334
+ model_time = time.time() - model_start
335
+
336
+ estimated_total_wavs += len(results_dict['generated_wavs'])
337
+
338
+ timing_stats['dataloader_time'] = dataloader_time
339
+ timing_stats['model_inference_time'] = model_time
340
+
341
+ if args.only_llm:
342
+ results_dict['generated_wavs'] = [None] * len(results_dict['prompt_speech_tokens'])
343
+
344
+ for i in range(len(results_dict['generated_wavs'])):
345
+ future = executor.submit(
346
+ save_file_async, results_dict['generated_wavs'][i],
347
+ results_dict['prompt_speech_tokens'][i],
348
+ results_dict['generated_speech_tokens'][i],
349
+ batch['infos'][i].copy(), timing_stats.copy()
350
+ )
351
+ pending_futures.append(future)
352
+
353
+ completed_futures = []
354
+ for future in pending_futures:
355
+ if future.done():
356
+ try:
357
+ duration = future.result()
358
+ succeed_duration += duration
359
+ succeed_wavs += 1
360
+ except Exception as e:
361
+ failed_wavs += 1
362
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
363
+ tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in async save task: {e}")
364
+ completed_futures.append(future)
365
+
366
+ for future in completed_futures:
367
+ pending_futures.remove(future)
368
+
369
+ if local_rank == 0:
370
+ update_n = world_size * len(batch["prompt_text_tokens_for_llm"])
371
+ if progress_bar.n + update_n > progress_bar.total:
372
+ progress_bar.update(progress_bar.total - progress_bar.n)
373
+ else:
374
+ progress_bar.update(update_n)
375
+
376
+ current_time = time.time()
377
+ if current_time - last_print_time >= 120 and not args.only_llm:
378
+ elapsed_time = current_time - start_time
379
+ avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
380
+ estimated_total_duration = avg_duration * estimated_total_wavs
381
+ current_rtf = elapsed_time / estimated_total_duration if estimated_total_duration > 0.01 else 0
382
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
383
+ tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Estimated RTF: {current_rtf:.5f}, Elapsed time: {elapsed_time:.2f}s") # noqa
384
+ last_print_time = current_time
385
+ except StopIteration:
386
+ break
387
+ except Exception as e:
388
+ failed_wavs += 1
389
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
390
+ tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in main loop: {e}")
391
+ continue
392
+
393
+ total_time = time.time() - start_time
394
+
395
+ if local_rank == 0:
396
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
397
+ tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
398
+
399
+ for future in pending_futures:
400
+ try:
401
+ duration = future.result(timeout=60)
402
+ succeed_duration += duration
403
+ succeed_wavs += 1
404
+ except Exception as e:
405
+ failed_wavs += 1
406
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
407
+ tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in final async save task: {e}")
408
+ executor.shutdown(wait=True)
409
+
410
+ if local_rank == 0:
411
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
412
+ tqdm.write(f"[{timestamp}] - [INFO] - All async save tasks completed.")
413
+ progress_bar.close()
414
+
415
+ if not args.only_llm:
416
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
417
+ tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h), RTF: {total_time / succeed_duration:.5f}") # noqa
418
+
419
+ dist.barrier()
420
+ dist.destroy_process_group()
421
+
422
+
423
+ if __name__ == "__main__":
424
+ main()
flashcosyvoice/config.py ADDED
@@ -0,0 +1,80 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+
4
+ import torch
5
+ from transformers import AutoConfig
6
+
7
+
8
+ @dataclass
9
+ class CosyVoice2LLMConfig:
10
+ architectures: list[str] = field(default_factory=lambda: ["Qwen2ForCausalLM"])
11
+ attention_dropout: float = 0.0
12
+ bos_token_id: int = 151643
13
+ eos_token_id: int = 6561 # speech eos
14
+ hidden_act: str = "silu"
15
+ hidden_size: int = 896
16
+ initializer_range: float = 0.02
17
+ intermediate_size: int = 4864
18
+ max_position_embeddings: int = 32768
19
+ max_window_layers: int = 24
20
+ model_type: str = "qwen2"
21
+ num_attention_heads: int = 14
22
+ num_hidden_layers: int = 24
23
+ num_key_value_heads: int = 2
24
+ head_dim: int = 64
25
+ rms_norm_eps: float = 1e-06
26
+ rope_scaling: dict | None = None
27
+ rope_theta: float = 1000000.0
28
+ sliding_window: int = 32768
29
+ tie_word_embeddings: bool = False
30
+ torch_dtype: torch.dtype = torch.bfloat16
31
+ transformers_version: str = "4.52.0.dev0"
32
+ use_cache: bool = True
33
+ use_sliding_window: bool = False
34
+ vocab_size: int = 158500 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
35
+ text_vocab_size: int = 151936
36
+ speech_vocab_size: int = 6562  # actually 6564; we only target non-streaming inference, so tokens 6562 and 6563 (used only for streaming TTS) are cut off
37
+ lm_head_bias: bool = True
38
+ qkv_bias: bool = True
39
+ fp16_flow: bool = True
40
+
41
+
42
+ @dataclass
43
+ class SamplingParams:
44
+ temperature: float = 1.0
45
+ min_tokens: int = 2
46
+ max_tokens: int = 64
47
+ ignore_eos: bool = False
48
+ top_k: int = 25
49
+ # RasSampler parameters
50
+ use_ras: bool = False
51
+ win_size: int = 10
52
+ tau_r: float = 0.1
53
+ top_p: float = 0.8
54
+
55
+
56
+ @dataclass
57
+ class Config:
58
+ model: str
59
+ max_num_batched_tokens: int = 1572864
60
+ max_num_seqs: int = 1024
61
+ max_model_len: int = 1536 # 15s prompt + 30s generated audio for 25hz audio tokenizer
62
+ gpu_memory_utilization: float = 0.9
63
+ tensor_parallel_size: int = 1
64
+ enforce_eager: bool = False
65
+ hf_config: CosyVoice2LLMConfig | AutoConfig = field(default_factory=CosyVoice2LLMConfig)
66
+ eos: int = -1
67
+ kvcache_block_size: int = 256
68
+ num_kvcache_blocks: int = -1
69
+ min_token_text_ratio: int = 2
70
+ max_token_text_ratio: int = 20
71
+ rank: int = 0
72
+
73
+ def __post_init__(self):
74
+ assert os.path.isdir(self.model)
75
+ assert self.kvcache_block_size % 256 == 0
76
+ assert 1 <= self.tensor_parallel_size <= 8
77
+
78
+ max_pos = getattr(self.hf_config, "max_position_embeddings", 4096)
79
+ self.max_model_len = min(self.max_model_len, max_pos)
80
+ assert self.max_num_batched_tokens >= self.max_model_len
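For orientation, a minimal usage sketch of these dataclasses (the checkpoint directory below is a placeholder; Config.__post_init__ asserts that it exists, and it is expected to contain the LLM weights plus flow.pt and hift.pt):

    from flashcosyvoice.config import Config, SamplingParams

    # Placeholder path to a CosyVoice2 checkpoint directory.
    cfg = Config(model="/path/to/CosyVoice2-0.5B", enforce_eager=True)
    # RAS sampling with the defaults defined above.
    sp = SamplingParams(temperature=1.0, top_k=25, use_ras=True, win_size=10, tau_r=0.1, top_p=0.8)
    print(cfg.max_model_len, cfg.hf_config.speech_vocab_size)  # 1536 6562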
flashcosyvoice/cosyvoice2.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import time
15
+ from datetime import datetime
16
+
17
+ import s3tokenizer
18
+ import torch
19
+ from tqdm import tqdm
20
+
21
+ from flashcosyvoice.config import Config, SamplingParams
22
+ from flashcosyvoice.engine.llm_engine import LLMEngine
23
+ from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
24
+ from flashcosyvoice.modules.hifigan import HiFTGenerator
25
+
26
+
27
+ class CosyVoice2(torch.nn.Module):
28
+ def __init__(self, config: Config = None):
29
+ super().__init__()
30
+ self.config = Config() if config is None else config
31
+
32
+ self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()
33
+
34
+ self.llm = LLMEngine(**self.config.__dict__)
35
+
36
+ self.use_tqdm = torch.distributed.get_node_local_rank() == 0
37
+
38
+ self.flow = CausalMaskedDiffWithXvec()
39
+ if self.config.hf_config.fp16_flow:
40
+ timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
41
+ tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16")
42
+ self.flow.half()
43
+ self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True)
44
+ self.flow.cuda().eval()
45
+
46
+ self.hift = HiFTGenerator()
47
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()}
48
+ self.hift.load_state_dict(hift_state_dict, strict=True)
49
+ self.hift.cuda().eval()
50
+
51
+ @torch.inference_mode()
52
+ def forward(
53
+ self, prompt_mels_for_llm: torch.Tensor, prompt_mels_lens_for_llm: torch.Tensor,
54
+ prompt_text_tokens_for_llm: list[list[int]], text_tokens_for_llm: list[list[int]],
55
+ prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor,
56
+ spk_emb_for_flow: torch.Tensor,
57
+ sampling_params: SamplingParams | list[SamplingParams],
58
+ batch_size_flow: int,
59
+ only_llm: bool,
60
+ **kwargs, # for compatibility
61
+ ):
62
+ timing_stats = {}
63
+
64
+ # Audio tokenization
65
+ start_time = time.time()
66
+ prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
67
+ prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda()
68
+ )
69
+ timing_stats['audio_tokenization'] = time.time() - start_time
70
+
71
+ batch_size = prompt_speech_tokens.shape[0]
72
+ assert len(prompt_text_tokens_for_llm) == batch_size
73
+
74
+ # Prepare LLM inputs
75
+ start_time = time.time()
76
+ valid_prompt_speech_tokens = []
77
+ inputs = []
78
+ for i in range(batch_size):
79
+ speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
80
+ valid_prompt_speech_tokens.append(speech_tokens_i)
81
+ inputs.append([self.config.hf_config.speech_vocab_size] + prompt_text_tokens_for_llm[i] + text_tokens_for_llm[i] + [self.config.hf_config.speech_vocab_size + 1] + speech_tokens_i)
82
+ timing_stats['prepare_llm_inputs'] = time.time() - start_time
83
+
84
+ # LLM generation
85
+ start_time = time.time()
86
+ llm_outputs = self.llm.generate(inputs, sampling_params, use_tqdm=self.use_tqdm)
87
+ timing_stats['llm_generation'] = time.time() - start_time
88
+
89
+ results_dict = {
90
+ "prompt_speech_tokens": valid_prompt_speech_tokens,
91
+ "generated_speech_tokens": [o['token_ids'][:-1] for o in llm_outputs],
92
+ }
93
+ if only_llm:
94
+ return results_dict, timing_stats
95
+
96
+ # Prepare Flow inputs
97
+ start_time = time.time()
98
+ flow_inputs = []
99
+ flow_inputs_lens = []
100
+ for i, o in enumerate(llm_outputs):
101
+ generated_speech_tokens = o['token_ids'][:-1] # ignore last eos
102
+ prompt_speech_tokens = valid_prompt_speech_tokens[i]
103
+ flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
104
+ flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
105
+ flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
106
+ flow_inputs_lens = torch.tensor(flow_inputs_lens)
107
+ timing_stats['prepare_flow_inputs'] = time.time() - start_time
108
+
109
+ # Flow generation and HiFi-GAN generation (with batching)
110
+ total_batch_size = flow_inputs.shape[0]
111
+ generated_wavs = []
112
+ flow_total_time = 0.0
113
+ hifigan_total_time = 0.0
114
+
115
+ # Process in batches according to batch_size_flow, batch_size_flow <= total_batch_size
116
+ # NOTE(xcsong): When executing both LLM and Flow on the same GPU,
117
+ # Flow can easily saturate the GPU's SMs and memory, so it is processed in smaller batches to avoid OOM.
118
+ num_batches = (total_batch_size + batch_size_flow - 1) // batch_size_flow
119
+ batch_iterator = range(0, total_batch_size, batch_size_flow)
120
+ if self.use_tqdm:
121
+ batch_iterator = tqdm(batch_iterator, desc="Generating wavs (Flow+HiFi-GAN)", leave=False, unit="batch",
122
+ total=num_batches, dynamic_ncols=True, position=self.config.rank + 1)
123
+
124
+ for start_idx in batch_iterator:
125
+ end_idx = min(start_idx + batch_size_flow, total_batch_size)
126
+ batch_flow_inputs = flow_inputs[start_idx:end_idx]
127
+ batch_flow_inputs_lens = flow_inputs_lens[start_idx:end_idx]
128
+ batch_prompt_mels = prompt_mels_for_flow[start_idx:end_idx]
129
+ batch_prompt_mels_lens = prompt_mels_lens_for_flow[start_idx:end_idx]
130
+ batch_spk_emb = spk_emb_for_flow[start_idx:end_idx]
131
+
132
+ # Flow generation for this batch
133
+ flow_start_time = time.time()
134
+ with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
135
+ batch_generated_mels, batch_generated_mels_lens = self.flow(
136
+ batch_flow_inputs.cuda(), batch_flow_inputs_lens.cuda(),
137
+ batch_prompt_mels.cuda(), batch_prompt_mels_lens.cuda(), batch_spk_emb.cuda(),
138
+ streaming=False, finalize=True
139
+ )
140
+ flow_total_time += time.time() - flow_start_time
141
+
142
+ # HiFi-GAN generation for this batch
143
+ hifigan_start_time = time.time()
144
+ batch_size_current = end_idx - start_idx
145
+ for i in range(batch_size_current):
146
+ mel = batch_generated_mels[i, :, batch_prompt_mels_lens[i].item():batch_generated_mels_lens[i].item()].unsqueeze(0)
147
+ wav, _ = self.hift(speech_feat=mel)
148
+ generated_wavs.append(wav)
149
+ hifigan_total_time += time.time() - hifigan_start_time
150
+
151
+ timing_stats['flow_generation'] = flow_total_time
152
+ timing_stats['hifigan_generation'] = hifigan_total_time
153
+
154
+ # Calculate total time and batch statistics
155
+ timing_stats['model.forward_total'] = sum(timing_stats.values())
156
+ timing_stats['batch_size'] = len(generated_wavs)
157
+ timing_stats['batch_size_flow'] = batch_size_flow
158
+
159
+ results_dict['generated_wavs'] = generated_wavs
160
+ return results_dict, timing_stats
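To summarize the request layout assembled in forward above, each LLM prompt is built as

    # Illustrative layout (ids are the defaults from config.py; 6562/6563 are the two extra
    # vocabulary slots the config comment labels eos and task_id):
    # [6562] + prompt_text_tokens + text_tokens + [6563] + prompt_speech_tokens

and the model then decodes 25 Hz speech tokens until the speech EOS; the trailing EOS is dropped (token_ids[:-1]) before prompt and generated speech tokens are concatenated as input to the flow model.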
flashcosyvoice/cosyvoice3.py ADDED
@@ -0,0 +1 @@
1
+ # TODO(xcsong): Implement CosyVoice3 when it is released
flashcosyvoice/engine/__init__.py ADDED
File without changes
flashcosyvoice/engine/block_manager.py ADDED
@@ -0,0 +1,114 @@
1
+ from collections import deque
2
+
3
+ import numpy as np
4
+ import xxhash
5
+
6
+ from flashcosyvoice.engine.sequence import Sequence
7
+
8
+
9
+ class Block:
10
+
11
+ def __init__(self, block_id):
12
+ self.block_id = block_id
13
+ self.ref_count = 0
14
+ self.hash = -1
15
+ self.token_ids = []
16
+
17
+ def update(self, hash: int, token_ids: list[int]):
18
+ self.hash = hash
19
+ self.token_ids = token_ids
20
+
21
+ def reset(self):
22
+ self.ref_count = 1
23
+ self.hash = -1
24
+ self.token_ids = []
25
+
26
+
27
+ class BlockManager:
28
+
29
+ def __init__(self, num_blocks: int, block_size: int):
30
+ assert num_blocks > 0
31
+ self.block_size = block_size
32
+ self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
33
+ self.hash_to_block_id: dict[int, int] = dict()
34
+ self.free_block_ids: deque[int] = deque(range(num_blocks))
35
+ self.used_block_ids: set[int] = set()
36
+
37
+ @classmethod
38
+ def compute_hash(cls, token_ids: list[int], prefix: int = -1):
39
+ h = xxhash.xxh64()
40
+ if prefix != -1:
41
+ h.update(prefix.to_bytes(8, "little"))
42
+ h.update(np.array(token_ids).tobytes())
43
+ return h.intdigest()
44
+
45
+ def _allocate_block(self, block_id: int) -> Block:
46
+ block = self.blocks[block_id]
47
+ assert block.ref_count == 0
48
+ block.reset()
49
+ self.free_block_ids.remove(block_id)
50
+ self.used_block_ids.add(block_id)
51
+ return self.blocks[block_id]
52
+
53
+ def _deallocate_block(self, block_id: int) -> Block:
54
+ assert self.blocks[block_id].ref_count == 0
55
+ self.used_block_ids.remove(block_id)
56
+ self.free_block_ids.append(block_id)
57
+
58
+ def can_allocate(self, seq: Sequence) -> bool:
59
+ return len(self.free_block_ids) >= seq.num_blocks
60
+
61
+ def allocate(self, seq: Sequence):
62
+ assert not seq.block_table
63
+ h = -1
64
+ cache_miss = False
65
+ for i in range(seq.num_blocks):
66
+ token_ids = seq.block(i)
67
+ h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
68
+ block_id = self.hash_to_block_id.get(h, -1)
69
+ if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
70
+ cache_miss = True
71
+ if cache_miss:
72
+ block_id = self.free_block_ids[0]
73
+ block = self._allocate_block(block_id)
74
+ else:
75
+ seq.num_cached_tokens += self.block_size
76
+ if block_id in self.used_block_ids:
77
+ block = self.blocks[block_id]
78
+ block.ref_count += 1
79
+ else:
80
+ block = self._allocate_block(block_id)
81
+ if h != -1:
82
+ block.update(h, token_ids)
83
+ self.hash_to_block_id[h] = block_id
84
+ seq.block_table.append(block_id)
85
+
86
+ def deallocate(self, seq: Sequence):
87
+ for block_id in reversed(seq.block_table):
88
+ block = self.blocks[block_id]
89
+ block.ref_count -= 1
90
+ if block.ref_count == 0:
91
+ self._deallocate_block(block_id)
92
+ seq.num_cached_tokens = 0
93
+ seq.block_table.clear()
94
+
95
+ def can_append(self, seq: Sequence) -> bool:
96
+ return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
97
+
98
+ def may_append(self, seq: Sequence):
99
+ block_table = seq.block_table
100
+ last_block = self.blocks[block_table[-1]]
101
+ if len(seq) % self.block_size == 1:
102
+ assert last_block.hash != -1
103
+ block_id = self.free_block_ids[0]
104
+ self._allocate_block(block_id)
105
+ block_table.append(block_id)
106
+ elif len(seq) % self.block_size == 0:
107
+ assert last_block.hash == -1
108
+ token_ids = seq.block(seq.num_blocks - 1)
109
+ prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
110
+ h = self.compute_hash(token_ids, prefix)
111
+ last_block.update(h, token_ids)
112
+ self.hash_to_block_id[h] = last_block.block_id
113
+ else:
114
+ assert last_block.hash == -1
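A small sketch of the chained hashing behind the prefix cache above: every full block's hash folds in the previous block's hash, so sequences sharing the same leading blocks resolve to the same keys in hash_to_block_id and can reuse cached blocks in allocate().

    from flashcosyvoice.engine.block_manager import BlockManager

    block_a = list(range(256))            # one full block of token ids
    block_b = list(range(256, 512))
    h_a = BlockManager.compute_hash(block_a)              # first block: no prefix
    h_b = BlockManager.compute_hash(block_b, prefix=h_a)  # second block chains on h_a
    # Identical prefixes always reproduce identical keys, which is what allocate() checks.
    assert h_a == BlockManager.compute_hash(block_a)
    assert h_b == BlockManager.compute_hash(block_b, prefix=h_a)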
flashcosyvoice/engine/llm_engine.py ADDED
@@ -0,0 +1,125 @@
1
+ import atexit
2
+ from dataclasses import fields
3
+ from time import perf_counter
4
+
5
+ import torch.multiprocessing as mp
6
+ from tqdm.auto import tqdm
7
+ from transformers import AutoTokenizer
8
+
9
+ from flashcosyvoice.config import Config, SamplingParams
10
+ from flashcosyvoice.engine.model_runner import ModelRunner
11
+ from flashcosyvoice.engine.scheduler import Scheduler
12
+ from flashcosyvoice.engine.sequence import Sequence
13
+
14
+
15
+ class LLMEngine:
16
+
17
+ def __init__(self, model, **kwargs):
18
+ config_fields = {field.name for field in fields(Config)}
19
+ config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
20
+ config = Config(model, **config_kwargs)
21
+ self.ps = []
22
+ self.events = []
23
+ ctx = mp.get_context("spawn")
24
+ assert config.tensor_parallel_size == 1, "NOTE(xcsong): Currently only support tp=1"
25
+ for i in range(1, config.tensor_parallel_size):
26
+ event = ctx.Event()
27
+ process = ctx.Process(target=ModelRunner, args=(config, i, event))
28
+ process.start()
29
+ self.ps.append(process)
30
+ self.events.append(event)
31
+ if hasattr(config.hf_config, "speech_vocab_size"):
32
+ # NOTE: this is a non-chat model; all of these special tokens remain randomly initialized.
33
+ special_tokens = {
34
+ 'eos_token': '<|endoftext|>',
35
+ 'pad_token': '<|endoftext|>',
36
+ 'additional_special_tokens': [
37
+ '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
38
+ '[breath]', '<strong>', '</strong>', '[noise]',
39
+ '[laughter]', '[cough]', '[clucking]', '[accent]',
40
+ '[quick_breath]',
41
+ "<laughter>", "</laughter>",
42
+ "[hissing]", "[sigh]", "[vocalized-noise]",
43
+ "[lipsmack]", "[mn]"
44
+ ]
45
+ }
46
+ self.tokenizer = AutoTokenizer.from_pretrained(f"{config.model}/CosyVoice-BlankEN")
47
+ self.tokenizer.add_special_tokens(special_tokens)
48
+ self.skip_special_tokens = True
49
+ else:
50
+ self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
51
+ if hasattr(config.hf_config, "eos_token_id"):
52
+ config.eos = config.hf_config.eos_token_id
53
+ else:
54
+ config.eos = self.tokenizer.eos_token_id
55
+ self.model_runner = ModelRunner(config, config.rank, self.events)
56
+ self.scheduler = Scheduler(config)
57
+ self.config = config
58
+ atexit.register(self.exit)
59
+
60
+ def exit(self):
61
+ self.model_runner.call("exit")
62
+ del self.model_runner
63
+ for p in self.ps:
64
+ p.join()
65
+
66
+ def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
67
+ if isinstance(prompt, str):
68
+ prompt = self.tokenizer.encode(prompt)
69
+ seq = Sequence(prompt, sampling_params)
70
+ self.scheduler.add(seq)
71
+
72
+ def step(self):
73
+ seqs, is_prefill = self.scheduler.schedule()
74
+ token_ids = self.model_runner.call("run", seqs, is_prefill)
75
+ self.scheduler.postprocess(seqs, token_ids)
76
+ outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
77
+ num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
78
+ return outputs, num_tokens
79
+
80
+ def is_finished(self):
81
+ return self.scheduler.is_finished()
82
+
83
+ def generate(
84
+ self,
85
+ prompts: list[str] | list[list[int]],
86
+ sampling_params: SamplingParams | list[SamplingParams],
87
+ use_tqdm: bool = True,
88
+ ) -> list[str]:
89
+ if use_tqdm:
90
+ pbar = tqdm(total=len(prompts), desc="Generating tokens (LLM)", leave=False,
91
+ dynamic_ncols=True, position=self.config.rank + 1)
92
+ if not isinstance(sampling_params, list):
93
+ sampling_params = [sampling_params] * len(prompts)
94
+ for prompt, sp in zip(prompts, sampling_params):
95
+ self.add_request(prompt, sp)
96
+ outputs = {}
97
+ prefill_throughput = decode_throughput = instant_decode_throughput = 0.
98
+ total_decode_tokens = 0
99
+ total_decode_time = 0.
100
+ while not self.is_finished():
101
+ t = perf_counter()
102
+ output, num_tokens = self.step()
103
+ step_time = perf_counter() - t
104
+ if use_tqdm:
105
+ if num_tokens > 0:
106
+ prefill_throughput = num_tokens / step_time
107
+ else:
108
+ instant_decode_throughput = -num_tokens / step_time
109
+ total_decode_tokens += -num_tokens
110
+ total_decode_time += step_time
111
+ decode_throughput = total_decode_tokens / total_decode_time if total_decode_time > 0 else 0
112
+ pbar.set_postfix({
113
+ "Prefill": f"{int(prefill_throughput)}tok/s",
114
+ "AvgDecode": f"{int(decode_throughput)}tok/s",
115
+ "InstDecode": f"{int(instant_decode_throughput)}tok/s",
116
+ })
117
+ for seq_id, token_ids in output:
118
+ outputs[seq_id] = token_ids
119
+ if use_tqdm:
120
+ pbar.update(1)
121
+ outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
122
+ outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
123
+ if use_tqdm:
124
+ pbar.close()
125
+ return outputs
flashcosyvoice/engine/model_runner.py ADDED
@@ -0,0 +1,310 @@
1
+ import pickle
2
+ from multiprocessing.shared_memory import SharedMemory
3
+ from multiprocessing.synchronize import Event
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+ from flashcosyvoice.config import Config
9
+ from flashcosyvoice.engine.sequence import Sequence
10
+ from flashcosyvoice.modules.qwen2 import Qwen2ForCausalLM
11
+ from flashcosyvoice.modules.sampler import RasSampler, Sampler
12
+ from flashcosyvoice.utils.context import (get_context, reset_context,
13
+ set_context)
14
+ from flashcosyvoice.utils.loader import load_model
15
+
16
+
17
+ class ModelRunner:
18
+
19
+ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
20
+ self.config = config
21
+ hf_config = config.hf_config
22
+ self.block_size = config.kvcache_block_size
23
+ self.enforce_eager = config.enforce_eager
24
+ self.world_size = config.tensor_parallel_size
25
+ self.rank = rank
26
+ self.event = event
27
+
28
+ # TODO(xcsong): support tp > 1
29
+ if self.world_size > 1:
30
+ dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
31
+ torch.cuda.set_device(rank)
32
+ default_dtype = torch.get_default_dtype()
33
+ torch.set_default_dtype(hf_config.torch_dtype)
34
+ torch.set_default_device("cuda")
35
+ self.model = Qwen2ForCausalLM(hf_config)
36
+ load_model(self.model, config.model, hf_config)
37
+ self.sampler = Sampler()
38
+ self.ras_sampler = RasSampler()
39
+ self.warmup_model()
40
+ self.allocate_kv_cache()
41
+ if not self.enforce_eager:
42
+ self.capture_cudagraph()
43
+ torch.set_default_device("cpu")
44
+ torch.set_default_dtype(default_dtype)
45
+
46
+ if self.world_size > 1:
47
+ if rank == 0:
48
+ self.shm = SharedMemory(name="flashcosyvoice", create=True, size=2**20)
49
+ dist.barrier()
50
+ else:
51
+ dist.barrier()
52
+ self.shm = SharedMemory(name="flashcosyvoice")
53
+ self.loop()
54
+
55
+ def exit(self):
56
+ if self.world_size > 1:
57
+ self.shm.close()
58
+ dist.barrier()
59
+ if self.rank == 0:
60
+ self.shm.unlink()
61
+ if not self.enforce_eager:
62
+ del self.graphs, self.graph_pool
63
+ torch.cuda.synchronize()
64
+ if self.world_size > 1:
65
+ dist.destroy_process_group()
66
+
67
+ def loop(self):
68
+ while True:
69
+ method_name, args = self.read_shm()
70
+ self.call(method_name, *args)
71
+ if method_name == "exit":
72
+ break
73
+
74
+ def read_shm(self):
75
+ assert self.world_size > 1 and self.rank
76
+ self.event.wait()
77
+ n = int.from_bytes(self.shm.buf[0:4], "little")
78
+ method_name, *args = pickle.loads(self.shm.buf[4:n + 4])
79
+ self.event.clear()
80
+ return method_name, args
81
+
82
+ def write_shm(self, method_name, *args):
83
+ assert self.world_size > 1 and not self.rank
84
+ data = pickle.dumps([method_name, *args])
85
+ n = len(data)
86
+ self.shm.buf[0:4] = n.to_bytes(4, "little")
87
+ self.shm.buf[4:n + 4] = data
88
+ for event in self.event:
89
+ event.set()
90
+
91
+ def call(self, method_name, *args):
92
+ if self.world_size > 1 and self.rank == 0:
93
+ self.write_shm(method_name, *args)
94
+ method = getattr(self, method_name, None)
95
+ return method(*args)
96
+
97
+ def warmup_model(self):
98
+ torch.cuda.empty_cache()
99
+ torch.cuda.reset_peak_memory_stats()
100
+ max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
101
+ num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
102
+ seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
103
+ self.run(seqs, True)
104
+ torch.cuda.empty_cache()
105
+
106
+ def allocate_kv_cache(self):
107
+ config = self.config
108
+ hf_config = config.hf_config
109
+ free, total = torch.cuda.mem_get_info()
110
+ used = total - free
111
+ peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
112
+ current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
113
+ num_kv_heads = hf_config.num_key_value_heads // self.world_size
114
+ head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
115
+ block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
116
+ config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
117
+ assert config.num_kvcache_blocks > 0, "try to **increase** gpu_memory_utilization"
118
+ self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
119
+ layer_id = 0
120
+ for module in self.model.modules():
121
+ if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
122
+ module.k_cache = self.kv_cache[0, layer_id]
123
+ module.v_cache = self.kv_cache[1, layer_id]
124
+ layer_id += 1
125
+
126
+ def prepare_block_tables(self, seqs: list[Sequence]):
127
+ max_len = max(len(seq.block_table) for seq in seqs)
128
+ block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
129
+ block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
130
+ return block_tables
131
+
132
+ def prepare_prefill(self, seqs: list[Sequence]):
133
+ input_ids = []
134
+ positions = []
135
+ cu_seqlens_q = [0]
136
+ cu_seqlens_k = [0]
137
+ max_seqlen_q = 0
138
+ max_seqlen_k = 0
139
+ slot_mapping = []
140
+ block_tables = None
141
+ for seq in seqs:
142
+ seqlen = len(seq)
143
+ input_ids.extend(seq[seq.num_cached_tokens:])
144
+ positions.extend(list(range(seq.num_cached_tokens, seqlen)))
145
+ seqlen_q = seqlen - seq.num_cached_tokens
146
+ seqlen_k = seqlen
147
+ cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
148
+ cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
149
+ max_seqlen_q = max(seqlen_q, max_seqlen_q)
150
+ max_seqlen_k = max(seqlen_k, max_seqlen_k)
151
+ if not seq.block_table:
152
+ continue
153
+ for i in range(seq.num_cached_blocks, seq.num_blocks):
154
+ start = seq.block_table[i] * self.block_size
155
+ if i != seq.num_blocks - 1:
156
+ end = start + self.block_size
157
+ else:
158
+ end = start + seq.last_block_num_tokens
159
+ slot_mapping.extend(list(range(start, end)))
160
+ if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
161
+ block_tables = self.prepare_block_tables(seqs)
162
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
163
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
164
+ cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
165
+ cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
166
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
167
+ set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
168
+ return input_ids, positions
169
+
170
+ def prepare_decode(self, seqs: list[Sequence]):
171
+ input_ids = []
172
+ positions = []
173
+ slot_mapping = []
174
+ context_lens = []
175
+ for seq in seqs:
176
+ input_ids.append(seq.last_token)
177
+ positions.append(len(seq))
178
+ context_lens.append(len(seq))
179
+ slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
180
+ input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
181
+ positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
182
+ slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
183
+ context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
184
+ block_tables = self.prepare_block_tables(seqs)
185
+ set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
186
+ return input_ids, positions
187
+
188
+ def prepare_sample(self, seqs: list[Sequence]):
189
+ temperatures = []
190
+ top_ks = []
191
+ win_sizes = []
192
+ tau_rs = []
193
+ top_ps = []
194
+ min_tokens_list = []
195
+ use_ras_list = []
196
+
197
+ for seq in seqs:
198
+ temperatures.append(seq.temperature)
199
+ top_ks.append(seq.top_k)
200
+ win_sizes.append(seq.win_size)
201
+ tau_rs.append(seq.tau_r)
202
+ top_ps.append(seq.top_p)
203
+ min_tokens_list.append(seq.min_tokens)
204
+ use_ras_list.append(seq.use_ras)
205
+
206
+ temperatures_tensor = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
207
+ # check all items equal
208
+ assert all(item == temperatures[0] for item in temperatures)
209
+ assert all(item == top_ks[0] for item in top_ks)
210
+ assert all(item == win_sizes[0] for item in win_sizes)
211
+ assert all(item == tau_rs[0] for item in tau_rs)
212
+ assert all(item == top_ps[0] for item in top_ps)
213
+ assert all(item == use_ras_list[0] for item in use_ras_list)
214
+
215
+ return {
216
+ 'temperatures': temperatures_tensor,
217
+ 'top_k': top_ks[0],
218
+ 'win_size': win_sizes[0],
219
+ 'tau_r': tau_rs[0],
220
+ 'top_p': top_ps[0],
221
+ 'eos_token': self.config.eos,
222
+ 'min_tokens': min_tokens_list,
223
+ 'use_ras': use_ras_list[0]
224
+ }
225
+
226
+ @torch.inference_mode()
227
+ def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
228
+ if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
229
+ return self.model.compute_logits(self.model(input_ids, positions))
230
+ else:
231
+ bs = input_ids.size(0)
232
+ context = get_context()
233
+ graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
234
+ graph_vars = self.graph_vars
235
+ for k, v in graph_vars.items():
236
+ if k != "outputs":
237
+ v.zero_()
238
+ graph_vars["input_ids"][:bs] = input_ids
239
+ graph_vars["positions"][:bs] = positions
240
+ graph_vars["slot_mapping"][:bs] = context.slot_mapping
241
+ graph_vars["context_lens"][:bs] = context.context_lens
242
+ graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
243
+ graph.replay()
244
+ return self.model.compute_logits(graph_vars["outputs"][:bs])
245
+
246
+ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
247
+ input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
248
+ if self.rank == 0 or self.world_size == 1:
249
+ sample_params = self.prepare_sample(seqs)
250
+ logits = self.run_model(input_ids, positions, is_prefill)
251
+
252
+ if sample_params['use_ras']:
253
+ # Prepare decoded tokens list for RasSampler
254
+ decoded_tokens_list = [seq.completion_token_ids for seq in seqs]
255
+ # Pass all parameters as lists to RasSampler
256
+ token_ids = self.ras_sampler(
257
+ logits,
258
+ decoded_tokens_list,
259
+ win_size=sample_params['win_size'],
260
+ tau_r=sample_params['tau_r'],
261
+ top_p=sample_params['top_p'],
262
+ top_k=sample_params['top_k'],
263
+ eos_token=sample_params['eos_token'],
264
+ min_tokens=sample_params['min_tokens']
265
+ ).tolist()
266
+ else:
267
+ # Use the default sampler with list form of top_ks
268
+ token_ids = self.sampler(logits, sample_params['temperatures'], sample_params['top_k']).tolist()
269
+ else:
270
+ logits = self.run_model(input_ids, positions, is_prefill)
271
+ token_ids = None
272
+ reset_context()
273
+ return token_ids
274
+
275
+ @torch.inference_mode()
276
+ def capture_cudagraph(self):
277
+ config = self.config
278
+ hf_config = config.hf_config
279
+ max_bs = min(self.config.max_num_seqs, 512)
280
+ max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
281
+ input_ids = torch.zeros(max_bs, dtype=torch.int64)
282
+ positions = torch.zeros(max_bs, dtype=torch.int64)
283
+ slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
284
+ context_lens = torch.zeros(max_bs, dtype=torch.int32)
285
+ block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
286
+ outputs = torch.zeros(max_bs, hf_config.hidden_size)
287
+ self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
288
+ self.graphs = {}
289
+ self.graph_pool = None
290
+
291
+ for bs in reversed(self.graph_bs):
292
+ graph = torch.cuda.CUDAGraph()
293
+ set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
294
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
295
+ with torch.cuda.graph(graph, self.graph_pool):
296
+ outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
297
+ if self.graph_pool is None:
298
+ self.graph_pool = graph.pool()
299
+ self.graphs[bs] = graph
300
+ torch.cuda.synchronize()
301
+ reset_context()
302
+
303
+ self.graph_vars = dict(
304
+ input_ids=input_ids,
305
+ positions=positions,
306
+ slot_mapping=slot_mapping,
307
+ context_lens=context_lens,
308
+ block_tables=block_tables,
309
+ outputs=outputs,
310
+ )
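As a worked example of the sizing in allocate_kv_cache above, with the CosyVoice2 LLM defaults from config.py (24 layers, 2 KV heads, head_dim 64, bf16) and kvcache_block_size = 256, one block costs 2 * 24 * 256 * 2 * 64 * 2 bytes = 3 MiB, so each GiB of memory left after weights and activation peaks holds roughly 340 blocks, i.e. about 87k cached tokens.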
flashcosyvoice/engine/scheduler.py ADDED
@@ -0,0 +1,77 @@
1
+ from collections import deque
2
+
3
+ from flashcosyvoice.config import Config
4
+ from flashcosyvoice.engine.block_manager import BlockManager
5
+ from flashcosyvoice.engine.sequence import Sequence, SequenceStatus
6
+
7
+
8
+ class Scheduler:
9
+
10
+ def __init__(self, config: Config):
11
+ self.max_num_seqs = config.max_num_seqs
12
+ self.max_num_batched_tokens = config.max_num_batched_tokens
13
+ self.eos = config.eos
14
+ self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
15
+ self.waiting: deque[Sequence] = deque()
16
+ self.running: deque[Sequence] = deque()
17
+
18
+ def is_finished(self):
19
+ return not self.waiting and not self.running
20
+
21
+ def add(self, seq: Sequence):
22
+ self.waiting.append(seq)
23
+
24
+ def schedule(self) -> tuple[list[Sequence], bool]:
25
+ # prefill
26
+ scheduled_seqs = []
27
+ num_seqs = 0
28
+ num_batched_tokens = 0
29
+ while self.waiting and num_seqs < self.max_num_seqs:
30
+ seq = self.waiting[0]
31
+ if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
32
+ break
33
+ num_seqs += 1
34
+ self.block_manager.allocate(seq)
35
+ num_batched_tokens += len(seq) - seq.num_cached_tokens
36
+ seq.status = SequenceStatus.RUNNING
37
+ self.waiting.popleft()
38
+ self.running.append(seq)
39
+ scheduled_seqs.append(seq)
40
+ if scheduled_seqs:
41
+ return scheduled_seqs, True
42
+
43
+ # decode
44
+ while self.running and num_seqs < self.max_num_seqs:
45
+ seq = self.running.popleft()
46
+ while not self.block_manager.can_append(seq):
47
+ if self.running:
48
+ self.preempt(self.running.pop())
49
+ else:
50
+ self.preempt(seq)
51
+ break
52
+ else:
53
+ num_seqs += 1
54
+ self.block_manager.may_append(seq)
55
+ scheduled_seqs.append(seq)
56
+ assert scheduled_seqs
57
+ self.running.extendleft(reversed(scheduled_seqs))
58
+ return scheduled_seqs, False
59
+
60
+ def preempt(self, seq: Sequence):
61
+ seq.status = SequenceStatus.WAITING
62
+ self.block_manager.deallocate(seq)
63
+ self.waiting.appendleft(seq)
64
+
65
+ def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
66
+ for seq, token_id in zip(seqs, token_ids):
67
+ seq.append_token(token_id)
68
+ # Check if the sequence has reached the maximum number of tokens
69
+ reached_max_tokens = seq.num_completion_tokens == seq.max_tokens
70
+ # Check if the sequence has reached EOS and has generated enough tokens (satisfying min_tokens requirements)
71
+ eos_with_min_tokens = (not seq.ignore_eos and token_id == self.eos and
72
+ seq.num_completion_tokens >= seq.min_tokens)
73
+
74
+ if reached_max_tokens or eos_with_min_tokens:
75
+ seq.status = SequenceStatus.FINISHED
76
+ self.block_manager.deallocate(seq)
77
+ self.running.remove(seq)
flashcosyvoice/engine/sequence.py ADDED
@@ -0,0 +1,90 @@
1
+ from copy import copy
2
+ from enum import Enum, auto
3
+ from itertools import count
4
+
5
+ from flashcosyvoice.config import SamplingParams
6
+
7
+
8
+ class SequenceStatus(Enum):
9
+ WAITING = auto()
10
+ RUNNING = auto()
11
+ FINISHED = auto()
12
+
13
+
14
+ class Sequence:
15
+ block_size = 256
16
+ counter = count()
17
+
18
+ def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
19
+ self.seq_id = next(Sequence.counter)
20
+ self.status = SequenceStatus.WAITING
21
+ self.token_ids = copy(token_ids)
22
+ self.last_token = token_ids[-1]
23
+ self.num_tokens = len(self.token_ids)
24
+ self.num_prompt_tokens = len(token_ids)
25
+ self.num_cached_tokens = 0
26
+ self.block_table = []
27
+ self.temperature = sampling_params.temperature
28
+ self.min_tokens = sampling_params.min_tokens
29
+ self.max_tokens = sampling_params.max_tokens
30
+ self.ignore_eos = sampling_params.ignore_eos
31
+ self.top_k = sampling_params.top_k
32
+ # RasSampler parameters
33
+ self.use_ras = sampling_params.use_ras
34
+ self.win_size = sampling_params.win_size
35
+ self.tau_r = sampling_params.tau_r
36
+ self.top_p = sampling_params.top_p
37
+
38
+ def __len__(self):
39
+ return self.num_tokens
40
+
41
+ def __getitem__(self, key):
42
+ return self.token_ids[key]
43
+
44
+ @property
45
+ def is_finished(self):
46
+ return self.status == SequenceStatus.FINISHED
47
+
48
+ @property
49
+ def num_completion_tokens(self):
50
+ return self.num_tokens - self.num_prompt_tokens
51
+
52
+ @property
53
+ def prompt_token_ids(self):
54
+ return self.token_ids[:self.num_prompt_tokens]
55
+
56
+ @property
57
+ def completion_token_ids(self):
58
+ return self.token_ids[self.num_prompt_tokens:]
59
+
60
+ @property
61
+ def num_cached_blocks(self):
62
+ return self.num_cached_tokens // self.block_size
63
+
64
+ @property
65
+ def num_blocks(self):
66
+ return (self.num_tokens + self.block_size - 1) // self.block_size
67
+
68
+ @property
69
+ def last_block_num_tokens(self):
70
+ return self.num_tokens - (self.num_blocks - 1) * self.block_size
71
+
72
+ def block(self, i):
73
+ assert 0 <= i < self.num_blocks
74
+ return self.token_ids[i*self.block_size: (i+1)*self.block_size]
75
+
76
+ def append_token(self, token_id: int):
77
+ self.token_ids.append(token_id)
78
+ self.last_token = token_id
79
+ self.num_tokens += 1
80
+
81
+ def __getstate__(self):
82
+ return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
83
+ self.token_ids if self.num_completion_tokens == 0 else self.last_token)
84
+
85
+ def __setstate__(self, state):
86
+ self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
87
+ if self.num_completion_tokens == 0:
88
+ self.token_ids = state[-1]
89
+ else:
90
+ self.last_token = state[-1]
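A quick sketch of the block bookkeeping these properties implement, using the class-level block_size of 256:

    from flashcosyvoice.config import SamplingParams
    from flashcosyvoice.engine.sequence import Sequence

    seq = Sequence(list(range(300)), SamplingParams(max_tokens=8))
    assert seq.num_blocks == 2               # ceil(300 / 256)
    assert seq.last_block_num_tokens == 44   # 300 - 256
    assert seq.block(1) == list(range(256, 300))
    seq.append_token(1234)
    assert seq.num_completion_tokens == 1    # tokens beyond the original prompt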
flashcosyvoice/modules/__init__.py ADDED
File without changes
flashcosyvoice/modules/flow.py ADDED
@@ -0,0 +1,198 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from flashcosyvoice.modules.flow_components.estimator import \
8
+ CausalConditionalDecoder
9
+ from flashcosyvoice.modules.flow_components.upsample_encoder import (
10
+ UpsampleConformerEncoder, make_pad_mask)
11
+
12
+
13
+ # TODO(xcsong): make it configurable
14
+ @dataclass
15
+ class CfmParams:
16
+ sigma_min: float = 1e-6
17
+ solver: str = "euler"
18
+ t_scheduler: str = "cosine"
19
+ training_cfg_rate: float = 0.2
20
+ inference_cfg_rate: float = 0.7
21
+
22
+
23
+ class CausalConditionalCFM(torch.nn.Module):
24
+ def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
25
+ super().__init__()
26
+ self.n_feats = in_channels
27
+ self.n_spks = n_spks
28
+ self.spk_emb_dim = spk_emb_dim
29
+ self.solver = cfm_params.solver
30
+ if hasattr(cfm_params, "sigma_min"):
31
+ self.sigma_min = cfm_params.sigma_min
32
+ else:
33
+ self.sigma_min = 1e-4
34
+ self.t_scheduler = cfm_params.t_scheduler
35
+ self.training_cfg_rate = cfm_params.training_cfg_rate
36
+ self.inference_cfg_rate = cfm_params.inference_cfg_rate
37
+ in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
38
+ # Just change the architecture of the estimator here
39
+ self.estimator = CausalConditionalDecoder() if estimator is None else estimator
40
+
41
+ @torch.inference_mode()
42
+ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
43
+ """Forward diffusion
44
+
45
+ Args:
46
+ mu (torch.Tensor): output of encoder
47
+ shape: (batch_size, n_feats, mel_timesteps)
48
+ mask (torch.Tensor): output_mask
49
+ shape: (batch_size, 1, mel_timesteps)
50
+ n_timesteps (int): number of diffusion steps
51
+ temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
52
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
53
+ shape: (batch_size, spk_emb_dim)
54
+ cond: Not used but kept for future purposes
55
+
56
+ Returns:
57
+ sample: generated mel-spectrogram
58
+ shape: (batch_size, n_feats, mel_timesteps)
59
+ """
60
+ z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
61
+ # fix prompt and overlap part mu and z
62
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
63
+ if self.t_scheduler == 'cosine':
64
+ t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
65
+ return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
66
+
67
+ def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
68
+ """
69
+ Fixed euler solver for ODEs.
70
+ Args:
71
+ x (torch.Tensor): random noise
72
+ t_span (torch.Tensor): n_timesteps interpolated
73
+ shape: (n_timesteps + 1,)
74
+ mu (torch.Tensor): output of encoder
75
+ shape: (batch_size, n_feats, mel_timesteps)
76
+ mask (torch.Tensor): output_mask
77
+ shape: (batch_size, 1, mel_timesteps)
78
+ spks (torch.Tensor, optional): speaker ids. Defaults to None.
79
+ shape: (batch_size, spk_emb_dim)
80
+ cond: Not used but kept for future purposes
81
+ """
82
+ batch_size = x.size(0)
83
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
84
+
85
+ # Keep the intermediate solutions so they can be inspected (e.g. dumped from a debugger and plotted).
86
+ # A return_all_steps flag could expose them in the future.
87
+ sol = []
88
+
89
+ # Do not use concat: it may change the memory format and make TensorRT inference return wrong results!
90
+ # Create tensors with double batch size for CFG (conditional + unconditional)
91
+ x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
92
+ mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
93
+ mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
94
+ t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
95
+ spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
96
+ cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)
97
+
98
+ for step in range(1, len(t_span)):
99
+ # Classifier-Free Guidance inference introduced in VoiceBox
100
+ # Copy conditional and unconditional input
101
+ x_in[:batch_size] = x
102
+ x_in[batch_size:] = x
103
+ mask_in[:batch_size] = mask
104
+ mask_in[batch_size:] = mask
105
+ mu_in[:batch_size] = mu
106
+ # Unconditional part remains 0
107
+ t_in.fill_(t)
108
+ spks_in[:batch_size] = spks
109
+ cond_in[:batch_size] = cond
110
+
111
+ dphi_dt = self.estimator(
112
+ x_in, mask_in,
113
+ mu_in, t_in,
114
+ spks_in,
115
+ cond_in,
116
+ streaming
117
+ )
118
+ dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
119
+ dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
120
+ x = x + dt * dphi_dt
121
+ t = t + dt
122
+ sol.append(x)
123
+ if step < len(t_span) - 1:
124
+ dt = t_span[step + 1] - t
125
+
126
+ return sol[-1].float()
127
+
128
+
129
+ class CausalMaskedDiffWithXvec(torch.nn.Module):
130
+ def __init__(
131
+ self,
132
+ input_size: int = 512,
133
+ output_size: int = 80,
134
+ spk_embed_dim: int = 192,
135
+ output_type: str = "mel",
136
+ vocab_size: int = 6561,
137
+ input_frame_rate: int = 25,
138
+ token_mel_ratio: int = 2,
139
+ pre_lookahead_len: int = 3,
140
+ encoder: torch.nn.Module = None,
141
+ decoder: torch.nn.Module = None,
142
+ ):
143
+ super().__init__()
144
+ self.input_size = input_size
145
+ self.output_size = output_size
146
+ self.vocab_size = vocab_size
147
+ self.output_type = output_type
148
+ self.input_frame_rate = input_frame_rate
149
+ self.input_embedding = nn.Embedding(vocab_size, input_size)
150
+ self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
151
+ self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
152
+ self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
153
+ self.decoder = CausalConditionalCFM() if decoder is None else decoder
154
+ self.token_mel_ratio = token_mel_ratio
155
+ self.pre_lookahead_len = pre_lookahead_len
156
+
157
+ @torch.inference_mode()
158
+ def forward(self,
159
+ token,
160
+ token_len,
161
+ prompt_feat,
162
+ prompt_feat_len,
163
+ embedding,
164
+ streaming,
165
+ finalize):
166
+ # xvec projection
167
+ embedding = F.normalize(embedding, dim=1)
168
+ embedding = self.spk_embed_affine_layer(embedding)
169
+
170
+ # concat text and prompt_text
171
+ mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
172
+ token = self.input_embedding(torch.clamp(token, min=0)) * mask
173
+
174
+ # text encode
175
+ if finalize is True:
176
+ h, h_lengths = self.encoder(token, token_len, streaming=streaming)
177
+ else:
178
+ token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
179
+ h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
180
+ h = self.encoder_proj(h)
181
+
182
+ # get conditions
183
+ conds = torch.zeros_like(h, device=token.device)
184
+ for i, j in enumerate(prompt_feat_len):
185
+ conds[i, :j] = prompt_feat[i, :j]
186
+ conds = conds.transpose(1, 2)
187
+
188
+ h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
189
+ mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
190
+ feat, _ = self.decoder(
191
+ mu=h.transpose(1, 2).contiguous(),
192
+ mask=mask.unsqueeze(1),
193
+ spks=embedding,
194
+ cond=conds,
195
+ n_timesteps=10,
196
+ streaming=streaming
197
+ ) # [B, num_mels, T]
198
+ return feat.float(), h_lengths
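In equation form, each iteration of solve_euler above performs classifier-free guidance with lambda = inference_cfg_rate = 0.7 followed by an explicit Euler step on the cosine-warped time grid t_k = 1 - cos(pi * k / (2 * n_timesteps)):

    v = (1 + lambda) * v_theta(x, t; mu, spk, cond) - lambda * v_theta(x, t; 0, 0, 0)
    x <- x + (t_{k+1} - t_k) * v

with n_timesteps = 10 as called from CausalMaskedDiffWithXvec.forward above.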
flashcosyvoice/modules/flow_components/__init__.py ADDED
File without changes
flashcosyvoice/modules/flow_components/estimator.py ADDED
@@ -0,0 +1,974 @@
1
+ import math
2
+ from typing import Any, Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
8
+ AdaLayerNormZero, ApproximateGELU)
9
+ from diffusers.models.attention_processor import Attention
10
+ from diffusers.models.lora import LoRACompatibleLinear
11
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
12
+ from einops import pack, rearrange, repeat
13
+
14
+ from flashcosyvoice.modules.flow_components.upsample_encoder import \
15
+ add_optional_chunk_mask
16
+
17
+
18
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
19
+ assert mask.dtype == torch.bool
20
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
21
+ mask = mask.to(dtype)
22
+ # attention mask bias
23
+ # NOTE(Mddct): torch.finfo jit issues
24
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
25
+ mask = (1.0 - mask) * -1.0e+10
26
+ return mask
27
+
28
+
29
+ class SnakeBeta(nn.Module):
30
+ """
31
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
32
+ Shape:
33
+ - Input: (B, C, T)
34
+ - Output: (B, C, T), same shape as the input
35
+ Parameters:
36
+ - alpha - trainable parameter that controls frequency
37
+ - beta - trainable parameter that controls magnitude
38
+ References:
39
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
40
+ https://arxiv.org/abs/2006.08195
41
+ Examples:
42
+ >>> a1 = snakebeta(256)
43
+ >>> x = torch.randn(256)
44
+ >>> x = a1(x)
45
+
46
+ Args:
47
+ in_features: shape of the input
48
+ out_features: shape of the output
49
+ alpha: trainable parameter that controls frequency
50
+ alpha_trainable: whether alpha is trainable
51
+ alpha_logscale: whether to use log scale for alpha
52
+ alpha is initialized to 1 by default, higher values = higher-frequency.
53
+ beta is initialized to 1 by default, higher values = higher-magnitude.
54
+ alpha will be trained along with the rest of your model.
55
+ """
56
+
57
+ def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
58
+ super().__init__()
59
+ self.in_features = out_features if isinstance(out_features, list) else [out_features]
60
+ self.proj = LoRACompatibleLinear(in_features, out_features)
61
+
62
+ # initialize alpha
63
+ self.alpha_logscale = alpha_logscale
64
+ if self.alpha_logscale: # log scale alphas initialized to zeros
65
+ self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
66
+ self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
67
+ else: # linear scale alphas initialized to ones
68
+ self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
69
+ self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
70
+
71
+ self.alpha.requires_grad = alpha_trainable
72
+ self.beta.requires_grad = alpha_trainable
73
+
74
+ self.no_div_by_zero = 0.000000001
75
+
76
+ def forward(self, x):
77
+ """
78
+ Forward pass of the function.
79
+ Applies the function to the input elementwise.
80
+ SnakeBeta(x) := x + (1/b) * sin^2(a * x)
81
+ """
82
+ x = self.proj(x)
83
+ if self.alpha_logscale:
84
+ alpha = torch.exp(self.alpha)
85
+ beta = torch.exp(self.beta)
86
+ else:
87
+ alpha = self.alpha
88
+ beta = self.beta
89
+
90
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
91
+
92
+ return x
93
+
94
+
95
+ class FeedForward(nn.Module):
96
+ r"""
97
+ A feed-forward layer.
98
+
99
+ Parameters:
100
+ dim (`int`): The number of channels in the input.
101
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
102
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
103
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
104
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
105
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
106
+ """
107
+
108
+ def __init__(
109
+ self,
110
+ dim: int,
111
+ dim_out: Optional[int] = None,
112
+ mult: int = 4,
113
+ dropout: float = 0.0,
114
+ activation_fn: str = "geglu",
115
+ final_dropout: bool = False,
116
+ ):
117
+ super().__init__()
118
+ inner_dim = int(dim * mult)
119
+ dim_out = dim_out if dim_out is not None else dim
120
+
121
+ if activation_fn == "gelu":
122
+ act_fn = GELU(dim, inner_dim)
123
+ if activation_fn == "gelu-approximate":
124
+ act_fn = GELU(dim, inner_dim, approximate="tanh")
125
+ elif activation_fn == "geglu":
126
+ act_fn = GEGLU(dim, inner_dim)
127
+ elif activation_fn == "geglu-approximate":
128
+ act_fn = ApproximateGELU(dim, inner_dim)
129
+ elif activation_fn == "snakebeta":
130
+ act_fn = SnakeBeta(dim, inner_dim)
131
+
132
+ self.net = nn.ModuleList([])
133
+ # project in
134
+ self.net.append(act_fn)
135
+ # project dropout
136
+ self.net.append(nn.Dropout(dropout))
137
+ # project out
138
+ self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
139
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
140
+ if final_dropout:
141
+ self.net.append(nn.Dropout(dropout))
142
+
143
+ def forward(self, hidden_states):
144
+ for module in self.net:
145
+ hidden_states = module(hidden_states)
146
+ return hidden_states
147
+
148
+
149
+ @maybe_allow_in_graph
150
+ class BasicTransformerBlock(nn.Module):
151
+ r"""
152
+ A basic Transformer block.
153
+
154
+ Parameters:
155
+ dim (`int`): The number of channels in the input and output.
156
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
157
+ attention_head_dim (`int`): The number of channels in each head.
158
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
159
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
160
+ only_cross_attention (`bool`, *optional*):
161
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
162
+ double_self_attention (`bool`, *optional*):
163
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
164
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
165
+ num_embeds_ada_norm (:
166
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
167
+ attention_bias (:
168
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ dim: int,
174
+ num_attention_heads: int,
175
+ attention_head_dim: int,
176
+ dropout=0.0,
177
+ cross_attention_dim: Optional[int] = None,
178
+ activation_fn: str = "geglu",
179
+ num_embeds_ada_norm: Optional[int] = None,
180
+ attention_bias: bool = False,
181
+ only_cross_attention: bool = False,
182
+ double_self_attention: bool = False,
183
+ upcast_attention: bool = False,
184
+ norm_elementwise_affine: bool = True,
185
+ norm_type: str = "layer_norm",
186
+ final_dropout: bool = False,
187
+ ):
188
+ super().__init__()
189
+ self.only_cross_attention = only_cross_attention
190
+
191
+ self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
192
+ self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
193
+
194
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
195
+ raise ValueError(
196
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
197
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
198
+ )
199
+
200
+ # Define 3 blocks. Each block has its own normalization layer.
201
+ # 1. Self-Attn
202
+ if self.use_ada_layer_norm:
203
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
204
+ elif self.use_ada_layer_norm_zero:
205
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
206
+ else:
207
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
208
+ self.attn1 = Attention(
209
+ query_dim=dim,
210
+ heads=num_attention_heads,
211
+ dim_head=attention_head_dim,
212
+ dropout=dropout,
213
+ bias=attention_bias,
214
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
215
+ upcast_attention=upcast_attention,
216
+ )
217
+
218
+ # 2. Cross-Attn
219
+ if cross_attention_dim is not None or double_self_attention:
220
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
221
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
222
+ # the second cross attention block.
223
+ self.norm2 = (
224
+ AdaLayerNorm(dim, num_embeds_ada_norm)
225
+ if self.use_ada_layer_norm
226
+ else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
227
+ )
228
+ self.attn2 = Attention(
229
+ query_dim=dim,
230
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
231
+ heads=num_attention_heads,
232
+ dim_head=attention_head_dim,
233
+ dropout=dropout,
234
+ bias=attention_bias,
235
+ upcast_attention=upcast_attention,
236
+ # scale_qk=False, # uncomment this to disable flash attention
237
+ ) # is self-attn if encoder_hidden_states is none
238
+ else:
239
+ self.norm2 = None
240
+ self.attn2 = None
241
+
242
+ # 3. Feed-forward
243
+ self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
244
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
245
+
246
+ # let chunk size default to None
247
+ self._chunk_size = None
248
+ self._chunk_dim = 0
249
+
250
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
251
+ # Sets chunk feed-forward
252
+ self._chunk_size = chunk_size
253
+ self._chunk_dim = dim
254
+
255
+ def forward(
256
+ self,
257
+ hidden_states: torch.FloatTensor,
258
+ attention_mask: Optional[torch.FloatTensor] = None,
259
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
260
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
261
+ timestep: Optional[torch.LongTensor] = None,
262
+ cross_attention_kwargs: Dict[str, Any] = None,
263
+ class_labels: Optional[torch.LongTensor] = None,
264
+ ):
265
+ # Notice that normalization is always applied before the real computation in the following blocks.
266
+ # 1. Self-Attention
267
+ if self.use_ada_layer_norm:
268
+ norm_hidden_states = self.norm1(hidden_states, timestep)
269
+ elif self.use_ada_layer_norm_zero:
270
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
271
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
272
+ )
273
+ else:
274
+ norm_hidden_states = self.norm1(hidden_states)
275
+
276
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
277
+
278
+ attn_output = self.attn1(
279
+ norm_hidden_states,
280
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
281
+ attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
282
+ **cross_attention_kwargs,
283
+ )
284
+ if self.use_ada_layer_norm_zero:
285
+ attn_output = gate_msa.unsqueeze(1) * attn_output
286
+ hidden_states = attn_output + hidden_states
287
+
288
+ # 2. Cross-Attention
289
+ if self.attn2 is not None:
290
+ norm_hidden_states = (
291
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
292
+ )
293
+
294
+ attn_output = self.attn2(
295
+ norm_hidden_states,
296
+ encoder_hidden_states=encoder_hidden_states,
297
+ attention_mask=encoder_attention_mask,
298
+ **cross_attention_kwargs,
299
+ )
300
+ hidden_states = attn_output + hidden_states
301
+
302
+ # 3. Feed-forward
303
+ norm_hidden_states = self.norm3(hidden_states)
304
+
305
+ if self.use_ada_layer_norm_zero:
306
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
307
+
308
+ if self._chunk_size is not None:
309
+ # "feed_forward_chunk_size" can be used to save memory
310
+ if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
311
+ raise ValueError(
312
+ f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
313
+ )
314
+
315
+ num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
316
+ ff_output = torch.cat(
317
+ [self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
318
+ dim=self._chunk_dim,
319
+ )
320
+ else:
321
+ ff_output = self.ff(norm_hidden_states)
322
+
323
+ if self.use_ada_layer_norm_zero:
324
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
325
+
326
+ hidden_states = ff_output + hidden_states
327
+
328
+ return hidden_states
329
+
330
+
331
+ class SinusoidalPosEmb(torch.nn.Module):
332
+ def __init__(self, dim):
333
+ super().__init__()
334
+ self.dim = dim
335
+ assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
336
+
337
+ def forward(self, x, scale=1000):
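+ # x holds diffusion timesteps of shape (batch,); returns (batch, dim) with sin in the first half and cos in the second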
338
+ if x.ndim < 1:
339
+ x = x.unsqueeze(0)
340
+ device = x.device
341
+ half_dim = self.dim // 2
342
+ emb = math.log(10000) / (half_dim - 1)
343
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
344
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
345
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
346
+ return emb
347
+
348
+
349
+ class Block1D(torch.nn.Module):
350
+ def __init__(self, dim, dim_out, groups=8):
351
+ super().__init__()
352
+ self.block = torch.nn.Sequential(
353
+ torch.nn.Conv1d(dim, dim_out, 3, padding=1),
354
+ torch.nn.GroupNorm(groups, dim_out),
355
+ nn.Mish(),
356
+ )
357
+
358
+ def forward(self, x, mask):
359
+ output = self.block(x * mask)
360
+ return output * mask
361
+
362
+
363
+ class ResnetBlock1D(torch.nn.Module):
364
+ def __init__(self, dim, dim_out, time_emb_dim, groups=8):
365
+ super().__init__()
366
+ self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
367
+
368
+ self.block1 = Block1D(dim, dim_out, groups=groups)
369
+ self.block2 = Block1D(dim_out, dim_out, groups=groups)
370
+
371
+ self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
372
+
373
+ def forward(self, x, mask, time_emb):
374
+ h = self.block1(x, mask)
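+ # project the time embedding to dim_out and add it to every frame (broadcast over the time axis)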
375
+ h += self.mlp(time_emb).unsqueeze(-1)
376
+ h = self.block2(h, mask)
377
+ output = h + self.res_conv(x * mask)
378
+ return output
379
+
380
+
381
+ class Downsample1D(nn.Module):
382
+ def __init__(self, dim):
383
+ super().__init__()
384
+ self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
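+ # kernel 3, stride 2, padding 1: halves the temporal resolution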
385
+
386
+ def forward(self, x):
387
+ return self.conv(x)
388
+
389
+
390
+ class TimestepEmbedding(nn.Module):
391
+ def __init__(
392
+ self,
393
+ in_channels: int,
394
+ time_embed_dim: int,
395
+ act_fn: str = "silu",
396
+ out_dim: int = None,
397
+ post_act_fn: Optional[str] = None,
398
+ cond_proj_dim=None,
399
+ ):
400
+ super().__init__()
401
+ assert act_fn == "silu", "act_fn must be silu"
402
+
403
+ self.linear_1 = nn.Linear(in_channels, time_embed_dim)
404
+
405
+ if cond_proj_dim is not None:
406
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
407
+ else:
408
+ self.cond_proj = None
409
+
410
+ self.act = nn.SiLU()
411
+
412
+ if out_dim is not None:
413
+ time_embed_dim_out = out_dim
414
+ else:
415
+ time_embed_dim_out = time_embed_dim
416
+ self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
417
+
418
+ if post_act_fn is None:
419
+ self.post_act = None
420
+ else:
421
+ self.post_act = nn.SiLU()
422
+
423
+ def forward(self, sample, condition=None):
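+ # sample: (batch, in_channels) sinusoidal embedding; returns (batch, time_embed_dim) unless out_dim overrides it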
424
+ if condition is not None:
425
+ sample = sample + self.cond_proj(condition)
426
+ sample = self.linear_1(sample)
427
+
428
+ if self.act is not None:
429
+ sample = self.act(sample)
430
+
431
+ sample = self.linear_2(sample)
432
+
433
+ if self.post_act is not None:
434
+ sample = self.post_act(sample)
435
+ return sample
436
+
437
+
438
+ class Upsample1D(nn.Module):
439
+ """A 1D upsampling layer with an optional convolution.
440
+
441
+ Parameters:
442
+ channels (`int`):
443
+ number of channels in the inputs and outputs.
444
+ use_conv (`bool`, default `False`):
445
+ option to use a convolution.
446
+ use_conv_transpose (`bool`, default `True`):
447
+ option to use a convolution transpose.
448
+ out_channels (`int`, optional):
449
+ number of output channels. Defaults to `channels`.
450
+ """
451
+
452
+ def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
453
+ super().__init__()
454
+ self.channels = channels
455
+ self.out_channels = out_channels or channels
456
+ self.use_conv = use_conv
457
+ self.use_conv_transpose = use_conv_transpose
458
+ self.name = name
459
+
460
+ self.conv = None
461
+ if use_conv_transpose:
462
+ self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
463
+ elif use_conv:
464
+ self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
465
+
466
+ def forward(self, inputs):
467
+ assert inputs.shape[1] == self.channels
468
+ if self.use_conv_transpose:
469
+ return self.conv(inputs)
470
+
471
+ outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
472
+
473
+ if self.use_conv:
474
+ outputs = self.conv(outputs)
475
+
476
+ return outputs
477
+
478
+
479
+ class Transpose(torch.nn.Module):
480
+ def __init__(self, dim0: int, dim1: int):
481
+ super().__init__()
482
+ self.dim0 = dim0
483
+ self.dim1 = dim1
484
+
485
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
486
+ x = torch.transpose(x, self.dim0, self.dim1)
487
+ return x
488
+
489
+
490
+ class CausalConv1d(torch.nn.Conv1d):
491
+ def __init__(
492
+ self,
493
+ in_channels: int,
494
+ out_channels: int,
495
+ kernel_size: int,
496
+ stride: int = 1,
497
+ dilation: int = 1,
498
+ groups: int = 1,
499
+ bias: bool = True,
500
+ padding_mode: str = 'zeros',
501
+ device=None,
502
+ dtype=None
503
+ ) -> None:
504
+ super(CausalConv1d, self).__init__(in_channels, out_channels,
505
+ kernel_size, stride,
506
+ padding=0, dilation=dilation,
507
+ groups=groups, bias=bias,
508
+ padding_mode=padding_mode,
509
+ device=device, dtype=dtype)
510
+ assert stride == 1
511
+ self.causal_padding = kernel_size - 1
512
+
513
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
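+ # pad (kernel_size - 1) frames on the left only, so each output frame depends on current and past inputs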
514
+ x = F.pad(x, (self.causal_padding, 0), value=0.0)
515
+ x = super(CausalConv1d, self).forward(x)
516
+ return x
517
+
518
+
519
+ class CausalBlock1D(Block1D):
520
+ def __init__(self, dim: int, dim_out: int):
521
+ super(CausalBlock1D, self).__init__(dim, dim_out)
522
+ self.block = torch.nn.Sequential(
523
+ CausalConv1d(dim, dim_out, 3),
524
+ Transpose(1, 2),
525
+ nn.LayerNorm(dim_out),
526
+ Transpose(1, 2),
527
+ nn.Mish(),
528
+ )
529
+
530
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
531
+ output = self.block(x * mask)
532
+ return output * mask
533
+
534
+
535
+ class CausalResnetBlock1D(ResnetBlock1D):
536
+ def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
537
+ super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
538
+ self.block1 = CausalBlock1D(dim, dim_out)
539
+ self.block2 = CausalBlock1D(dim_out, dim_out)
540
+
541
+
542
+ class ConditionalDecoder(nn.Module):
543
+ """
544
+ This decoder requires an input with the same shape as the target. So, if your text content
545
+ is shorter or longer than the output, please re-sample it before feeding it to the decoder.
546
+
547
+ Args:
548
+ in_channels: number of input channels
549
+ out_channels: number of output channels
550
+ channels: tuple of channel dimensions
551
+ dropout: dropout rate
552
+ attention_head_dim: dimension of attention heads
553
+ n_blocks: number of transformer blocks
554
+ num_mid_blocks: number of middle blocks
555
+ num_heads: number of attention heads
556
+ act_fn: activation function name
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ in_channels,
562
+ out_channels,
563
+ channels=(256, 256),
564
+ dropout=0.05,
565
+ attention_head_dim=64,
566
+ n_blocks=1,
567
+ num_mid_blocks=2,
568
+ num_heads=4,
569
+ act_fn="snake",
570
+ ):
571
+ super().__init__()
572
+ channels = tuple(channels)
573
+ self.in_channels = in_channels
574
+ self.out_channels = out_channels
575
+
576
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
577
+ time_embed_dim = channels[0] * 4
578
+ self.time_mlp = TimestepEmbedding(
579
+ in_channels=in_channels,
580
+ time_embed_dim=time_embed_dim,
581
+ act_fn="silu",
582
+ )
583
+ self.down_blocks = nn.ModuleList([])
584
+ self.mid_blocks = nn.ModuleList([])
585
+ self.up_blocks = nn.ModuleList([])
586
+
587
+ output_channel = in_channels
588
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
589
+ input_channel = output_channel
590
+ output_channel = channels[i]
591
+ is_last = i == len(channels) - 1
592
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
593
+ transformer_blocks = nn.ModuleList(
594
+ [
595
+ BasicTransformerBlock(
596
+ dim=output_channel,
597
+ num_attention_heads=num_heads,
598
+ attention_head_dim=attention_head_dim,
599
+ dropout=dropout,
600
+ activation_fn=act_fn,
601
+ )
602
+ for _ in range(n_blocks)
603
+ ]
604
+ )
605
+ downsample = (
606
+ Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
607
+ )
608
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
609
+
610
+ for _ in range(num_mid_blocks):
611
+ input_channel = channels[-1]
612
+ out_channels = channels[-1]
613
+ resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
614
+
615
+ transformer_blocks = nn.ModuleList(
616
+ [
617
+ BasicTransformerBlock(
618
+ dim=output_channel,
619
+ num_attention_heads=num_heads,
620
+ attention_head_dim=attention_head_dim,
621
+ dropout=dropout,
622
+ activation_fn=act_fn,
623
+ )
624
+ for _ in range(n_blocks)
625
+ ]
626
+ )
627
+
628
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
629
+
630
+ channels = channels[::-1] + (channels[0],)
631
+ for i in range(len(channels) - 1):
632
+ input_channel = channels[i] * 2
633
+ output_channel = channels[i + 1]
634
+ is_last = i == len(channels) - 2
635
+ resnet = ResnetBlock1D(
636
+ dim=input_channel,
637
+ dim_out=output_channel,
638
+ time_emb_dim=time_embed_dim,
639
+ )
640
+ transformer_blocks = nn.ModuleList(
641
+ [
642
+ BasicTransformerBlock(
643
+ dim=output_channel,
644
+ num_attention_heads=num_heads,
645
+ attention_head_dim=attention_head_dim,
646
+ dropout=dropout,
647
+ activation_fn=act_fn,
648
+ )
649
+ for _ in range(n_blocks)
650
+ ]
651
+ )
652
+ upsample = (
653
+ Upsample1D(output_channel, use_conv_transpose=True)
654
+ if not is_last
655
+ else nn.Conv1d(output_channel, output_channel, 3, padding=1)
656
+ )
657
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
658
+ self.final_block = Block1D(channels[-1], channels[-1])
659
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
660
+ self.initialize_weights()
661
+
662
+ def initialize_weights(self):
663
+ for m in self.modules():
664
+ if isinstance(m, nn.Conv1d):
665
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
666
+ if m.bias is not None:
667
+ nn.init.constant_(m.bias, 0)
668
+ elif isinstance(m, nn.GroupNorm):
669
+ nn.init.constant_(m.weight, 1)
670
+ nn.init.constant_(m.bias, 0)
671
+ elif isinstance(m, nn.Linear):
672
+ nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
673
+ if m.bias is not None:
674
+ nn.init.constant_(m.bias, 0)
675
+
676
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
677
+ """Forward pass of the UNet1DConditional model.
678
+
679
+ Args:
680
+ x (torch.Tensor): shape (batch_size, in_channels, time)
681
+ mask (torch.Tensor): shape (batch_size, 1, time)
682
+ t (torch.Tensor): shape (batch_size,)
683
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
684
+ cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
685
+
686
+ Returns:
687
+ torch.Tensor: output tensor of shape (batch_size, out_channels, time).
692
+ """
693
+
694
+ t = self.time_embeddings(t).to(t.dtype)
695
+ t = self.time_mlp(t)
696
+
697
+ x = pack([x, mu], "b * t")[0]
698
+
699
+ if spks is not None:
700
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
701
+ x = pack([x, spks], "b * t")[0]
702
+ if cond is not None:
703
+ x = pack([x, cond], "b * t")[0]
704
+
705
+ hiddens = []
706
+ masks = [mask]
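+ # U-Net style downsampling path: keep per-resolution hidden states and masks for the skip connections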
707
+ for resnet, transformer_blocks, downsample in self.down_blocks:
708
+ mask_down = masks[-1]
709
+ x = resnet(x, mask_down, t)
710
+ x = rearrange(x, "b c t -> b t c").contiguous()
711
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
712
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
713
+ for transformer_block in transformer_blocks:
714
+ x = transformer_block(
715
+ hidden_states=x,
716
+ attention_mask=attn_mask,
717
+ timestep=t,
718
+ )
719
+ x = rearrange(x, "b t c -> b c t").contiguous()
720
+ hiddens.append(x) # Save hidden states for skip connections
721
+ x = downsample(x * mask_down)
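+ # subsample the mask by 2 to match the downsampled time axis; the last entry is discarded after the loop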
722
+ masks.append(mask_down[:, :, ::2])
723
+ masks = masks[:-1]
724
+ mask_mid = masks[-1]
725
+
726
+ for resnet, transformer_blocks in self.mid_blocks:
727
+ x = resnet(x, mask_mid, t)
728
+ x = rearrange(x, "b c t -> b t c").contiguous()
729
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
730
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
731
+ for transformer_block in transformer_blocks:
732
+ x = transformer_block(
733
+ hidden_states=x,
734
+ attention_mask=attn_mask,
735
+ timestep=t,
736
+ )
737
+ x = rearrange(x, "b t c -> b c t").contiguous()
738
+
739
+ for resnet, transformer_blocks, upsample in self.up_blocks:
740
+ mask_up = masks.pop()
741
+ skip = hiddens.pop()
742
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
743
+ x = resnet(x, mask_up, t)
744
+ x = rearrange(x, "b c t -> b t c").contiguous()
745
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
746
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
747
+ for transformer_block in transformer_blocks:
748
+ x = transformer_block(
749
+ hidden_states=x,
750
+ attention_mask=attn_mask,
751
+ timestep=t,
752
+ )
753
+ x = rearrange(x, "b t c -> b c t").contiguous()
754
+ x = upsample(x * mask_up)
755
+ x = self.final_block(x, mask_up)
756
+ output = self.final_proj(x * mask_up)
757
+ return output * mask
758
+
759
+
760
+ class CausalConditionalDecoder(ConditionalDecoder):
761
+ """
762
+ This decoder requires an input with the same shape as the target. So, if your text content
763
+ is shorter or longer than the output, please re-sample it before feeding it to the decoder.
764
+
765
+ Args:
766
+ in_channels: number of input channels
767
+ out_channels: number of output channels
768
+ channels: list of channel dimensions
769
+ dropout: dropout rate
770
+ attention_head_dim: dimension of attention heads
771
+ n_blocks: number of transformer blocks
772
+ num_mid_blocks: number of middle blocks
773
+ num_heads: number of attention heads
774
+ act_fn: activation function name
775
+ static_chunk_size: size of static chunks
776
+ num_decoding_left_chunks: number of left chunks for decoding
777
+ """
778
+
779
+ def __init__(
780
+ self,
781
+ in_channels=320,
782
+ out_channels=80,
783
+ channels=[256], # noqa
784
+ dropout=0.0,
785
+ attention_head_dim=64,
786
+ n_blocks=4,
787
+ num_mid_blocks=12,
788
+ num_heads=8,
789
+ act_fn="gelu",
790
+ static_chunk_size=50,
791
+ num_decoding_left_chunks=-1,
792
+ ):
793
+ torch.nn.Module.__init__(self)
794
+ channels = tuple(channels)
795
+ self.in_channels = in_channels
796
+ self.out_channels = out_channels
797
+ self.time_embeddings = SinusoidalPosEmb(in_channels)
798
+ time_embed_dim = channels[0] * 4
799
+ self.time_mlp = TimestepEmbedding(
800
+ in_channels=in_channels,
801
+ time_embed_dim=time_embed_dim,
802
+ act_fn="silu",
803
+ )
804
+ self.static_chunk_size = static_chunk_size
805
+ self.num_decoding_left_chunks = num_decoding_left_chunks
806
+ self.down_blocks = nn.ModuleList([])
807
+ self.mid_blocks = nn.ModuleList([])
808
+ self.up_blocks = nn.ModuleList([])
809
+
810
+ output_channel = in_channels
811
+ for i in range(len(channels)): # pylint: disable=consider-using-enumerate
812
+ input_channel = output_channel
813
+ output_channel = channels[i]
814
+ is_last = i == len(channels) - 1
815
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
816
+ transformer_blocks = nn.ModuleList(
817
+ [
818
+ BasicTransformerBlock(
819
+ dim=output_channel,
820
+ num_attention_heads=num_heads,
821
+ attention_head_dim=attention_head_dim,
822
+ dropout=dropout,
823
+ activation_fn=act_fn,
824
+ )
825
+ for _ in range(n_blocks)
826
+ ]
827
+ )
828
+ downsample = (
829
+ Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
830
+ )
831
+ self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
832
+
833
+ for _ in range(num_mid_blocks):
834
+ input_channel = channels[-1]
835
+ out_channels = channels[-1]
836
+ resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
837
+
838
+ transformer_blocks = nn.ModuleList(
839
+ [
840
+ BasicTransformerBlock(
841
+ dim=output_channel,
842
+ num_attention_heads=num_heads,
843
+ attention_head_dim=attention_head_dim,
844
+ dropout=dropout,
845
+ activation_fn=act_fn,
846
+ )
847
+ for _ in range(n_blocks)
848
+ ]
849
+ )
850
+
851
+ self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
852
+
853
+ channels = channels[::-1] + (channels[0],)
854
+ for i in range(len(channels) - 1):
855
+ input_channel = channels[i] * 2
856
+ output_channel = channels[i + 1]
857
+ is_last = i == len(channels) - 2
858
+ resnet = CausalResnetBlock1D(
859
+ dim=input_channel,
860
+ dim_out=output_channel,
861
+ time_emb_dim=time_embed_dim,
862
+ )
863
+ transformer_blocks = nn.ModuleList(
864
+ [
865
+ BasicTransformerBlock(
866
+ dim=output_channel,
867
+ num_attention_heads=num_heads,
868
+ attention_head_dim=attention_head_dim,
869
+ dropout=dropout,
870
+ activation_fn=act_fn,
871
+ )
872
+ for _ in range(n_blocks)
873
+ ]
874
+ )
875
+ upsample = (
876
+ Upsample1D(output_channel, use_conv_transpose=True)
877
+ if not is_last
878
+ else CausalConv1d(output_channel, output_channel, 3)
879
+ )
880
+ self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
881
+ self.final_block = CausalBlock1D(channels[-1], channels[-1])
882
+ self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
883
+ self.initialize_weights()
884
+
885
+ def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
886
+ """Forward pass of the UNet1DConditional model.
887
+
888
+ Args:
889
+ x (torch.Tensor): shape (batch_size, in_channels, time)
890
+ mask (torch.Tensor): shape (batch_size, 1, time)
891
+ t (torch.Tensor): shape (batch_size,)
892
+ spks (torch.Tensor, optional): shape (batch_size, condition_channels). Defaults to None.
893
+ cond (torch.Tensor, optional): placeholder for future use. Defaults to None.
894
+
895
+ Returns:
896
+ torch.Tensor: output tensor of shape (batch_size, out_channels, time).
901
+ """
902
+ t = self.time_embeddings(t).to(t.dtype)
903
+ t = self.time_mlp(t)
904
+
905
+ x = pack([x, mu], "b * t")[0]
906
+
907
+ if spks is not None:
908
+ spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
909
+ x = pack([x, spks], "b * t")[0]
910
+ if cond is not None:
911
+ x = pack([x, cond], "b * t")[0]
912
+
913
+ hiddens = []
914
+ masks = [mask]
915
+ for resnet, transformer_blocks, downsample in self.down_blocks:
916
+ mask_down = masks[-1]
917
+ x = resnet(x, mask_down, t)
918
+ x = rearrange(x, "b c t -> b t c").contiguous()
919
+ if streaming is True:
920
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
921
+ else:
922
+ attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
923
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
924
+ for transformer_block in transformer_blocks:
925
+ x = transformer_block(
926
+ hidden_states=x,
927
+ attention_mask=attn_mask,
928
+ timestep=t,
929
+ )
930
+ x = rearrange(x, "b t c -> b c t").contiguous()
931
+ hiddens.append(x) # Save hidden states for skip connections
932
+ x = downsample(x * mask_down)
933
+ masks.append(mask_down[:, :, ::2])
934
+ masks = masks[:-1]
935
+ mask_mid = masks[-1]
936
+
937
+ for resnet, transformer_blocks in self.mid_blocks:
938
+ x = resnet(x, mask_mid, t)
939
+ x = rearrange(x, "b c t -> b t c").contiguous()
940
+ if streaming is True:
941
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
942
+ else:
943
+ attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
944
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
945
+ for transformer_block in transformer_blocks:
946
+ x = transformer_block(
947
+ hidden_states=x,
948
+ attention_mask=attn_mask,
949
+ timestep=t,
950
+ )
951
+ x = rearrange(x, "b t c -> b c t").contiguous()
952
+
953
+ for resnet, transformer_blocks, upsample in self.up_blocks:
954
+ mask_up = masks.pop()
955
+ skip = hiddens.pop()
956
+ x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
957
+ x = resnet(x, mask_up, t)
958
+ x = rearrange(x, "b c t -> b t c").contiguous()
959
+ if streaming is True:
960
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
961
+ else:
962
+ attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
963
+ attn_mask = mask_to_bias(attn_mask, x.dtype)
964
+ for transformer_block in transformer_blocks:
965
+ x = transformer_block(
966
+ hidden_states=x,
967
+ attention_mask=attn_mask,
968
+ timestep=t,
969
+ )
970
+ x = rearrange(x, "b t c -> b c t").contiguous()
971
+ x = upsample(x * mask_up)
972
+ x = self.final_block(x, mask_up)
973
+ output = self.final_proj(x * mask_up)
974
+ return output * mask
flashcosyvoice/modules/flow_components/upsample_encoder.py ADDED
@@ -0,0 +1,998 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def subsequent_chunk_mask(
10
+ size: int,
11
+ chunk_size: int,
12
+ num_left_chunks: int = -1,
13
+ device: torch.device = torch.device("cpu"),
14
+ ) -> torch.Tensor:
15
+ """Create mask for subsequent steps (size, size) with chunk size,
16
+ this is for streaming encoder
17
+
18
+ Args:
19
+ size (int): size of mask
20
+ chunk_size (int): size of chunk
21
+ num_left_chunks (int): number of left chunks
22
+ <0: use full chunk
23
+ >=0: use num_left_chunks
24
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
25
+
26
+ Returns:
27
+ torch.Tensor: mask
28
+
29
+ Examples:
30
+ >>> subsequent_chunk_mask(4, 2)
31
+ [[1, 1, 0, 0],
32
+ [1, 1, 0, 0],
33
+ [1, 1, 1, 1],
34
+ [1, 1, 1, 1]]
35
+ """
36
+ # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
37
+ pos_idx = torch.arange(size, device=device)
38
+ block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
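+ # position i may attend to all positions up to the end of its own chunk (block_value[i])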
39
+ ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
40
+ return ret
41
+
42
+
43
+ def add_optional_chunk_mask(xs: torch.Tensor,
44
+ masks: torch.Tensor,
45
+ use_dynamic_chunk: bool,
46
+ use_dynamic_left_chunk: bool,
47
+ decoding_chunk_size: int,
48
+ static_chunk_size: int,
49
+ num_decoding_left_chunks: int,
50
+ enable_full_context: bool = True):
51
+ """ Apply optional mask for encoder.
52
+
53
+ Args:
54
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
55
+ mask (torch.Tensor): mask for xs, (B, 1, L)
56
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
57
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
58
+ training.
59
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
60
+ 0: default for training, use random dynamic chunk.
61
+ <0: for decoding, use full chunk.
62
+ >0: for decoding, use fixed chunk size as set.
63
+ static_chunk_size (int): chunk size for static chunk training/decoding
64
+ if it's greater than 0, if use_dynamic_chunk is true,
65
+ this parameter will be ignored
66
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
67
+ the chunk size is decoding_chunk_size.
68
+ >=0: use num_decoding_left_chunks
69
+ <0: use all left chunks
70
+ enable_full_context (bool):
71
+ True: chunk size is either [1, 25] or full context(max_len)
72
+ False: chunk size ~ U[1, 25]
73
+
74
+ Returns:
75
+ torch.Tensor: chunk mask of the input xs.
76
+ """
77
+ # Whether to use chunk mask or not
78
+ if use_dynamic_chunk:
79
+ max_len = xs.size(1)
80
+ if decoding_chunk_size < 0:
81
+ chunk_size = max_len
82
+ num_left_chunks = -1
83
+ elif decoding_chunk_size > 0:
84
+ chunk_size = decoding_chunk_size
85
+ num_left_chunks = num_decoding_left_chunks
86
+ else:
87
+ # chunk size is either [1, 25] or full context(max_len).
88
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
89
+ # delay, the maximum frame is 100 / 4 = 25.
90
+ chunk_size = torch.randint(1, max_len, (1, )).item()
91
+ num_left_chunks = -1
92
+ if chunk_size > max_len // 2 and enable_full_context:
93
+ chunk_size = max_len
94
+ else:
95
+ chunk_size = chunk_size % 25 + 1
96
+ if use_dynamic_left_chunk:
97
+ max_left_chunks = (max_len - 1) // chunk_size
98
+ num_left_chunks = torch.randint(0, max_left_chunks,
99
+ (1, )).item()
100
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
101
+ num_left_chunks,
102
+ xs.device) # (L, L)
103
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
104
+ chunk_masks = masks & chunk_masks # (B, L, L)
105
+ elif static_chunk_size > 0:
106
+ num_left_chunks = num_decoding_left_chunks
107
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
108
+ num_left_chunks,
109
+ xs.device) # (L, L)
110
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
111
+ chunk_masks = masks & chunk_masks # (B, L, L)
112
+ else:
113
+ chunk_masks = masks
114
+ assert chunk_masks.dtype == torch.bool
115
+ if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
116
+ print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
117
+ chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
118
+ return chunk_masks
119
+
120
+
121
+ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
122
+ """Make mask tensor containing indices of padded part.
123
+
124
+ See description of make_non_pad_mask.
125
+
126
+ Args:
127
+ lengths (torch.Tensor): Batch of lengths (B,).
128
+ Returns:
129
+ torch.Tensor: Mask tensor containing indices of padded part.
130
+
131
+ Examples:
132
+ >>> lengths = [5, 3, 2]
133
+ >>> make_pad_mask(lengths)
134
+ masks = [[0, 0, 0, 0 ,0],
135
+ [0, 0, 0, 1, 1],
136
+ [0, 0, 1, 1, 1]]
137
+ """
138
+ batch_size = lengths.size(0)
139
+ max_len = max_len if max_len > 0 else lengths.max().item()
140
+ seq_range = torch.arange(0,
141
+ max_len,
142
+ dtype=torch.int64,
143
+ device=lengths.device)
144
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
145
+ seq_length_expand = lengths.unsqueeze(-1)
146
+ mask = seq_range_expand >= seq_length_expand
147
+ return mask
148
+
149
+
150
+ class EspnetRelPositionalEncoding(torch.nn.Module):
151
+ """Relative positional encoding module (new implementation).
152
+
153
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
154
+
155
+ See : Appendix B in https://arxiv.org/abs/1901.02860
156
+
157
+ Args:
158
+ d_model (int): Embedding dimension.
159
+ max_len (int): Maximum input length.
160
+
161
+ """
162
+
163
+ def __init__(self, d_model: int, max_len: int = 5000):
164
+ super(EspnetRelPositionalEncoding, self).__init__()
165
+ self.d_model = d_model
166
+ self.xscale = math.sqrt(self.d_model)
167
+ self.pe = None
168
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
169
+
170
+ def extend_pe(self, x: torch.Tensor):
171
+ """Reset the positional encodings."""
172
+ if self.pe is not None:
173
+ # self.pe contains both positive and negative parts
174
+ # the length of self.pe is 2 * input_len - 1
175
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
176
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
177
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
178
+ return
179
+ # Suppose `i` is the position of the query vector and `j` the
180
+ # position of the key vector. We use positive relative positions when keys
181
+ # are to the left (i>j) and negative relative positions otherwise (i<j).
182
+ pe_positive = torch.zeros(x.size(1), self.d_model)
183
+ pe_negative = torch.zeros(x.size(1), self.d_model)
184
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
185
+ div_term = torch.exp(
186
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
187
+ * -(math.log(10000.0) / self.d_model)
188
+ )
189
+ pe_positive[:, 0::2] = torch.sin(position * div_term)
190
+ pe_positive[:, 1::2] = torch.cos(position * div_term)
191
+ pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
192
+ pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
193
+
194
+ # Reverse the order of positive indices and concat both positive and
195
+ # negative indices. This is used to support the shifting trick
196
+ # as in https://arxiv.org/abs/1901.02860
197
+ pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
198
+ pe_negative = pe_negative[1:].unsqueeze(0)
199
+ pe = torch.cat([pe_positive, pe_negative], dim=1)
200
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
201
+
202
+ def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
203
+ -> Tuple[torch.Tensor, torch.Tensor]:
204
+ """Add positional encoding.
205
+
206
+ Args:
207
+ x (torch.Tensor): Input tensor (batch, time, `*`).
208
+
209
+ Returns:
210
+ torch.Tensor: Encoded tensor (batch, time, `*`).
211
+
212
+ """
213
+ self.extend_pe(x)
214
+ x = x * self.xscale
215
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
216
+ return x, pos_emb
217
+
218
+ def position_encoding(self,
219
+ offset: Union[int, torch.Tensor],
220
+ size: int) -> torch.Tensor:
221
+ """ For getting encoding in a streaming fashion
222
+
223
+ Attention!!!!!
224
+ we apply dropout only once at the whole-utterance level in the non-
225
+ streaming case, but this function will be called several times with
226
+ increasing input size in a streaming scenario, so the dropout would
227
+ be applied several times.
228
+
229
+ Args:
230
+ offset (int or torch.tensor): start offset
231
+ size (int): required size of position encoding
232
+
233
+ Returns:
234
+ torch.Tensor: Corresponding encoding
235
+ """
236
+ # How to subscript a Union type:
237
+ # https://github.com/pytorch/pytorch/issues/69434
238
+ if isinstance(offset, int):
239
+ pos_emb = self.pe[
240
+ :,
241
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
242
+ ]
243
+ elif isinstance(offset, torch.Tensor):
244
+ pos_emb = self.pe[
245
+ :,
246
+ self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
247
+ ]
248
+ return pos_emb
249
+
250
+
251
+ class LinearNoSubsampling(torch.nn.Module):
252
+ """Linear transform the input without subsampling
253
+
254
+ Args:
255
+ idim (int): Input dimension.
256
+ odim (int): Output dimension.
257
+ pos_enc_class (torch.nn.Module): Positional encoding class.
258
+
259
+ """
260
+
261
+ def __init__(self, idim: int, odim: int,
262
+ pos_enc_class: torch.nn.Module):
263
+ super().__init__()
264
+ self.out = torch.nn.Sequential(
265
+ torch.nn.Linear(idim, odim),
266
+ torch.nn.LayerNorm(odim, eps=1e-5),
267
+ )
268
+ self.pos_enc = pos_enc_class
269
+ self.right_context = 0
270
+ self.subsampling_rate = 1
271
+
272
+ def forward(
273
+ self,
274
+ x: torch.Tensor,
275
+ x_mask: torch.Tensor,
276
+ offset: Union[int, torch.Tensor] = 0
277
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
278
+ """Input x.
279
+
280
+ Args:
281
+ x (torch.Tensor): Input tensor (#batch, time, idim).
282
+ x_mask (torch.Tensor): Input mask (#batch, 1, time).
283
+
284
+ Returns:
285
+ torch.Tensor: linear input tensor (#batch, time', odim),
286
+ where time' = time .
287
+ torch.Tensor: linear input mask (#batch, 1, time'),
288
+ where time' = time .
289
+
290
+ """
291
+ x = self.out(x)
292
+ x, pos_emb = self.pos_enc(x, offset)
293
+ return x, pos_emb, x_mask
294
+
295
+ def position_encoding(self, offset: Union[int, torch.Tensor],
296
+ size: int) -> torch.Tensor:
297
+ return self.pos_enc.position_encoding(offset, size)
298
+
299
+
300
+ class Upsample1D(nn.Module):
301
+ """A 1D upsampling layer with an optional convolution.
302
+
303
+ Parameters:
304
+ channels (`int`):
305
+ number of channels in the inputs.
306
+ out_channels (`int`):
307
+ number of output channels.
308
+ stride (`int`, default `2`):
309
+ upsampling factor along the time axis; the output length is `stride * input length`.
312
+ """
313
+
314
+ def __init__(self, channels: int, out_channels: int, stride: int = 2):
315
+ super().__init__()
316
+ self.channels = channels
317
+ self.out_channels = out_channels
318
+ self.stride = stride
319
+ # In this mode, first nearest-neighbour interpolate, then conv with stride=1
320
+ self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
321
+
322
+ def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
323
+ outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
324
+ outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
325
+ outputs = self.conv(outputs)
326
+ return outputs, input_lengths * self.stride
327
+
328
+
329
+ class PreLookaheadLayer(nn.Module):
330
+ def __init__(self, channels: int, pre_lookahead_len: int = 1):
331
+ super().__init__()
332
+ self.channels = channels
333
+ self.pre_lookahead_len = pre_lookahead_len
334
+ self.conv1 = nn.Conv1d(
335
+ channels, channels,
336
+ kernel_size=pre_lookahead_len + 1,
337
+ stride=1, padding=0,
338
+ )
339
+ self.conv2 = nn.Conv1d(
340
+ channels, channels,
341
+ kernel_size=3, stride=1, padding=0,
342
+ )
343
+
344
+ def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
345
+ """
346
+ inputs: (batch_size, seq_len, channels)
347
+ """
348
+ outputs = inputs.transpose(1, 2).contiguous()
349
+ context = context.transpose(1, 2).contiguous()
350
+ # look ahead
351
+ if context.size(2) == 0:
352
+ outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
353
+ else:
354
+ assert self.training is False, 'you have passed context, make sure that you are running in inference mode'
355
+ assert context.size(2) == self.pre_lookahead_len
356
+ outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
357
+ outputs = F.leaky_relu(self.conv1(outputs))
358
+ # outputs
359
+ outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
360
+ outputs = self.conv2(outputs)
361
+ outputs = outputs.transpose(1, 2).contiguous()
362
+
363
+ # residual connection
364
+ outputs = outputs + inputs
365
+ return outputs
366
+
367
+
368
+ class MultiHeadedAttention(nn.Module):
369
+ """Multi-Head Attention layer.
370
+
371
+ Args:
372
+ n_head (int): The number of heads.
373
+ n_feat (int): The number of features.
374
+ dropout_rate (float): Dropout rate.
375
+ key_bias (bool): Whether to use bias in key linear layer.
376
+
377
+ """
378
+
379
+ def __init__(self,
380
+ n_head: int,
381
+ n_feat: int,
382
+ dropout_rate: float,
383
+ key_bias: bool = True):
384
+ super().__init__()
385
+ assert n_feat % n_head == 0
386
+ # We assume d_v always equals d_k
387
+ self.d_k = n_feat // n_head
388
+ self.h = n_head
389
+ self.linear_q = nn.Linear(n_feat, n_feat)
390
+ self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
391
+ self.linear_v = nn.Linear(n_feat, n_feat)
392
+ self.linear_out = nn.Linear(n_feat, n_feat)
393
+ self.dropout = nn.Dropout(p=dropout_rate)
394
+
395
+ def forward_qkv(
396
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
397
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
398
+ """Transform query, key and value.
399
+
400
+ Args:
401
+ query (torch.Tensor): Query tensor (#batch, time1, size).
402
+ key (torch.Tensor): Key tensor (#batch, time2, size).
403
+ value (torch.Tensor): Value tensor (#batch, time2, size).
404
+
405
+ Returns:
406
+ torch.Tensor: Transformed query tensor, size
407
+ (#batch, n_head, time1, d_k).
408
+ torch.Tensor: Transformed key tensor, size
409
+ (#batch, n_head, time2, d_k).
410
+ torch.Tensor: Transformed value tensor, size
411
+ (#batch, n_head, time2, d_k).
412
+
413
+ """
414
+ n_batch = query.size(0)
415
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
416
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
417
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
418
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
419
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
420
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
421
+
422
+ return q, k, v
423
+
424
+ def forward_attention(
425
+ self,
426
+ value: torch.Tensor,
427
+ scores: torch.Tensor,
428
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
429
+ ) -> torch.Tensor:
430
+ """Compute attention context vector.
431
+
432
+ Args:
433
+ value (torch.Tensor): Transformed value, size
434
+ (#batch, n_head, time2, d_k).
435
+ scores (torch.Tensor): Attention score, size
436
+ (#batch, n_head, time1, time2).
437
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
438
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
439
+
440
+ Returns:
441
+ torch.Tensor: Transformed value (#batch, time1, d_model)
442
+ weighted by the attention score (#batch, time1, time2).
443
+
444
+ """
445
+ n_batch = value.size(0)
446
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
447
+ # 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
448
+ # 1st chunk to ease the onnx export.]
449
+ # 2. pytorch training
450
+ if mask.size(2) > 0: # time2 > 0
451
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
452
+ # For last chunk, time2 might be larger than scores.size(-1)
453
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
454
+ scores = scores.masked_fill(mask, -float('inf'))
455
+ attn = torch.softmax(scores, dim=-1).masked_fill(
456
+ mask, 0.0) # (batch, head, time1, time2)
457
+ # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
458
+ # 1. onnx(16/-1, -1/-1, 16/0)
459
+ # 2. jit (16/-1, -1/-1, 16/0, 16/4)
460
+ else:
461
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
462
+
463
+ p_attn = self.dropout(attn)
464
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
465
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
466
+ self.h * self.d_k)
467
+ ) # (batch, time1, d_model)
468
+
469
+ return self.linear_out(x) # (batch, time1, d_model)
470
+
471
+ def forward(
472
+ self,
473
+ query: torch.Tensor,
474
+ key: torch.Tensor,
475
+ value: torch.Tensor,
476
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
477
+ pos_emb: torch.Tensor = torch.empty(0),
478
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
479
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
480
+ """Compute scaled dot product attention.
481
+
482
+ Args:
483
+ query (torch.Tensor): Query tensor (#batch, time1, size).
484
+ key (torch.Tensor): Key tensor (#batch, time2, size).
485
+ value (torch.Tensor): Value tensor (#batch, time2, size).
486
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
487
+ (#batch, time1, time2).
488
+ 1.When applying cross attention between decoder and encoder,
489
+ the batch padding mask for input is in (#batch, 1, T) shape.
490
+ 2.When applying self attention of encoder,
491
+ the mask is in (#batch, T, T) shape.
492
+ 3.When applying self attention of decoder,
493
+ the mask is in (#batch, L, L) shape.
494
+ 4.If the different position in decoder see different block
495
+ of the encoder, such as Mocha, the passed in mask could be
496
+ in (#batch, L, T) shape. But there is no such case in current
497
+ CosyVoice.
498
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
499
+ where `cache_t == chunk_size * num_decoding_left_chunks`
500
+ and `head * d_k == size`
501
+
502
+
503
+ Returns:
504
+ torch.Tensor: Output tensor (#batch, time1, d_model).
505
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
506
+ where `cache_t == chunk_size * num_decoding_left_chunks`
507
+ and `head * d_k == size`
508
+
509
+ """
510
+ q, k, v = self.forward_qkv(query, key, value)
511
+
512
+ # NOTE(xcsong):
513
+ # when export onnx model, for 1st chunk, we feed
514
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
515
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
516
+ # In all modes, `if cache.size(0) > 0` will always be `True`
517
+ # and we will always do splitting and
518
+ # concatenation (this will simplify onnx export). Note that
519
+ # it's OK to concat & split zero-shaped tensors(see code below).
520
+ # when export jit model, for 1st chunk, we always feed
521
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
522
+ # >>> a = torch.ones((1, 2, 0, 4))
523
+ # >>> b = torch.ones((1, 2, 3, 4))
524
+ # >>> c = torch.cat((a, b), dim=2)
525
+ # >>> torch.equal(b, c) # True
526
+ # >>> d = torch.split(a, 2, dim=-1)
527
+ # >>> torch.equal(d[0], d[1]) # True
528
+ if cache.size(0) > 0:
529
+ key_cache, value_cache = torch.split(cache,
530
+ cache.size(-1) // 2,
531
+ dim=-1)
532
+ k = torch.cat([key_cache, k], dim=2)
533
+ v = torch.cat([value_cache, v], dim=2)
534
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
535
+ # non-trivial to calculate `next_cache_start` here.
536
+ new_cache = torch.cat((k, v), dim=-1)
537
+
538
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
539
+ return self.forward_attention(v, scores, mask), new_cache
540
+
541
+
542
+ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
543
+ """Multi-Head Attention layer with relative position encoding.
544
+ Paper: https://arxiv.org/abs/1901.02860
545
+ Args:
546
+ n_head (int): The number of heads.
547
+ n_feat (int): The number of features.
548
+ dropout_rate (float): Dropout rate.
549
+ key_bias (bool): Whether to use bias in key linear layer.
550
+ """
551
+
552
+ def __init__(self,
553
+ n_head: int,
554
+ n_feat: int,
555
+ dropout_rate: float,
556
+ key_bias: bool = True):
557
+ super().__init__(n_head, n_feat, dropout_rate, key_bias)
558
+ # linear transformation for positional encoding
559
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
560
+ # these two learnable bias are used in matrix c and matrix d
561
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
562
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
563
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
564
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
565
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
566
+
567
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
568
+ """Compute relative positional encoding.
569
+
570
+ Args:
571
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
572
+ time1 means the length of query vector.
573
+
574
+ Returns:
575
+ torch.Tensor: Output tensor.
576
+
577
+ """
578
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
579
+ device=x.device,
580
+ dtype=x.dtype)
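+ # prepend a zero column, reshape, then slice: the relative-shift trick from Transformer-XL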
581
+ x_padded = torch.cat([zero_pad, x], dim=-1)
582
+
583
+ x_padded = x_padded.view(x.size()[0],
584
+ x.size()[1],
585
+ x.size(3) + 1, x.size(2))
586
+ x = x_padded[:, :, 1:].view_as(x)[
587
+ :, :, :, : x.size(-1) // 2 + 1
588
+ ] # only keep the positions from 0 to time2
589
+ return x
590
+
591
+ def forward(
592
+ self,
593
+ query: torch.Tensor,
594
+ key: torch.Tensor,
595
+ value: torch.Tensor,
596
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
597
+ pos_emb: torch.Tensor = torch.empty(0),
598
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
599
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
600
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
601
+ Args:
602
+ query (torch.Tensor): Query tensor (#batch, time1, size).
603
+ key (torch.Tensor): Key tensor (#batch, time2, size).
604
+ value (torch.Tensor): Value tensor (#batch, time2, size).
605
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
606
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
607
+ pos_emb (torch.Tensor): Positional embedding tensor
608
+ (#batch, time2, size).
609
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
610
+ where `cache_t == chunk_size * num_decoding_left_chunks`
611
+ and `head * d_k == size`
612
+ Returns:
613
+ torch.Tensor: Output tensor (#batch, time1, d_model).
614
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
615
+ where `cache_t == chunk_size * num_decoding_left_chunks`
616
+ and `head * d_k == size`
617
+ """
618
+ q, k, v = self.forward_qkv(query, key, value)
619
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
620
+
621
+ # NOTE(xcsong):
622
+ # when export onnx model, for 1st chunk, we feed
623
+ # cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
624
+ # or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
625
+ # In all modes, `if cache.size(0) > 0` will always be `True`
626
+ # and we will always do splitting and
627
+ # concatenation (this will simplify onnx export). Note that
628
+ # it's OK to concat & split zero-shaped tensors(see code below).
629
+ # when export jit model, for 1st chunk, we always feed
630
+ # cache(0, 0, 0, 0) since jit supports dynamic if-branch.
631
+ # >>> a = torch.ones((1, 2, 0, 4))
632
+ # >>> b = torch.ones((1, 2, 3, 4))
633
+ # >>> c = torch.cat((a, b), dim=2)
634
+ # >>> torch.equal(b, c) # True
635
+ # >>> d = torch.split(a, 2, dim=-1)
636
+ # >>> torch.equal(d[0], d[1]) # True
637
+ if cache.size(0) > 0:
638
+ key_cache, value_cache = torch.split(cache,
639
+ cache.size(-1) // 2,
640
+ dim=-1)
641
+ k = torch.cat([key_cache, k], dim=2)
642
+ v = torch.cat([value_cache, v], dim=2)
643
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
644
+ # non-trivial to calculate `next_cache_start` here.
645
+ new_cache = torch.cat((k, v), dim=-1)
646
+
647
+ n_batch_pos = pos_emb.size(0)
648
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
649
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
650
+
651
+ # (batch, head, time1, d_k)
652
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
653
+ # (batch, head, time1, d_k)
654
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
655
+
656
+ # compute attention score
657
+ # first compute matrix a and matrix c
658
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
659
+ # (batch, head, time1, time2)
660
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
661
+
662
+ # compute matrix b and matrix d
663
+ # (batch, head, time1, time2)
664
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
665
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
666
+ if matrix_ac.shape != matrix_bd.shape:
667
+ matrix_bd = self.rel_shift(matrix_bd)
668
+
669
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
670
+ self.d_k) # (batch, head, time1, time2)
671
+
672
+ return self.forward_attention(v, scores, mask), new_cache
673
+
674
+
675
+ class PositionwiseFeedForward(torch.nn.Module):
676
+ """Positionwise feed forward layer.
677
+
678
+ FeedForward is applied at each position of the sequence.
679
+ The output dim is the same as the input dim.
680
+
681
+ Args:
682
+ idim (int): Input dimension.
683
+ hidden_units (int): The number of hidden units.
684
+ dropout_rate (float): Dropout rate.
685
+ activation (torch.nn.Module): Activation function
686
+ """
687
+
688
+ def __init__(
689
+ self,
690
+ idim: int,
691
+ hidden_units: int,
692
+ dropout_rate: float,
693
+ activation: torch.nn.Module = torch.nn.ReLU(),
694
+ ):
695
+ super(PositionwiseFeedForward, self).__init__()
696
+ self.w_1 = torch.nn.Linear(idim, hidden_units)
697
+ self.activation = activation
698
+ self.dropout = torch.nn.Dropout(dropout_rate)
699
+ self.w_2 = torch.nn.Linear(hidden_units, idim)
700
+
701
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
702
+ """Forward function.
703
+
704
+ Args:
705
+ xs: input tensor (B, L, D)
706
+ Returns:
707
+ output tensor, (B, L, D)
708
+ """
709
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
710
+
711
+
712
+ class ConformerEncoderLayer(nn.Module):
713
+ """Encoder layer module.
714
+ Args:
715
+ size (int): Input dimension.
716
+ self_attn (torch.nn.Module): Self-attention module instance.
717
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
718
+ instance can be used as the argument.
719
+ feed_forward (torch.nn.Module): Feed-forward module instance.
720
+ `PositionwiseFeedForward` instance can be used as the argument.
721
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
722
+ instance.
723
+ `PositionwiseFeedForward` instance can be used as the argument.
724
+ conv_module (torch.nn.Module): Convolution module instance.
725
+ `ConvolutionModule` instance can be used as the argument.
726
+ dropout_rate (float): Dropout rate.
727
+ normalize_before (bool):
728
+ True: use layer_norm before each sub-block.
729
+ False: use layer_norm after each sub-block.
730
+ """
731
+
732
+ def __init__(
733
+ self,
734
+ size: int,
735
+ self_attn: torch.nn.Module,
736
+ feed_forward: Optional[nn.Module] = None,
737
+ feed_forward_macaron: Optional[nn.Module] = None,
738
+ conv_module: Optional[nn.Module] = None,
739
+ dropout_rate: float = 0.0,
740
+ normalize_before: bool = True,
741
+ ):
742
+ super().__init__()
743
+ self.self_attn = self_attn
744
+ self.feed_forward = feed_forward
745
+ self.feed_forward_macaron = feed_forward_macaron
746
+ self.conv_module = conv_module
747
+ self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
748
+ self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
749
+ if feed_forward_macaron is not None:
750
+ self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
751
+ self.ff_scale = 0.5
752
+ else:
753
+ self.ff_scale = 1.0
754
+ if self.conv_module is not None:
755
+ self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
756
+ self.norm_final = nn.LayerNorm(
757
+ size, eps=1e-12) # for the final output of the block
758
+ self.dropout = nn.Dropout(dropout_rate)
759
+ self.size = size
760
+ self.normalize_before = normalize_before
761
+
762
+ def forward(
763
+ self,
764
+ x: torch.Tensor,
765
+ mask: torch.Tensor,
766
+ pos_emb: torch.Tensor,
767
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
768
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
769
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
770
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
771
+ """Compute encoded features.
772
+
773
+ Args:
774
+ x (torch.Tensor): (#batch, time, size)
775
+ mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
776
+ (0, 0, 0) means fake mask.
777
+ pos_emb (torch.Tensor): positional encoding, must not be None
778
+ for ConformerEncoderLayer.
779
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
780
+ (#batch, 1, time), (0, 0, 0) means fake mask.
781
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
782
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
783
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
784
+ (#batch=1, size, cache_t2)
785
+ Returns:
786
+ torch.Tensor: Output tensor (#batch, time, size).
787
+ torch.Tensor: Mask tensor (#batch, time, time).
788
+ torch.Tensor: att_cache tensor,
789
+ (#batch=1, head, cache_t1 + time, d_k * 2).
790
+ torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
791
+ """
792
+
793
+ # whether to use macaron style
794
+ if self.feed_forward_macaron is not None:
795
+ residual = x
796
+ if self.normalize_before:
797
+ x = self.norm_ff_macaron(x)
798
+ x = residual + self.ff_scale * self.dropout(
799
+ self.feed_forward_macaron(x))
800
+ if not self.normalize_before:
801
+ x = self.norm_ff_macaron(x)
802
+
803
+ # multi-headed self-attention module
804
+ residual = x
805
+ if self.normalize_before:
806
+ x = self.norm_mha(x)
807
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
808
+ att_cache)
809
+ x = residual + self.dropout(x_att)
810
+ if not self.normalize_before:
811
+ x = self.norm_mha(x)
812
+
813
+ # convolution module
814
+ # Fake new cnn cache here, and then change it in conv_module
815
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
816
+ if self.conv_module is not None:
817
+ residual = x
818
+ if self.normalize_before:
819
+ x = self.norm_conv(x)
820
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
821
+ x = residual + self.dropout(x)
822
+
823
+ if not self.normalize_before:
824
+ x = self.norm_conv(x)
825
+
826
+ # feed forward module
827
+ residual = x
828
+ if self.normalize_before:
829
+ x = self.norm_ff(x)
830
+
831
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
832
+ if not self.normalize_before:
833
+ x = self.norm_ff(x)
834
+
835
+ if self.conv_module is not None:
836
+ x = self.norm_final(x)
837
+
838
+ return x, mask, new_att_cache, new_cnn_cache
839
+
840
+
841
+ class UpsampleConformerEncoder(torch.nn.Module):
842
+ """
843
+ Args:
844
+ input_size (int): input dim
845
+ output_size (int): dimension of attention
846
+ attention_heads (int): the number of heads of multi head attention
847
+ linear_units (int): the hidden units number of position-wise feed
848
+ forward
849
+ num_blocks (int): the number of decoder blocks
850
+ static_chunk_size (int): chunk size for static chunk training and
851
+ decoding
852
+ use_dynamic_chunk (bool): whether to use dynamic chunk size for
853
+ training or not. You can only use a fixed chunk (chunk_size > 0)
854
+ or a dynamic chunk size (use_dynamic_chunk = True)
855
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
856
+ dynamic chunk training
857
+ key_bias: whether to use bias in attention.linear_k, False for whisper models.
858
+ """
859
+
860
+ def __init__(
861
+ self,
862
+ input_size: int = 512,
863
+ output_size: int = 512,
864
+ attention_heads: int = 8,
865
+ linear_units: int = 2048,
866
+ num_blocks: int = 6,
867
+ static_chunk_size: int = 25,
868
+ use_dynamic_chunk: bool = False,
869
+ use_dynamic_left_chunk: bool = False,
870
+ key_bias: bool = True,
871
+ ):
872
+ super().__init__()
873
+ self._output_size = output_size
874
+
875
+ self.embed = LinearNoSubsampling(
876
+ input_size, output_size,
877
+ EspnetRelPositionalEncoding(output_size),
878
+ )
879
+
880
+ self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
881
+ self.static_chunk_size = static_chunk_size
882
+ self.use_dynamic_chunk = use_dynamic_chunk
883
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
884
+ activation = torch.nn.SiLU()
885
+ # self-attention module definition
886
+ encoder_selfattn_layer_args = (
887
+ attention_heads,
888
+ output_size,
889
+ 0.0,
890
+ key_bias,
891
+ )
892
+ # feed-forward module definition
893
+ positionwise_layer_args = (
894
+ output_size,
895
+ linear_units,
896
+ 0.0,
897
+ activation,
898
+ )
899
+ # convolution module definition
900
+ self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
901
+ self.encoders = torch.nn.ModuleList([
902
+ ConformerEncoderLayer(
903
+ output_size,
904
+ RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
905
+ PositionwiseFeedForward(*positionwise_layer_args),
906
+ ) for _ in range(num_blocks)
907
+ ])
908
+ self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
909
+ self.up_embed = LinearNoSubsampling(
910
+ input_size, output_size,
911
+ EspnetRelPositionalEncoding(output_size),
912
+ )
913
+ self.up_encoders = torch.nn.ModuleList([
914
+ ConformerEncoderLayer(
915
+ output_size,
916
+ RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
917
+ PositionwiseFeedForward(*positionwise_layer_args),
918
+ ) for _ in range(4)
919
+ ])
920
+
921
+ def output_size(self) -> int:
922
+ return self._output_size
923
+
924
+ def forward(
925
+ self,
926
+ xs: torch.Tensor,
927
+ xs_lens: torch.Tensor,
928
+ context: torch.Tensor = torch.zeros(0, 0, 0),
929
+ decoding_chunk_size: int = 0,
930
+ num_decoding_left_chunks: int = -1,
931
+ streaming: bool = False,
932
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
933
+ """Embed positions in tensor.
934
+
935
+ Args:
936
+ xs: padded input tensor (B, T, D)
937
+ xs_lens: input length (B)
938
+ decoding_chunk_size: decoding chunk size for dynamic chunk
939
+ 0: default for training, use random dynamic chunk.
940
+ <0: for decoding, use full chunk.
941
+ >0: for decoding, use fixed chunk size as set.
942
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
943
+ the chunk size is decoding_chunk_size.
944
+ >=0: use num_decoding_left_chunks
945
+ <0: use all left chunks
946
+ Returns:
947
+ encoder output tensor xs, and subsampled masks
948
+ xs: padded output tensor (B, T' ~= T/subsample_rate, D)
949
+ masks: torch.Tensor batch padding mask after subsample
950
+ (B, 1, T' ~= T/subsample_rate)
951
+ NOTE(xcsong):
952
+ We pass the `__call__` method of the modules instead of `forward` to the
953
+ checkpointing API because `__call__` attaches all the hooks of the module.
954
+ https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
955
+ """
956
+ T = xs.size(1)
957
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
958
+ xs, pos_emb, masks = self.embed(xs, masks)
959
+ if context.size(1) != 0:
960
+ assert self.training is False, 'you have passed context, make sure that you are running inference mode'
961
+ context_masks = torch.ones(1, 1, context.size(1)).to(masks)
962
+ context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
963
+ mask_pad = masks # (B, 1, T/subsample_rate)
964
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
965
+ # lookahead + conformer encoder
966
+ xs = self.pre_lookahead_layer(xs, context=context)
967
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
968
+
969
+ # upsample + conformer encoder
970
+ xs = xs.transpose(1, 2).contiguous()
971
+ xs, xs_lens = self.up_layer(xs, xs_lens)
972
+ xs = xs.transpose(1, 2).contiguous()
973
+ T = xs.size(1)
974
+ masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
975
+ xs, pos_emb, masks = self.up_embed(xs, masks)
976
+ mask_pad = masks # (B, 1, T/subsample_rate)
977
+ chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
978
+ xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
979
+
980
+ xs = self.after_norm(xs)
981
+ # Here we assume the mask is not changed in encoder layers, so just
982
+ # return the masks before encoder layers, and the masks will be used
983
+ # for cross attention with decoder later
984
+ return xs, masks
985
+
986
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
987
+ pos_emb: torch.Tensor,
988
+ mask_pad: torch.Tensor) -> torch.Tensor:
989
+ for layer in self.encoders:
990
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
991
+ return xs
992
+
993
+ def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
994
+ pos_emb: torch.Tensor,
995
+ mask_pad: torch.Tensor) -> torch.Tensor:
996
+ for layer in self.up_encoders:
997
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
998
+ return xs
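For orientation (not part of the commit): a minimal shape-check sketch of `UpsampleConformerEncoder`, assuming the package is importable as `flashcosyvoice.modules.flow_components.upsample_encoder` and using randomly initialized weights; the exact output length is an assumption based on the stride-2 `Upsample1D` layer.

```python
import torch
from flashcosyvoice.modules.flow_components.upsample_encoder import UpsampleConformerEncoder

encoder = UpsampleConformerEncoder().eval()   # defaults: 512-dim, 6 conformer blocks + 4 post-upsample blocks
xs = torch.randn(2, 100, 512)                 # (B, T, input_size): token-rate features
xs_lens = torch.tensor([100, 80])             # valid lengths per utterance
with torch.no_grad():
    ys, masks = encoder(xs, xs_lens)
print(ys.shape, masks.shape)                  # time axis is roughly doubled by the stride-2 up_layer
```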
flashcosyvoice/modules/hifigan.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """HIFI-GAN"""
16
+
17
+ from typing import Dict, List
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from scipy.signal import get_window
24
+ from torch.nn import Conv1d, ConvTranspose1d
25
+ from torch.nn.utils import remove_weight_norm
26
+
27
+ try:
28
+ from torch.nn.utils.parametrizations import weight_norm
29
+ except ImportError:
30
+ from torch.nn.utils import weight_norm # noqa
31
+
32
+ from flashcosyvoice.modules.hifigan_components.layers import (
33
+ ResBlock, SourceModuleHnNSF, SourceModuleHnNSF2, init_weights)
34
+
35
+
36
+ class ConvRNNF0Predictor(nn.Module):
37
+ def __init__(self,
38
+ num_class: int = 1,
39
+ in_channels: int = 80,
40
+ cond_channels: int = 512
41
+ ):
42
+ super().__init__()
43
+
44
+ self.num_class = num_class
45
+ self.condnet = nn.Sequential(
46
+ weight_norm( # noqa
47
+ nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
48
+ ),
49
+ nn.ELU(),
50
+ weight_norm( # noqa
51
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
52
+ ),
53
+ nn.ELU(),
54
+ weight_norm( # noqa
55
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
56
+ ),
57
+ nn.ELU(),
58
+ weight_norm( # noqa
59
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
60
+ ),
61
+ nn.ELU(),
62
+ weight_norm( # noqa
63
+ nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
64
+ ),
65
+ nn.ELU(),
66
+ )
67
+ self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
68
+
69
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
+ x = self.condnet(x)
71
+ x = x.transpose(1, 2)
72
+ return torch.abs(self.classifier(x).squeeze(-1))
73
+
74
+
75
+ class HiFTGenerator(nn.Module):
76
+ """
77
+ HiFTNet Generator: Neural Source Filter + ISTFTNet
78
+ https://arxiv.org/abs/2309.09493
79
+ """
80
+ def __init__(
81
+ self,
82
+ in_channels: int = 80,
83
+ base_channels: int = 512,
84
+ nb_harmonics: int = 8,
85
+ sampling_rate: int = 24000,
86
+ nsf_alpha: float = 0.1,
87
+ nsf_sigma: float = 0.003,
88
+ nsf_voiced_threshold: float = 10,
89
+ upsample_rates: List[int] = [8, 5, 3], # noqa
90
+ upsample_kernel_sizes: List[int] = [16, 11, 7], # noqa
91
+ istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, # noqa
92
+ resblock_kernel_sizes: List[int] = [3, 7, 11], # noqa
93
+ resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
94
+ source_resblock_kernel_sizes: List[int] = [7, 7, 11], # noqa
95
+ source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
96
+ lrelu_slope: float = 0.1,
97
+ audio_limit: float = 0.99,
98
+ f0_predictor: torch.nn.Module = None,
99
+ ):
100
+ super(HiFTGenerator, self).__init__()
101
+
102
+ self.out_channels = 1
103
+ self.nb_harmonics = nb_harmonics
104
+ self.sampling_rate = sampling_rate
105
+ self.istft_params = istft_params
106
+ self.lrelu_slope = lrelu_slope
107
+ self.audio_limit = audio_limit
108
+
109
+ self.num_kernels = len(resblock_kernel_sizes)
110
+ self.num_upsamples = len(upsample_rates)
111
+ # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
112
+ this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
113
+ self.m_source = this_SourceModuleHnNSF(
114
+ sampling_rate=sampling_rate,
115
+ upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
116
+ harmonic_num=nb_harmonics,
117
+ sine_amp=nsf_alpha,
118
+ add_noise_std=nsf_sigma,
119
+ voiced_threshod=nsf_voiced_threshold)
120
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
121
+
122
+ self.conv_pre = weight_norm( # noqa
123
+ Conv1d(in_channels, base_channels, 7, 1, padding=3)
124
+ )
125
+
126
+ # Up
127
+ self.ups = nn.ModuleList()
128
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
129
+ self.ups.append(
130
+ weight_norm( # noqa
131
+ ConvTranspose1d(
132
+ base_channels // (2**i),
133
+ base_channels // (2**(i + 1)),
134
+ k,
135
+ u,
136
+ padding=(k - u) // 2,
137
+ )
138
+ )
139
+ )
140
+
141
+ # Down
142
+ self.source_downs = nn.ModuleList()
143
+ self.source_resblocks = nn.ModuleList()
144
+ downsample_rates = [1] + upsample_rates[::-1][:-1]
145
+ downsample_cum_rates = np.cumprod(downsample_rates)
146
+ for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
147
+ if u == 1:
148
+ self.source_downs.append(
149
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
150
+ )
151
+ else:
152
+ self.source_downs.append(
153
+ Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
154
+ )
155
+
156
+ self.source_resblocks.append(
157
+ ResBlock(base_channels // (2 ** (i + 1)), k, d)
158
+ )
159
+
160
+ self.resblocks = nn.ModuleList()
161
+ for i in range(len(self.ups)):
162
+ ch = base_channels // (2**(i + 1))
163
+ for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
164
+ self.resblocks.append(ResBlock(ch, k, d))
165
+
166
+ self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3)) # noqa
167
+ self.ups.apply(init_weights)
168
+ self.conv_post.apply(init_weights)
169
+ self.reflection_pad = nn.ReflectionPad1d((1, 0))
170
+ self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
171
+ self.f0_predictor = ConvRNNF0Predictor() if f0_predictor is None else f0_predictor
172
+
173
+ def remove_weight_norm(self):
174
+ print('Removing weight norm...')
175
+ for up in self.ups:
176
+ remove_weight_norm(up)
177
+ for resblock in self.resblocks:
178
+ resblock.remove_weight_norm()
179
+ remove_weight_norm(self.conv_pre)
180
+ remove_weight_norm(self.conv_post)
181
+ self.m_source.remove_weight_norm()
182
+ for source_down in self.source_downs:
183
+ remove_weight_norm(source_down)
184
+ for source_resblock in self.source_resblocks:
185
+ source_resblock.remove_weight_norm()
186
+
187
+ def _stft(self, x):
188
+ spec = torch.stft(
189
+ x,
190
+ self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
191
+ return_complex=True)
192
+ spec = torch.view_as_real(spec) # [B, F, TT, 2]
193
+ return spec[..., 0], spec[..., 1]
194
+
195
+ def _istft(self, magnitude, phase):
196
+ magnitude = torch.clip(magnitude, max=1e2)
197
+ real = magnitude * torch.cos(phase)
198
+ img = magnitude * torch.sin(phase)
199
+ inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
200
+ self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
201
+ return inverse_transform
202
+
203
+ def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
204
+ s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
205
+ s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
206
+
207
+ x = self.conv_pre(x)
208
+ for i in range(self.num_upsamples):
209
+ x = F.leaky_relu(x, self.lrelu_slope)
210
+ x = self.ups[i](x)
211
+
212
+ if i == self.num_upsamples - 1:
213
+ x = self.reflection_pad(x)
214
+
215
+ # fusion
216
+ si = self.source_downs[i](s_stft)
217
+ si = self.source_resblocks[i](si)
218
+ x = x + si
219
+
220
+ xs = None
221
+ for j in range(self.num_kernels):
222
+ if xs is None:
223
+ xs = self.resblocks[i * self.num_kernels + j](x)
224
+ else:
225
+ xs += self.resblocks[i * self.num_kernels + j](x)
226
+ x = xs / self.num_kernels
227
+
228
+ x = F.leaky_relu(x)
229
+ x = self.conv_post(x)
230
+ magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
231
+ phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundant
232
+
233
+ x = self._istft(magnitude, phase)
234
+ x = torch.clamp(x, -self.audio_limit, self.audio_limit)
235
+ return x
236
+
237
+ @torch.inference_mode()
238
+ def forward(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
239
+ # mel->f0
240
+ f0 = self.f0_predictor(speech_feat)
241
+ # f0->source
242
+ s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
243
+ s, _, _ = self.m_source(s)
244
+ s = s.transpose(1, 2)
245
+ # use cache_source to avoid glitch
246
+ if cache_source.shape[2] != 0:
247
+ s[:, :, :cache_source.shape[2]] = cache_source
248
+ generated_speech = self.decode(x=speech_feat, s=s)
249
+ return generated_speech, s
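For orientation (not part of the commit): a smoke-test sketch of `HiFTGenerator` with random weights; the sample count per mel frame is inferred from the default `upsample_rates` and `hop_len` and should be treated as an assumption.

```python
import torch
from flashcosyvoice.modules.hifigan import HiFTGenerator

vocoder = HiFTGenerator().eval()        # defaults: 24 kHz, upsample_rates [8, 5, 3], istft hop_len 4
mel = torch.randn(1, 80, 200)           # (B, in_channels, mel frames)
wav, source = vocoder(speech_feat=mel)  # forward is wrapped in torch.inference_mode()
# each mel frame expands by prod(upsample_rates) * hop_len = 8 * 5 * 3 * 4 = 480 samples
print(wav.shape, source.shape)
```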
flashcosyvoice/modules/hifigan_components/__init__.py ADDED
File without changes
flashcosyvoice/modules/hifigan_components/layers.py ADDED
@@ -0,0 +1,433 @@
1
+ from typing import List
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.distributions.uniform import Uniform
7
+ from torch.nn import Conv1d
8
+ from torch.nn.utils import remove_weight_norm
9
+
10
+ try:
11
+ from torch.nn.utils.parametrizations import weight_norm
12
+ except ImportError:
13
+ from torch.nn.utils import weight_norm # noqa
14
+
15
+
16
+ def get_padding(kernel_size, dilation=1):
17
+ return int((kernel_size * dilation - dilation) / 2)
18
+
19
+
20
+ def init_weights(m, mean=0.0, std=0.01):
21
+ classname = m.__class__.__name__
22
+ if classname.find("Conv") != -1:
23
+ m.weight.data.normal_(mean, std)
24
+
25
+
26
+ """hifigan based generator implementation.
27
+
28
+ This code is modified from https://github.com/jik876/hifi-gan
29
+ ,https://github.com/kan-bayashi/ParallelWaveGAN and
30
+ https://github.com/NVIDIA/BigVGAN
31
+
32
+ """
33
+
34
+
35
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
36
+ # LICENSE is in incl_licenses directory.
37
+ class Snake(nn.Module):
38
+ '''
39
+ Implementation of a sine-based periodic activation function
40
+ Shape:
41
+ - Input: (B, C, T)
42
+ - Output: (B, C, T), same shape as the input
43
+ Parameters:
44
+ - alpha - trainable parameter
45
+ References:
46
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
47
+ https://arxiv.org/abs/2006.08195
48
+ Examples:
49
+ >>> a1 = Snake(256)
50
+ >>> x = torch.randn(256)
51
+ >>> x = a1(x)
52
+
53
+ Args:
54
+ in_features: shape of the input
55
+ alpha: trainable parameter
56
+ alpha_trainable: whether alpha is trainable
57
+ alpha_logscale: whether to use log scale for alpha
58
+ alpha is initialized to 1 by default, higher values = higher-frequency.
59
+ alpha will be trained along with the rest of your model.
60
+ '''
61
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
62
+ super(Snake, self).__init__()
63
+ self.in_features = in_features
64
+
65
+ # initialize alpha
66
+ self.alpha_logscale = alpha_logscale
67
+ if self.alpha_logscale: # log scale alphas initialized to zeros
68
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
69
+ else: # linear scale alphas initialized to ones
70
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
71
+
72
+ self.alpha.requires_grad = alpha_trainable
73
+
74
+ self.no_div_by_zero = 0.000000001
75
+
76
+ def forward(self, x):
77
+ '''
78
+ Forward pass of the function.
79
+ Applies the function to the input elementwise.
80
+ Snake ∶= x + 1/a * sin^2 (xa)
81
+ '''
82
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
83
+ if self.alpha_logscale:
84
+ alpha = torch.exp(alpha)
85
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
86
+
87
+ return x
88
+
89
+
90
+ class ResBlock(torch.nn.Module):
91
+ """Residual block module in HiFiGAN/BigVGAN."""
92
+ def __init__(
93
+ self,
94
+ channels: int = 512,
95
+ kernel_size: int = 3,
96
+ dilations: List[int] = [1, 3, 5], # noqa
97
+ ):
98
+ super(ResBlock, self).__init__()
99
+ self.convs1 = nn.ModuleList()
100
+ self.convs2 = nn.ModuleList()
101
+
102
+ for dilation in dilations:
103
+ self.convs1.append(
104
+ weight_norm( # noqa
105
+ Conv1d(
106
+ channels,
107
+ channels,
108
+ kernel_size,
109
+ 1,
110
+ dilation=dilation,
111
+ padding=get_padding(kernel_size, dilation)
112
+ )
113
+ )
114
+ )
115
+ self.convs2.append(
116
+ weight_norm( # noqa
117
+ Conv1d(
118
+ channels,
119
+ channels,
120
+ kernel_size,
121
+ 1,
122
+ dilation=1,
123
+ padding=get_padding(kernel_size, 1)
124
+ )
125
+ )
126
+ )
127
+ self.convs1.apply(init_weights)
128
+ self.convs2.apply(init_weights)
129
+ self.activations1 = nn.ModuleList([
130
+ Snake(channels, alpha_logscale=False)
131
+ for _ in range(len(self.convs1))
132
+ ])
133
+ self.activations2 = nn.ModuleList([
134
+ Snake(channels, alpha_logscale=False)
135
+ for _ in range(len(self.convs2))
136
+ ])
137
+
138
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
139
+ for idx in range(len(self.convs1)):
140
+ xt = self.activations1[idx](x)
141
+ xt = self.convs1[idx](xt)
142
+ xt = self.activations2[idx](xt)
143
+ xt = self.convs2[idx](xt)
144
+ x = xt + x
145
+ return x
146
+
147
+ def remove_weight_norm(self):
148
+ for idx in range(len(self.convs1)):
149
+ remove_weight_norm(self.convs1[idx])
150
+ remove_weight_norm(self.convs2[idx])
151
+
152
+
153
+ class SineGen(torch.nn.Module):
154
+ """ Definition of sine generator
155
+ SineGen(samp_rate, harmonic_num = 0,
156
+ sine_amp = 0.1, noise_std = 0.003,
157
+ voiced_threshold = 0,
158
+ flag_for_pulse=False)
159
+ samp_rate: sampling rate in Hz
160
+ harmonic_num: number of harmonic overtones (default 0)
161
+ sine_amp: amplitude of sine waveform (default 0.1)
162
+ noise_std: std of Gaussian noise (default 0.003)
163
+ voiced_threshold: F0 threshold for U/V classification (default 0)
164
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
165
+ Note: when flag_for_pulse is True, the first time step of a voiced
166
+ segment is always sin(np.pi) or cos(0)
167
+ """
168
+
169
+ def __init__(self, samp_rate, harmonic_num=0,
170
+ sine_amp=0.1, noise_std=0.003,
171
+ voiced_threshold=0):
172
+ super(SineGen, self).__init__()
173
+ self.sine_amp = sine_amp
174
+ self.noise_std = noise_std
175
+ self.harmonic_num = harmonic_num
176
+ self.sampling_rate = samp_rate
177
+ self.voiced_threshold = voiced_threshold
178
+
179
+ def _f02uv(self, f0):
180
+ # generate uv signal
181
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
182
+ return uv
183
+
184
+ @torch.no_grad()
185
+ def forward(self, f0):
186
+ """
187
+ :param f0: [B, 1, sample_len], Hz
188
+ :return: [B, 1, sample_len]
189
+ """
190
+
191
+ F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
192
+ for i in range(self.harmonic_num + 1):
193
+ F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
194
+
195
+ theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
196
+ u_dist = Uniform(low=-np.pi, high=np.pi)
197
+ phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
198
+ phase_vec[:, 0, :] = 0
199
+
200
+ # generate sine waveforms
201
+ sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
202
+
203
+ # generate uv signal
204
+ uv = self._f02uv(f0)
205
+
206
+ # noise: for unvoiced should be similar to sine_amp
207
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
208
+ # . for voiced regions is self.noise_std
209
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
210
+ noise = noise_amp * torch.randn_like(sine_waves)
211
+
212
+ # first: set the unvoiced part to 0 by uv
213
+ # then: additive noise
214
+ sine_waves = sine_waves * uv + noise
215
+ return sine_waves, uv, noise
216
+
217
+
218
+ class SourceModuleHnNSF(torch.nn.Module):
219
+ """ SourceModule for hn-nsf
220
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
221
+ add_noise_std=0.003, voiced_threshod=0)
222
+ sampling_rate: sampling_rate in Hz
223
+ harmonic_num: number of harmonic above F0 (default: 0)
224
+ sine_amp: amplitude of sine source signal (default: 0.1)
225
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
226
+ note that amplitude of noise in unvoiced is decided
227
+ by sine_amp
228
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
229
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
230
+ F0_sampled (batchsize, length, 1)
231
+ Sine_source (batchsize, length, 1)
232
+ noise_source (batchsize, length 1)
233
+ uv (batchsize, length, 1)
234
+ """
235
+
236
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
237
+ add_noise_std=0.003, voiced_threshod=0):
238
+ super(SourceModuleHnNSF, self).__init__()
239
+
240
+ self.sine_amp = sine_amp
241
+ self.noise_std = add_noise_std
242
+
243
+ # to produce sine waveforms
244
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
245
+ sine_amp, add_noise_std, voiced_threshod)
246
+
247
+ # to merge source harmonics into a single excitation
248
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
249
+ self.l_tanh = torch.nn.Tanh()
250
+
251
+ def forward(self, x):
252
+ """
253
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
254
+ F0_sampled (batchsize, length, 1)
255
+ Sine_source (batchsize, length, 1)
256
+ noise_source (batchsize, length 1)
257
+ """
258
+ # source for harmonic branch
259
+ with torch.no_grad():
260
+ sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
261
+ sine_wavs = sine_wavs.transpose(1, 2)
262
+ uv = uv.transpose(1, 2)
263
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
264
+
265
+ # source for noise branch, in the same shape as uv
266
+ noise = torch.randn_like(uv) * self.sine_amp / 3
267
+ return sine_merge, noise, uv
268
+
269
+
270
+ class SineGen2(torch.nn.Module):
271
+ """ Definition of sine generator
272
+ SineGen(samp_rate, harmonic_num = 0,
273
+ sine_amp = 0.1, noise_std = 0.003,
274
+ voiced_threshold = 0,
275
+ flag_for_pulse=False)
276
+ samp_rate: sampling rate in Hz
277
+ harmonic_num: number of harmonic overtones (default 0)
278
+ sine_amp: amplitude of sine waveform (default 0.1)
279
+ noise_std: std of Gaussian noise (default 0.003)
280
+ voiced_threshold: F0 threshold for U/V classification (default 0)
281
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
282
+ Note: when flag_for_pulse is True, the first time step of a voiced
283
+ segment is always sin(np.pi) or cos(0)
284
+ """
285
+
286
+ def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
287
+ sine_amp=0.1, noise_std=0.003,
288
+ voiced_threshold=0,
289
+ flag_for_pulse=False):
290
+ super(SineGen2, self).__init__()
291
+ self.sine_amp = sine_amp
292
+ self.noise_std = noise_std
293
+ self.harmonic_num = harmonic_num
294
+ self.dim = self.harmonic_num + 1
295
+ self.sampling_rate = samp_rate
296
+ self.voiced_threshold = voiced_threshold
297
+ self.flag_for_pulse = flag_for_pulse
298
+ self.upsample_scale = upsample_scale
299
+
300
+ def _f02uv(self, f0):
301
+ # generate uv signal
302
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
303
+ return uv
304
+
305
+ def _f02sine(self, f0_values):
306
+ """ f0_values: (batchsize, length, dim)
307
+ where dim indicates fundamental tone and overtones
308
+ """
309
+ # convert to F0 in rad. The integer part n can be ignored
310
+ # because 2 * np.pi * n doesn't affect phase
311
+ rad_values = (f0_values / self.sampling_rate) % 1
312
+
313
+ # initial phase noise (no noise for fundamental component)
314
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
315
+ rand_ini[:, 0] = 0
316
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
317
+
318
+ # instantaneous phase sine[t] = sin(2*pi * \sum_{i=1}^{t} rad)
319
+ if not self.flag_for_pulse:
320
+ rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
321
+ scale_factor=1 / self.upsample_scale,
322
+ mode="linear").transpose(1, 2)
323
+
324
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
325
+ phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
326
+ scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
327
+ sines = torch.sin(phase)
328
+ else:
329
+ # If necessary, make sure that the first time step of every
330
+ # voiced segments is sin(pi) or cos(0)
331
+ # This is used for pulse-train generation
332
+
333
+ # identify the last time step in unvoiced segments
334
+ uv = self._f02uv(f0_values)
335
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
336
+ uv_1[:, -1, :] = 1
337
+ u_loc = (uv < 1) * (uv_1 > 0)
338
+
339
+ # get the instantaneous phase
340
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
341
+ # different batch needs to be processed differently
342
+ for idx in range(f0_values.shape[0]):
343
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
344
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
345
+ # stores the accumulation of i.phase within
346
+ # each voiced segments
347
+ tmp_cumsum[idx, :, :] = 0
348
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
349
+
350
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
351
+ # within the previous voiced segment.
352
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
353
+
354
+ # get the sines
355
+ sines = torch.cos(i_phase * 2 * np.pi)
356
+ return sines
357
+
358
+ def forward(self, f0):
359
+ """ sine_tensor, uv = forward(f0)
360
+ input F0: tensor(batchsize=1, length, dim=1)
361
+ f0 for unvoiced steps should be 0
362
+ output sine_tensor: tensor(batchsize=1, length, dim)
363
+ output uv: tensor(batchsize=1, length, 1)
364
+ """
365
+ # fundamental component
366
+ fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
367
+
368
+ # generate sine waveforms
369
+ sine_waves = self._f02sine(fn) * self.sine_amp
370
+
371
+ # generate uv signal
372
+ uv = self._f02uv(f0)
373
+
374
+ # noise: for unvoiced should be similar to sine_amp
375
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
376
+ # . for voiced regions is self.noise_std
377
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
378
+ noise = noise_amp * torch.randn_like(sine_waves)
379
+
380
+ # first: set the unvoiced part to 0 by uv
381
+ # then: additive noise
382
+ sine_waves = sine_waves * uv + noise
383
+ return sine_waves, uv, noise
384
+
385
+
386
+ class SourceModuleHnNSF2(torch.nn.Module):
387
+ """ SourceModule for hn-nsf
388
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
389
+ add_noise_std=0.003, voiced_threshod=0)
390
+ sampling_rate: sampling_rate in Hz
391
+ harmonic_num: number of harmonic above F0 (default: 0)
392
+ sine_amp: amplitude of sine source signal (default: 0.1)
393
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
394
+ note that amplitude of noise in unvoiced is decided
395
+ by sine_amp
396
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
397
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
398
+ F0_sampled (batchsize, length, 1)
399
+ Sine_source (batchsize, length, 1)
400
+ noise_source (batchsize, length 1)
401
+ uv (batchsize, length, 1)
402
+ """
403
+
404
+ def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
405
+ add_noise_std=0.003, voiced_threshod=0):
406
+ super(SourceModuleHnNSF2, self).__init__()
407
+
408
+ self.sine_amp = sine_amp
409
+ self.noise_std = add_noise_std
410
+
411
+ # to produce sine waveforms
412
+ self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
413
+ sine_amp, add_noise_std, voiced_threshod)
414
+
415
+ # to merge source harmonics into a single excitation
416
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
417
+ self.l_tanh = torch.nn.Tanh()
418
+
419
+ def forward(self, x):
420
+ """
421
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
422
+ F0_sampled (batchsize, length, 1)
423
+ Sine_source (batchsize, length, 1)
424
+ noise_source (batchsize, length 1)
425
+ """
426
+ # source for harmonic branch
427
+ with torch.no_grad():
428
+ sine_wavs, uv, _ = self.l_sin_gen(x)
429
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
430
+
431
+ # source for noise branch, in the same shape as uv
432
+ noise = torch.randn_like(uv) * self.sine_amp / 3
433
+ return sine_merge, noise, uv
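For orientation (not part of the commit): a small sketch of the NSF source module used on the 22.05 kHz branch, driven by a synthetic F0 contour already at sample rate; `upsample_scale` is accepted for signature parity with `SourceModuleHnNSF2` but appears unused here, so the value below is arbitrary.

```python
import torch
from flashcosyvoice.modules.hifigan_components.layers import SourceModuleHnNSF

source = SourceModuleHnNSF(sampling_rate=22050, upsample_scale=480, harmonic_num=8)
f0 = torch.full((1, 22050, 1), 220.0)   # (B, sample_len, 1): constant 220 Hz, voiced
f0[:, 11025:, :] = 0.0                  # second half unvoiced (F0 == 0)
sine, noise, uv = source(f0)
print(sine.shape, uv.shape)             # (1, 22050, 1): merged harmonic excitation and U/V flags
```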
flashcosyvoice/modules/qwen2.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ from torch import nn
16
+ from transformers import AutoConfig
17
+
18
+ from flashcosyvoice.config import CosyVoice2LLMConfig
19
+ from flashcosyvoice.modules.qwen2_components.layers import (
20
+ ParallelLMHead, Qwen2DecoderLayer, RMSNorm, VocabParallelEmbedding)
21
+
22
+
23
+ class Qwen2Model(nn.Module):
24
+
25
+ def __init__(
26
+ self,
27
+ config: CosyVoice2LLMConfig,
28
+ ):
29
+ super().__init__()
30
+ self.vocab_size = config.vocab_size
31
+ self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
32
+ self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
33
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
34
+
35
+ def forward(
36
+ self,
37
+ input_ids: torch.Tensor,
38
+ positions: torch.Tensor,
39
+ ) -> torch.Tensor:
40
+ hidden_states = self.embed_tokens(input_ids)
41
+ residual = None
42
+ for layer in self.layers:
43
+ hidden_states, residual = layer(
44
+ positions,
45
+ hidden_states,
46
+ residual,
47
+ )
48
+ hidden_states, _ = self.norm(hidden_states, residual)
49
+ return hidden_states
50
+
51
+
52
+ class Qwen2ForCausalLM(nn.Module):
53
+ packed_modules_mapping = {
54
+ "q_proj": ("qkv_proj", "q"),
55
+ "k_proj": ("qkv_proj", "k"),
56
+ "v_proj": ("qkv_proj", "v"),
57
+ "gate_proj": ("gate_up_proj", 0),
58
+ "up_proj": ("gate_up_proj", 1),
59
+ }
60
+
61
+ def __init__(
62
+ self,
63
+ config: CosyVoice2LLMConfig | AutoConfig
64
+ ):
65
+ super().__init__()
66
+ self.model = Qwen2Model(config)
67
+ if hasattr(config, "speech_vocab_size"):
68
+ self.lm_head = ParallelLMHead(config.speech_vocab_size, config.hidden_size, bias=getattr(config, "lm_head_bias", True))
69
+ self.model_type = "speech_llm"
70
+ else:
71
+ self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False)
72
+ self.model_type = "text_llm"
73
+ self.tie_word_embeddings = config.tie_word_embeddings
74
+ if self.tie_word_embeddings:
75
+ if self.model_type == "speech_llm":
76
+ assert config.vocab_size == config.speech_vocab_size, "vocab_size and speech_vocab_size must be the same when tie_word_embeddings is True"
77
+ self.lm_head.weight.data = self.model.embed_tokens.weight.data
78
+
79
+ def forward(
80
+ self,
81
+ input_ids: torch.Tensor,
82
+ positions: torch.Tensor,
83
+ ) -> torch.Tensor:
84
+ hidden_states = self.model(input_ids, positions)
85
+ return hidden_states
86
+
87
+ def compute_logits(
88
+ self,
89
+ hidden_states: torch.Tensor,
90
+ ) -> torch.Tensor:
91
+ logits = self.lm_head(hidden_states)
92
+ return logits
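For orientation (not part of the commit): a construction-only sketch of the LLM wrapper. It assumes `CosyVoice2LLMConfig` can be instantiated with its defaults and that the CUDA-side dependencies (`flash_attn`, `triton`) are installed; running an actual forward pass additionally requires the engine to set up the KV-cache context via `flashcosyvoice.utils.context`.

```python
from flashcosyvoice.config import CosyVoice2LLMConfig
from flashcosyvoice.modules.qwen2 import Qwen2ForCausalLM

config = CosyVoice2LLMConfig()   # assumed: Qwen2 dims plus speech_vocab_size for the speech head
llm = Qwen2ForCausalLM(config)
print(llm.model_type)            # "speech_llm" when the config defines speech_vocab_size
print(llm.lm_head.weight.shape)  # (speech_vocab_size, hidden_size); forward() returns hidden states,
                                 # compute_logits() projects them through this head
```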
flashcosyvoice/modules/qwen2_components/__init__.py ADDED
File without changes
flashcosyvoice/modules/qwen2_components/layers.py ADDED
@@ -0,0 +1,616 @@
1
+ from functools import lru_cache
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import triton
8
+ import triton.language as tl
9
+ from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
10
+
11
+ from flashcosyvoice.config import CosyVoice2LLMConfig
12
+ from flashcosyvoice.utils.context import get_context
13
+
14
+
15
+ class SiluAndMul(nn.Module):
16
+
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ @torch.compile
21
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
22
+ x, y = x.chunk(2, -1)
23
+ return F.silu(x) * y
24
+
25
+
26
+ class RMSNorm(nn.Module):
27
+
28
+ def __init__(
29
+ self,
30
+ hidden_size: int,
31
+ eps: float = 1e-6,
32
+ ) -> None:
33
+ super().__init__()
34
+ self.hidden_size = hidden_size
35
+ self.eps = eps
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+
38
+ @torch.compile
39
+ def rms_forward(
40
+ self,
41
+ x: torch.Tensor,
42
+ ) -> torch.Tensor:
43
+ orig_dtype = x.dtype
44
+ x = x.to(torch.float32)
45
+ var = x.pow(2).mean(dim=-1, keepdim=True)
46
+ x.mul_(torch.rsqrt(var + self.eps))
47
+ x = x.to(orig_dtype).mul_(self.weight)
48
+ return x
49
+
50
+ @torch.compile
51
+ def add_rms_forward(
52
+ self,
53
+ x: torch.Tensor,
54
+ residual: torch.Tensor,
55
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
56
+ orig_dtype = x.dtype
57
+ x = x.to(torch.float32).add_(residual.to(torch.float32))
58
+ residual = x.to(orig_dtype)
59
+ var = x.pow(2).mean(dim=-1, keepdim=True)
60
+ x.mul_(torch.rsqrt(var + self.eps))
61
+ x = x.to(orig_dtype).mul_(self.weight)
62
+ return x, residual
63
+
64
+ def forward(
65
+ self,
66
+ x: torch.Tensor,
67
+ residual: torch.Tensor | None = None,
68
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
69
+ if residual is None:
70
+ return self.rms_forward(x)
71
+ else:
72
+ return self.add_rms_forward(x, residual)
73
+
74
+
75
+ @triton.jit
76
+ def store_kvcache_kernel(
77
+ key_ptr,
78
+ key_stride,
79
+ value_ptr,
80
+ value_stride,
81
+ k_cache_ptr,
82
+ v_cache_ptr,
83
+ slot_mapping_ptr,
84
+ D: tl.constexpr,
85
+ ):
86
+ idx = tl.program_id(0)
87
+ key_offsets = idx * key_stride + tl.arange(0, D)
88
+ value_offsets = idx * value_stride + tl.arange(0, D)
89
+ key = tl.load(key_ptr + key_offsets)
90
+ value = tl.load(value_ptr + value_offsets)
91
+ slot = tl.load(slot_mapping_ptr + idx)
92
+ cache_offsets = slot * D + tl.arange(0, D)
93
+ tl.store(k_cache_ptr + cache_offsets, key)
94
+ tl.store(v_cache_ptr + cache_offsets, value)
95
+
96
+
97
+ def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
98
+ N, num_heads, head_dim = key.shape
99
+ D = num_heads * head_dim
100
+ assert key.stride(-1) == 1 and value.stride(-1) == 1
101
+ assert key.stride(1) == head_dim and value.stride(1) == head_dim
102
+ assert k_cache.stride(1) == D and v_cache.stride(1) == D
103
+ assert slot_mapping.numel() == N
104
+ store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
105
+
106
+
107
+ class Attention(nn.Module):
108
+
109
+ def __init__(
110
+ self,
111
+ num_heads,
112
+ head_dim,
113
+ scale,
114
+ num_kv_heads,
115
+ ):
116
+ super().__init__()
117
+ self.num_heads = num_heads
118
+ self.head_dim = head_dim
119
+ self.scale = scale
120
+ self.num_kv_heads = num_kv_heads
121
+ self.k_cache = self.v_cache = torch.tensor([])
122
+
123
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
124
+ o: torch.Tensor
125
+ q = q.view(-1, self.num_heads, self.head_dim)
126
+ k = k.view(-1, self.num_kv_heads, self.head_dim)
127
+ v = v.view(-1, self.num_kv_heads, self.head_dim)
128
+ context = get_context()
129
+ k_cache, v_cache = self.k_cache, self.v_cache
130
+ if k_cache.numel() and v_cache.numel():
131
+ store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
132
+ if context.is_prefill:
133
+ if context.block_tables is not None: # prefix cache
134
+ k, v = k_cache, v_cache
135
+ o = flash_attn_varlen_func(q, k, v,
136
+ max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
137
+ max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
138
+ softmax_scale=self.scale, causal=True, block_table=context.block_tables)
139
+ else: # decode
140
+ o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
141
+ cache_seqlens=context.context_lens, block_table=context.block_tables,
142
+ softmax_scale=self.scale, causal=True)
143
+ o = o.view(-1, self.num_heads * self.head_dim)
144
+ return o
145
+
146
+
147
+ class VocabParallelEmbedding(nn.Module):
148
+
149
+ def __init__(
150
+ self,
151
+ num_embeddings: int,
152
+ embedding_dim: int,
153
+ ):
154
+ super().__init__()
155
+ # TODO(xcsong): support tp > 1
156
+ self.tp_rank = 0 # dist.get_rank()
157
+ self.tp_size = 1 # dist.get_world_size()
158
+ assert num_embeddings % self.tp_size == 0
159
+ self.num_embeddings = num_embeddings
160
+ self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
161
+ self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
162
+ self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
163
+ self.embedding_dim = embedding_dim
164
+ self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
165
+ self.weight.weight_loader = self.weight_loader
166
+
167
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
168
+ param_data = param.data
169
+ shard_size = param_data.size(0)
170
+ start_idx = self.tp_rank * shard_size
171
+ loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
172
+ assert param_data.size() == loaded_weight.size()
173
+ param_data.copy_(loaded_weight)
174
+
175
+ def forward(self, x: torch.Tensor):
176
+ if self.tp_size > 1:
177
+ mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
178
+ x = mask * (x - self.vocab_start_idx)
179
+ y = F.embedding(x, self.weight)
180
+ if self.tp_size > 1:
181
+ y = mask.unsqueeze(1) * y
182
+ dist.all_reduce(y)
183
+ return y
184
+
185
+
186
+ class ParallelLMHead(VocabParallelEmbedding):
187
+
188
+ def __init__(
189
+ self,
190
+ num_embeddings: int,
191
+ embedding_dim: int,
192
+ bias: bool = False,
193
+ ):
194
+ super().__init__(num_embeddings, embedding_dim)
195
+ if bias:
196
+ self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
197
+ self.bias.weight_loader = self.weight_loader
198
+ else:
199
+ self.register_parameter("bias", None)
200
+
201
+ def forward(self, x: torch.Tensor):
202
+ context = get_context()
203
+ if context.is_prefill:
204
+ last_indices = context.cu_seqlens_q[1:] - 1
205
+ x = x[last_indices].contiguous()
206
+ logits = F.linear(x, self.weight, self.bias)
207
+ if self.tp_size > 1:
208
+ all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
209
+ dist.gather(logits, all_logits, 0)
210
+ logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
211
+ return logits
212
+
213
+
214
+ def divide(numerator, denominator):
215
+ assert numerator % denominator == 0
216
+ return numerator // denominator
217
+
218
+
219
+ class LinearBase(nn.Module):
220
+
221
+ def __init__(
222
+ self,
223
+ input_size: int,
224
+ output_size: int,
225
+ tp_dim: int | None = None,
226
+ ):
227
+ super().__init__()
228
+ self.input_size = input_size
229
+ self.output_size = output_size
230
+ self.tp_dim = tp_dim
231
+ # TODO(xcsong): support tp > 1
232
+ self.tp_rank = 0 # dist.get_rank()
233
+ self.tp_size = 1 # dist.get_world_size()
234
+
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ raise NotImplementedError
237
+
238
+
239
+ class ReplicatedLinear(LinearBase):
240
+
241
+ def __init__(
242
+ self,
243
+ input_size: int,
244
+ output_size: int,
245
+ bias: bool = False,
246
+ ):
247
+ super().__init__(input_size, output_size)
248
+ self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
249
+ self.weight.weight_loader = self.weight_loader
250
+ if bias:
251
+ self.bias = nn.Parameter(torch.empty(self.output_size))
252
+ self.bias.weight_loader = self.weight_loader
253
+ else:
254
+ self.register_parameter("bias", None)
255
+
256
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
257
+ assert param.size() == loaded_weight.size()
258
+ param.data.copy_(loaded_weight)
259
+
260
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
261
+ return F.linear(x, self.weight, self.bias)
262
+
263
+
264
+ class ColumnParallelLinear(LinearBase):
265
+
266
+ def __init__(
267
+ self,
268
+ input_size: int,
269
+ output_size: int,
270
+ bias: bool = False,
271
+ ):
272
+ super().__init__(input_size, output_size, 0)
273
+ self.input_size_per_partition = input_size
274
+ self.output_size_per_partition = divide(output_size, self.tp_size)
275
+ self.output_partition_sizes = [self.output_size_per_partition]
276
+ if hasattr(self, "output_sizes"):
277
+ self.output_partition_sizes = [
278
+ divide(output_size, self.tp_size)
279
+ for output_size in self.output_sizes
280
+ ]
281
+
282
+ self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
283
+ self.weight.weight_loader = self.weight_loader
284
+ if bias:
285
+ self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
286
+ self.bias.weight_loader = self.weight_loader
287
+ else:
288
+ self.register_parameter("bias", None)
289
+
290
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
291
+ param_data = param.data
292
+ shard_size = param_data.size(self.tp_dim)
293
+ start_idx = self.tp_rank * shard_size
294
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
295
+ assert param_data.size() == loaded_weight.size()
296
+ param_data.copy_(loaded_weight)
297
+
298
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
299
+ return F.linear(x, self.weight, self.bias)
300
+
301
+
302
+ class MergedColumnParallelLinear(ColumnParallelLinear):
303
+
304
+ def __init__(
305
+ self,
306
+ input_size: int,
307
+ output_sizes: list[int],
308
+ bias: bool = False,
309
+ ):
310
+ self.output_sizes = output_sizes
311
+ super().__init__(input_size, sum(output_sizes), bias=bias)
312
+
313
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
314
+ param_data = param.data
315
+ shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
316
+ shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
317
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
318
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
319
+ assert param_data.size() == loaded_weight.size()
320
+ param_data.copy_(loaded_weight)
321
+
322
+
323
+ class QKVParallelLinear(ColumnParallelLinear):
324
+
325
+ def __init__(
326
+ self,
327
+ hidden_size: int,
328
+ head_size: int,
329
+ total_num_heads: int,
330
+ total_num_kv_heads: int | None = None,
331
+ bias: bool = False,
332
+ ):
333
+ self.hidden_size = hidden_size
334
+ self.head_size = head_size
335
+ self.total_num_heads = total_num_heads
336
+ if total_num_kv_heads is None:
337
+ total_num_kv_heads = total_num_heads
338
+ self.total_num_kv_heads = total_num_kv_heads
339
+ # TODO(xcsong): support tp > 1
340
+ tp_size = 1 # dist.get_world_size()
341
+ self.num_heads = divide(self.total_num_heads, tp_size)
342
+ self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
343
+ input_size = self.hidden_size
344
+ output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
345
+ self.output_sizes = [
346
+ self.num_heads * self.head_size * tp_size, # q_proj
347
+ self.num_kv_heads * self.head_size * tp_size, # k_proj
348
+ self.num_kv_heads * self.head_size * tp_size, # v_proj
349
+ ]
350
+
351
+ super().__init__(input_size, output_size, bias)
352
+
353
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
354
+ param_data = param.data
355
+ assert loaded_shard_id in ["q", "k", "v"]
356
+ if loaded_shard_id == "q":
357
+ shard_size = self.num_heads * self.head_size
358
+ shard_offset = 0
359
+ elif loaded_shard_id == "k":
360
+ shard_size = self.num_kv_heads * self.head_size
361
+ shard_offset = self.num_heads * self.head_size
362
+ else:
363
+ shard_size = self.num_kv_heads * self.head_size
364
+ shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
365
+ param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
366
+ loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
367
+ assert param_data.size() == loaded_weight.size()
368
+ param_data.copy_(loaded_weight)
369
+
370
+
371
+ class RowParallelLinear(LinearBase):
372
+
373
+ def __init__(
374
+ self,
375
+ input_size: int,
376
+ output_size: int,
377
+ bias: bool = False,
378
+ ):
379
+ super().__init__(input_size, output_size, 1)
380
+ self.input_size_per_partition = divide(input_size, self.tp_size)
381
+ self.output_size_per_partition = output_size
382
+ self.output_partition_sizes = [output_size]
383
+
384
+ self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
385
+ self.weight.weight_loader = self.weight_loader
386
+ if bias:
387
+ self.bias = nn.Parameter(torch.empty(self.output_size))
388
+ self.bias.weight_loader = self.weight_loader
389
+ else:
390
+ self.register_parameter("bias", None)
391
+
392
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
393
+ param_data = param.data
394
+ shard_size = param_data.size(self.tp_dim)
395
+ start_idx = self.tp_rank * shard_size
396
+ loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
397
+ assert param_data.size() == loaded_weight.size()
398
+ param_data.copy_(loaded_weight)
399
+
400
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
401
+ y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
402
+ if self.tp_size > 1:
403
+ dist.all_reduce(y)
404
+ return y
405
+
406
+
407
+ def apply_rotary_emb(
408
+ x: torch.Tensor,
409
+ cos: torch.Tensor,
410
+ sin: torch.Tensor,
411
+ ) -> torch.Tensor:
412
+ cos = cos.unsqueeze(-2)
413
+ sin = sin.unsqueeze(-2)
414
+ x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
415
+ y1 = x1 * cos - x2 * sin
416
+ y2 = x2 * cos + x1 * sin
417
+ return torch.cat((y1, y2), dim=-1).to(x.dtype)
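apply_rotary_emb above uses the half-split RoPE formulation: the head dimension is cut into two halves and each (x1[i], x2[i]) pair is rotated by a position-dependent angle, which leaves vector norms unchanged. A minimal standalone sketch of that property on toy shapes (no head dimension, purely illustrative):

    import torch

    def rotate_half_split(x, cos, sin):
        # Same math as apply_rotary_emb, restated for a 2-D toy tensor.
        x1, x2 = torch.chunk(x.float(), 2, dim=-1)
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    head_dim = 8
    pos = torch.arange(4, dtype=torch.float)                              # four toy positions
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.einsum("i,j->ij", pos, inv_freq)                        # [4, head_dim // 2]
    cos, sin = freqs.cos(), freqs.sin()

    q = torch.randn(4, head_dim)
    q_rot = rotate_half_split(q, cos, sin)
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)  # rotation preserves norm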
418
+
419
+
420
+ class RotaryEmbedding(nn.Module):
421
+
422
+ def __init__(
423
+ self,
424
+ head_size: int,
425
+ rotary_dim: int,
426
+ max_position_embeddings: int,
427
+ base: float,
428
+ ) -> None:
429
+ super().__init__()
430
+ self.head_size = head_size
431
+ assert rotary_dim == head_size
432
+ inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
433
+ t = torch.arange(max_position_embeddings, dtype=torch.float)
434
+ freqs = torch.einsum("i,j -> ij", t, inv_freq)
435
+ cos = freqs.cos()
436
+ sin = freqs.sin()
437
+ cache = torch.cat((cos, sin), dim=-1)
438
+ self.register_buffer("cos_sin_cache", cache, persistent=False)
439
+
440
+ @torch.compile
441
+ def forward(
442
+ self,
443
+ positions: torch.Tensor,
444
+ query: torch.Tensor,
445
+ key: torch.Tensor,
446
+ ) -> tuple[torch.Tensor, torch.Tensor]:
447
+ positions = positions.flatten()
448
+ num_tokens = positions.shape[0]
449
+ cos_sin = self.cos_sin_cache[positions]
450
+ cos, sin = cos_sin.chunk(2, dim=-1)
451
+ query_shape = query.shape
452
+ query = query.view(num_tokens, -1, self.head_size)
453
+ query = apply_rotary_emb(query, cos, sin).view(query_shape)
454
+ key_shape = key.shape
455
+ key = key.view(num_tokens, -1, self.head_size)
456
+ key = apply_rotary_emb(key, cos, sin).view(key_shape)
457
+ return query, key
458
+
459
+
460
+ @lru_cache(1)
461
+ def get_rope(
462
+ head_size: int,
463
+ rotary_dim: int,
464
+ max_position: int,
465
+ base: float,
466
+ rope_scaling: dict | None = None,
467
+ ):
468
+ assert rope_scaling is None
469
+ rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
470
+ return rotary_emb
471
+
472
+
473
+ class Qwen2Attention(nn.Module):
474
+
475
+ def __init__(
476
+ self,
477
+ hidden_size: int,
478
+ num_heads: int,
479
+ num_kv_heads: int,
480
+ max_position: int = 4096 * 32,
481
+ head_dim: int | None = None,
482
+ rms_norm_eps: float = 1e-06,
483
+ qkv_bias: bool = True,
484
+ rope_theta: float = 1000000.0,
485
+ rope_scaling: tuple | None = None,
486
+ ) -> None:
487
+ super().__init__()
488
+ self.hidden_size = hidden_size
489
+ # TODO(xcsong): support tp > 1
490
+ tp_size = 1 # dist.get_world_size()
491
+ self.total_num_heads = num_heads
492
+ assert self.total_num_heads % tp_size == 0
493
+ self.num_heads = self.total_num_heads // tp_size
494
+ self.total_num_kv_heads = num_kv_heads
495
+ assert self.total_num_kv_heads % tp_size == 0
496
+ self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
497
+ self.head_dim = head_dim or hidden_size // self.total_num_heads
498
+ self.q_size = self.num_heads * self.head_dim
499
+ self.kv_size = self.num_kv_heads * self.head_dim
500
+ self.scaling = self.head_dim**-0.5
501
+ self.rope_theta = rope_theta
502
+
503
+ self.qkv_proj = QKVParallelLinear(
504
+ hidden_size,
505
+ self.head_dim,
506
+ self.total_num_heads,
507
+ self.total_num_kv_heads,
508
+ bias=qkv_bias,
509
+ )
510
+ self.o_proj = RowParallelLinear(
511
+ self.total_num_heads * self.head_dim,
512
+ hidden_size,
513
+ bias=False,
514
+ )
515
+
516
+ self.rotary_emb = get_rope(
517
+ self.head_dim,
518
+ rotary_dim=self.head_dim,
519
+ max_position=max_position,
520
+ base=self.rope_theta,
521
+ rope_scaling=rope_scaling,
522
+ )
523
+ self.attn = Attention(self.num_heads,
524
+ self.head_dim,
525
+ self.scaling,
526
+ num_kv_heads=self.num_kv_heads)
527
+
528
+ def forward(
529
+ self,
530
+ positions: torch.Tensor,
531
+ hidden_states: torch.Tensor,
532
+ ) -> torch.Tensor:
533
+ qkv = self.qkv_proj(hidden_states)
534
+ q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
535
+ q, k = self.rotary_emb(positions, q, k)
536
+ o = self.attn(q, k, v)
537
+ output = self.o_proj(o)
538
+ return output
539
+
540
+
541
+ class Qwen2MLP(nn.Module):
542
+
543
+ def __init__(
544
+ self,
545
+ hidden_size: int,
546
+ intermediate_size: int,
547
+ hidden_act: str,
548
+ ) -> None:
549
+ super().__init__()
550
+ self.gate_up_proj = MergedColumnParallelLinear(
551
+ hidden_size,
552
+ [intermediate_size] * 2,
553
+ bias=False,
554
+ )
555
+ self.down_proj = RowParallelLinear(
556
+ intermediate_size,
557
+ hidden_size,
558
+ bias=False,
559
+ )
560
+ assert hidden_act == "silu"
561
+ self.act_fn = SiluAndMul()
562
+
563
+ def forward(self, x):
564
+ gate_up = self.gate_up_proj(x)
565
+ x = self.act_fn(gate_up)
566
+ x = self.down_proj(x)
567
+ return x
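Qwen2MLP is the usual SwiGLU block: gate_up_proj emits the concatenation [gate, up], the activation computes silu(gate) * up, and down_proj maps back to hidden_size. SiluAndMul itself is defined elsewhere in this repo; the toy restatement below (with an arbitrary intermediate size) only shows the split-and-gate step:

    import torch
    import torch.nn.functional as F

    def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
        # Split the merged projection back into its halves and apply the SwiGLU gate.
        gate, up = gate_up.chunk(2, dim=-1)
        return F.silu(gate) * up

    x = torch.randn(3, 2 * 4864)   # toy [num_tokens, 2 * intermediate_size]
    print(silu_and_mul(x).shape)   # torch.Size([3, 4864])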
568
+
569
+
570
+ class Qwen2DecoderLayer(nn.Module):
571
+
572
+ def __init__(
573
+ self,
574
+ config: CosyVoice2LLMConfig,
575
+ ) -> None:
576
+ super().__init__()
577
+ self.hidden_size = config.hidden_size
578
+ self.self_attn = Qwen2Attention(
579
+ hidden_size=self.hidden_size,
580
+ num_heads=config.num_attention_heads,
581
+ num_kv_heads=config.num_key_value_heads,
582
+ max_position=config.max_position_embeddings,
583
+ rms_norm_eps=config.rms_norm_eps,
584
+ qkv_bias=getattr(config, "qkv_bias", True),
585
+ head_dim=getattr(config, "head_dim", None),
586
+ rope_theta=getattr(config, "rope_theta", 1000000.0),
587
+ rope_scaling=getattr(config, "rope_scaling", None),
588
+ )
589
+ self.mlp = Qwen2MLP(
590
+ hidden_size=config.hidden_size,
591
+ intermediate_size=config.intermediate_size,
592
+ hidden_act=config.hidden_act,
593
+ )
594
+ self.input_layernorm = RMSNorm(config.hidden_size,
595
+ eps=config.rms_norm_eps)
596
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
597
+ eps=config.rms_norm_eps)
598
+
599
+ def forward(
600
+ self,
601
+ positions: torch.Tensor,
602
+ hidden_states: torch.Tensor,
603
+ residual: torch.Tensor | None,
604
+ ) -> tuple[torch.Tensor, torch.Tensor]:
605
+ if residual is None:
606
+ residual = hidden_states
607
+ hidden_states = self.input_layernorm(hidden_states)
608
+ else:
609
+ hidden_states, residual = self.input_layernorm(hidden_states, residual)
610
+ hidden_states = self.self_attn(
611
+ positions=positions,
612
+ hidden_states=hidden_states,
613
+ )
614
+ hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
615
+ hidden_states = self.mlp(hidden_states)
616
+ return hidden_states, residual
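The (hidden_states, residual) pair threaded through forward is the fused add-and-norm pre-norm pattern: when RMSNorm is called with a residual, it is assumed to return (norm(x + residual), x + residual), so the residual addition never appears as a separate op between sublayers. The real RMSNorm lives elsewhere in this repo; a reference sketch of the assumed two-argument contract:

    import torch
    from torch import nn

    class ToyRMSNorm(nn.Module):
        # Reference behaviour assumed for the two-argument RMSNorm used by Qwen2DecoderLayer.
        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x: torch.Tensor, residual: torch.Tensor | None = None):
            if residual is not None:
                x = x + residual   # fused residual add
                residual = x       # carried to the next sublayer
            norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
            return norm if residual is None else (norm, residual)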
flashcosyvoice/modules/sampler.py ADDED
@@ -0,0 +1,231 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class Sampler(nn.Module):
6
+ """
7
+ Optimized sampler that replaces per-sequence Python loops with vectorized tensor operations, significantly improving throughput
8
+
9
+ Performance optimizations:
10
+ 1. Using batch processing instead of sequence loops, reducing Python loop overhead
11
+ 2. Using PyTorch's vectorized operations (like torch.sort, torch.gather) for parallel computation
12
+ 3. Using mask operations to apply top-k filtering at once, avoiding per-sequence processing
13
+ """
14
+ def __init__(self):
15
+ super().__init__()
16
+
17
+ def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, top_k: int | None = None):
18
+ """
19
+ Perform sampling with vectorized top-k filtering
20
+
21
+ Args:
22
+ logits: Logits tensor with shape [batch_size, vocab_size]
23
+ temperatures: Temperature parameters with shape [batch_size]
24
+ top_k: Top-k value for filtering (uniform across all sequences)
25
+
26
+ Returns:
27
+ Sampled token IDs
28
+ """
29
+ logits = logits.to(torch.float)
30
+ greedy_tokens = logits.argmax(dim=-1) # Greedy decoding result, used when temperature=0
31
+ logits.div_(temperatures.unsqueeze(dim=1)) # Apply temperature scaling
32
+
33
+ # Apply uniform top-k filtering if top_k is provided
34
+ if top_k is not None and top_k > 0:
35
+ vocab_size = logits.size(-1)
36
+
37
+ # A boolean keep-mask is derived below from each sequence's top-k threshold
39
+
40
+ # Batch sorting for all sequences at once
41
+ sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
42
+
43
+ # Get threshold for each sequence (the k-th largest value)
44
+ k_value = min(top_k, vocab_size) # Ensure k doesn't exceed vocab size
45
+ thresholds = sorted_logits[:, k_value-1:k_value] # Shape [batch_size, 1]
46
+ thresholds = thresholds.expand(-1, vocab_size) # Expand to match logits shape
47
+
48
+ # Create mask: only keep logits greater than or equal to threshold
49
+ mask = logits >= thresholds
50
+
51
+ # Apply mask: set logits not in top-k to negative infinity
52
+ logits = torch.where(mask, logits, torch.tensor(float('-inf'), device=logits.device))
53
+
54
+ probs = torch.softmax(logits, dim=-1, dtype=torch.float)
55
+ # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
56
+ sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
57
+ return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
58
+
59
+
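The line probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) above is the exponential-race form of the Gumbel-max trick: dividing each probability by an i.i.d. Exponential(1) draw and taking the argmax samples exactly from the categorical distribution, with no explicit multinomial call. A quick empirical check on a toy distribution (illustration only):

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.7])
    n = 200_000

    # Exponential-race sampling, as in Sampler.forward.
    noise = torch.empty(n, 3).exponential_(1)
    samples = (probs / noise).argmax(dim=-1)

    freq = torch.bincount(samples, minlength=3).float() / n
    print(freq)  # approximately tensor([0.10, 0.20, 0.70]), up to Monte Carlo error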
60
+ class RasSampler(nn.Module):
61
+ """
62
+ Optimized Repetition Aware Sampling implementation
63
+
64
+ Performance optimizations:
65
+ 1. Using vectorized nucleus sampling instead of loop implementation, improving sampling efficiency
66
+ 2. Using tensor operations to calculate repetition rate, reducing Python loop overhead
67
+ 3. Optimizing EOS handling logic, reducing unnecessary resampling
68
+ 4. Using PyTorch's vectorized operations for parallel computation
69
+ 5. Batch processing for all sequences, dramatically improving throughput
70
+ 6. Robust handling for sequences of any length, including empty sequences
71
+ """
72
+ def __init__(self):
73
+ super().__init__()
74
+
75
+ def forward(self, logits: torch.Tensor, decoded_tokens_list: list,
76
+ win_size: int = 10, tau_r: float = 0.1,
77
+ top_p: float = 0.8, top_k: int = 25,
78
+ eos_token: int = 6561, min_tokens: list[int] | None = None):
79
+ """
80
+ Execute repetition-aware sampling with batched, vectorized operations
81
+
82
+ Args:
83
+ logits: Input logits with shape [batch_size, vocab_size]
84
+ decoded_tokens_list: List of decoded tokens, one token list per batch item
85
+ win_size: Window size for repetition detection (uniform across all batch items)
86
+ tau_r: Repetition threshold (uniform across all batch items)
87
+ top_p: Nucleus sampling probability threshold (uniform across all batch items)
88
+ top_k: Nucleus sampling top-k threshold (uniform across all batch items)
89
+ eos_token: End of sequence token ID (uniform across all batch items)
90
+ min_tokens: List of minimum tokens to generate before allowing EOS, one per batch item
91
+ Returns:
92
+ Selected token IDs
93
+ """
94
+ batch_size = logits.size(0)
95
+ device = logits.device
96
+ result = torch.zeros(batch_size, dtype=torch.long, device=device)
97
+
98
+ # Set default values if not provided
99
+ if min_tokens is None:
100
+ min_tokens = [2] * batch_size
101
+
102
+ # Ensure min_tokens list has the correct length
103
+ assert len(min_tokens) == batch_size, f"min_tokens length {len(min_tokens)} != batch_size {batch_size}"
104
+
105
+ # Force continue decode first token
106
+ for i in range(batch_size):
107
+ if i < len(decoded_tokens_list) and len(decoded_tokens_list[i]) == 0:
108
+ logits[i, eos_token] = -float('inf')
109
+
110
+ # 1. First, perform nucleus sampling for all sequences
111
+ probs = torch.softmax(logits, dim=-1)
112
+
113
+ # Use vectorized nucleus sampling for all sequences
114
+ # This can be done in batch since top_p and top_k are uniform
115
+ sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
116
+ cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
117
+
118
+ # Create masks for top-p and top-k filtering
119
+ top_p_mask = cumulative_probs <= top_p
120
+
121
+ # Create top-k mask (first top_k positions are True)
122
+ top_k_mask = torch.zeros_like(top_p_mask)
123
+ top_k_mask[:, :top_k] = True
124
+
125
+ # Combine masks
126
+ mask = top_p_mask & top_k_mask
127
+
128
+ # Ensure at least one token is selected per sequence
129
+ first_token_mask = torch.zeros_like(mask)
130
+ first_token_mask[:, 0] = True
131
+ mask = mask | first_token_mask
132
+
133
+ # Sample from the filtered distribution
134
+ sample_probs = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs))
135
+ sample_probs = sample_probs / sample_probs.sum(dim=-1, keepdim=True)
136
+
137
+ # Sample indices from the filtered distribution
138
+ sampled_indices = torch.multinomial(sample_probs, 1).squeeze(-1)
139
+ top_ids = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1)
140
+
141
+ # 2. Check for repetitions and apply random sampling if needed
142
+ # Extract recent tokens for each sequence, handling empty or short sequences
143
+ recent_tokens_list = []
144
+ for i in range(batch_size):
145
+ # Handle index out of range or empty tokens
146
+ if i < len(decoded_tokens_list):
147
+ tokens = decoded_tokens_list[i]
148
+ if len(tokens) > 0:
149
+ start_idx = max(0, len(tokens) - win_size)
150
+ recent_tokens_list.append(tokens[start_idx:])
151
+ else:
152
+ recent_tokens_list.append([]) # Empty list for empty tokens
153
+ else:
154
+ recent_tokens_list.append([]) # Empty list for missing batch items
155
+
156
+ # Check if we have any tokens to process for repetition detection
157
+ if any(len(tokens) > 0 for tokens in recent_tokens_list):
158
+ # Convert to padded tensor for batch processing
159
+ max_recent_len = max(len(tokens) for tokens in recent_tokens_list)
160
+ if max_recent_len > 0: # Only proceed if we have tokens
161
+ recent_tokens_tensor = torch.zeros((batch_size, max_recent_len), dtype=torch.long, device=device) - 1
162
+ for i, tokens in enumerate(recent_tokens_list):
163
+ if len(tokens) > 0:
164
+ recent_tokens_tensor[i, -len(tokens):] = torch.tensor(tokens, device=device)
165
+
166
+ # Create a mask for valid positions and to avoid division by zero
167
+ valid_positions_mask = torch.zeros_like(recent_tokens_tensor, dtype=torch.bool)
168
+ for i, tokens in enumerate(recent_tokens_list):
169
+ if len(tokens) > 0:
170
+ valid_positions_mask[i, -len(tokens):] = True
171
+
172
+ # Check repetition rates
173
+ repetition_counts = torch.zeros(batch_size, device=device)
174
+ for i in range(batch_size):
175
+ if len(recent_tokens_list[i]) > 0:
176
+ repetition_counts[i] = (recent_tokens_tensor[i] == top_ids[i]).sum()
177
+
178
+ # Calculate repetition rates, avoiding division by zero
179
+ recent_lengths = torch.tensor([max(1, len(tokens)) for tokens in recent_tokens_list], device=device)
180
+ repetition_rates = repetition_counts / recent_lengths
181
+
182
+ # Identify sequences needing random sampling
183
+ need_random = repetition_rates >= tau_r
184
+
185
+ # Apply random sampling where needed
186
+ if need_random.any():
187
+ random_indices = torch.multinomial(probs[need_random], 1).squeeze(-1)
188
+ top_ids[need_random] = random_indices
189
+
190
+ # 3. Handle EOS tokens
191
+ # Create mask for sequences that should ignore EOS tokens
192
+ ignore_eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
193
+ for i in range(batch_size):
194
+ if i < len(decoded_tokens_list):
195
+ ignore_eos_mask[i] = len(decoded_tokens_list[i]) < min_tokens[i]
196
+ else:
197
+ ignore_eos_mask[i] = True # Default to ignoring EOS for missing sequences
198
+
199
+ is_eos_mask = top_ids == eos_token
200
+ need_resample = ignore_eos_mask & is_eos_mask
201
+
202
+ # Resample for sequences that need it
203
+ if need_resample.any():
204
+ max_trials = 100
205
+ for attempt in range(max_trials):
206
+ # Break if no more resampling needed
207
+ if not need_resample.any():
208
+ break
209
+
210
+ # Sample new tokens for sequences that need resampling
211
+ new_samples = torch.multinomial(probs[need_resample], 1).squeeze(-1)
212
+
213
+ # Update top_ids with new samples
214
+ top_ids[need_resample] = new_samples
215
+
216
+ # Update which sequences still need resampling
217
+ is_eos_mask = top_ids == eos_token
218
+ need_resample = ignore_eos_mask & is_eos_mask
219
+
220
+ # If still have EOS tokens that should be ignored, force them to be non-EOS
221
+ if need_resample.any():
222
+ # Force to a non-EOS token (e.g., the second most likely token)
223
+ for i in range(batch_size):
224
+ if need_resample[i]:
225
+ # Get second most likely token (or first if only one token)
226
+ second_best_idx = 1 if sorted_indices.size(1) > 1 else 0
227
+ top_ids[i] = sorted_indices[i, second_best_idx]
228
+
229
+ result = top_ids
230
+
231
+ return result
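A sketch of how RasSampler might be driven during decoding, assuming the package in this commit is importable; the shapes and token histories below are made up for illustration (6561 is simply the default eos_token above, and 6562 the speech vocabulary size mentioned in the loader comments):

    import torch
    from flashcosyvoice.modules.sampler import RasSampler

    sampler = RasSampler()
    logits = torch.randn(2, 6562)                 # one decode step, batch of 2

    # Per-sequence histories; the empty list forces a non-EOS first token.
    decoded = [[], [1493, 1493, 1493, 1493, 1493]]

    next_ids = sampler(logits, decoded, win_size=10, tau_r=0.1,
                       top_p=0.8, top_k=25, eos_token=6561, min_tokens=[2, 2])
    print(next_ids.shape)                         # torch.Size([2])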
flashcosyvoice/utils/__init__.py ADDED
File without changes
flashcosyvoice/utils/audio.py ADDED
@@ -0,0 +1,77 @@
1
+ import numpy as np
2
+ import torch
3
+ from librosa.filters import mel as librosa_mel_fn
4
+ from scipy.io.wavfile import read
5
+
6
+ MAX_WAV_VALUE = 32768.0
7
+
8
+
9
+ def load_wav(full_path):
10
+ sampling_rate, data = read(full_path)
11
+ return data, sampling_rate
12
+
13
+
14
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
15
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
16
+
17
+
18
+ def dynamic_range_decompression(x, C=1):
19
+ return np.exp(x) / C
20
+
21
+
22
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
23
+ return torch.log(torch.clamp(x, min=clip_val) * C)
24
+
25
+
26
+ def dynamic_range_decompression_torch(x, C=1):
27
+ return torch.exp(x) / C
28
+
29
+
30
+ def spectral_normalize_torch(magnitudes):
31
+ output = dynamic_range_compression_torch(magnitudes)
32
+ return output
33
+
34
+
35
+ def spectral_de_normalize_torch(magnitudes):
36
+ output = dynamic_range_decompression_torch(magnitudes)
37
+ return output
38
+
39
+
40
+ mel_basis = {}
41
+ hann_window = {}
42
+
43
+
44
+ def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480,
45
+ win_size=1920, fmin=0, fmax=8000, center=False):
46
+ global mel_basis, hann_window # pylint: disable=global-statement
47
+ if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
48
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
49
+ mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
50
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
51
+
52
+ y = torch.nn.functional.pad(
53
+ y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
54
+ )
55
+ y = y.squeeze(1)
56
+
57
+ spec = torch.view_as_real(
58
+ torch.stft(
59
+ y,
60
+ n_fft,
61
+ hop_length=hop_size,
62
+ win_length=win_size,
63
+ window=hann_window[str(y.device)],
64
+ center=center,
65
+ pad_mode="reflect",
66
+ normalized=False,
67
+ onesided=True,
68
+ return_complex=True,
69
+ )
70
+ )
71
+
72
+ spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
73
+
74
+ spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
75
+ spec = spectral_normalize_torch(spec)
76
+
77
+ return spec
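With the defaults above (24 kHz input, n_fft = win_size = 1920, hop 480, center=False plus reflect padding of (1920 - 480) / 2 samples on each side), one second of audio maps to exactly 50 mel frames. A quick shape check on a dummy waveform, assuming this module is importable:

    import torch
    from flashcosyvoice.utils.audio import mel_spectrogram

    wav = torch.randn(1, 24000)    # [batch, samples]: 1 s of fake 24 kHz audio
    mel = mel_spectrogram(wav)     # defaults: 80 mel bins, 50 frames per second
    print(mel.shape)               # torch.Size([1, 80, 50])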
flashcosyvoice/utils/context.py ADDED
@@ -0,0 +1,28 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+
6
+ @dataclass
7
+ class Context:
8
+ is_prefill: bool = False
9
+ cu_seqlens_q: torch.Tensor | None = None
10
+ cu_seqlens_k: torch.Tensor | None = None
11
+ max_seqlen_q: int = 0
12
+ max_seqlen_k: int = 0
13
+ slot_mapping: torch.Tensor | None = None
14
+ context_lens: torch.Tensor | None = None
15
+ block_tables: torch.Tensor | None = None
16
+
17
+ _CONTEXT = Context()
18
+
19
+ def get_context():
20
+ return _CONTEXT
21
+
22
+ def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
23
+ global _CONTEXT
24
+ _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)
25
+
26
+ def reset_context():
27
+ global _CONTEXT
28
+ _CONTEXT = Context()
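The module keeps a single process-global Context so attention kernels can read prefill/decode metadata without it being passed through every forward call. A minimal round-trip sketch with placeholder values:

    import torch
    from flashcosyvoice.utils.context import get_context, reset_context, set_context

    # Publish metadata before running the model...
    set_context(is_prefill=True,
                cu_seqlens_q=torch.tensor([0, 7, 12], dtype=torch.int32),
                cu_seqlens_k=torch.tensor([0, 7, 12], dtype=torch.int32),
                max_seqlen_q=7, max_seqlen_k=7)

    ctx = get_context()            # ...and read it back inside the layers.
    print(ctx.is_prefill, ctx.max_seqlen_q)

    reset_context()                # clear after the forward pass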
flashcosyvoice/utils/loader.py ADDED
@@ -0,0 +1,116 @@
1
+ import os
2
+ from glob import glob
3
+
4
+ import torch
5
+ from safetensors import safe_open
6
+ from torch import nn
7
+
8
+ from flashcosyvoice.config import CosyVoice2LLMConfig
9
+
10
+
11
+ def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
12
+ param.data.copy_(loaded_weight)
13
+
14
+
15
+ def load_text_llm(model: nn.Module, path: str):
16
+ packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
17
+ for file in glob(os.path.join(path, "*.safetensors")):
18
+ with safe_open(file, "pt", "cpu") as f:
19
+ for weight_name in f.keys():
20
+ for k in packed_modules_mapping:
21
+ if k in weight_name:
22
+ v, shard_id = packed_modules_mapping[k]
23
+ param_name = weight_name.replace(k, v)
24
+ param = model.get_parameter(param_name)
25
+ weight_loader = param.weight_loader
26
+ weight_loader(param, f.get_tensor(weight_name), shard_id)
27
+ break
28
+ else:
29
+ param = model.get_parameter(weight_name)
30
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
31
+ weight_loader(param, f.get_tensor(weight_name))
32
+
33
+
34
+ def load_speech_llm(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig):
35
+ packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
36
+
37
+ # NOTE(xcsong): 1. load speech embedding + sos/taskid embedding + lm head
38
+ embedding_weights = {}
39
+ tmp_weights = torch.load(f"{path}/llm.pt", map_location="cpu", weights_only=True)
40
+ missed, missed_names = 0, []
41
+ for k, v in tmp_weights.items():
42
+ if k == "speech_embedding.weight": # torch.Size([6564, 896])
43
+ speech_embedding_size = hf_config.speech_vocab_size # 6562
44
+ # NOTE(xcsong): padding to 6592 for vllm tensor parallel
45
+ if speech_embedding_size != v.shape[0]: # [6564, 896] -> [6562, 896]
46
+ assert speech_embedding_size <= v.shape[0], f"speech_embedding_size should be less than or equal to {v.shape[0]}, but got {speech_embedding_size}"
47
+ v = v[:speech_embedding_size, :]
48
+ embedding_weights["speech_embedding.weight"] = v
49
+ elif k == "llm_embedding.weight": # torch.Size([2, 896]), eos and task_id
50
+ assert v.shape[0] == 2, f"llm_embedding.weight should be of shape [2, 896], but got {v.shape}"
51
+ embedding_weights["llm_embedding.weight"] = v
52
+ elif k == "llm.model.model.embed_tokens.weight": # torch.Size([151936, 896])
53
+ embedding_weights["model.embed_tokens.weight"] = v
54
+ elif k == "llm_decoder.weight": # torch.Size([6564, 896])
55
+ lm_head_size = hf_config.speech_vocab_size # 6562
56
+ if lm_head_size != v.shape[0]: # [6564, 896] -> [6562, 896]
57
+ assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
58
+ v = v[:lm_head_size, :]
59
+ param = model.get_parameter("lm_head.weight")
60
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
61
+ weight_loader(param, v)
62
+ elif k == "llm_decoder.bias": # torch.Size([6564])
63
+ lm_head_size = hf_config.speech_vocab_size # 6562
64
+ if lm_head_size != v.shape[0]: # [6564] -> [6562]
65
+ assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
66
+ v = v[:lm_head_size]
67
+ param = model.get_parameter("lm_head.bias")
68
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
69
+ weight_loader(param, v)
70
+ elif "llm.model." in k:
71
+ weight_name = k.replace("llm.model.", "")
72
+ for kk in packed_modules_mapping:
73
+ if kk in weight_name:
74
+ vv, shard_id = packed_modules_mapping[kk]
75
+ param_name = weight_name.replace(kk, vv)
76
+ try:
77
+ param = model.get_parameter(param_name)
78
+ weight_loader = param.weight_loader
79
+ weight_loader(param, v, shard_id)
80
+ break
81
+ except Exception as e:
82
+ print(e)
83
+ print(f"skip parameter (1): {weight_name}")
84
+ continue
85
+ else:
86
+ try:
87
+ param = model.get_parameter(weight_name)
88
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
89
+ weight_loader(param, v)
90
+ except Exception as e:
91
+ print(e)
92
+ print(f"skip parameter (2): {weight_name}")
93
+ continue
94
+ else:
95
+ missed += 1
96
+ missed_names.append(weight_name)
97
+ continue
98
+ print(f"missed {missed} parameters: {missed_names}")
99
+
100
+ # NOTE(xcsong): 2. merge text embedding, sos/taskid embedding, and speech embedding
101
+ text_embedding_weight = embedding_weights["model.embed_tokens.weight"].cpu() # [151936, 896]
102
+ sos_taskid_embedding_weight = embedding_weights["llm_embedding.weight"].cpu() # [2, 896]
103
+ speech_embedding_weight = embedding_weights["speech_embedding.weight"].cpu() # [6562, 896]
104
+ final_embedding_weight = torch.cat([speech_embedding_weight, sos_taskid_embedding_weight, text_embedding_weight], dim=0) # [158500, 896]
105
+ param = model.get_parameter("model.embed_tokens.weight")
106
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
107
+ weight_loader(param, final_embedding_weight)
108
+
109
+
110
+ def load_model(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig | None = None):
111
+ if model.model_type == "speech_llm":
112
+ load_speech_llm(model, path, hf_config)
113
+ elif model.model_type == "text_llm":
114
+ load_text_llm(model, path)
115
+ else:
116
+ raise ValueError(f"Unsupported model type: {model.model_type}")
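The torch.cat in load_speech_llm fixes the merged vocabulary layout: speech tokens occupy the first speech_vocab_size rows, the two llm_embedding rows (sos/eos and task_id) come next, and text tokens follow. A small sketch of the resulting id offsets, assuming speech_vocab_size = 6562 as in the comments above (the runtime's actual id mapping is defined elsewhere):

    SPEECH_VOCAB_SIZE = 6562                # rows 0..6561: speech tokens
    SOS_EOS_ID = SPEECH_VOCAB_SIZE          # 6562: llm_embedding row 0
    TASK_ID = SPEECH_VOCAB_SIZE + 1         # 6563: llm_embedding row 1
    TEXT_OFFSET = SPEECH_VOCAB_SIZE + 2     # 6564: first text-token row

    def text_token_to_merged_id(text_id: int) -> int:
        # A Qwen2 text token indexes the merged embedding at text_id + TEXT_OFFSET.
        return TEXT_OFFSET + text_id

    print(text_token_to_merged_id(0))       # 6564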
flashcosyvoice/utils/memory.py ADDED
@@ -0,0 +1,19 @@
1
+ import os
2
+
3
+ import torch
4
+ from pynvml import * # noqa
5
+
6
+
7
+ def get_gpu_memory():
8
+ torch.cuda.synchronize()
9
+ nvmlInit()
10
+ visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
11
+ cuda_device_idx = torch.cuda.current_device()
12
+ cuda_device_idx = visible_device[cuda_device_idx]
13
+ handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
14
+ mem_info = nvmlDeviceGetMemoryInfo(handle)
15
+ total_memory = mem_info.total
16
+ used_memory = mem_info.used
17
+ free_memory = mem_info.free
18
+ nvmlShutdown()
19
+ return total_memory, used_memory, free_memory
stepaudio2.py ADDED
@@ -0,0 +1,204 @@
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
3
+
4
+ from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels
5
+
6
+
7
+ class StepAudio2Base:
8
+
9
+ def __init__(self, model_path: str):
10
+ self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right")
11
+ self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
12
+ self.eos_token_id = self.llm_tokenizer.eos_token_id
13
+
14
+ def __call__(self, messages: list, **kwargs):
15
+ messages, mels = self.apply_chat_template(messages)
16
+
17
+ # Tokenize prompts
18
+ prompt_ids = []
19
+ for msg in messages:
20
+ if isinstance(msg, str):
21
+ prompt_ids.append(self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"])
22
+ elif isinstance(msg, list):
23
+ prompt_ids.append(torch.tensor([msg], dtype=torch.int32))
24
+ else:
25
+ raise ValueError(f"Unsupported content type: {type(msg)}")
26
+ prompt_ids = torch.cat(prompt_ids, dim=-1).cuda()
27
+ attention_mask = torch.ones_like(prompt_ids)
28
+
29
+ #mels = None if len(mels) == 0 else torch.stack(mels).cuda()
30
+ #mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda')
31
+ if len(mels) == 0:
32
+ mels = None
33
+ mel_lengths = None
34
+ else:
35
+ mels, mel_lengths = padding_mels(mels)
36
+ mels = mels.cuda()
37
+ mel_lengths = mel_lengths.cuda()
38
+
39
+ generate_inputs = {
40
+ "input_ids": prompt_ids,
41
+ "wavs": mels,
42
+ "wav_lens": mel_lengths,
43
+ "attention_mask": attention_mask
44
+ }
45
+
46
+ generation_config = dict(max_new_tokens=2048,
47
+ pad_token_id=self.llm_tokenizer.pad_token_id,
48
+ eos_token_id=self.eos_token_id,
49
+ )
50
+ generation_config.update(kwargs)
51
+ generation_config = GenerationConfig(**generation_config)
52
+
53
+ outputs = self.llm.generate(**generate_inputs, generation_config=generation_config)
54
+ output_token_ids = outputs[0, prompt_ids.shape[-1] : -1].tolist()
55
+ output_text_tokens = [i for i in output_token_ids if i < 151688]
56
+ output_audio_tokens = [i - 151696 for i in output_token_ids if i > 151695]
57
+ output_text = self.llm_tokenizer.decode(output_text_tokens)
58
+ return output_token_ids, output_text, output_audio_tokens
59
+
60
+ def apply_chat_template(self, messages: list):
61
+ results = []
62
+ mels = []
63
+ for msg in messages:
64
+ content = msg
65
+ if isinstance(content, str):
66
+ text_with_audio = content
67
+ results.append(text_with_audio)
68
+ elif isinstance(content, dict):
69
+ if content["type"] == "text":
70
+ results.append(f"{content['text']}")
71
+ elif content["type"] == "audio":
72
+ audio = load_audio(content['audio'])
73
+ for i in range(0, audio.shape[0], 16000 * 25):
74
+ mel = log_mel_spectrogram(audio[i:i+16000*25], n_mels=128, padding=479)
75
+ mels.append(mel)
76
+ audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
77
+ results.append(f"<audio_start>{audio_tokens}<audio_end>")
78
+ elif content["type"] == "token":
79
+ results.append(content["token"])
80
+ else:
81
+ raise ValueError(f"Unsupported content type: {type(content)}")
82
+ # print(results)
83
+ return results, mels
84
+
85
+
86
+ class StepAudio2(StepAudio2Base):
87
+
88
+ def __init__(self, model_path: str):
89
+ super().__init__(model_path)
90
+ self.llm_tokenizer.eos_token = "<|EOT|>"
91
+ self.llm.config.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
92
+ self.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
93
+
94
+ def apply_chat_template(self, messages: list):
95
+ results = []
96
+ mels = []
97
+ for msg in messages:
98
+ role = msg["role"]
99
+ content = msg["content"]
100
+ if role == "user":
101
+ role = "human"
102
+ if isinstance(content, str):
103
+ text_with_audio = f"<|BOT|>{role}\n{content}"
104
+ text_with_audio += '<|EOT|>' if msg.get('eot', True) else ''
105
+ results.append(text_with_audio)
106
+ elif isinstance(content, list):
107
+ results.append(f"<|BOT|>{role}\n")
108
+ for item in content:
109
+ if item["type"] == "text":
110
+ results.append(f"{item['text']}")
111
+ elif item["type"] == "audio":
112
+ audio = load_audio(item['audio'])
113
+ for i in range(0, audio.shape[0], 16000 * 25):
114
+ mel = log_mel_spectrogram(audio[i:i+16000*25], n_mels=128, padding=479)
115
+ mels.append(mel)
116
+ audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
117
+ results.append(f"<audio_start>{audio_tokens}<audio_end>")
118
+ elif item["type"] == "token":
119
+ results.append(item["token"])
120
+ if msg.get('eot', True):
121
+ results.append('<|EOT|>')
122
+ elif content is None:
123
+ results.append(f"<|BOT|>{role}\n")
124
+ else:
125
+ raise ValueError(f"Unsupported content type: {type(content)}")
126
+ # print(results)
127
+ return results, mels
128
+
129
+ if __name__ == '__main__':
130
+ from token2wav import Token2wav
131
+
132
+ model = StepAudio2('/mnt/gpfs/lijingbei/Step-Audio-2-mini')
133
+ token2wav = Token2wav('/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav')
134
+
135
+ # Text-to-text conversation
136
+ print()
137
+ messages = [
138
+ {"role": "system", "content": "You are a helpful assistant."},
139
+ {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
140
+ {"role": "assistant", "content": None}
141
+ ]
142
+ tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
143
+ print(text)
144
+
145
+ # Text-to-speech conversation
146
+ print()
147
+ messages = [
148
+ {"role": "system", "content": "You are a helpful assistant."},
149
+ {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
150
+ {"role": "assistant", "content": "<tts_start>", "eot": False}, # Insert <tts_start> for speech response
151
+ ]
152
+ tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
153
+ print(text)
154
+ print(tokens)
155
+ audio = token2wav(audio, prompt_wav='assets/default_male.wav')
156
+ with open('output-male.wav', 'wb') as f:
157
+ f.write(audio)
158
+
159
+ # Speech-to-text conversation
160
+ print()
161
+ messages = [
162
+ {"role": "system", "content": "You are a helpful assistant."},
163
+ {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
164
+ {"role": "assistant", "content": None}
165
+ ]
166
+ tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
167
+ print(text)
168
+
169
+ # Speech-to-speech conversation
170
+ print()
171
+ messages = [
172
+ {"role": "system", "content": "You are a helpful assistant."},
173
+ {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
174
+ {"role": "assistant", "content": "<tts_start>", "eot": False}, # Insert <tts_start> for speech response
175
+ ]
176
+ tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
177
+ print(text)
178
+ print(tokens)
179
+ audio = token2wav(audio, prompt_wav='assets/default_female.wav')
180
+ with open('output-female.wav', 'wb') as f:
181
+ f.write(audio)
182
+
183
+ # Multi-turn conversation
184
+ print()
185
+ messages.pop(-1)
186
+ messages += [
187
+ {"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"},
188
+ {"type": "token", "token": tokens}]},
189
+ {"role": "human", "content": "Now write a 4-line poem about it."},
190
+ {"role": "assistant", "content": None}
191
+ ]
192
+ tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
193
+ print(text)
194
+
195
+ # Multi-modal inputs
196
+ print()
197
+ messages = [
198
+ {"role": "system", "content": "You are a helpful assistant."},
199
+ {"role": "human", "content": [{"type": "text", "text": "Translate the speech into Chinese."},
200
+ {"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
201
+ {"role": "assistant", "content": None}
202
+ ]
203
+ tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
204
+ print(text)
token2wav.py ADDED
@@ -0,0 +1,79 @@
1
+ import io
2
+
3
+ import torch
4
+ import torchaudio
5
+ import s3tokenizer
6
+ import onnxruntime
7
+
8
+ import torchaudio.compliance.kaldi as kaldi
9
+ from flashcosyvoice.modules.hifigan import HiFTGenerator
10
+ from flashcosyvoice.utils.audio import mel_spectrogram
11
+ from hyperpyyaml import load_hyperpyyaml
12
+
13
+
14
+ class Token2wav():
15
+
16
+ def __init__(self, model_path, float16=False):
17
+ self.float16 = float16
18
+
19
+ self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval()
20
+
21
+ option = onnxruntime.SessionOptions()
22
+ option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
23
+ option.intra_op_num_threads = 1
24
+ self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])
25
+
26
+ with open(f"{model_path}/flow.yaml", "r") as f:
27
+ configs = load_hyperpyyaml(f)
28
+ self.flow = configs['flow']
29
+ if float16:
30
+ self.flow.half()
31
+ self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True)
32
+ self.flow.cuda().eval()
33
+
34
+ self.hift = HiFTGenerator()
35
+ hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()}
36
+ self.hift.load_state_dict(hift_state_dict, strict=True)
37
+ self.hift.cuda().eval()
38
+
39
+ def __call__(self, generated_speech_tokens, prompt_wav):
40
+ audio = s3tokenizer.load_audio(prompt_wav, sr=16000) # [T]
41
+ mels = s3tokenizer.log_mel_spectrogram(audio)
42
+ mels, mels_lens = s3tokenizer.padding([mels])
43
+ prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())
44
+
45
+ spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
46
+ spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
47
+ spk_emb = torch.tensor(self.spk_model.run(
48
+ None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
49
+ )[0], device='cuda')
50
+
51
+ audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
52
+ audio = audio.mean(dim=0, keepdim=True) # [1, T]
53
+ if sample_rate != 24000:
54
+ audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
55
+ prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
56
+ prompt_mels = prompt_mel.unsqueeze(0).cuda()
57
+ prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')
58
+
59
+ generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
60
+ generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')
61
+
62
+ with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
63
+ mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens,
64
+ prompt_speech_tokens, prompt_speech_tokens_lens,
65
+ prompt_mels, prompt_mels_lens, spk_emb, 10)
66
+
67
+ wav, _ = self.hift(speech_feat=mel)
68
+ output = io.BytesIO()
69
+ torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')
70
+
71
+ return output.getvalue()
72
+
73
+ if __name__ == '__main__':
74
+ token2wav = Token2wav('/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav')
75
+
76
+ tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299]
77
+ audio = token2wav(tokens, 'assets/default_male.wav')
78
+ with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f:
79
+ f.write(audio)
utils.py ADDED
@@ -0,0 +1,91 @@
1
+ import librosa
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.nn.utils.rnn import pad_sequence
5
+ import torchaudio
6
+ from typing import List
7
+
8
+
9
+ def _mel_filters(n_mels: int) -> torch.Tensor:
10
+ """Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
11
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
12
+ if n_mels == 128:
13
+ return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=128))
14
+ else:
15
+ return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
16
+
17
+ def load_audio(file_path, target_rate=16000, max_length=None):
18
+ """
19
+ Open an audio file and read as mono waveform, resampling as necessary
20
+ If max_length is provided, truncate the audio to that length
21
+ """
22
+ waveform, sample_rate = torchaudio.load(file_path)
23
+ if sample_rate != target_rate:
24
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
25
+ audio = waveform[0] # get the first channel
26
+
27
+ # Truncate audio if it exceeds max_length
28
+ if max_length is not None and audio.shape[0] > max_length:
29
+ audio = audio[:max_length]
30
+
31
+ return audio
32
+
33
+ def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
34
+ """
35
+ Compute the log-Mel spectrogram with specific padding for StepAudio
36
+ """
37
+ if not torch.is_tensor(audio):
38
+ if isinstance(audio, str):
39
+ audio = load_audio(audio)
+ else:
+ audio = torch.from_numpy(audio)
41
+ if device is not None:
42
+ audio = audio.to(device)
43
+ if padding > 0:
44
+ audio = F.pad(audio, (0, padding))
45
+ window = torch.hann_window(400).to(audio.device)
46
+ stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
47
+ magnitudes = stft[..., :-1].abs() ** 2
48
+ filters = _mel_filters(n_mels)
49
+ mel_spec = filters @ magnitudes
50
+
51
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
52
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
53
+ log_spec = (log_spec + 4.0) / 4.0
54
+ return log_spec
55
+
56
+ def compute_token_num(max_feature_len):
57
+ # First, audio goes through encoder:
58
+ # 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
59
+ # 2. conv2: kernel=3, stride=2, padding=1 -> size/2
60
+ # 3. avg_pooler: kernel=2, stride=2 -> size/2
61
+ max_feature_len = max_feature_len - 2 # remove padding
62
+ encoder_output_dim = (max_feature_len + 1) // 2 // 2 # after conv2 and avg_pooler
63
+
64
+ # Then through adaptor (parameters from config file):
65
+ padding = 1
66
+ kernel_size = 3 # from config: audio_encoder_config.kernel_size
67
+ stride = 2 # from config: audio_encoder_config.adapter_stride
68
+ adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
69
+ return adapter_output_dim
70
+
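A worked instance of the arithmetic in compute_token_num, using the 25-second chunking from stepaudio2.py: a full 16 kHz chunk padded by 479 samples yields 2502 mel frames, which the encoder and adaptor reduce to 313 audio-patch tokens (about 12.5 tokens per second). This only illustrates the formula above; it is not a value read from any config:

    # 25 s chunk at 16 kHz, padding=479, STFT hop 160 (see log_mel_spectrogram above).
    num_samples = 25 * 16000 + 479
    mel_frames = num_samples // 160              # 2502 (the last STFT frame is dropped)

    feat = mel_frames - 2                        # 2500, padding removed
    encoder_out = (feat + 1) // 2 // 2           # 625 after conv2 (stride 2) and avg_pooler
    adapter_out = (encoder_out + 2 * 1 - 3) // 2 + 1
    print(mel_frames, adapter_out)               # 2502 313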
71
+ def padding_mels(data: List[torch.Tensor]):
72
+ """ Pad a list of mel features into a single batched tensor
73
+
74
+ Parameters
75
+ ----------
76
+ data: List[Tensor], shape of Tensor (128, T)
77
+
78
+ Returns:
79
+ -------
80
+ feats, feats lengths
81
+ """
82
+ sample = data
83
+ assert isinstance(sample, list)
84
+ feats_lengths = torch.tensor([s.size(1)-2 for s in sample],
85
+ dtype=torch.int32)
86
+ feats = [s.t() for s in sample]
87
+ padded_feats = pad_sequence(feats,
88
+ batch_first=True,
89
+ padding_value=0)
90
+
91
+ return padded_feats.transpose(1, 2), feats_lengths