Commit 7e6946d
Parent: 3cf0e6f
add model

Files changed:
- cosyvoice2/flow/__init__.py +0 -0
- cosyvoice2/flow/decoder_dit.py +585 -0
- cosyvoice2/flow/flow.py +225 -0
- cosyvoice2/flow/flow_matching.py +205 -0
- cosyvoice2/transformer/__init__.py +0 -0
- cosyvoice2/transformer/attention.py +328 -0
- cosyvoice2/transformer/embedding.py +119 -0
- cosyvoice2/transformer/encoder_layer.py +163 -0
- cosyvoice2/transformer/positionwise_feed_forward.py +56 -0
- cosyvoice2/transformer/subsampling.py +79 -0
- cosyvoice2/transformer/upsample_encoder_v2.py +483 -0
- cosyvoice2/utils/class_utils.py +41 -0
- cosyvoice2/utils/common.py +101 -0
- cosyvoice2/utils/mask.py +49 -0
- flashcosyvoice/__init__.py +0 -0
- flashcosyvoice/cli.py +424 -0
- flashcosyvoice/config.py +80 -0
- flashcosyvoice/cosyvoice2.py +160 -0
- flashcosyvoice/cosyvoice3.py +1 -0
- flashcosyvoice/engine/__init__.py +0 -0
- flashcosyvoice/engine/block_manager.py +114 -0
- flashcosyvoice/engine/llm_engine.py +125 -0
- flashcosyvoice/engine/model_runner.py +310 -0
- flashcosyvoice/engine/scheduler.py +77 -0
- flashcosyvoice/engine/sequence.py +90 -0
- flashcosyvoice/modules/__init__.py +0 -0
- flashcosyvoice/modules/flow.py +198 -0
- flashcosyvoice/modules/flow_components/__init__.py +0 -0
- flashcosyvoice/modules/flow_components/estimator.py +974 -0
- flashcosyvoice/modules/flow_components/upsample_encoder.py +998 -0
- flashcosyvoice/modules/hifigan.py +249 -0
- flashcosyvoice/modules/hifigan_components/__init__.py +0 -0
- flashcosyvoice/modules/hifigan_components/layers.py +433 -0
- flashcosyvoice/modules/qwen2.py +92 -0
- flashcosyvoice/modules/qwen2_components/__init__.py +0 -0
- flashcosyvoice/modules/qwen2_components/layers.py +616 -0
- flashcosyvoice/modules/sampler.py +231 -0
- flashcosyvoice/utils/__init__.py +0 -0
- flashcosyvoice/utils/audio.py +77 -0
- flashcosyvoice/utils/context.py +28 -0
- flashcosyvoice/utils/loader.py +116 -0
- flashcosyvoice/utils/memory.py +19 -0
- stepaudio2.py +204 -0
- token2wav.py +79 -0
- utils.py +91 -0
cosyvoice2/flow/__init__.py
ADDED
File without changes
cosyvoice2/flow/decoder_dit.py
ADDED
@@ -0,0 +1,585 @@
import math
import torch
import numpy as np
from typing import Optional
from einops import pack, rearrange, repeat
import torch.nn as nn
import torch.nn.functional as F


"""
DiT-v5
- Add convolution in DiTBlock to increase high-freq component
"""


class MLP(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.,
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


class Attention(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        head_dim: int = 64,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        self.scale = head_dim ** -0.5

        self.to_q = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_k = nn.Linear(dim, self.inner_dim, bias=qkv_bias)
        self.to_v = nn.Linear(dim, self.inner_dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.proj = nn.Linear(self.inner_dim, dim)

    def to_heads(self, ts: torch.Tensor):
        b, t, c = ts.shape
        # (b, t, nh, c)
        ts = ts.reshape(b, t, self.num_heads, c // self.num_heads)
        ts = ts.transpose(1, 2)
        return ts

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Args:
            x(torch.Tensor): shape (b, t, c)
            attn_mask(torch.Tensor): shape (b, t, t)
        """
        b, t, c = x.shape

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.to_heads(q)  # (b, nh, t, c)
        k = self.to_heads(k)
        v = self.to_heads(v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        attn_mask = attn_mask.unsqueeze(1)
        x = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )  # (b, nh, t, c)
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward_chunk(self, x: torch.Tensor, att_cache: torch.Tensor = None, attn_mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, dt, c)
            att_cache: shape (b, nh, t, c*2)
        """
        b, t, c = x.shape

        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        q = self.to_heads(q)  # (b, nh, t, c)
        k = self.to_heads(k)
        v = self.to_heads(v)

        q = self.q_norm(q)
        k = self.k_norm(k)

        # unpack {k,v}_cache
        if att_cache is not None:
            if attn_mask is not None:
                k_cache, v_cache = att_cache.chunk(2, dim=3)
                k = torch.cat([k, k_cache], dim=2)
                v = torch.cat([v, v_cache], dim=2)

            else:
                k_cache, v_cache = att_cache.chunk(2, dim=3)
                k = torch.cat([k, k_cache], dim=2)
                v = torch.cat([v, v_cache], dim=2)

        # new {k,v}_cache
        new_att_cache = torch.cat([k, v], dim=3)
        # attn_mask = torch.ones((b, 1, t, t1), dtype=torch.bool, device=x.device)
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (b, nh, t, c)
        x = x.transpose(1, 2).reshape(b, t, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, new_att_cache


def modulate(x, shift, scale):
    return x * (1 + scale) + shift


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        # from SinusoidalPosEmb
        self.scale = 1000

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t * self.scale, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


# Convolution related
class Transpose(torch.nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        x = torch.transpose(x, self.dim0, self.dim1)
        return x


class CausalConv1d(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
    ) -> None:
        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size)
        self.causal_padding = (kernel_size - 1, 0)

    def forward(self, x: torch.Tensor):
        x = F.pad(x, self.causal_padding)
        x = super(CausalConv1d, self).forward(x)
        return x

    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
        if cnn_cache is None:
            cnn_cache = x.new_zeros((x.shape[0], self.in_channels, self.causal_padding[0]))
        x = torch.cat([cnn_cache, x], dim=2)
        new_cnn_cache = x[..., -self.causal_padding[0]:]
        x = super(CausalConv1d, self).forward(x)
        return x, new_cnn_cache


class CausalConvBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size

        self.block = torch.nn.Sequential(
            # norm
            # conv1
            Transpose(1, 2),
            CausalConv1d(in_channels, out_channels, kernel_size),
            Transpose(1, 2),
            # norm & act
            nn.LayerNorm(out_channels),
            nn.Mish(),
            # conv2
            Transpose(1, 2),
            CausalConv1d(out_channels, out_channels, kernel_size),
            Transpose(1, 2),
        )

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, t, c)
            mask: shape (b, t, 1)
        """
        if mask is not None: x = x * mask
        x = self.block(x)
        if mask is not None: x = x * mask
        return x

    def forward_chunk(self, x: torch.Tensor, cnn_cache: torch.Tensor = None):
        """
        Args:
            x: shape (b, dt, c)
            cnn_cache: shape (b, c1+c2, 2)
        """
        if cnn_cache is not None:
            cnn_cache1, cnn_cache2 = cnn_cache.split((self.in_channels, self.out_channels), dim=1)
        else:
            cnn_cache1, cnn_cache2 = None, None
        x = self.block[0](x)
        x, new_cnn_cache1 = self.block[1].forward_chunk(x, cnn_cache1)
        x = self.block[2:6](x)
        x, new_cnn_cache2 = self.block[6].forward_chunk(x, cnn_cache2)
        x = self.block[7](x)
        new_cnn_cache = torch.cat((new_cnn_cache1, new_cnn_cache2), dim=1)
        return x, new_cnn_cache


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, head_dim, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, head_dim=head_dim, qkv_bias=True, qk_norm=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        self.norm3 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.conv = CausalConvBlock(in_channels=hidden_size, out_channels=hidden_size, kernel_size=3)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 9 * hidden_size, bias=True)
        )

    def forward(self, x: torch.Tensor, c: torch.Tensor, attn_mask: torch.Tensor):
        """Args
            x: shape (b, t, c)
            c: shape (b, 1, c)
            attn_mask: shape (b, t, t), bool type attention mask
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
            = self.adaLN_modulation(c).chunk(9, dim=-1)
        # attention
        x = x + gate_msa * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), attn_mask)
        # conv
        x = x + gate_conv * self.conv(modulate(self.norm3(x), shift_conv, scale_conv))
        # mlp
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

    def forward_chunk(self, x: torch.Tensor, c: torch.Tensor, cnn_cache: torch.Tensor = None, att_cache: torch.Tensor = None, mask: torch.Tensor = None):
        """
        Args:
            x: shape (b, dt, c)
            c: shape (b, 1, c)
            cnn_cache: shape (b, c1+c2, 2)
            att_cache: shape (b, nh, t, c * 2)
        """
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_conv, scale_conv, gate_conv \
            = self.adaLN_modulation(c).chunk(9, dim=-1)
        # attention
        x_att, new_att_cache = self.attn.forward_chunk(modulate(self.norm1(x), shift_msa, scale_msa), att_cache, mask)
        x = x + gate_msa * x_att
        # conv
        x_conv, new_cnn_cache = self.conv.forward_chunk(modulate(self.norm3(x), shift_conv, scale_conv), cnn_cache)
        x = x + gate_conv * x_conv
        # mlp
        x = x + gate_mlp * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x, new_cnn_cache, new_att_cache


class FinalLayer(nn.Module):
    """
    The final layer of DiT.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mlp_ratio: float = 4.0,
        depth: int = 28,
        num_heads: int = 8,
        head_dim: int = 64,
        hidden_size: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.in_proj = nn.Linear(in_channels, hidden_size)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, head_dim, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, self.out_channels)

        self.initialize_weights()

        self.enable_cuda_graph = False
        self.use_cuda_graph = False

        self.graph_chunk = {}
        self.inference_buffers_chunk = {}
        self.max_size_chunk = {}

        self.register_buffer('att_cache_buffer', torch.zeros((16, 2, 8, 1000, 128)), persistent=False)
        self.register_buffer('cnn_cache_buffer', torch.zeros((16, 2, 1024, 2)), persistent=False)

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def _init_cuda_graph_chunk(self):
        # get dtype, device from registered buffer
        dtype, device = self.cnn_cache_buffer.dtype, self.cnn_cache_buffer.device
        # init cuda graph for streaming forward
        with torch.no_grad():
            for chunk_size in [30, 48, 96]:
                if chunk_size == 30 or chunk_size == 48:
                    max_size = 500
                    self.max_size_chunk[chunk_size] = max_size
                else:
                    max_size = 1000
                    self.max_size_chunk[chunk_size] = max_size
                static_x1 = torch.zeros((2, 320, chunk_size), dtype=dtype, device=device)
                static_t1 = torch.zeros((2, 1, 512), dtype=dtype, device=device)
                static_mask1 = torch.ones((2, chunk_size, max_size + chunk_size), dtype=torch.bool, device=device)
                static_att_cache = torch.zeros((16, 2, 8, max_size, 128), dtype=dtype, device=device)
                static_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
                static_inputs1 = [
                    static_x1,
                    static_t1,
                    static_mask1,
                    static_cnn_cache,
                    static_att_cache,
                ]
                static_new_cnn_cache = torch.zeros((16, 2, 1024, 2), dtype=dtype, device=device)
                static_new_att_cache = torch.zeros((16, 2, 8, max_size + chunk_size, 128), dtype=dtype, device=device)
                self.blocks_forward_chunk(
                    static_inputs1[0],
                    static_inputs1[1],
                    static_inputs1[2],
                    static_inputs1[3],
                    static_inputs1[4],
                    static_new_cnn_cache,
                    static_new_att_cache)
                graph_chunk = torch.cuda.CUDAGraph()
                with torch.cuda.graph(graph_chunk):
                    static_out1 = self.blocks_forward_chunk(static_x1, static_t1, static_mask1, static_cnn_cache, static_att_cache, static_new_cnn_cache, static_new_att_cache)
                static_outputs1 = [static_out1, static_new_cnn_cache, static_new_att_cache]
                self.inference_buffers_chunk[chunk_size] = {
                    'static_inputs': static_inputs1,
                    'static_outputs': static_outputs1
                }
                self.graph_chunk[chunk_size] = graph_chunk

    def _init_cuda_graph_all(self):
        self._init_cuda_graph_chunk()
        self.use_cuda_graph = True
        print(f"CUDA Graph initialized successfully for chunk decoder")

    def forward(self, x, mask, mu, t, spks=None, cond=None):
        """Args:
            x: shape (b, c, t)
            mask: shape (b, 1, t)
            t: shape (b,)
            spks: shape (b, c)
            cond: shape (b, c, t)
        """
        # (sfy) chunk training strategy should not be open-sourced

        # time
        t = self.t_embedder(t).unsqueeze(1)  # (b, 1, c)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        return self.blocks_forward(x, t, mask)

    def blocks_forward(self, x, t, mask):
        x = x.transpose(1, 2)
        attn_mask = mask.bool()
        x = self.in_proj(x)
        for block in self.blocks:
            x = block(x, t, attn_mask)
        x = self.final_layer(x, t)
        x = x.transpose(1, 2)
        return x

    def forward_chunk(self,
                      x: torch.Tensor,
                      mu: torch.Tensor,
                      t: torch.Tensor,
                      spks: torch.Tensor,
                      cond: torch.Tensor,
                      cnn_cache: torch.Tensor = None,
                      att_cache: torch.Tensor = None,
                      ):
        """
        Args:
            x: shape (b, dt, c)
            mu: shape (b, dt, c)
            t: shape (b,)
            spks: shape (b, c)
            cond: shape (b, dt, c)
            cnn_cache: shape (depth, b, c1+c2, 2)
            att_cache: shape (depth, b, nh, t, c * 2)
        """

        # time
        t = self.t_embedder(t).unsqueeze(1)  # (b, 1, c)
        x = pack([x, mu], "b * t")[0]
        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        # create fake cache
        if cnn_cache is None:
            cnn_cache = [None] * len(self.blocks)
        if att_cache is None:
            att_cache = [None] * len(self.blocks)
        if att_cache[0] is not None:
            last_att_len = att_cache.shape[3]
        else:
            last_att_len = 0
        chunk_size = x.shape[2]
        mask = torch.ones(x.shape[0], chunk_size, last_att_len + chunk_size, dtype=torch.bool, device=x.device)
        if self.use_cuda_graph and att_cache[0] is not None and chunk_size in self.graph_chunk and last_att_len <= self.max_size_chunk[chunk_size]:
            padded_mask = torch.zeros((2, chunk_size, self.max_size_chunk[chunk_size] + chunk_size), dtype=mask.dtype, device=mask.device)
            padded_mask[:, :, :mask.shape[-1]] = mask
            padded_att_cache = torch.zeros((16, 2, 8, self.max_size_chunk[chunk_size], 128), dtype=att_cache.dtype, device=att_cache.device)
            padded_att_cache[:, :, :, :last_att_len, :] = att_cache
            self.inference_buffers_chunk[chunk_size]['static_inputs'][0].copy_(x)
            self.inference_buffers_chunk[chunk_size]['static_inputs'][1].copy_(t)
            self.inference_buffers_chunk[chunk_size]['static_inputs'][2].copy_(padded_mask)
            self.inference_buffers_chunk[chunk_size]['static_inputs'][3].copy_(cnn_cache)
            self.inference_buffers_chunk[chunk_size]['static_inputs'][4].copy_(padded_att_cache)
            self.graph_chunk[chunk_size].replay()
            x = self.inference_buffers_chunk[chunk_size]['static_outputs'][0][:, :, :chunk_size]
            new_cnn_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][1]
            new_att_cache = self.inference_buffers_chunk[chunk_size]['static_outputs'][2][:, :, :, :chunk_size + last_att_len, :]
        else:
            mask = None
            x = self.blocks_forward_chunk(x, t, mask, cnn_cache, att_cache, self.cnn_cache_buffer, self.att_cache_buffer)
            new_cnn_cache = self.cnn_cache_buffer
            new_att_cache = self.att_cache_buffer[:, :, :, :last_att_len + chunk_size, :]

        return x, new_cnn_cache, new_att_cache

    def blocks_forward_chunk(self, x, t, mask, cnn_cache=None, att_cache=None, cnn_cache_buffer=None, att_cache_buffer=None):
        x = x.transpose(1, 2)
        x = self.in_proj(x)
        for b_idx, block in enumerate(self.blocks):
            x, this_new_cnn_cache, this_new_att_cache \
                = block.forward_chunk(x, t, cnn_cache[b_idx], att_cache[b_idx], mask)
            cnn_cache_buffer[b_idx] = this_new_cnn_cache
            att_cache_buffer[b_idx][:, :, :this_new_att_cache.shape[2], :] = this_new_att_cache
        x = self.final_layer(x, t)
        x = x.transpose(1, 2)
        return x
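Reviewer note: a minimal smoke test of the DiT estimator added above. The channel sizes (mel dim 80, packed input 320, hidden size 64, depth 2) are illustrative assumptions for a quick CPU check, not the shipped configuration; the packing order x/mu/spks/cond follows the forward() code above.

    import torch
    from cosyvoice2.flow.decoder_dit import DiT

    # assumed toy config: in_channels = 80 (x) + 80 (mu) + 80 (spks) + 80 (cond)
    dit = DiT(in_channels=320, out_channels=80, depth=2,
              num_heads=2, head_dim=32, hidden_size=64).eval()

    b, c, t = 1, 80, 50
    x = torch.randn(b, c, t)        # noisy mel at the current flow step
    mu = torch.randn(b, c, t)       # encoder output, same shape as x
    spks = torch.randn(b, 80)       # speaker embedding after the affine projection
    cond = torch.zeros(b, c, t)     # prompt-mel condition (zeros when absent)
    mask = torch.ones(b, 1, t)      # (b, 1, t) padding mask, cast to bool inside
    tt = torch.rand(b)              # flow-matching time in [0, 1]

    with torch.no_grad():
        v = dit(x, mask, mu, tt, spks=spks, cond=cond)
    print(v.shape)                  # expected: (1, 80, 50), the velocity estimate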
cosyvoice2/flow/flow.py
ADDED
@@ -0,0 +1,225 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn import functional as F

from cosyvoice2.utils.mask import make_pad_mask
from cosyvoice2.flow.flow_matching import CausalConditionalCFM
from cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2


class CausalMaskedDiffWithXvec(torch.nn.Module):
    def __init__(self,
                 input_size: int = 512,
                 output_size: int = 80,
                 spk_embed_dim: int = 192,
                 output_type: str = "mel",
                 vocab_size: int = 5121,
                 encoder: UpsampleConformerEncoderV2 = None,
                 decoder: CausalConditionalCFM = None,
                 input_embedding: torch.nn.Module = None,
                 ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.vocab_size = vocab_size
        self.output_type = output_type
        self.pre_lookahead_len = int(encoder.pre_lookahead_layer.pre_lookahead_len)
        self.up_rate = int(encoder.up_layer.stride)
        if input_embedding is None:
            self.input_embedding = nn.Embedding(vocab_size, input_size)
        else:
            self.input_embedding = input_embedding
        self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
        self.encoder = encoder
        self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
        self.decoder = decoder

        # xvec projection with CUDA Graph optimization
        # initialize CUDA Graph related state
        self.enable_cuda_graph = False
        self.static_embedding = None
        self.static_output = None
        self.graph = None
        self.embedding_shape = None

    def scatter_cuda_graph(self, enable_cuda_graph: bool):
        self.enable_cuda_graph = enable_cuda_graph
        if self.enable_cuda_graph:
            # self.encoder.scatter_cuda_graph(enable_cuda_graph)
            self.decoder.scatter_cuda_graph(enable_cuda_graph)

    @torch.inference_mode()
    def inference(self,
                  token,
                  token_len,
                  prompt_token,
                  prompt_token_len,
                  prompt_feat,
                  prompt_feat_len,
                  embedding,
                  n_timesteps: int = 10,
                  ):
        assert token.shape[0] == 1

        # xvec projection
        embedding = F.normalize(embedding, dim=1)
        embedding = self.spk_embed_affine_layer(embedding)

        # concat text and prompt_text
        token_len = prompt_token_len + token_len
        token = torch.concat([prompt_token, token], dim=1)

        mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
        token = self.input_embedding(torch.clamp(token, min=0)) * mask

        # token encode
        h, _ = self.encoder.forward(token, token_len)
        h = self.encoder_proj(h)

        # condition
        mel_len1 = prompt_feat.shape[1]
        mel_len2 = h.shape[1] - prompt_feat.shape[1]

        conds = torch.zeros_like(h)
        conds[:, :mel_len1] = prompt_feat
        conds = conds.transpose(1, 2).contiguous()

        mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)

        feat = self.decoder.forward(
            mu=h.transpose(1, 2).contiguous(),
            mask=mask.unsqueeze(1),
            spks=embedding,
            cond=conds,
            n_timesteps=n_timesteps,
        )

        feat = feat[:, :, mel_len1:]
        assert feat.shape[2] == mel_len2
        return feat

    @torch.inference_mode()
    def setup_cache(self,
                    token: torch.Tensor,
                    mel: torch.Tensor,
                    spk: torch.Tensor,
                    n_timesteps: int = 10,
                    ):
        """
        Args:
            token: shape (b, t), with look ahead tokens
            mel: shape (b, t, c), groundtruth mel
            spk: shape (b, 192), speaker embedding
        Returns:
            cache: dict {
                'conformer': {'cnn_cache': xxx, 'att_cache': xxx},
                'estimator': {'cnn_cache': xxx, 'att_cache': xxx}
            }
        """
        # check if look ahead token included
        assert (token.shape[1] - self.pre_lookahead_len) * self.up_rate == mel.shape[1], (token.shape, mel.shape)

        # xvec projection
        spk = F.normalize(spk, dim=1)
        spk = self.spk_embed_affine_layer(spk)

        token = self.input_embedding(token)
        # NOTE encoder.forward_chunk will strip the look ahead part
        h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
            xs=token,
            last_chunk=False,
            cnn_cache=None,
            att_cache=None,
        )
        h = self.encoder_proj(h)

        feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
            mu=h.transpose(1, 2).contiguous(),
            spks=spk,
            cond=mel.transpose(1, 2).contiguous(),
            n_timesteps=n_timesteps,
            temperature=1.0,
            cnn_cache=None,
            att_cache=None,
        )

        cache = {
            'conformer_cnn_cache': conformer_cnn_cache,
            'conformer_att_cache': conformer_att_cache,
            'estimator_cnn_cache': estimator_cnn_cache,
            'estimator_att_cache': estimator_att_cache,
        }
        return cache

    @torch.inference_mode()
    def inference_chunk(self,
                        token: torch.Tensor,
                        spk: torch.Tensor,
                        cache: dict,
                        last_chunk: bool = False,
                        n_timesteps: int = 10,
                        ):
        """
        Args:
            token: shape (b, t), with look ahead tokens
            spk: shape (b, 192), speaker embedding
            cache: dict {
                'conformer_cnn_cache': xxx,
                ...
            }
        """
        # unpack cache
        conformer_cnn_cache = cache['conformer_cnn_cache']
        conformer_att_cache = cache['conformer_att_cache']
        estimator_cnn_cache = cache['estimator_cnn_cache']
        estimator_att_cache = cache['estimator_att_cache']

        # xvec projection
        spk = F.normalize(spk, dim=1)
        spk = self.spk_embed_affine_layer(spk)

        token = self.input_embedding(token)
        # if not the last chunk, h is shorter than xs for a length of lookahead_length * stride (6)
        h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
            xs=token,
            last_chunk=last_chunk,
            cnn_cache=conformer_cnn_cache,
            att_cache=conformer_att_cache,
        )
        h = self.encoder_proj(h)

        cond = torch.zeros_like(h)
        # forward estimator
        feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
            mu=h.transpose(1, 2).contiguous(),
            spks=spk,
            cond=cond.transpose(1, 2).contiguous(),
            n_timesteps=n_timesteps,
            temperature=1.0,
            cnn_cache=estimator_cnn_cache,
            att_cache=estimator_att_cache,
        )

        new_cache = {
            'conformer_cnn_cache': conformer_cnn_cache,
            'conformer_att_cache': conformer_att_cache,
            'estimator_cnn_cache': estimator_cnn_cache,
            'estimator_att_cache': estimator_att_cache,
        }

        return feat, new_cache
cosyvoice2/flow/flow_matching.py
ADDED
@@ -0,0 +1,205 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import onnxruntime
import torch
import torch.nn.functional as F

from cosyvoice2.flow.decoder_dit import DiT
from cosyvoice2.utils.mask import make_pad_mask


"""
Inference wrapper
"""
class CausalConditionalCFM(torch.nn.Module):
    def __init__(self, estimator: DiT, inference_cfg_rate: float = 0.7):
        super().__init__()
        self.estimator = estimator
        self.inference_cfg_rate = inference_cfg_rate
        self.out_channels = estimator.out_channels
        # a maximum of 600s
        self.register_buffer('rand_noise', torch.randn([1, self.out_channels, 50 * 600]), persistent=False)

        self.register_buffer('cnn_cache_buffer', torch.zeros(16, 16, 2, 1024, 2), persistent=False)
        self.register_buffer('att_cache_buffer', torch.zeros(16, 16, 2, 8, 1000, 128), persistent=False)

    def scatter_cuda_graph(self, enable_cuda_graph: bool):
        if enable_cuda_graph:
            self.estimator._init_cuda_graph_all()

    def solve_euler(self, x, t_span, mu, mask, spks, cond):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
        """
        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)
        assert self.inference_cfg_rate > 0, 'inference_cfg_rate better > 0'

        # constant during denoising
        mask_in = torch.cat([mask, mask], dim=0)
        mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
        spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
        cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)

        for step in range(1, len(t_span)):

            x_in = torch.cat([x, x], dim=0)
            t_in = torch.cat([t, t], dim=0)

            dphi_dt = self.estimator.forward(
                x_in,
                mask_in,
                mu_in,
                t_in,
                spks_in,
                cond_in,
            )
            dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return x

    @torch.inference_mode()
    def forward(self, mu, mask, spks, cond, n_timesteps=10, temperature=1.0):
        z = self.rand_noise[:, :, :mu.size(2)] * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # cosine scheduling
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        return self.solve_euler(z, t_span, mu, mask, spks, cond)

    def solve_euler_chunk(self,
                          x: torch.Tensor,
                          t_span: torch.Tensor,
                          mu: torch.Tensor,
                          spks: torch.Tensor,
                          cond: torch.Tensor,
                          cnn_cache: torch.Tensor = None,
                          att_cache: torch.Tensor = None,
                          ):
        """
        Fixed euler solver for ODEs.
        Args:
            x (torch.Tensor): random noise
            t_span (torch.Tensor): n_timesteps interpolated
                shape: (n_timesteps + 1,)
            mu (torch.Tensor): output of encoder
                shape: (batch_size, n_feats, mel_timesteps)
            mask (torch.Tensor): output_mask
                shape: (batch_size, 1, mel_timesteps)
            spks (torch.Tensor, optional): speaker ids. Defaults to None.
                shape: (batch_size, spk_emb_dim)
            cond: Not used but kept for future purposes
            cnn_cache: shape (n_time, depth, b, c1+c2, 2)
            att_cache: shape (n_time, depth, b, nh, t, c * 2)
        """
        assert self.inference_cfg_rate > 0, 'cfg rate should be > 0'

        t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
        t = t.unsqueeze(dim=0)  # (b,)

        # setup initial cache
        if cnn_cache is None:
            cnn_cache = [None for _ in range(len(t_span) - 1)]
        if att_cache is None:
            att_cache = [None for _ in range(len(t_span) - 1)]
        # next chunk's cache at each timestep

        if att_cache[0] is not None:
            last_att_len = att_cache.shape[4]
        else:
            last_att_len = 0

        # constant during denoising
        mu_in = torch.cat([mu, torch.zeros_like(mu)], dim=0)
        spks_in = torch.cat([spks, torch.zeros_like(spks)], dim=0)
        cond_in = torch.cat([cond, torch.zeros_like(cond)], dim=0)
        for step in range(1, len(t_span)):
            # torch.cuda.memory._record_memory_history(max_entries=100000)
            this_att_cache = att_cache[step - 1]
            this_cnn_cache = cnn_cache[step - 1]

            dphi_dt, this_new_cnn_cache, this_new_att_cache = self.estimator.forward_chunk(
                x=x.repeat(2, 1, 1),
                mu=mu_in,
                t=t.repeat(2),
                spks=spks_in,
                cond=cond_in,
                cnn_cache=this_cnn_cache,
                att_cache=this_att_cache,
            )
            dphi_dt, cfg_dphi_dt = dphi_dt.chunk(2, dim=0)
            dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

            self.cnn_cache_buffer[step - 1] = this_new_cnn_cache
            self.att_cache_buffer[step - 1][:, :, :, :x.shape[2] + last_att_len, :] = this_new_att_cache

        cnn_cache = self.cnn_cache_buffer
        att_cache = self.att_cache_buffer[:, :, :, :, :x.shape[2] + last_att_len, :]
        return x, cnn_cache, att_cache

    @torch.inference_mode()
    def forward_chunk(self,
                      mu: torch.Tensor,
                      spks: torch.Tensor,
                      cond: torch.Tensor,
                      n_timesteps: int = 10,
                      temperature: float = 1.0,
                      cnn_cache: torch.Tensor = None,
                      att_cache: torch.Tensor = None,
                      ):
        """
        Args:
            mu(torch.Tensor): shape (b, c, t)
            spks(torch.Tensor): shape (b, 192)
            cond(torch.Tensor): shape (b, c, t)
            cnn_cache: shape (n_time, depth, b, c1+c2, 2)
            att_cache: shape (n_time, depth, b, nh, t, c * 2)
        """
        # get offset from att_cache
        offset = att_cache.shape[4] if att_cache is not None else 0
        z = self.rand_noise[:, :, offset:offset + mu.size(2)] * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
        # cosine scheduling
        t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
        x, new_cnn_cache, new_att_cache = self.solve_euler_chunk(
            x=z,
            t_span=t_span,
            mu=mu,
            spks=spks,
            cond=cond,
            att_cache=att_cache,
            cnn_cache=cnn_cache,
        )
        return x, new_cnn_cache, new_att_cache
cosyvoice2/transformer/__init__.py
ADDED
File without changes
cosyvoice2/transformer/attention.py
ADDED
@@ -0,0 +1,328 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2022 Xingchen Song ([email protected])
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-Head Attention layer definition."""

import math
from typing import Tuple

import torch
from torch import nn


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).

        Returns:
            torch.Tensor: Transformed query tensor, size
                (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor, size
                (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor, size
                (#batch, n_head, time2, d_k).

        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value (torch.Tensor): Transformed value, size
                (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score, size
                (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask, size (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.

        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).

        """
        n_batch = value.size(0)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be True?
        #   1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
        #           1st chunk to ease the onnx export.]
        #   2. pytorch training
        if mask.size(2) > 0:  # time2 > 0
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            # For last chunk, time2 might be larger than scores.size(-1)
            mask = mask[:, :, :, :scores.size(-1)]  # (batch, 1, *, time2)
            scores = scores.masked_fill(mask, -float('inf'))
            attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0)  # (batch, head, time1, time2)
        # NOTE(xcsong): When will `if mask.size(2) > 0` be False?
        #   1. onnx(16/-1, -1/-1, 16/0)
        #   2. jit (16/-1, -1/-1, 16/0, 16/4)
        else:
            attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
                                                 self.h * self.d_k)
             )  # (batch, time1, d_model)

        return self.linear_out(x)  # (batch, time1, d_model)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
                1.When applying cross attention between decoder and encoder,
                the batch padding mask for input is in (#batch, 1, T) shape.
                2.When applying self attention of encoder,
                the mask is in (#batch, T, T) shape.
                3.When applying self attention of decoder,
                the mask is in (#batch, L, L) shape.
                4.If the different position in decoder see different block
                of the encoder, such as Mocha, the passed in mask could be
                in (#batch, L, T) shape. But there is no such case in current
                CosyVoice.
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`


        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`

        """
        q, k, v = self.forward_qkv(query, key, value)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit  model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        #   >>> a = torch.ones((1, 2, 0, 4))
        #   >>> b = torch.ones((1, 2, 3, 4))
        #   >>> c = torch.cat((a, b), dim=2)
        #   >>> torch.equal(b, c)        # True
        #   >>> d = torch.split(a, 2, dim=-1)
        #   >>> torch.equal(d[0], d[1])  # True
        if cache.size(0) > 0:
            key_cache, value_cache = torch.split(cache,
                                                 cache.size(-1) // 2,
                                                 dim=-1)
            k = torch.cat([key_cache, k], dim=2)
            v = torch.cat([value_cache, v], dim=2)
        # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
        #   non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """

    def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
                 key_bias: bool = True):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_head, n_feat, dropout_rate, key_bias)
        # linear transformation for positional encoding
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # these two learnable bias are used in matrix c and matrix d
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Compute relative positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
                time1 means the length of query vector.

        Returns:
            torch.Tensor: Output tensor.

        """
        zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
                               device=x.device,
                               dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        x_padded = x_padded.view(x.size()[0],
                                 x.size()[1],
                                 x.size(3) + 1, x.size(2))
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]  # only keep the positions from 0 to time2
        return x

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        pos_emb: torch.Tensor = torch.empty(0),
        cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2), (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): Positional embedding tensor
                (#batch, time2, size).
            cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
                where `cache_t == chunk_size * num_decoding_left_chunks`
                and `head * d_k == size`
        """
        q, k, v = self.forward_qkv(query, key, value)
        q = q.transpose(1, 2)  # (batch, time1, head, d_k)

        # NOTE(xcsong):
        #   when export onnx model, for 1st chunk, we feed
        #       cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
        #       or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
        #       In all modes, `if cache.size(0) > 0` will alwayse be `True`
        #       and we will always do splitting and
        #       concatnation(this will simplify onnx export). Note that
        #       it's OK to concat & split zero-shaped tensors(see code below).
        #   when export jit  model, for 1st chunk, we always feed
        #       cache(0, 0, 0, 0) since jit supports dynamic if-branch.
        #   >>> a = torch.ones((1, 2, 0, 4))
        #   >>> b = torch.ones((1, 2, 3, 4))
        #   >>> c = torch.cat((a, b), dim=2)
        #   >>> torch.equal(b, c)        # True
        #   >>> d = torch.split(a, 2, dim=-1)
        #   >>> torch.equal(d[0], d[1])  # True
|
295 |
+
if cache is not None and cache.size(0) > 0:
|
296 |
+
key_cache, value_cache = torch.split(cache, cache.size(-1) // 2, dim=-1)
|
297 |
+
k = torch.cat([key_cache, k], dim=2)
|
298 |
+
v = torch.cat([value_cache, v], dim=2)
|
299 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
300 |
+
# non-trivial to calculate `next_cache_start` here.
|
301 |
+
new_cache = torch.cat((k, v), dim=-1)
|
302 |
+
|
303 |
+
n_batch_pos = pos_emb.size(0)
|
304 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
305 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
306 |
+
|
307 |
+
# (batch, head, time1, d_k)
|
308 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
309 |
+
# (batch, head, time1, d_k)
|
310 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
311 |
+
|
312 |
+
# compute attention score
|
313 |
+
# first compute matrix a and matrix c
|
314 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
315 |
+
# (batch, head, time1, time2)
|
316 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
317 |
+
|
318 |
+
# compute matrix b and matrix d
|
319 |
+
# (batch, head, time1, time2)
|
320 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
321 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
322 |
+
if matrix_ac.shape != matrix_bd.shape:
|
323 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
324 |
+
|
325 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
326 |
+
self.d_k) # (batch, head, time1, time2)
|
327 |
+
|
328 |
+
return self.forward_attention(v, scores, mask), new_cache
|
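Note (illustrative, not part of the commit): a minimal sketch of how the relative-position attention above is driven by the ESPnet-style pos_emb of length 2*T-1 defined in the next file; the head count and feature size are assumptions for the example.

# Sketch only: wiring RelPositionMultiHeadedAttention with EspnetRelPositionalEncoding.
import torch
from cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding

attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
emb = EspnetRelPositionalEncoding(d_model=256, dropout_rate=0.0)
x = torch.randn(1, 10, 256)                        # (batch, time, size)
x, pos_emb = emb(x)                                # pos_emb: (1, 2*10-1, 256)
out, new_cache = attn(x, x, x,
                      mask=torch.ones((0, 0, 0), dtype=torch.bool),  # fake mask
                      pos_emb=pos_emb,
                      cache=torch.zeros((0, 0, 0, 0)))
print(out.shape)        # torch.Size([1, 10, 256])
print(new_cache.shape)  # torch.Size([1, 4, 10, 128]) == (batch, head, time, d_k * 2)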
cosyvoice2/transformer/embedding.py
ADDED
@@ -0,0 +1,119 @@
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Positional Encoding Module."""

import math
from typing import Tuple, Union

import torch
import torch.nn.functional as F
import numpy as np


class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See: Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Construct a PositionalEncoding object."""
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` is the
        # position of the key vector. We use positive relative positions when
        # keys are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """For getting the encoding in a streaming fashion.

        Attention!!!!!
        We apply dropout only once at the whole-utterance level in a
        non-streaming way, but we will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.Tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
        ]
        return pos_emb
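Note (illustrative, not part of the commit): the pe buffer built by extend_pe stores the reversed positive relative positions followed by the negative ones, so position_encoding(size=T) returns 2*T-1 vectors centered on the zero offset.

import torch
from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding

emb = EspnetRelPositionalEncoding(d_model=8, dropout_rate=0.0, max_len=16)
print(emb.pe.shape)                                   # torch.Size([1, 31, 8]) == (1, 2*max_len-1, d_model)
print(emb.position_encoding(offset=0, size=5).shape)  # torch.Size([1, 9, 8]): relative offsets +4 ... -4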
cosyvoice2/transformer/encoder_layer.py
ADDED
@@ -0,0 +1,163 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder self-attention layer definition."""

from typing import Optional, Tuple

import torch
from torch import nn


class ConformerEncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance.
            `ConvolutionModule` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        enable_cuda_graph (bool): Control whether to enable CUDA Graph.
    """

    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FNN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)  # for the CNN module
            self.norm_final = nn.LayerNorm(
                size, eps=1e-12)  # for the final output of the block
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
                (0, 0, 0) means fake mask.
            pos_emb (torch.Tensor): positional encoding, must not be None
                for ConformerEncoderLayer.
            mask_pad (torch.Tensor): batch padding mask used for the conv module.
                (#batch, 1, time), (0, 0, 0) means fake mask.
            att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
                (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
            cnn_cache (torch.Tensor): Convolution cache in the conformer layer
                (#batch=1, size, cache_t2)
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time, time).
            torch.Tensor: att_cache tensor,
                (#batch=1, head, cache_t1 + time, d_k * 2).
            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
        """
        return self._forward_impl(x, mask, pos_emb, mask_pad, att_cache, cnn_cache)

    def _forward_impl(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
        att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
        cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """The original forward implementation."""
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)
        # att_cache: (b, head, cache_t, d_k*2)
        x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
                                              att_cache)
        x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Fake new cnn cache here, and then change it in conv_module
        new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module
        residual = x
        if self.normalize_before:
            x = self.norm_ff(x)

        x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)
        return x, mask, new_att_cache, new_cnn_cache
cosyvoice2/transformer/positionwise_feed_forward.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    The feed-forward transform is applied on each position of the sequence.
    The output dim is the same as the input dim.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation function
    """

    def __init__(
        self,
        idim: int,
        hidden_units: int,
        dropout_rate: float,
        activation: torch.nn.Module = torch.nn.ReLU(),
    ):
        """Construct a PositionwiseFeedForward object."""
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs: input tensor (B, L, D)
        Returns:
            output tensor, (B, L, D)
        """
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
cosyvoice2/transformer/subsampling.py
ADDED
@@ -0,0 +1,79 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Subsampling layer definition."""

from typing import Tuple, Union

import torch


class BaseSubsampling(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.right_context = 0
        self.subsampling_rate = 1

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class LinearNoSubsampling(BaseSubsampling):
    """Linear transform of the input without subsampling.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.

    """

    def __init__(self, idim: int, odim: int, dropout_rate: float,
                 pos_enc_class: torch.nn.Module):
        """Construct a linear object."""
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
            torch.nn.Dropout(dropout_rate),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time.
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time.

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask
cosyvoice2/transformer/upsample_encoder_v2.py
ADDED
@@ -0,0 +1,483 @@
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#               2022 Xingchen Song ([email protected])
#               2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple, List

import torch
from torch import nn
from torch.nn import functional as F

from cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice2.utils.class_utils import (
    COSYVOICE_EMB_CLASSES,
    COSYVOICE_SUBSAMPLE_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
    COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice2.utils.mask import (
    make_pad_mask,
)

import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 128


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2, scale_factor: float = None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat-interpolate, then conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
        self.scale_factor = float(self.stride) if scale_factor is None else float(scale_factor)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
        outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride

    def forward_chunk(self, inputs: torch.Tensor, input_lengths: torch.Tensor, cache: torch.Tensor = torch.zeros((0, 0, 0))):
        """
        Args:
            inputs(torch.Tensor): shape (b, c, t)
            input_lengths(torch.Tensor): shape (b), can be None
            cache(torch.Tensor): shape (b, c, cache_t), where cache_t = stride * 2
        """
        outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")

        if cache is None:
            cache = inputs.new_zeros(inputs.shape[0], inputs.shape[1], self.stride * 2)
        outputs = torch.cat([cache, outputs], dim=2)
        new_cache = outputs[..., -self.stride * 2:]
        outputs = self.conv(outputs)

        if input_lengths is not None:
            input_lengths = input_lengths * self.stride
        return outputs, input_lengths, new_cache


class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = inputs.transpose(1, 2).contiguous()
        # look ahead
        outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # outputs
        outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs
        return outputs

    def forward_chunk(self, inputs: torch.Tensor, cache: torch.Tensor = None):
        """
        Args:
            inputs(torch.Tensor): shape (b, t, c)
            cache(torch.Tensor): shape (b, c, cache_t=2), c = channels
        """
        outputs = inputs.transpose(1, 2).contiguous()
        outputs = F.leaky_relu(self.conv1(outputs))
        # the length of outputs is input length - pre_lookahead_len
        if cache is None:
            cache = outputs.new_zeros(outputs.shape[0], outputs.shape[1], 2)
        # NOTE
        new_cache = outputs[..., -2:]
        outputs = torch.cat([cache, outputs], dim=2)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()
        # residual connection
        outputs = outputs + inputs[:, :-self.pre_lookahead_len]
        return outputs, new_cache


class UpsampleConformerEncoderV2(torch.nn.Module):
    """Customize each sample's chunk attention mask."""

    def __init__(
        self,
        # input & output
        input_size: int,
        output_size: int = 256,
        input_layer: str = "linear",
        pre_lookahead_len: int = 3,
        # size
        num_blocks: int = 6,
        num_up_blocks: int = 4,
        # upsampling
        up_stride: int = 2,
        up_scale_factor: float = 2,
        # attention
        attention_heads: int = 4,
        pos_enc_layer_type: str = "rel_pos_espnet",
        selfattention_layer_type: str = "rel_selfattn",
        key_bias: bool = True,
        # mlp
        linear_units: int = 2048,
        # dropouts
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        # other
        normalize_before: bool = True,
        activation_type: str = "swish",
        **kwargs,
    ):
        super().__init__()
        self._output_size = output_size
        self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
                output_size,
                positional_dropout_rate
            ),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
        activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
        # self-attention module definition
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            key_bias,
        )
        # feed-forward module definition
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        self.pre_lookahead_layer = PreLookaheadLayer(
            channels=output_size,
            pre_lookahead_len=pre_lookahead_len
        )
        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args
                ),
                PositionwiseFeedForward(*positionwise_layer_args),
                None,
                None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_blocks)
        ])
        self.up_layer = Upsample1D(
            channels=output_size,
            out_channels=output_size,
            stride=up_stride,
            scale_factor=up_scale_factor
        )
        self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
            input_size,
            output_size,
            dropout_rate,
            COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
                output_size,
                positional_dropout_rate
            ),
        )
        self.up_encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
                    *encoder_selfattn_layer_args
                ),
                PositionwiseFeedForward(*positionwise_layer_args),
                None,
                None,
                dropout_rate,
                normalize_before,
            ) for _ in range(num_up_blocks)
        ])

        self.enable_cuda_graph = False
        self.use_cuda_graph = False
        self.graph_encoder = {}
        self.graph_up_encoder = {}
        self.inference_buffers_encoder = {}
        self.inference_buffers_up_encoder = {}
        self.max_static_time = 1500

    # FIXME(sfy) revert hard-coded bfloat16
    # this method is skipped in CausalMaskedDiffWithXvec.scatter_cuda_graph
    def scatter_cuda_graph(self, enable_cuda_graph: bool):
        self.enable_cuda_graph = enable_cuda_graph
        if self.enable_cuda_graph:
            self._init_cuda_graph()

    def _init_cuda_graph(self):
        """Initialize CUDA Graphs for the encoder and up_encoder stacks."""

        for l in range(100, 1500, 10):
            static_x = torch.zeros((1, l, 512),
                                   dtype=torch.float32, device=torch.device('cuda'))
            static_mask = torch.ones((1, 1, l),
                                     dtype=torch.bool, device=torch.device('cuda'))
            static_pos_emb = torch.zeros((1, 2 * l - 1, 512),
                                         dtype=torch.float32, device=torch.device('cuda'))

            static_inputs = [
                static_x,
                static_mask,
                static_pos_emb,
            ]

            self._forward_impl_encoder(
                static_inputs[0],
                static_inputs[1],
                static_inputs[2],
            )
            graph = torch.cuda.CUDAGraph()
            with torch.no_grad():
                with torch.cuda.graph(graph):
                    static_out_x = self._forward_impl_encoder(
                        static_inputs[0],
                        static_inputs[1],
                        static_inputs[2]
                    )
            self.graph_encoder[l] = graph
            static_outputs = [
                static_out_x,
            ]
            self.inference_buffers_encoder[l] = {
                'static_inputs': static_inputs,
                'static_outputs': static_outputs
            }

        for l in range(100, 1500, 10):
            static_x = torch.zeros((1, l, 512),
                                   dtype=torch.float32, device=torch.device('cuda'))
            static_mask = torch.ones((1, 1, l),
                                     dtype=torch.bool, device=torch.device('cuda'))
            static_pos_emb = torch.zeros((1, 2 * l - 1, 512),
                                         dtype=torch.float32, device=torch.device('cuda'))

            static_inputs = [
                static_x,
                static_mask,
                static_pos_emb,
            ]

            self._forward_impl_up_encoder(
                static_inputs[0],
                static_inputs[1],
                static_inputs[2],
            )
            graph = torch.cuda.CUDAGraph()
            with torch.no_grad():
                with torch.cuda.graph(graph):
                    static_out_x = self._forward_impl_up_encoder(
                        static_inputs[0],
                        static_inputs[1],
                        static_inputs[2]
                    )
            self.graph_up_encoder[l] = graph
            static_outputs = [
                static_out_x,
            ]
            self.inference_buffers_up_encoder[l] = {
                'static_inputs': static_inputs,
                'static_outputs': static_outputs
            }

        self.use_cuda_graph = True
        print("CUDA Graph initialized successfully for encoder and up_encoder")

    # @torch.compile(dynamic=True, backend="eager")
    def _forward_impl_encoder(self,
                              x: torch.Tensor,
                              mask: torch.Tensor,
                              pos_emb: torch.Tensor):
        for layer in self.encoders:
            x, _, _, _ = layer(x, mask, pos_emb)
        return x

    # @torch.compile(dynamic=True, backend="eager")
    def _forward_impl_up_encoder(self,
                                 x: torch.Tensor,
                                 mask: torch.Tensor,
                                 pos_emb: torch.Tensor):
        for layer in self.up_encoders:
            x, _, _, _ = layer(x, mask, pos_emb)
        return x

    def output_size(self) -> int:
        return self._output_size

    # @torch.compile(dynamic=True, backend="eager")
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (sfy) chunk training strategy should not be open-sourced
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.embed(xs, masks)

        # lookahead
        xs = self.pre_lookahead_layer(xs)
        # conformer block
        if self.enable_cuda_graph and xs.shape[1] in self.graph_encoder:
            self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
            self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
            self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
            self.graph_encoder[xs.shape[1]].replay()
            xs = self.inference_buffers_encoder[xs.shape[1]]['static_outputs'][0]
        else:
            xs = self._forward_impl_encoder(xs, masks, pos_emb)
        # upsample
        xs = xs.transpose(1, 2).contiguous()
        xs, xs_lens = self.up_layer(xs, xs_lens)
        xs = xs.transpose(1, 2).contiguous()

        # 2nd conformer block
        T = xs.size(1)
        masks = ~make_pad_mask(xs_lens, T).unsqueeze(1)  # (B, 1, T)
        xs, pos_emb, masks = self.up_embed(xs, masks)
        if self.enable_cuda_graph and xs.shape[1] in self.graph_up_encoder:
            self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
            self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
            self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
            self.graph_up_encoder[xs.shape[1]].replay()
            xs = self.inference_buffers_up_encoder[xs.shape[1]]['static_outputs'][0]
        else:
            xs = self._forward_impl_up_encoder(xs, masks, pos_emb)
        # post norm
        if self.normalize_before:
            xs = self.after_norm(xs)
        return xs, masks

    @torch.compile(dynamic=True, backend="eager")
    def forward_chunk(self,
                      xs: torch.Tensor,
                      last_chunk: bool = False,
                      cnn_cache: torch.Tensor = None,
                      att_cache: torch.Tensor = None,
                      ):
        """
        Args:
            xs: shape (b, dt, c)
            last_chunk: bool. If last chunk, will pad input with lookaheads
            att_cache: shape (depth1+depth2, b, nh, 2*t1, c).
            cnn_cache: shape (b, c, t1+t2). Where t1=2 (pre_lookahead_layer), t2=4 (up_layer)
        """
        if att_cache is not None:
            assert att_cache.shape[3] % 2 == 0, att_cache.shape
        if cnn_cache is not None:
            assert cnn_cache.shape[2] == 2 + self.up_layer.stride * 2, cnn_cache.shape

        # unpack caches
        offset1 = att_cache.shape[3] // 2 if att_cache is not None else 0
        att_cache1 = att_cache[:len(self.encoders), :, :, :offset1] if att_cache is not None else [None] * len(self.encoders)
        att_cache2 = att_cache[len(self.encoders):] if att_cache is not None else [None] * len(self.encoders)
        cnn_cache1 = cnn_cache[:, :, :2] if cnn_cache is not None else None
        cnn_cache2 = cnn_cache[:, :, 2:] if cnn_cache is not None else None
        xs, _, _ = self.embed(xs, None)
        if last_chunk:
            xs = F.pad(xs, (0, 0, 0, self.pre_lookahead_layer.pre_lookahead_len))

        # this_cnn_cache: shape (b=1, c=512, t=2)
        xs, new_cnn_cache1 = self.pre_lookahead_layer.forward_chunk(xs, cache=cnn_cache1)

        # remake pos_emb, offset param is ignored by position_encoding
        pos_emb = self.embed.position_encoding(offset=None, size=offset1 + xs.shape[1])

        # first conformer
        chunk_masks = torch.zeros((0, 0, 0))
        new_att_cache1 = []

        for idx, layer in enumerate(self.encoders):
            # this_att_cache: shape (b, nh, t, c * 2)
            xs, _, this_new_att_cache1, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache1[idx])
            new_att_cache1.append(this_new_att_cache1)
        new_att_cache1 = torch.stack(new_att_cache1, dim=0)

        # upsample + conformer encoder, xs: (b, t, c) -> (b, c, t)
        xs = xs.transpose(1, 2).contiguous()
        # this_cnn_cache: shape (b=1, c=512, t=2*2)
        xs, _, new_cnn_cache2 = self.up_layer.forward_chunk(xs, None, cache=cnn_cache2)
        xs = xs.transpose(1, 2).contiguous()

        # at this time, xs are doubled in length
        xs, _, _ = self.up_embed(xs, None)

        # remake pos_emb
        pos_emb = self.embed.position_encoding(offset=None, size=offset1 * self.up_layer.stride + xs.shape[1])

        # second conformer
        chunk_masks = torch.zeros((0, 0, 0), dtype=torch.bfloat16)
        new_att_cache2 = []

        for idx, layer in enumerate(self.up_encoders):
            xs, _, this_new_att_cache2, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache2[idx])
            new_att_cache2.append(this_new_att_cache2)
        new_att_cache2 = torch.stack(new_att_cache2, dim=0)

        if self.normalize_before:
            xs = self.after_norm(xs)

        # pack new cache
        new_att_cache = torch.cat([new_att_cache1.repeat(1, 1, 1, 2, 1), new_att_cache2], dim=0)
        new_cnn_cache = torch.cat([new_cnn_cache1, new_cnn_cache2], dim=2)

        return xs, new_cnn_cache, new_att_cache
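Note (illustrative, not part of the commit): a rough shape walk-through of UpsampleConformerEncoderV2.forward under assumed constructor values (512-dim features, default strides); the first conformer stack runs at the token rate, Upsample1D doubles the sequence length, and the second stack runs at the doubled rate.

import torch
from cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2

encoder = UpsampleConformerEncoderV2(input_size=512, output_size=512)  # sizes are assumptions
xs = torch.randn(2, 100, 512)         # (batch, speech tokens, dim)
xs_lens = torch.tensor([100, 80])
ys, masks = encoder(xs, xs_lens)
print(ys.shape)                       # torch.Size([2, 200, 512]): doubled by the Upsample1D stage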
cosyvoice2/utils/class_utils.py
ADDED
@@ -0,0 +1,41 @@
# Copyright [2023-11-28] <[email protected], Xingchen Song>
#            2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from cosyvoice2.transformer.subsampling import LinearNoSubsampling
from cosyvoice2.transformer.attention import RelPositionMultiHeadedAttention
from cosyvoice2.transformer.embedding import EspnetRelPositionalEncoding


COSYVOICE_ACTIVATION_CLASSES = {
    "hardtanh": torch.nn.Hardtanh,
    "tanh": torch.nn.Tanh,
    "relu": torch.nn.ReLU,
    "selu": torch.nn.SELU,
    "swish": torch.nn.SiLU,
    "gelu": torch.nn.GELU,
}

COSYVOICE_SUBSAMPLE_CLASSES = {
    "linear": LinearNoSubsampling,
}

COSYVOICE_EMB_CLASSES = {
    "rel_pos_espnet": EspnetRelPositionalEncoding,
}

COSYVOICE_ATTENTION_CLASSES = {
    "rel_selfattn": RelPositionMultiHeadedAttention,
}
cosyvoice2/utils/common.py
ADDED
@@ -0,0 +1,101 @@
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Utility functions for Transformer."""

import random
from typing import List

import numpy as np
import torch

IGNORE_ID = -1


def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    max_len = max([len(item) for item in xs])
    batchs = len(xs)
    ndim = xs[0].ndim
    if ndim == 1:
        pad_res = torch.zeros(batchs,
                              max_len,
                              dtype=xs[0].dtype,
                              device=xs[0].device)
    elif ndim == 2:
        pad_res = torch.zeros(batchs,
                              max_len,
                              xs[0].shape[1],
                              dtype=xs[0].dtype,
                              device=xs[0].device)
    elif ndim == 3:
        pad_res = torch.zeros(batchs,
                              max_len,
                              xs[0].shape[1],
                              xs[0].shape[2],
                              dtype=xs[0].dtype,
                              device=xs[0].device)
    else:
        raise ValueError(f"Unsupported ndim: {ndim}")
    pad_res.fill_(pad_value)
    for i in range(batchs):
        pad_res[i, :len(xs[i])] = xs[i]
    return pad_res


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def fade_in_out(fade_in_mel, fade_out_mel, window):
    device = fade_in_mel.device
    fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
    mel_overlap_len = int(window.shape[0] / 2)
    if fade_in_mel.device == torch.device('cpu'):
        fade_in_mel = fade_in_mel.clone()
    fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
        fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
    return fade_in_mel.to(device)


def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
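Note (illustrative, not part of the commit): fade_in_out above cross-fades the overlapping frames of two consecutive mel chunks with the two halves of a window; the window and mel sizes below are assumptions.

import torch
from cosyvoice2.utils.common import fade_in_out

window = torch.hann_window(20)      # rising half fades the new chunk in, falling half fades the old one out
prev_mel = torch.randn(1, 80, 50)   # (batch, n_mels, frames), tail of the previous chunk
next_mel = torch.randn(1, 80, 50)   # chunk being appended
next_mel = fade_in_out(next_mel, prev_mel, window)  # first 10 frames are blended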
cosyvoice2/utils/mask.py
ADDED
@@ -0,0 +1,49 @@
# Copyright (c) 2019 Shigeki Karita
#               2020 Mobvoi Inc (Binbin Zhang)
#               2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
from typing import List


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of the padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of the padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
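Note (illustrative, not part of the commit): elsewhere in this commit (e.g. UpsampleConformerEncoderV2.forward) the pad mask is inverted and unsqueezed into the (#batch, 1, T) shape the encoders expect.

import torch
from cosyvoice2.utils.mask import make_pad_mask

lengths = torch.tensor([5, 3, 2])
pad = make_pad_mask(lengths)       # True where a frame is padding, shape (3, 5)
attn_mask = ~pad.unsqueeze(1)      # (3, 1, 5), True where a frame is valid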
flashcosyvoice/__init__.py
ADDED
File without changes
flashcosyvoice/cli.py
ADDED
@@ -0,0 +1,424 @@
1 |
+
# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
""" Example Usage: see README.md
|
15 |
+
"""
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import json
|
19 |
+
import os
|
20 |
+
import random
|
21 |
+
import sys
|
22 |
+
import time
|
23 |
+
from concurrent.futures import ThreadPoolExecutor
|
24 |
+
from datetime import datetime
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import onnxruntime
|
28 |
+
import s3tokenizer
|
29 |
+
import torch
|
30 |
+
import torch.distributed as dist
|
31 |
+
import torchaudio
|
32 |
+
import torchaudio.compliance.kaldi as kaldi
|
33 |
+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
34 |
+
from tqdm import tqdm
|
35 |
+
|
36 |
+
from flashcosyvoice.config import Config, CosyVoice2LLMConfig, SamplingParams
|
37 |
+
from flashcosyvoice.cosyvoice2 import CosyVoice2
|
38 |
+
from flashcosyvoice.utils.audio import mel_spectrogram
|
39 |
+
|
40 |
+
|
41 |
+
def set_all_random_seed(seed):
|
42 |
+
random.seed(seed)
|
43 |
+
np.random.seed(seed)
|
44 |
+
torch.manual_seed(seed)
|
45 |
+
torch.cuda.manual_seed_all(seed)
|
46 |
+
|
47 |
+
|
48 |
+
def save_file_async(
|
49 |
+
wav, prompt_speech_tokens, generated_speech_tokens,
|
50 |
+
info, timing_stats
|
51 |
+
):
|
52 |
+
"""Save audio asynchronously."""
|
53 |
+
try:
|
54 |
+
os.makedirs(os.path.dirname(info['wav']), exist_ok=True)
|
55 |
+
if wav is not None:
|
56 |
+
wav = wav.cpu()
|
57 |
+
torchaudio.save(info['wav'], wav, 24000)
|
58 |
+
duration = wav.shape[-1] / 24000.0
|
59 |
+
rtf = ((timing_stats['dataloader_time'] + timing_stats['model_inference_time']) / timing_stats['batch_size']) / duration
|
60 |
+
timing_stats['rtf'] = rtf
|
61 |
+
else:
|
62 |
+
duration = 0.0
|
63 |
+
info['timing_stats'] = timing_stats
|
64 |
+
info['prompt_speech_tokens'] = prompt_speech_tokens
|
65 |
+
info['generated_speech_tokens'] = generated_speech_tokens
|
66 |
+
with open(f"{info['wav'].replace('.wav', '.json')}", "w") as f:
|
67 |
+
json.dump(info, f, ensure_ascii=False, indent=4)
|
68 |
+
return duration
|
69 |
+
except Exception as e:
|
70 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
71 |
+
tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
|
72 |
+
return 0.0
|
73 |
+
|
74 |
+
|
75 |
+
class AudioDataset(Dataset):
|
76 |
+
|
77 |
+
def __init__(self, text_norm, text_tokenizer, data_list, model_config: Config):
|
78 |
+
self.datas = []
|
79 |
+
self.text_norm = text_norm
|
80 |
+
self.model_config = model_config
|
81 |
+
|
82 |
+
"""Example data_list:
|
83 |
+
```
|
84 |
+
{"key": "uttid_1", "prompt_text": "你好,我是小明。", "text": "你好,我是小红。", "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"}
|
85 |
+
{"key": "uttid_2", "prompt_text": "你好,我是小红。", "text": "你好,我是小明。", "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"}
|
86 |
+
```
|
87 |
+
Note:
|
88 |
+
- `key` is the key of this sample.
|
89 |
+
- `prompt_text` is the text used for prompt.
|
90 |
+
- `text` is the text to be synthesized into new audio.
|
91 |
+
- `prompt_wav` is the audio used for prompt.
|
92 |
+
- `wav` is the path where the generated audio will be saved (we highly recommend pre-defining the save path before running the script).
|
93 |
+
"""
|
94 |
+
missing = 0
|
95 |
+
with open(data_list, 'r', encoding='utf-8') as f:
|
96 |
+
lines = f.readlines()
|
97 |
+
total_lines = len(lines)
|
98 |
+
if torch.distributed.get_node_local_rank() == 0:
|
99 |
+
iterator = tqdm(lines, desc='Loading data')
|
100 |
+
else:
|
101 |
+
iterator = lines
|
102 |
+
for line in iterator:
|
103 |
+
data = json.loads(line.strip())
|
104 |
+
valid = True
|
105 |
+
for k in ['key', 'prompt_text', 'text', 'prompt_wav']:
|
106 |
+
if k not in data:
|
107 |
+
valid = False
|
108 |
+
break
|
109 |
+
if data[k] is None:
|
110 |
+
valid = False
|
111 |
+
break
|
112 |
+
if not os.path.exists(data['prompt_wav']):
|
113 |
+
valid = False
|
114 |
+
if valid:
|
115 |
+
self.datas.append(data)
|
116 |
+
else:
|
117 |
+
missing += 1
|
118 |
+
if torch.distributed.get_node_local_rank() == 0:
|
119 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
120 |
+
tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, found {missing} missing lines, total valid lines == {len(self.datas)}.')
|
121 |
+
|
122 |
+
self.text_tokenizer = text_tokenizer
|
123 |
+
|
124 |
+
option = onnxruntime.SessionOptions()
|
125 |
+
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
126 |
+
option.intra_op_num_threads = 1
|
127 |
+
self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
|
128 |
+
providers=["CPUExecutionProvider"])
|
129 |
+
|
130 |
+
def __len__(self):
|
131 |
+
return len(self.datas)
|
132 |
+
|
133 |
+
def __getitem__(self, idx):
|
134 |
+
data = self.datas[idx]
|
135 |
+
|
136 |
+
try:
|
137 |
+
# 1. feature for s3tokenizer
|
138 |
+
audio = s3tokenizer.load_audio(data['prompt_wav'], sr=16000) # [T]
|
139 |
+
log_mel = s3tokenizer.log_mel_spectrogram(audio) # [num_mels, T]
|
140 |
+
|
141 |
+
# 2. feature for speaker embedding
|
142 |
+
spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
|
143 |
+
spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
|
144 |
+
spk_emb = self.spk_model.run(
|
145 |
+
None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
|
146 |
+
)[0].flatten().tolist()
|
147 |
+
|
148 |
+
# 3. feature for flow
|
149 |
+
audio, sample_rate = torchaudio.load(data['prompt_wav'], backend='soundfile')
|
150 |
+
audio = audio.mean(dim=0, keepdim=True) # [1, T]
|
151 |
+
if sample_rate != 24000:
|
152 |
+
audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
|
153 |
+
mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0) # [T, num_mels]
|
154 |
+
mel_len = mel.shape[0]
|
155 |
+
|
156 |
+
# 4. feature for llm
|
157 |
+
if self.text_norm is not None:
|
158 |
+
prompt_texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['prompt_text'].strip()))["sentences"]]
|
159 |
+
prompt_text = ''.join(prompt_texts)
|
160 |
+
texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['text'].strip()))["sentences"]]
|
161 |
+
text = ''.join(texts)
|
162 |
+
else:
|
163 |
+
prompt_text = data['prompt_text']
|
164 |
+
text = data['text']
|
165 |
+
prompt_text_ids = self.text_tokenizer.encode(prompt_text)
|
166 |
+
prompt_text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in prompt_text_ids]
|
167 |
+
text_ids = self.text_tokenizer.encode(text)
|
168 |
+
text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in text_ids]
|
169 |
+
item = {
|
170 |
+
"prompt_text_tokens": prompt_text_ids, "text_tokens": text_ids,
|
171 |
+
"spk_emb": spk_emb, "mel": mel, "mel_len": mel_len, "log_mel": log_mel, "info": data,
|
172 |
+
"min_tokens": len(text_ids) * self.model_config.min_token_text_ratio,
|
173 |
+
"max_tokens": len(text_ids) * self.model_config.max_token_text_ratio,
|
174 |
+
}
|
175 |
+
except Exception as e:
|
176 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
177 |
+
tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}")
|
178 |
+
return None
|
179 |
+
return item
|
180 |
+
|
181 |
+
|
182 |
+
def collate_fn(batch):
|
183 |
+
prompt_mels_for_llm = [item["log_mel"] for item in batch if item is not None]
|
184 |
+
prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_mels_for_llm) # [B, num_mels=128, T]
|
185 |
+
prompt_text_tokens_for_llm = [item["prompt_text_tokens"] for item in batch if item is not None]
|
186 |
+
text_tokens_for_llm = [item["text_tokens"] for item in batch if item is not None]
|
187 |
+
prompt_mels_for_flow = [item["mel"] for item in batch if item is not None]
|
188 |
+
prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0) # [B, T', num_mels=80]
|
189 |
+
prompt_mels_lens_for_flow = [item["mel_len"] for item in batch if item is not None]
|
190 |
+
prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
|
191 |
+
spk_emb_for_flow = [item["spk_emb"] for item in batch if item is not None]
|
192 |
+
spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
|
193 |
+
sampling_params = [SamplingParams(min_tokens=item["min_tokens"], max_tokens=item["max_tokens"], use_ras=True) for item in batch if item is not None]
|
194 |
+
infos = [item["info"] for item in batch if item is not None]
|
195 |
+
return {
|
196 |
+
"prompt_mels_for_llm": prompt_mels_for_llm,
|
197 |
+
"prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
|
198 |
+
"prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
|
199 |
+
"text_tokens_for_llm": text_tokens_for_llm,
|
200 |
+
"prompt_mels_for_flow": prompt_mels_for_flow,
|
201 |
+
"prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
|
202 |
+
"spk_emb_for_flow": spk_emb_for_flow,
|
203 |
+
"sampling_params": sampling_params,
|
204 |
+
"infos": infos,
|
205 |
+
}
|
206 |
+
|
207 |
+
|
208 |
+
def init_distributed():
|
209 |
+
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
210 |
+
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
211 |
+
rank = int(os.environ.get('RANK', 0))
|
212 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
213 |
+
tqdm.write(f'[{timestamp}] - [INFO] - Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
|
214 |
+
torch.cuda.set_device(local_rank)
|
215 |
+
dist.init_process_group("nccl")
|
216 |
+
return world_size, local_rank, rank
|
217 |
+
|
218 |
+
|
219 |
+
def get_args():
|
220 |
+
parser = argparse.ArgumentParser(description='FlashCosyVoice')
|
221 |
+
parser.add_argument('--model_path',
|
222 |
+
required=True,
|
223 |
+
type=str,
|
224 |
+
help='model path')
|
225 |
+
parser.add_argument('--data_list',
|
226 |
+
required=True,
|
227 |
+
type=str,
|
228 |
+
help='data list')
|
229 |
+
parser.add_argument('--batch_size_dataloader',
|
230 |
+
required=True,
|
231 |
+
type=int,
|
232 |
+
help='batch size (per-device) for dataloading')
|
233 |
+
parser.add_argument('--batch_size_flow',
|
234 |
+
required=True,
|
235 |
+
type=int,
|
236 |
+
help='batch size (per-device) for flow-matching')
|
237 |
+
parser.add_argument('--num_workers',
|
238 |
+
type=int,
|
239 |
+
default=4,
|
240 |
+
help='workers for dataloader')
|
241 |
+
parser.add_argument('--prefetch',
|
242 |
+
type=int,
|
243 |
+
default=5,
|
244 |
+
help='prefetch for dataloader')
|
245 |
+
parser.add_argument('--enable_tn',
|
246 |
+
action='store_true',
|
247 |
+
help='enable text normalization')
|
248 |
+
parser.add_argument('--only_llm',
|
249 |
+
action='store_true',
|
250 |
+
help='only generate speech tokens from llm')
|
251 |
+
parser.add_argument('--fp16_flow',
|
252 |
+
action='store_true',
|
253 |
+
help='enable fp16 flow')
|
254 |
+
parser.add_argument('--seed',
|
255 |
+
type=int,
|
256 |
+
default=1986,
|
257 |
+
help='random seed for generation')
|
258 |
+
args = parser.parse_args()
|
259 |
+
return args
|
260 |
+
|
261 |
+
|
262 |
+
def main():
|
263 |
+
args = get_args()
|
264 |
+
|
265 |
+
if args.enable_tn:
|
266 |
+
# Check the Python version: ttsfrd is only supported on Python 3.10
|
267 |
+
if sys.version_info.major == 3 and sys.version_info.minor == 10:
|
268 |
+
# Check if ttsfrd is installed
|
269 |
+
try:
|
270 |
+
import ttsfrd
|
271 |
+
from cosyvoice_ttsfrd import get_resource_path
|
272 |
+
except ImportError as e:
|
273 |
+
raise ImportError("ttsfrd is not installed, please install it first, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for installation guide.") from e
|
274 |
+
text_norm = ttsfrd.TtsFrontendEngine()
|
275 |
+
text_norm.initialize(get_resource_path())
|
276 |
+
text_norm.set_lang_type('pinyinvg')
|
277 |
+
else:
|
278 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
279 |
+
tqdm.write(f"[{timestamp}] - [WARNING] - Only python 3.10 is supported for ttsfrd, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for more info. Setting enable_tn to False...")
|
280 |
+
# TODO: maybe we should use wetext if python version is not 3.10?
|
281 |
+
args.enable_tn = False
|
282 |
+
text_norm = None
|
283 |
+
else:
|
284 |
+
text_norm = None
|
285 |
+
|
286 |
+
assert (torch.cuda.is_available())
|
287 |
+
world_size, local_rank, rank = init_distributed()
|
288 |
+
config = Config(model=args.model_path, enforce_eager=True, tensor_parallel_size=1,
|
289 |
+
max_num_seqs=args.batch_size_dataloader,
|
290 |
+
hf_config=CosyVoice2LLMConfig(fp16_flow=args.fp16_flow), rank=local_rank)
|
291 |
+
model = CosyVoice2(config)
|
292 |
+
|
293 |
+
set_all_random_seed(args.seed)
|
294 |
+
|
295 |
+
dataset = AudioDataset(text_norm, model.llm.tokenizer, args.data_list, config)
|
296 |
+
sampler = DistributedSampler(dataset,
|
297 |
+
num_replicas=world_size,
|
298 |
+
rank=rank)
|
299 |
+
dataloader = DataLoader(dataset, batch_size=args.batch_size_dataloader, num_workers=args.num_workers, pin_memory=True,
|
300 |
+
sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn)
|
301 |
+
total_steps = len(dataset)
|
302 |
+
|
303 |
+
if local_rank == 0:
|
304 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
305 |
+
tqdm.write(f"[{timestamp}] - [INFO] - {args}")
|
306 |
+
progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
|
307 |
+
position=0, leave=True, dynamic_ncols=True)
|
308 |
+
|
309 |
+
cpu_counts = os.cpu_count()
|
310 |
+
executor = ThreadPoolExecutor(max_workers=max(1, min(args.batch_size_dataloader, cpu_counts // 8)))  # keep at least one worker on small machines
|
311 |
+
pending_futures = []
|
312 |
+
dataloader_iter = iter(dataloader)
|
313 |
+
succeed_duration = 0.01 # avoid division by zero
|
314 |
+
start_time = time.time()
|
315 |
+
estimated_total_wavs = 0
|
316 |
+
succeed_wavs = 0
|
317 |
+
failed_wavs = 0
|
318 |
+
last_print_time = start_time
|
319 |
+
|
320 |
+
while True:
|
321 |
+
try:
|
322 |
+
dataloader_start = time.time()
|
323 |
+
batch = next(dataloader_iter)
|
324 |
+
dataloader_time = time.time() - dataloader_start
|
325 |
+
|
326 |
+
if len(batch['infos']) == 0:
|
327 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
328 |
+
tqdm.write(f"[{timestamp}] - [WARNING] - rank {rank} of {world_size}: No valid batch found, skipping this batch...")
|
329 |
+
continue
|
330 |
+
|
331 |
+
model_start = time.time()
|
332 |
+
results_dict, timing_stats = model(**batch, batch_size_flow=args.batch_size_flow,
|
333 |
+
only_llm=args.only_llm)
|
334 |
+
model_time = time.time() - model_start
|
335 |
+
|
336 |
+
estimated_total_wavs += len(results_dict['generated_wavs'])
|
337 |
+
|
338 |
+
timing_stats['dataloader_time'] = dataloader_time
|
339 |
+
timing_stats['model_inference_time'] = model_time
|
340 |
+
|
341 |
+
if args.only_llm:
|
342 |
+
results_dict['generated_wavs'] = [None] * len(results_dict['prompt_speech_tokens'])
|
343 |
+
|
344 |
+
for i in range(len(results_dict['generated_wavs'])):
|
345 |
+
future = executor.submit(
|
346 |
+
save_file_async, results_dict['generated_wavs'][i],
|
347 |
+
results_dict['prompt_speech_tokens'][i],
|
348 |
+
results_dict['generated_speech_tokens'][i],
|
349 |
+
batch['infos'][i].copy(), timing_stats.copy()
|
350 |
+
)
|
351 |
+
pending_futures.append(future)
|
352 |
+
|
353 |
+
completed_futures = []
|
354 |
+
for future in pending_futures:
|
355 |
+
if future.done():
|
356 |
+
try:
|
357 |
+
duration = future.result()
|
358 |
+
succeed_duration += duration
|
359 |
+
succeed_wavs += 1
|
360 |
+
except Exception as e:
|
361 |
+
failed_wavs += 1
|
362 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
363 |
+
tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in async save task: {e}")
|
364 |
+
completed_futures.append(future)
|
365 |
+
|
366 |
+
for future in completed_futures:
|
367 |
+
pending_futures.remove(future)
|
368 |
+
|
369 |
+
if local_rank == 0:
|
370 |
+
update_n = world_size * len(batch["prompt_text_tokens_for_llm"])
|
371 |
+
if progress_bar.n + update_n > progress_bar.total:
|
372 |
+
progress_bar.update(progress_bar.total - progress_bar.n)
|
373 |
+
else:
|
374 |
+
progress_bar.update(update_n)
|
375 |
+
|
376 |
+
current_time = time.time()
|
377 |
+
if current_time - last_print_time >= 120 and not args.only_llm:
|
378 |
+
elapsed_time = current_time - start_time
|
379 |
+
avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
|
380 |
+
estimated_total_duration = avg_duration * estimated_total_wavs
|
381 |
+
current_rtf = elapsed_time / estimated_total_duration if estimated_total_duration > 0.01 else 0
|
382 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
383 |
+
tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Estimated RTF: {current_rtf:.5f}, Elapsed time: {elapsed_time:.2f}s") # noqa
|
384 |
+
last_print_time = current_time
|
385 |
+
except StopIteration:
|
386 |
+
break
|
387 |
+
except Exception as e:
|
388 |
+
failed_wavs += 1
|
389 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
390 |
+
tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in main loop: {e}")
|
391 |
+
continue
|
392 |
+
|
393 |
+
total_time = time.time() - start_time
|
394 |
+
|
395 |
+
if local_rank == 0:
|
396 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
397 |
+
tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
|
398 |
+
|
399 |
+
for future in pending_futures:
|
400 |
+
try:
|
401 |
+
duration = future.result(timeout=60)
|
402 |
+
succeed_duration += duration
|
403 |
+
succeed_wavs += 1
|
404 |
+
except Exception as e:
|
405 |
+
failed_wavs += 1
|
406 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
407 |
+
tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in final async save task: {e}")
|
408 |
+
executor.shutdown(wait=True)
|
409 |
+
|
410 |
+
if local_rank == 0:
|
411 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
412 |
+
tqdm.write(f"[{timestamp}] - [INFO] - All async save tasks completed.")
|
413 |
+
progress_bar.close()
|
414 |
+
|
415 |
+
if not args.only_llm:
|
416 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
417 |
+
tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h), RTF: {total_time / succeed_duration:.5f}") # noqa
|
418 |
+
|
419 |
+
dist.barrier()
|
420 |
+
dist.destroy_process_group()
|
421 |
+
|
422 |
+
|
423 |
+
if __name__ == "__main__":
|
424 |
+
main()
|
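A minimal, hypothetical sketch of driving the CLI above. Each `--data_list` line is a JSON object with `key`, `prompt_text`, `text`, `prompt_wav`, and `wav` fields, and the script expects to run under a distributed launcher so that `WORLD_SIZE` / `RANK` / `LOCAL_RANK` are set for `init_distributed()`. The paths and the launch command below are assumptions, not part of this commit.

```python
# Build a two-line data_list for flashcosyvoice.cli (paths are placeholders:
# prompt_wav must already exist, wav is where the synthesized audio will be written).
import json

samples = [
    {"key": "uttid_1", "prompt_text": "你好,我是小明。", "text": "你好,我是小红。",
     "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"},
    {"key": "uttid_2", "prompt_text": "你好,我是小红。", "text": "你好,我是小明。",
     "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"},
]
with open("data.list", "w", encoding="utf-8") as f:
    for s in samples:
        f.write(json.dumps(s, ensure_ascii=False) + "\n")

# Assumed launch, one process per GPU:
#   torchrun --nproc_per_node=8 -m flashcosyvoice.cli \
#       --model_path /path/to/CosyVoice2-0.5B --data_list data.list \
#       --batch_size_dataloader 32 --batch_size_flow 8 --fp16_flow
```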
flashcosyvoice/config.py
ADDED
@@ -0,0 +1,80 @@
1 |
+
import os
|
2 |
+
from dataclasses import dataclass, field
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from transformers import AutoConfig
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class CosyVoice2LLMConfig:
|
10 |
+
architectures: list[str] = field(default_factory=lambda: ["Qwen2ForCausalLM"])
|
11 |
+
attention_dropout: float = 0.0
|
12 |
+
bos_token_id: int = 151643
|
13 |
+
eos_token_id: int = 6561 # speech eos
|
14 |
+
hidden_act: str = "silu"
|
15 |
+
hidden_size: int = 896
|
16 |
+
initializer_range: float = 0.02
|
17 |
+
intermediate_size: int = 4864
|
18 |
+
max_position_embeddings: int = 32768
|
19 |
+
max_window_layers: int = 24
|
20 |
+
model_type: str = "qwen2"
|
21 |
+
num_attention_heads: int = 14
|
22 |
+
num_hidden_layers: int = 24
|
23 |
+
num_key_value_heads: int = 2
|
24 |
+
head_dim: int = 64
|
25 |
+
rms_norm_eps: float = 1e-06
|
26 |
+
rope_scaling: dict | None = None
|
27 |
+
rope_theta: float = 1000000.0
|
28 |
+
sliding_window: int = 32768
|
29 |
+
tie_word_embeddings: bool = False
|
30 |
+
torch_dtype: torch.dtype = torch.bfloat16
|
31 |
+
transformers_version: str = "4.52.0.dev0"
|
32 |
+
use_cache: bool = True
|
33 |
+
use_sliding_window: bool = False
|
34 |
+
vocab_size: int = 158500 # text_vocab_size + speech_vocab_size + 2 (eos and task_id)
|
35 |
+
text_vocab_size: int = 151936
|
36 |
+
speech_vocab_size: int = 6562 # actually 6564, we only care about non-streaming inference, so cut off tokens (6562, 6563) that are only used for streaming TTS
|
37 |
+
lm_head_bias: bool = True
|
38 |
+
qkv_bias: bool = True
|
39 |
+
fp16_flow: bool = True
|
40 |
+
|
41 |
+
|
42 |
+
@dataclass
|
43 |
+
class SamplingParams:
|
44 |
+
temperature: float = 1.0
|
45 |
+
min_tokens: int = 2
|
46 |
+
max_tokens: int = 64
|
47 |
+
ignore_eos: bool = False
|
48 |
+
top_k: int = 25
|
49 |
+
# RasSampler parameters
|
50 |
+
use_ras: bool = False
|
51 |
+
win_size: int = 10
|
52 |
+
tau_r: float = 0.1
|
53 |
+
top_p: float = 0.8
|
54 |
+
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class Config:
|
58 |
+
model: str
|
59 |
+
max_num_batched_tokens: int = 1572864
|
60 |
+
max_num_seqs: int = 1024
|
61 |
+
max_model_len: int = 1536 # 15s prompt + 30s generated audio for 25hz audio tokenizer
|
62 |
+
gpu_memory_utilization: float = 0.9
|
63 |
+
tensor_parallel_size: int = 1
|
64 |
+
enforce_eager: bool = False
|
65 |
+
hf_config: CosyVoice2LLMConfig | AutoConfig = field(default_factory=CosyVoice2LLMConfig)
|
66 |
+
eos: int = -1
|
67 |
+
kvcache_block_size: int = 256
|
68 |
+
num_kvcache_blocks: int = -1
|
69 |
+
min_token_text_ratio: int = 2
|
70 |
+
max_token_text_ratio: int = 20
|
71 |
+
rank: int = 0
|
72 |
+
|
73 |
+
def __post_init__(self):
|
74 |
+
assert os.path.isdir(self.model)
|
75 |
+
assert self.kvcache_block_size % 256 == 0
|
76 |
+
assert 1 <= self.tensor_parallel_size <= 8
|
77 |
+
|
78 |
+
max_pos = getattr(self.hf_config, "max_position_embeddings", 4096)
|
79 |
+
self.max_model_len = min(self.max_model_len, max_pos)
|
80 |
+
assert self.max_num_batched_tokens >= self.max_model_len
|
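A small usage sketch for the sampling bounds defined here. The ratio values are the `Config` defaults above; constructing a full `Config` additionally requires `model` to point at an existing checkpoint directory because of the `__post_init__` assertions.

```python
from flashcosyvoice.config import SamplingParams

min_token_text_ratio, max_token_text_ratio = 2, 20  # Config defaults above
text_ids = [101, 102, 103]                          # toy tokenized target text
sp = SamplingParams(min_tokens=len(text_ids) * min_token_text_ratio,
                    max_tokens=len(text_ids) * max_token_text_ratio,
                    use_ras=True)
print(sp.min_tokens, sp.max_tokens, sp.top_p)  # 6 60 0.8
```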
flashcosyvoice/cosyvoice2.py
ADDED
@@ -0,0 +1,160 @@
1 |
+
# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
import time
|
15 |
+
from datetime import datetime
|
16 |
+
|
17 |
+
import s3tokenizer
|
18 |
+
import torch
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from flashcosyvoice.config import Config, SamplingParams
|
22 |
+
from flashcosyvoice.engine.llm_engine import LLMEngine
|
23 |
+
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec
|
24 |
+
from flashcosyvoice.modules.hifigan import HiFTGenerator
|
25 |
+
|
26 |
+
|
27 |
+
class CosyVoice2(torch.nn.Module):
|
28 |
+
def __init__(self, config: Config = None):
|
29 |
+
super().__init__()
|
30 |
+
self.config = Config() if config is None else config
|
31 |
+
|
32 |
+
self.audio_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").cuda().eval()
|
33 |
+
|
34 |
+
self.llm = LLMEngine(**self.config.__dict__)
|
35 |
+
|
36 |
+
self.use_tqdm = torch.distributed.get_node_local_rank() == 0
|
37 |
+
|
38 |
+
self.flow = CausalMaskedDiffWithXvec()
|
39 |
+
if self.config.hf_config.fp16_flow:
|
40 |
+
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
|
41 |
+
tqdm.write(f"[{timestamp}] - [INFO] - Casting flow to fp16")
|
42 |
+
self.flow.half()
|
43 |
+
self.flow.load_state_dict(torch.load(f"{self.config.model}/flow.pt", map_location="cpu", weights_only=True), strict=True)
|
44 |
+
self.flow.cuda().eval()
|
45 |
+
|
46 |
+
self.hift = HiFTGenerator()
|
47 |
+
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{self.config.model}/hift.pt", map_location="cpu", weights_only=True).items()}
|
48 |
+
self.hift.load_state_dict(hift_state_dict, strict=True)
|
49 |
+
self.hift.cuda().eval()
|
50 |
+
|
51 |
+
@torch.inference_mode()
|
52 |
+
def forward(
|
53 |
+
self, prompt_mels_for_llm: torch.Tensor, prompt_mels_lens_for_llm: torch.Tensor,
|
54 |
+
prompt_text_tokens_for_llm: list[list[int]], text_tokens_for_llm: list[list[int]],
|
55 |
+
prompt_mels_for_flow: torch.Tensor, prompt_mels_lens_for_flow: torch.Tensor,
|
56 |
+
spk_emb_for_flow: torch.Tensor,
|
57 |
+
sampling_params: SamplingParams | list[SamplingParams],
|
58 |
+
batch_size_flow: int,
|
59 |
+
only_llm: bool,
|
60 |
+
**kwargs, # for compatibility
|
61 |
+
):
|
62 |
+
timing_stats = {}
|
63 |
+
|
64 |
+
# Audio tokenization
|
65 |
+
start_time = time.time()
|
66 |
+
prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(
|
67 |
+
prompt_mels_for_llm.cuda(), prompt_mels_lens_for_llm.cuda()
|
68 |
+
)
|
69 |
+
timing_stats['audio_tokenization'] = time.time() - start_time
|
70 |
+
|
71 |
+
batch_size = prompt_speech_tokens.shape[0]
|
72 |
+
assert len(prompt_text_tokens_for_llm) == batch_size
|
73 |
+
|
74 |
+
# Prepare LLM inputs
|
75 |
+
start_time = time.time()
|
76 |
+
valid_prompt_speech_tokens = []
|
77 |
+
inputs = []
|
78 |
+
for i in range(batch_size):
|
79 |
+
speech_tokens_i = prompt_speech_tokens[i, :prompt_speech_tokens_lens[i].item()].tolist()
|
80 |
+
valid_prompt_speech_tokens.append(speech_tokens_i)
|
81 |
+
inputs.append([self.config.hf_config.speech_vocab_size] + prompt_text_tokens_for_llm[i] + text_tokens_for_llm[i] + [self.config.hf_config.speech_vocab_size + 1] + speech_tokens_i)
|
82 |
+
timing_stats['prepare_llm_inputs'] = time.time() - start_time
|
83 |
+
|
84 |
+
# LLM generation
|
85 |
+
start_time = time.time()
|
86 |
+
llm_outputs = self.llm.generate(inputs, sampling_params, use_tqdm=self.use_tqdm)
|
87 |
+
timing_stats['llm_generation'] = time.time() - start_time
|
88 |
+
|
89 |
+
results_dict = {
|
90 |
+
"prompt_speech_tokens": valid_prompt_speech_tokens,
|
91 |
+
"generated_speech_tokens": [o['token_ids'][:-1] for o in llm_outputs],
|
92 |
+
}
|
93 |
+
if only_llm:
|
94 |
+
return results_dict, timing_stats
|
95 |
+
|
96 |
+
# Prepare Flow inputs
|
97 |
+
start_time = time.time()
|
98 |
+
flow_inputs = []
|
99 |
+
flow_inputs_lens = []
|
100 |
+
for i, o in enumerate(llm_outputs):
|
101 |
+
generated_speech_tokens = o['token_ids'][:-1] # ignore last eos
|
102 |
+
prompt_speech_tokens = valid_prompt_speech_tokens[i]
|
103 |
+
flow_inputs.append(torch.tensor(prompt_speech_tokens + generated_speech_tokens))
|
104 |
+
flow_inputs_lens.append(len(prompt_speech_tokens) + len(generated_speech_tokens))
|
105 |
+
flow_inputs = torch.nn.utils.rnn.pad_sequence(flow_inputs, batch_first=True, padding_value=0)
|
106 |
+
flow_inputs_lens = torch.tensor(flow_inputs_lens)
|
107 |
+
timing_stats['prepare_flow_inputs'] = time.time() - start_time
|
108 |
+
|
109 |
+
# Flow generation and HiFi-GAN generation (with batching)
|
110 |
+
total_batch_size = flow_inputs.shape[0]
|
111 |
+
generated_wavs = []
|
112 |
+
flow_total_time = 0.0
|
113 |
+
hifigan_total_time = 0.0
|
114 |
+
|
115 |
+
# Process in chunks of batch_size_flow (the last chunk may be smaller than batch_size_flow)
|
116 |
+
# NOTE(xcsong): When executing both LLM and Flow on the same GPU,
|
117 |
+
# Flow can easily fill up the SM and memory. Therefore, batch processing is required to avoid OOM.
|
118 |
+
num_batches = (total_batch_size + batch_size_flow - 1) // batch_size_flow
|
119 |
+
batch_iterator = range(0, total_batch_size, batch_size_flow)
|
120 |
+
if self.use_tqdm:
|
121 |
+
batch_iterator = tqdm(batch_iterator, desc="Generating wavs (Flow+HiFi-GAN)", leave=False, unit="batch",
|
122 |
+
total=num_batches, dynamic_ncols=True, position=self.config.rank + 1)
|
123 |
+
|
124 |
+
for start_idx in batch_iterator:
|
125 |
+
end_idx = min(start_idx + batch_size_flow, total_batch_size)
|
126 |
+
batch_flow_inputs = flow_inputs[start_idx:end_idx]
|
127 |
+
batch_flow_inputs_lens = flow_inputs_lens[start_idx:end_idx]
|
128 |
+
batch_prompt_mels = prompt_mels_for_flow[start_idx:end_idx]
|
129 |
+
batch_prompt_mels_lens = prompt_mels_lens_for_flow[start_idx:end_idx]
|
130 |
+
batch_spk_emb = spk_emb_for_flow[start_idx:end_idx]
|
131 |
+
|
132 |
+
# Flow generation for this batch
|
133 |
+
flow_start_time = time.time()
|
134 |
+
with torch.amp.autocast("cuda", dtype=torch.float16 if self.config.hf_config.fp16_flow else torch.float32):
|
135 |
+
batch_generated_mels, batch_generated_mels_lens = self.flow(
|
136 |
+
batch_flow_inputs.cuda(), batch_flow_inputs_lens.cuda(),
|
137 |
+
batch_prompt_mels.cuda(), batch_prompt_mels_lens.cuda(), batch_spk_emb.cuda(),
|
138 |
+
streaming=False, finalize=True
|
139 |
+
)
|
140 |
+
flow_total_time += time.time() - flow_start_time
|
141 |
+
|
142 |
+
# HiFi-GAN generation for this batch
|
143 |
+
hifigan_start_time = time.time()
|
144 |
+
batch_size_current = end_idx - start_idx
|
145 |
+
for i in range(batch_size_current):
|
146 |
+
mel = batch_generated_mels[i, :, batch_prompt_mels_lens[i].item():batch_generated_mels_lens[i].item()].unsqueeze(0)
|
147 |
+
wav, _ = self.hift(speech_feat=mel)
|
148 |
+
generated_wavs.append(wav)
|
149 |
+
hifigan_total_time += time.time() - hifigan_start_time
|
150 |
+
|
151 |
+
timing_stats['flow_generation'] = flow_total_time
|
152 |
+
timing_stats['hifigan_generation'] = hifigan_total_time
|
153 |
+
|
154 |
+
# Calculate total time and batch statistics
|
155 |
+
timing_stats['model.forward_total'] = sum(timing_stats.values())
|
156 |
+
timing_stats['batch_size'] = len(generated_wavs)
|
157 |
+
timing_stats['batch_size_flow'] = batch_size_flow
|
158 |
+
|
159 |
+
results_dict['generated_wavs'] = generated_wavs
|
160 |
+
return results_dict, timing_stats
|
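The flow/vocoder stage above is deliberately chunked so that running the LLM and the flow model on the same GPU does not OOM; a toy illustration of the chunking arithmetic (the numbers are made up):

```python
# ceil-division into chunks of batch_size_flow; the last chunk may be smaller.
total_batch_size, batch_size_flow = 10, 4
num_batches = (total_batch_size + batch_size_flow - 1) // batch_size_flow
chunks = [(start, min(start + batch_size_flow, total_batch_size))
          for start in range(0, total_batch_size, batch_size_flow)]
print(num_batches, chunks)  # 3 [(0, 4), (4, 8), (8, 10)]
```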
flashcosyvoice/cosyvoice3.py
ADDED
@@ -0,0 +1 @@
1 |
+
# TODO(xcsong): Implement CosyVoice3 when it is released
|
flashcosyvoice/engine/__init__.py
ADDED
File without changes
|
flashcosyvoice/engine/block_manager.py
ADDED
@@ -0,0 +1,114 @@
1 |
+
from collections import deque
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import xxhash
|
5 |
+
|
6 |
+
from flashcosyvoice.engine.sequence import Sequence
|
7 |
+
|
8 |
+
|
9 |
+
class Block:
|
10 |
+
|
11 |
+
def __init__(self, block_id):
|
12 |
+
self.block_id = block_id
|
13 |
+
self.ref_count = 0
|
14 |
+
self.hash = -1
|
15 |
+
self.token_ids = []
|
16 |
+
|
17 |
+
def update(self, hash: int, token_ids: list[int]):
|
18 |
+
self.hash = hash
|
19 |
+
self.token_ids = token_ids
|
20 |
+
|
21 |
+
def reset(self):
|
22 |
+
self.ref_count = 1
|
23 |
+
self.hash = -1
|
24 |
+
self.token_ids = []
|
25 |
+
|
26 |
+
|
27 |
+
class BlockManager:
|
28 |
+
|
29 |
+
def __init__(self, num_blocks: int, block_size: int):
|
30 |
+
assert num_blocks > 0
|
31 |
+
self.block_size = block_size
|
32 |
+
self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
|
33 |
+
self.hash_to_block_id: dict[int, int] = dict()
|
34 |
+
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
35 |
+
self.used_block_ids: set[int] = set()
|
36 |
+
|
37 |
+
@classmethod
|
38 |
+
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
|
39 |
+
h = xxhash.xxh64()
|
40 |
+
if prefix != -1:
|
41 |
+
h.update(prefix.to_bytes(8, "little"))
|
42 |
+
h.update(np.array(token_ids).tobytes())
|
43 |
+
return h.intdigest()
|
44 |
+
|
45 |
+
def _allocate_block(self, block_id: int) -> Block:
|
46 |
+
block = self.blocks[block_id]
|
47 |
+
assert block.ref_count == 0
|
48 |
+
block.reset()
|
49 |
+
self.free_block_ids.remove(block_id)
|
50 |
+
self.used_block_ids.add(block_id)
|
51 |
+
return self.blocks[block_id]
|
52 |
+
|
53 |
+
def _deallocate_block(self, block_id: int) -> Block:
|
54 |
+
assert self.blocks[block_id].ref_count == 0
|
55 |
+
self.used_block_ids.remove(block_id)
|
56 |
+
self.free_block_ids.append(block_id)
|
57 |
+
|
58 |
+
def can_allocate(self, seq: Sequence) -> bool:
|
59 |
+
return len(self.free_block_ids) >= seq.num_blocks
|
60 |
+
|
61 |
+
def allocate(self, seq: Sequence):
|
62 |
+
assert not seq.block_table
|
63 |
+
h = -1
|
64 |
+
cache_miss = False
|
65 |
+
for i in range(seq.num_blocks):
|
66 |
+
token_ids = seq.block(i)
|
67 |
+
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
68 |
+
block_id = self.hash_to_block_id.get(h, -1)
|
69 |
+
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
70 |
+
cache_miss = True
|
71 |
+
if cache_miss:
|
72 |
+
block_id = self.free_block_ids[0]
|
73 |
+
block = self._allocate_block(block_id)
|
74 |
+
else:
|
75 |
+
seq.num_cached_tokens += self.block_size
|
76 |
+
if block_id in self.used_block_ids:
|
77 |
+
block = self.blocks[block_id]
|
78 |
+
block.ref_count += 1
|
79 |
+
else:
|
80 |
+
block = self._allocate_block(block_id)
|
81 |
+
if h != -1:
|
82 |
+
block.update(h, token_ids)
|
83 |
+
self.hash_to_block_id[h] = block_id
|
84 |
+
seq.block_table.append(block_id)
|
85 |
+
|
86 |
+
def deallocate(self, seq: Sequence):
|
87 |
+
for block_id in reversed(seq.block_table):
|
88 |
+
block = self.blocks[block_id]
|
89 |
+
block.ref_count -= 1
|
90 |
+
if block.ref_count == 0:
|
91 |
+
self._deallocate_block(block_id)
|
92 |
+
seq.num_cached_tokens = 0
|
93 |
+
seq.block_table.clear()
|
94 |
+
|
95 |
+
def can_append(self, seq: Sequence) -> bool:
|
96 |
+
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
97 |
+
|
98 |
+
def may_append(self, seq: Sequence):
|
99 |
+
block_table = seq.block_table
|
100 |
+
last_block = self.blocks[block_table[-1]]
|
101 |
+
if len(seq) % self.block_size == 1:
|
102 |
+
assert last_block.hash != -1
|
103 |
+
block_id = self.free_block_ids[0]
|
104 |
+
self._allocate_block(block_id)
|
105 |
+
block_table.append(block_id)
|
106 |
+
elif len(seq) % self.block_size == 0:
|
107 |
+
assert last_block.hash == -1
|
108 |
+
token_ids = seq.block(seq.num_blocks - 1)
|
109 |
+
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
110 |
+
h = self.compute_hash(token_ids, prefix)
|
111 |
+
last_block.update(h, token_ids)
|
112 |
+
self.hash_to_block_id[h] = last_block.block_id
|
113 |
+
else:
|
114 |
+
assert last_block.hash == -1
|
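The prefix cache keys each full block on a hash chained through the previous block's hash, so two requests share a KV block only when everything before and inside that block is identical; a self-contained illustration with toy token ids and block size 4, mirroring `compute_hash` above:

```python
import numpy as np
import xxhash

def compute_hash(token_ids, prefix=-1):
    # same chaining scheme as BlockManager.compute_hash
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()

block_size = 4
seq_a = [1, 2, 3, 4, 5, 6, 7, 8]
seq_b = [1, 2, 3, 4, 9, 6, 7, 8]  # differs only in the second block

def block_hashes(seq):
    h, out = -1, []
    for i in range(0, len(seq), block_size):
        h = compute_hash(seq[i:i + block_size], h)
        out.append(h)
    return out

ha, hb = block_hashes(seq_a), block_hashes(seq_b)
print(ha[0] == hb[0])  # True: identical first blocks can be shared
print(ha[1] == hb[1])  # False: any divergence breaks sharing for all later blocks
```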
flashcosyvoice/engine/llm_engine.py
ADDED
@@ -0,0 +1,125 @@
1 |
+
import atexit
|
2 |
+
from dataclasses import fields
|
3 |
+
from time import perf_counter
|
4 |
+
|
5 |
+
import torch.multiprocessing as mp
|
6 |
+
from tqdm.auto import tqdm
|
7 |
+
from transformers import AutoTokenizer
|
8 |
+
|
9 |
+
from flashcosyvoice.config import Config, SamplingParams
|
10 |
+
from flashcosyvoice.engine.model_runner import ModelRunner
|
11 |
+
from flashcosyvoice.engine.scheduler import Scheduler
|
12 |
+
from flashcosyvoice.engine.sequence import Sequence
|
13 |
+
|
14 |
+
|
15 |
+
class LLMEngine:
|
16 |
+
|
17 |
+
def __init__(self, model, **kwargs):
|
18 |
+
config_fields = {field.name for field in fields(Config)}
|
19 |
+
config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields}
|
20 |
+
config = Config(model, **config_kwargs)
|
21 |
+
self.ps = []
|
22 |
+
self.events = []
|
23 |
+
ctx = mp.get_context("spawn")
|
24 |
+
assert config.tensor_parallel_size == 1, "NOTE(xcsong): Currently only support tp=1"
|
25 |
+
for i in range(1, config.tensor_parallel_size):
|
26 |
+
event = ctx.Event()
|
27 |
+
process = ctx.Process(target=ModelRunner, args=(config, i, event))
|
28 |
+
process.start()
|
29 |
+
self.ps.append(process)
|
30 |
+
self.events.append(event)
|
31 |
+
if hasattr(config.hf_config, "speech_vocab_size"):
|
32 |
+
# NOTE: non-chat model, all these special tokens remain randomly initialized.
|
33 |
+
special_tokens = {
|
34 |
+
'eos_token': '<|endoftext|>',
|
35 |
+
'pad_token': '<|endoftext|>',
|
36 |
+
'additional_special_tokens': [
|
37 |
+
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
38 |
+
'[breath]', '<strong>', '</strong>', '[noise]',
|
39 |
+
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
40 |
+
'[quick_breath]',
|
41 |
+
"<laughter>", "</laughter>",
|
42 |
+
"[hissing]", "[sigh]", "[vocalized-noise]",
|
43 |
+
"[lipsmack]", "[mn]"
|
44 |
+
]
|
45 |
+
}
|
46 |
+
self.tokenizer = AutoTokenizer.from_pretrained(f"{config.model}/CosyVoice-BlankEN")
|
47 |
+
self.tokenizer.add_special_tokens(special_tokens)
|
48 |
+
self.skip_special_tokens = True
|
49 |
+
else:
|
50 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
51 |
+
if hasattr(config.hf_config, "eos_token_id"):
|
52 |
+
config.eos = config.hf_config.eos_token_id
|
53 |
+
else:
|
54 |
+
config.eos = self.tokenizer.eos_token_id
|
55 |
+
self.model_runner = ModelRunner(config, config.rank, self.events)
|
56 |
+
self.scheduler = Scheduler(config)
|
57 |
+
self.config = config
|
58 |
+
atexit.register(self.exit)
|
59 |
+
|
60 |
+
def exit(self):
|
61 |
+
self.model_runner.call("exit")
|
62 |
+
del self.model_runner
|
63 |
+
for p in self.ps:
|
64 |
+
p.join()
|
65 |
+
|
66 |
+
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
67 |
+
if isinstance(prompt, str):
|
68 |
+
prompt = self.tokenizer.encode(prompt)
|
69 |
+
seq = Sequence(prompt, sampling_params)
|
70 |
+
self.scheduler.add(seq)
|
71 |
+
|
72 |
+
def step(self):
|
73 |
+
seqs, is_prefill = self.scheduler.schedule()
|
74 |
+
token_ids = self.model_runner.call("run", seqs, is_prefill)
|
75 |
+
self.scheduler.postprocess(seqs, token_ids)
|
76 |
+
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
|
77 |
+
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
|
78 |
+
return outputs, num_tokens
|
79 |
+
|
80 |
+
def is_finished(self):
|
81 |
+
return self.scheduler.is_finished()
|
82 |
+
|
83 |
+
def generate(
|
84 |
+
self,
|
85 |
+
prompts: list[str] | list[list[int]],
|
86 |
+
sampling_params: SamplingParams | list[SamplingParams],
|
87 |
+
use_tqdm: bool = True,
|
88 |
+
) -> list[str]:
|
89 |
+
if use_tqdm:
|
90 |
+
pbar = tqdm(total=len(prompts), desc="Generating tokens (LLM)", leave=False,
|
91 |
+
dynamic_ncols=True, position=self.config.rank + 1)
|
92 |
+
if not isinstance(sampling_params, list):
|
93 |
+
sampling_params = [sampling_params] * len(prompts)
|
94 |
+
for prompt, sp in zip(prompts, sampling_params):
|
95 |
+
self.add_request(prompt, sp)
|
96 |
+
outputs = {}
|
97 |
+
prefill_throughput = decode_throughput = instant_decode_throughput = 0.
|
98 |
+
total_decode_tokens = 0
|
99 |
+
total_decode_time = 0.
|
100 |
+
while not self.is_finished():
|
101 |
+
t = perf_counter()
|
102 |
+
output, num_tokens = self.step()
|
103 |
+
step_time = perf_counter() - t
|
104 |
+
if use_tqdm:
|
105 |
+
if num_tokens > 0:
|
106 |
+
prefill_throughput = num_tokens / step_time
|
107 |
+
else:
|
108 |
+
instant_decode_throughput = -num_tokens / step_time
|
109 |
+
total_decode_tokens += -num_tokens
|
110 |
+
total_decode_time += step_time
|
111 |
+
decode_throughput = total_decode_tokens / total_decode_time if total_decode_time > 0 else 0
|
112 |
+
pbar.set_postfix({
|
113 |
+
"Prefill": f"{int(prefill_throughput)}tok/s",
|
114 |
+
"AvgDecode": f"{int(decode_throughput)}tok/s",
|
115 |
+
"InstDecode": f"{int(instant_decode_throughput)}tok/s",
|
116 |
+
})
|
117 |
+
for seq_id, token_ids in output:
|
118 |
+
outputs[seq_id] = token_ids
|
119 |
+
if use_tqdm:
|
120 |
+
pbar.update(1)
|
121 |
+
outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
|
122 |
+
outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
|
123 |
+
if use_tqdm:
|
124 |
+
pbar.close()
|
125 |
+
return outputs
|
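A usage sketch for the engine (not runnable as-is: the model path is a placeholder for a real checkpoint directory, and the prompt ids are toy values following the `[sos] + text + [task_id] + prompt speech` layout used by `CosyVoice2.forward`):

```python
from flashcosyvoice.config import SamplingParams
from flashcosyvoice.engine.llm_engine import LLMEngine

engine = LLMEngine("/path/to/CosyVoice2-0.5B",  # placeholder checkpoint directory
                   enforce_eager=True, tensor_parallel_size=1)
prompts = [[6562, 158000, 158001, 6563, 17, 23]]  # toy sos / text / task_id / prompt-speech ids
params = SamplingParams(min_tokens=4, max_tokens=128, use_ras=True)
outputs = engine.generate(prompts, params, use_tqdm=False)
print(outputs[0]["token_ids"])  # generated speech tokens, ending with the speech eos (6561)
```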
flashcosyvoice/engine/model_runner.py
ADDED
@@ -0,0 +1,310 @@
1 |
+
import pickle
|
2 |
+
from multiprocessing.shared_memory import SharedMemory
|
3 |
+
from multiprocessing.synchronize import Event
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.distributed as dist
|
7 |
+
|
8 |
+
from flashcosyvoice.config import Config
|
9 |
+
from flashcosyvoice.engine.sequence import Sequence
|
10 |
+
from flashcosyvoice.modules.qwen2 import Qwen2ForCausalLM
|
11 |
+
from flashcosyvoice.modules.sampler import RasSampler, Sampler
|
12 |
+
from flashcosyvoice.utils.context import (get_context, reset_context,
|
13 |
+
set_context)
|
14 |
+
from flashcosyvoice.utils.loader import load_model
|
15 |
+
|
16 |
+
|
17 |
+
class ModelRunner:
|
18 |
+
|
19 |
+
def __init__(self, config: Config, rank: int, event: Event | list[Event]):
|
20 |
+
self.config = config
|
21 |
+
hf_config = config.hf_config
|
22 |
+
self.block_size = config.kvcache_block_size
|
23 |
+
self.enforce_eager = config.enforce_eager
|
24 |
+
self.world_size = config.tensor_parallel_size
|
25 |
+
self.rank = rank
|
26 |
+
self.event = event
|
27 |
+
|
28 |
+
# TODO(xcsong): support tp > 1
|
29 |
+
if self.world_size > 1:
|
30 |
+
dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
|
31 |
+
torch.cuda.set_device(rank)
|
32 |
+
default_dtype = torch.get_default_dtype()
|
33 |
+
torch.set_default_dtype(hf_config.torch_dtype)
|
34 |
+
torch.set_default_device("cuda")
|
35 |
+
self.model = Qwen2ForCausalLM(hf_config)
|
36 |
+
load_model(self.model, config.model, hf_config)
|
37 |
+
self.sampler = Sampler()
|
38 |
+
self.ras_sampler = RasSampler()
|
39 |
+
self.warmup_model()
|
40 |
+
self.allocate_kv_cache()
|
41 |
+
if not self.enforce_eager:
|
42 |
+
self.capture_cudagraph()
|
43 |
+
torch.set_default_device("cpu")
|
44 |
+
torch.set_default_dtype(default_dtype)
|
45 |
+
|
46 |
+
if self.world_size > 1:
|
47 |
+
if rank == 0:
|
48 |
+
self.shm = SharedMemory(name="flashcosyvoice", create=True, size=2**20)
|
49 |
+
dist.barrier()
|
50 |
+
else:
|
51 |
+
dist.barrier()
|
52 |
+
self.shm = SharedMemory(name="flashcosyvoice")
|
53 |
+
self.loop()
|
54 |
+
|
55 |
+
def exit(self):
|
56 |
+
if self.world_size > 1:
|
57 |
+
self.shm.close()
|
58 |
+
dist.barrier()
|
59 |
+
if self.rank == 0:
|
60 |
+
self.shm.unlink()
|
61 |
+
if not self.enforce_eager:
|
62 |
+
del self.graphs, self.graph_pool
|
63 |
+
torch.cuda.synchronize()
|
64 |
+
if self.world_size > 1:
|
65 |
+
dist.destroy_process_group()
|
66 |
+
|
67 |
+
def loop(self):
|
68 |
+
while True:
|
69 |
+
method_name, args = self.read_shm()
|
70 |
+
self.call(method_name, *args)
|
71 |
+
if method_name == "exit":
|
72 |
+
break
|
73 |
+
|
74 |
+
def read_shm(self):
|
75 |
+
assert self.world_size > 1 and self.rank
|
76 |
+
self.event.wait()
|
77 |
+
n = int.from_bytes(self.shm.buf[0:4], "little")
|
78 |
+
method_name, *args = pickle.loads(self.shm.buf[4:n + 4])
|
79 |
+
self.event.clear()
|
80 |
+
return method_name, args
|
81 |
+
|
82 |
+
def write_shm(self, method_name, *args):
|
83 |
+
assert self.world_size > 1 and not self.rank
|
84 |
+
data = pickle.dumps([method_name, *args])
|
85 |
+
n = len(data)
|
86 |
+
self.shm.buf[0:4] = n.to_bytes(4, "little")
|
87 |
+
self.shm.buf[4:n + 4] = data
|
88 |
+
for event in self.event:
|
89 |
+
event.set()
|
90 |
+
|
91 |
+
def call(self, method_name, *args):
|
92 |
+
if self.world_size > 1 and self.rank == 0:
|
93 |
+
self.write_shm(method_name, *args)
|
94 |
+
method = getattr(self, method_name, None)
|
95 |
+
return method(*args)
|
96 |
+
|
97 |
+
def warmup_model(self):
|
98 |
+
torch.cuda.empty_cache()
|
99 |
+
torch.cuda.reset_peak_memory_stats()
|
100 |
+
max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
|
101 |
+
num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
|
102 |
+
seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
|
103 |
+
self.run(seqs, True)
|
104 |
+
torch.cuda.empty_cache()
|
105 |
+
|
106 |
+
def allocate_kv_cache(self):
|
107 |
+
config = self.config
|
108 |
+
hf_config = config.hf_config
|
109 |
+
free, total = torch.cuda.mem_get_info()
|
110 |
+
used = total - free
|
111 |
+
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
|
112 |
+
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
|
113 |
+
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
114 |
+
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
115 |
+
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
|
116 |
+
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
117 |
+
assert config.num_kvcache_blocks > 0, "try to **increase** gpu_memory_utilization"
|
118 |
+
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
119 |
+
layer_id = 0
|
120 |
+
for module in self.model.modules():
|
121 |
+
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
122 |
+
module.k_cache = self.kv_cache[0, layer_id]
|
123 |
+
module.v_cache = self.kv_cache[1, layer_id]
|
124 |
+
layer_id += 1
|
125 |
+
|
126 |
+
def prepare_block_tables(self, seqs: list[Sequence]):
|
127 |
+
max_len = max(len(seq.block_table) for seq in seqs)
|
128 |
+
block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
|
129 |
+
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
130 |
+
return block_tables
|
131 |
+
|
132 |
+
def prepare_prefill(self, seqs: list[Sequence]):
|
133 |
+
input_ids = []
|
134 |
+
positions = []
|
135 |
+
cu_seqlens_q = [0]
|
136 |
+
cu_seqlens_k = [0]
|
137 |
+
max_seqlen_q = 0
|
138 |
+
max_seqlen_k = 0
|
139 |
+
slot_mapping = []
|
140 |
+
block_tables = None
|
141 |
+
for seq in seqs:
|
142 |
+
seqlen = len(seq)
|
143 |
+
input_ids.extend(seq[seq.num_cached_tokens:])
|
144 |
+
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
|
145 |
+
seqlen_q = seqlen - seq.num_cached_tokens
|
146 |
+
seqlen_k = seqlen
|
147 |
+
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
148 |
+
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
149 |
+
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
150 |
+
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
151 |
+
if not seq.block_table:
|
152 |
+
continue
|
153 |
+
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
154 |
+
start = seq.block_table[i] * self.block_size
|
155 |
+
if i != seq.num_blocks - 1:
|
156 |
+
end = start + self.block_size
|
157 |
+
else:
|
158 |
+
end = start + seq.last_block_num_tokens
|
159 |
+
slot_mapping.extend(list(range(start, end)))
|
160 |
+
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
161 |
+
block_tables = self.prepare_block_tables(seqs)
|
162 |
+
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
163 |
+
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
164 |
+
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
165 |
+
cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
166 |
+
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
167 |
+
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
|
168 |
+
return input_ids, positions
|
169 |
+
|
170 |
+
def prepare_decode(self, seqs: list[Sequence]):
|
171 |
+
input_ids = []
|
172 |
+
positions = []
|
173 |
+
slot_mapping = []
|
174 |
+
context_lens = []
|
175 |
+
for seq in seqs:
|
176 |
+
input_ids.append(seq.last_token)
|
177 |
+
positions.append(len(seq))
|
178 |
+
context_lens.append(len(seq))
|
179 |
+
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
|
180 |
+
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
181 |
+
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
182 |
+
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
183 |
+
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
184 |
+
block_tables = self.prepare_block_tables(seqs)
|
185 |
+
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
186 |
+
return input_ids, positions
|
187 |
+
|
188 |
+
def prepare_sample(self, seqs: list[Sequence]):
|
189 |
+
temperatures = []
|
190 |
+
top_ks = []
|
191 |
+
win_sizes = []
|
192 |
+
tau_rs = []
|
193 |
+
top_ps = []
|
194 |
+
min_tokens_list = []
|
195 |
+
use_ras_list = []
|
196 |
+
|
197 |
+
for seq in seqs:
|
198 |
+
temperatures.append(seq.temperature)
|
199 |
+
top_ks.append(seq.top_k)
|
200 |
+
win_sizes.append(seq.win_size)
|
201 |
+
tau_rs.append(seq.tau_r)
|
202 |
+
top_ps.append(seq.top_p)
|
203 |
+
min_tokens_list.append(seq.min_tokens)
|
204 |
+
use_ras_list.append(seq.use_ras)
|
205 |
+
|
206 |
+
temperatures_tensor = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
|
207 |
+
# check all items equal
|
208 |
+
assert all(item == temperatures[0] for item in temperatures)
|
209 |
+
assert all(item == top_ks[0] for item in top_ks)
|
210 |
+
assert all(item == win_sizes[0] for item in win_sizes)
|
211 |
+
assert all(item == tau_rs[0] for item in tau_rs)
|
212 |
+
assert all(item == top_ps[0] for item in top_ps)
|
213 |
+
assert all(item == use_ras_list[0] for item in use_ras_list)
|
214 |
+
|
215 |
+
return {
|
216 |
+
'temperatures': temperatures_tensor,
|
217 |
+
'top_k': top_ks[0],
|
218 |
+
'win_size': win_sizes[0],
|
219 |
+
'tau_r': tau_rs[0],
|
220 |
+
'top_p': top_ps[0],
|
221 |
+
'eos_token': self.config.eos,
|
222 |
+
'min_tokens': min_tokens_list,
|
223 |
+
'use_ras': use_ras_list[0]
|
224 |
+
}
|
225 |
+
|
226 |
+
@torch.inference_mode()
|
227 |
+
def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
|
228 |
+
if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
|
229 |
+
return self.model.compute_logits(self.model(input_ids, positions))
|
230 |
+
else:
|
231 |
+
bs = input_ids.size(0)
|
232 |
+
context = get_context()
|
233 |
+
graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
|
234 |
+
graph_vars = self.graph_vars
|
235 |
+
for k, v in graph_vars.items():
|
236 |
+
if k != "outputs":
|
237 |
+
v.zero_()
|
238 |
+
graph_vars["input_ids"][:bs] = input_ids
|
239 |
+
graph_vars["positions"][:bs] = positions
|
240 |
+
graph_vars["slot_mapping"][:bs] = context.slot_mapping
|
241 |
+
graph_vars["context_lens"][:bs] = context.context_lens
|
242 |
+
graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
|
243 |
+
graph.replay()
|
244 |
+
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
245 |
+
|
246 |
+
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
247 |
+
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
248 |
+
if self.rank == 0 or self.world_size == 1:
|
249 |
+
sample_params = self.prepare_sample(seqs)
|
250 |
+
logits = self.run_model(input_ids, positions, is_prefill)
|
251 |
+
|
252 |
+
if sample_params['use_ras']:
|
253 |
+
# Prepare decoded tokens list for RasSampler
|
254 |
+
decoded_tokens_list = [seq.completion_token_ids for seq in seqs]
|
255 |
+
# Pass all parameters as lists to RasSampler
|
256 |
+
token_ids = self.ras_sampler(
|
257 |
+
logits,
|
258 |
+
decoded_tokens_list,
|
259 |
+
win_size=sample_params['win_size'],
|
260 |
+
tau_r=sample_params['tau_r'],
|
261 |
+
top_p=sample_params['top_p'],
|
262 |
+
top_k=sample_params['top_k'],
|
263 |
+
eos_token=sample_params['eos_token'],
|
264 |
+
min_tokens=sample_params['min_tokens']
|
265 |
+
).tolist()
|
266 |
+
else:
|
267 |
+
# Use the default sampler with the shared temperature tensor and top_k value
|
268 |
+
token_ids = self.sampler(logits, sample_params['temperatures'], sample_params['top_k']).tolist()
|
269 |
+
else:
|
270 |
+
logits = self.run_model(input_ids, positions, is_prefill)
|
271 |
+
token_ids = None
|
272 |
+
reset_context()
|
273 |
+
return token_ids
|
274 |
+
|
275 |
+
@torch.inference_mode()
|
276 |
+
def capture_cudagraph(self):
|
277 |
+
config = self.config
|
278 |
+
hf_config = config.hf_config
|
279 |
+
max_bs = min(self.config.max_num_seqs, 512)
|
280 |
+
max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
|
281 |
+
input_ids = torch.zeros(max_bs, dtype=torch.int64)
|
282 |
+
positions = torch.zeros(max_bs, dtype=torch.int64)
|
283 |
+
slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
|
284 |
+
context_lens = torch.zeros(max_bs, dtype=torch.int32)
|
285 |
+
block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
|
286 |
+
outputs = torch.zeros(max_bs, hf_config.hidden_size)
|
287 |
+
self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
|
288 |
+
self.graphs = {}
|
289 |
+
self.graph_pool = None
|
290 |
+
|
291 |
+
for bs in reversed(self.graph_bs):
|
292 |
+
graph = torch.cuda.CUDAGraph()
|
293 |
+
set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
|
294 |
+
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup
|
295 |
+
with torch.cuda.graph(graph, self.graph_pool):
|
296 |
+
outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture
|
297 |
+
if self.graph_pool is None:
|
298 |
+
self.graph_pool = graph.pool()
|
299 |
+
self.graphs[bs] = graph
|
300 |
+
torch.cuda.synchronize()
|
301 |
+
reset_context()
|
302 |
+
|
303 |
+
self.graph_vars = dict(
|
304 |
+
input_ids=input_ids,
|
305 |
+
positions=positions,
|
306 |
+
slot_mapping=slot_mapping,
|
307 |
+
context_lens=context_lens,
|
308 |
+
block_tables=block_tables,
|
309 |
+
outputs=outputs,
|
310 |
+
)
|
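`allocate_kv_cache` above sizes the cache from the GPU memory budget divided by the per-block byte cost; a worked example with this commit's default LLM config (illustrative only, the real computation also subtracts memory already used by weights and activations):

```python
# per-block bytes = 2 (K and V) * layers * block_size * kv_heads * head_dim * dtype bytes
num_hidden_layers, block_size = 24, 256
num_key_value_heads, head_dim = 2, 64
dtype_bytes = 2  # bfloat16
block_bytes = 2 * num_hidden_layers * block_size * num_key_value_heads * head_dim * dtype_bytes
print(block_bytes)  # 3_145_728 bytes = 3 MiB per KV block

budget = int(24e9 * 0.9)  # hypothetical 24 GB GPU at gpu_memory_utilization=0.9
num_blocks = budget // block_bytes
print(num_blocks, num_blocks * block_size)  # ~6866 blocks, ~1.76M cacheable tokens
```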
flashcosyvoice/engine/scheduler.py
ADDED
@@ -0,0 +1,77 @@
1 |
+
from collections import deque
|
2 |
+
|
3 |
+
from flashcosyvoice.config import Config
|
4 |
+
from flashcosyvoice.engine.block_manager import BlockManager
|
5 |
+
from flashcosyvoice.engine.sequence import Sequence, SequenceStatus
|
6 |
+
|
7 |
+
|
8 |
+
class Scheduler:
|
9 |
+
|
10 |
+
def __init__(self, config: Config):
|
11 |
+
self.max_num_seqs = config.max_num_seqs
|
12 |
+
self.max_num_batched_tokens = config.max_num_batched_tokens
|
13 |
+
self.eos = config.eos
|
14 |
+
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
|
15 |
+
self.waiting: deque[Sequence] = deque()
|
16 |
+
self.running: deque[Sequence] = deque()
|
17 |
+
|
18 |
+
def is_finished(self):
|
19 |
+
return not self.waiting and not self.running
|
20 |
+
|
21 |
+
def add(self, seq: Sequence):
|
22 |
+
self.waiting.append(seq)
|
23 |
+
|
24 |
+
def schedule(self) -> tuple[list[Sequence], bool]:
|
25 |
+
# prefill
|
26 |
+
scheduled_seqs = []
|
27 |
+
num_seqs = 0
|
28 |
+
num_batched_tokens = 0
|
29 |
+
while self.waiting and num_seqs < self.max_num_seqs:
|
30 |
+
seq = self.waiting[0]
|
31 |
+
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
32 |
+
break
|
33 |
+
num_seqs += 1
|
34 |
+
self.block_manager.allocate(seq)
|
35 |
+
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
36 |
+
seq.status = SequenceStatus.RUNNING
|
37 |
+
self.waiting.popleft()
|
38 |
+
self.running.append(seq)
|
39 |
+
scheduled_seqs.append(seq)
|
40 |
+
if scheduled_seqs:
|
41 |
+
return scheduled_seqs, True
|
42 |
+
|
43 |
+
# decode
|
44 |
+
while self.running and num_seqs < self.max_num_seqs:
|
45 |
+
seq = self.running.popleft()
|
46 |
+
while not self.block_manager.can_append(seq):
|
47 |
+
if self.running:
|
48 |
+
self.preempt(self.running.pop())
|
49 |
+
else:
|
50 |
+
self.preempt(seq)
|
51 |
+
break
|
52 |
+
else:
|
53 |
+
num_seqs += 1
|
54 |
+
self.block_manager.may_append(seq)
|
55 |
+
scheduled_seqs.append(seq)
|
56 |
+
assert scheduled_seqs
|
57 |
+
self.running.extendleft(reversed(scheduled_seqs))
|
58 |
+
return scheduled_seqs, False
|
59 |
+
|
60 |
+
def preempt(self, seq: Sequence):
|
61 |
+
seq.status = SequenceStatus.WAITING
|
62 |
+
self.block_manager.deallocate(seq)
|
63 |
+
self.waiting.appendleft(seq)
|
64 |
+
|
65 |
+
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
66 |
+
for seq, token_id in zip(seqs, token_ids):
|
67 |
+
seq.append_token(token_id)
|
68 |
+
# Check if the sequence has reached the maximum number of tokens
|
69 |
+
reached_max_tokens = seq.num_completion_tokens == seq.max_tokens
|
70 |
+
# Check if the sequence has reached EOS and has generated enough tokens (satisfying min_tokens requirements)
|
71 |
+
eos_with_min_tokens = (not seq.ignore_eos and token_id == self.eos and
|
72 |
+
seq.num_completion_tokens >= seq.min_tokens)
|
73 |
+
|
74 |
+
if reached_max_tokens or eos_with_min_tokens:
|
75 |
+
seq.status = SequenceStatus.FINISHED
|
76 |
+
self.block_manager.deallocate(seq)
|
77 |
+
self.running.remove(seq)
|
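The scheduler always prefers prefill: it drains the waiting queue first and only schedules decode steps when nothing new can be admitted, preempting the most recently queued running sequence back to waiting when KV-cache blocks run out. A hedged sketch of the driving loop (the model step is faked here; this is not the actual `LLMEngine` code, and the `Config` fields are assumed to be populated) is:

```python
import random

from flashcosyvoice.config import Config, SamplingParams
from flashcosyvoice.engine.scheduler import Scheduler
from flashcosyvoice.engine.sequence import Sequence

def fake_model_step(seqs, is_prefill):
    # stand-in for ModelRunner.run(): one sampled token per scheduled sequence
    return [random.randint(0, 6560) for _ in seqs]

def generate(config: Config, prompts: list[list[int]]) -> list[list[int]]:
    scheduler = Scheduler(config)                        # needs max_num_seqs, eos, kv-cache fields
    for token_ids in prompts:
        scheduler.add(Sequence(token_ids, SamplingParams()))

    outputs = {}
    while not scheduler.is_finished():
        seqs, is_prefill = scheduler.schedule()          # prefill batch or decode batch
        token_ids = fake_model_step(seqs, is_prefill)    # forward pass + sampling
        scheduler.postprocess(seqs, token_ids)           # append tokens, finish, free blocks
        for seq in seqs:
            if seq.is_finished:
                outputs[seq.seq_id] = seq.completion_token_ids
    return [outputs[k] for k in sorted(outputs)]
```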
flashcosyvoice/engine/sequence.py
ADDED
@@ -0,0 +1,90 @@
|
1 |
+
from copy import copy
|
2 |
+
from enum import Enum, auto
|
3 |
+
from itertools import count
|
4 |
+
|
5 |
+
from flashcosyvoice.config import SamplingParams
|
6 |
+
|
7 |
+
|
8 |
+
class SequenceStatus(Enum):
|
9 |
+
WAITING = auto()
|
10 |
+
RUNNING = auto()
|
11 |
+
FINISHED = auto()
|
12 |
+
|
13 |
+
|
14 |
+
class Sequence:
|
15 |
+
block_size = 256
|
16 |
+
counter = count()
|
17 |
+
|
18 |
+
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
|
19 |
+
self.seq_id = next(Sequence.counter)
|
20 |
+
self.status = SequenceStatus.WAITING
|
21 |
+
self.token_ids = copy(token_ids)
|
22 |
+
self.last_token = token_ids[-1]
|
23 |
+
self.num_tokens = len(self.token_ids)
|
24 |
+
self.num_prompt_tokens = len(token_ids)
|
25 |
+
self.num_cached_tokens = 0
|
26 |
+
self.block_table = []
|
27 |
+
self.temperature = sampling_params.temperature
|
28 |
+
self.min_tokens = sampling_params.min_tokens
|
29 |
+
self.max_tokens = sampling_params.max_tokens
|
30 |
+
self.ignore_eos = sampling_params.ignore_eos
|
31 |
+
self.top_k = sampling_params.top_k
|
32 |
+
# RasSampler parameters
|
33 |
+
self.use_ras = sampling_params.use_ras
|
34 |
+
self.win_size = sampling_params.win_size
|
35 |
+
self.tau_r = sampling_params.tau_r
|
36 |
+
self.top_p = sampling_params.top_p
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return self.num_tokens
|
40 |
+
|
41 |
+
def __getitem__(self, key):
|
42 |
+
return self.token_ids[key]
|
43 |
+
|
44 |
+
@property
|
45 |
+
def is_finished(self):
|
46 |
+
return self.status == SequenceStatus.FINISHED
|
47 |
+
|
48 |
+
@property
|
49 |
+
def num_completion_tokens(self):
|
50 |
+
return self.num_tokens - self.num_prompt_tokens
|
51 |
+
|
52 |
+
@property
|
53 |
+
def prompt_token_ids(self):
|
54 |
+
return self.token_ids[:self.num_prompt_tokens]
|
55 |
+
|
56 |
+
@property
|
57 |
+
def completion_token_ids(self):
|
58 |
+
return self.token_ids[self.num_prompt_tokens:]
|
59 |
+
|
60 |
+
@property
|
61 |
+
def num_cached_blocks(self):
|
62 |
+
return self.num_cached_tokens // self.block_size
|
63 |
+
|
64 |
+
@property
|
65 |
+
def num_blocks(self):
|
66 |
+
return (self.num_tokens + self.block_size - 1) // self.block_size
|
67 |
+
|
68 |
+
@property
|
69 |
+
def last_block_num_tokens(self):
|
70 |
+
return self.num_tokens - (self.num_blocks - 1) * self.block_size
|
71 |
+
|
72 |
+
def block(self, i):
|
73 |
+
assert 0 <= i < self.num_blocks
|
74 |
+
return self.token_ids[i*self.block_size: (i+1)*self.block_size]
|
75 |
+
|
76 |
+
def append_token(self, token_id: int):
|
77 |
+
self.token_ids.append(token_id)
|
78 |
+
self.last_token = token_id
|
79 |
+
self.num_tokens += 1
|
80 |
+
|
81 |
+
def __getstate__(self):
|
82 |
+
return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table,
|
83 |
+
self.token_ids if self.num_completion_tokens == 0 else self.last_token)
|
84 |
+
|
85 |
+
def __setstate__(self, state):
|
86 |
+
self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1]
|
87 |
+
if self.num_completion_tokens == 0:
|
88 |
+
self.token_ids = state[-1]
|
89 |
+
else:
|
90 |
+
self.last_token = state[-1]
|
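Because the KV cache is paged, a `Sequence` maps its tokens onto fixed-size blocks (256 tokens here): `num_blocks` is a ceiling division and `last_block_num_tokens` is whatever spills into the final block. A quick worked example of that arithmetic:

```python
from flashcosyvoice.engine.sequence import Sequence

seq = Sequence(list(range(600)))          # 600 prompt tokens, block_size = 256
assert seq.num_blocks == 3                # ceil(600 / 256)
assert seq.last_block_num_tokens == 88    # 600 - 2 * 256
assert len(seq.block(0)) == 256           # first full block
assert len(seq.block(2)) == 88            # partial last block

seq.append_token(1234)                    # decode one more token
assert seq.num_completion_tokens == 1
assert seq.last_block_num_tokens == 89
```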
flashcosyvoice/modules/__init__.py
ADDED
File without changes
|
flashcosyvoice/modules/flow.py
ADDED
@@ -0,0 +1,198 @@
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
from flashcosyvoice.modules.flow_components.estimator import \
|
8 |
+
CausalConditionalDecoder
|
9 |
+
from flashcosyvoice.modules.flow_components.upsample_encoder import (
|
10 |
+
UpsampleConformerEncoder, make_pad_mask)
|
11 |
+
|
12 |
+
|
13 |
+
# TODO(xcsong): make it configurable
|
14 |
+
@dataclass
|
15 |
+
class CfmParams:
|
16 |
+
sigma_min: float = 1e-6
|
17 |
+
solver: str = "euler"
|
18 |
+
t_scheduler: str = "cosine"
|
19 |
+
training_cfg_rate: float = 0.2
|
20 |
+
inference_cfg_rate: float = 0.7
|
21 |
+
|
22 |
+
|
23 |
+
class CausalConditionalCFM(torch.nn.Module):
|
24 |
+
def __init__(self, in_channels=320, cfm_params=CfmParams(), n_spks=1, spk_emb_dim=80, estimator: torch.nn.Module = None):
|
25 |
+
super().__init__()
|
26 |
+
self.n_feats = in_channels
|
27 |
+
self.n_spks = n_spks
|
28 |
+
self.spk_emb_dim = spk_emb_dim
|
29 |
+
self.solver = cfm_params.solver
|
30 |
+
if hasattr(cfm_params, "sigma_min"):
|
31 |
+
self.sigma_min = cfm_params.sigma_min
|
32 |
+
else:
|
33 |
+
self.sigma_min = 1e-4
|
34 |
+
self.t_scheduler = cfm_params.t_scheduler
|
35 |
+
self.training_cfg_rate = cfm_params.training_cfg_rate
|
36 |
+
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
37 |
+
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
38 |
+
# Just change the architecture of the estimator here
|
39 |
+
self.estimator = CausalConditionalDecoder() if estimator is None else estimator
|
40 |
+
|
41 |
+
@torch.inference_mode()
|
42 |
+
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
|
43 |
+
"""Forward diffusion
|
44 |
+
|
45 |
+
Args:
|
46 |
+
mu (torch.Tensor): output of encoder
|
47 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
48 |
+
mask (torch.Tensor): output_mask
|
49 |
+
shape: (batch_size, 1, mel_timesteps)
|
50 |
+
n_timesteps (int): number of diffusion steps
|
51 |
+
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
52 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
53 |
+
shape: (batch_size, spk_emb_dim)
|
54 |
+
cond: Not used but kept for future purposes
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
sample: generated mel-spectrogram
|
58 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
59 |
+
"""
|
60 |
+
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
61 |
+
# fix prompt and overlap part mu and z
|
62 |
+
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
63 |
+
if self.t_scheduler == 'cosine':
|
64 |
+
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
65 |
+
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
|
66 |
+
|
67 |
+
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
|
68 |
+
"""
|
69 |
+
Fixed euler solver for ODEs.
|
70 |
+
Args:
|
71 |
+
x (torch.Tensor): random noise
|
72 |
+
t_span (torch.Tensor): n_timesteps interpolated
|
73 |
+
shape: (n_timesteps + 1,)
|
74 |
+
mu (torch.Tensor): output of encoder
|
75 |
+
shape: (batch_size, n_feats, mel_timesteps)
|
76 |
+
mask (torch.Tensor): output_mask
|
77 |
+
shape: (batch_size, 1, mel_timesteps)
|
78 |
+
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
79 |
+
shape: (batch_size, spk_emb_dim)
|
80 |
+
cond: Not used but kept for future purposes
|
81 |
+
"""
|
82 |
+
batch_size = x.size(0)
|
83 |
+
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
84 |
+
|
85 |
+
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
86 |
+
# Or in future might add like a return_all_steps flag
|
87 |
+
sol = []
|
88 |
+
|
89 |
+
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
90 |
+
# Create tensors with double batch size for CFG (conditional + unconditional)
|
91 |
+
x_in = torch.zeros([batch_size * 2, x.size(1), x.size(2)], device=x.device, dtype=x.dtype)
|
92 |
+
mask_in = torch.zeros([batch_size * 2, mask.size(1), mask.size(2)], device=x.device, dtype=x.dtype)
|
93 |
+
mu_in = torch.zeros([batch_size * 2, mu.size(1), mu.size(2)], device=x.device, dtype=x.dtype)
|
94 |
+
t_in = torch.zeros([batch_size * 2], device=x.device, dtype=x.dtype)
|
95 |
+
spks_in = torch.zeros([batch_size * 2, spks.size(1)], device=x.device, dtype=x.dtype)
|
96 |
+
cond_in = torch.zeros([batch_size * 2, cond.size(1), cond.size(2)], device=x.device, dtype=x.dtype)
|
97 |
+
|
98 |
+
for step in range(1, len(t_span)):
|
99 |
+
# Classifier-Free Guidance inference introduced in VoiceBox
|
100 |
+
# Copy conditional and unconditional input
|
101 |
+
x_in[:batch_size] = x
|
102 |
+
x_in[batch_size:] = x
|
103 |
+
mask_in[:batch_size] = mask
|
104 |
+
mask_in[batch_size:] = mask
|
105 |
+
mu_in[:batch_size] = mu
|
106 |
+
# Unconditional part remains 0
|
107 |
+
t_in.fill_(t)
|
108 |
+
spks_in[:batch_size] = spks
|
109 |
+
cond_in[:batch_size] = cond
|
110 |
+
|
111 |
+
dphi_dt = self.estimator(
|
112 |
+
x_in, mask_in,
|
113 |
+
mu_in, t_in,
|
114 |
+
spks_in,
|
115 |
+
cond_in,
|
116 |
+
streaming
|
117 |
+
)
|
118 |
+
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [batch_size, batch_size], dim=0)
|
119 |
+
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
120 |
+
x = x + dt * dphi_dt
|
121 |
+
t = t + dt
|
122 |
+
sol.append(x)
|
123 |
+
if step < len(t_span) - 1:
|
124 |
+
dt = t_span[step + 1] - t
|
125 |
+
|
126 |
+
return sol[-1].float()
|
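Each Euler step above evaluates the estimator on a doubled batch (conditional rows followed by zeroed-out unconditional rows) and blends the two velocity fields with the classifier-free-guidance rate before integrating. Stripped of the batching details, one step reduces to the small helper below (`w` plays the role of `inference_cfg_rate`, 0.7 by default):

```python
# One guided Euler step; v_cond / v_uncond are the two halves of the estimator output.
def guided_euler_step(x, t, dt, v_cond, v_uncond, w=0.7):
    v = (1.0 + w) * v_cond - w * v_uncond   # classifier-free guidance blend
    return x + dt * v, t + dt
```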
127 |
+
|
128 |
+
|
129 |
+
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
130 |
+
def __init__(
|
131 |
+
self,
|
132 |
+
input_size: int = 512,
|
133 |
+
output_size: int = 80,
|
134 |
+
spk_embed_dim: int = 192,
|
135 |
+
output_type: str = "mel",
|
136 |
+
vocab_size: int = 6561,
|
137 |
+
input_frame_rate: int = 25,
|
138 |
+
token_mel_ratio: int = 2,
|
139 |
+
pre_lookahead_len: int = 3,
|
140 |
+
encoder: torch.nn.Module = None,
|
141 |
+
decoder: torch.nn.Module = None,
|
142 |
+
):
|
143 |
+
super().__init__()
|
144 |
+
self.input_size = input_size
|
145 |
+
self.output_size = output_size
|
146 |
+
self.vocab_size = vocab_size
|
147 |
+
self.output_type = output_type
|
148 |
+
self.input_frame_rate = input_frame_rate
|
149 |
+
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
150 |
+
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
151 |
+
self.encoder = UpsampleConformerEncoder() if encoder is None else encoder
|
152 |
+
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
153 |
+
self.decoder = CausalConditionalCFM() if decoder is None else decoder
|
154 |
+
self.token_mel_ratio = token_mel_ratio
|
155 |
+
self.pre_lookahead_len = pre_lookahead_len
|
156 |
+
|
157 |
+
@torch.inference_mode()
|
158 |
+
def forward(self,
|
159 |
+
token,
|
160 |
+
token_len,
|
161 |
+
prompt_feat,
|
162 |
+
prompt_feat_len,
|
163 |
+
embedding,
|
164 |
+
streaming,
|
165 |
+
finalize):
|
166 |
+
# xvec projection
|
167 |
+
embedding = F.normalize(embedding, dim=1)
|
168 |
+
embedding = self.spk_embed_affine_layer(embedding)
|
169 |
+
|
170 |
+
# concat text and prompt_text
|
171 |
+
mask = (~make_pad_mask(token_len, max_len=token.shape[1])).unsqueeze(-1).to(embedding)
|
172 |
+
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
173 |
+
|
174 |
+
# text encode
|
175 |
+
if finalize is True:
|
176 |
+
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
177 |
+
else:
|
178 |
+
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
179 |
+
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
|
180 |
+
h = self.encoder_proj(h)
|
181 |
+
|
182 |
+
# get conditions
|
183 |
+
conds = torch.zeros_like(h, device=token.device)
|
184 |
+
for i, j in enumerate(prompt_feat_len):
|
185 |
+
conds[i, :j] = prompt_feat[i, :j]
|
186 |
+
conds = conds.transpose(1, 2)
|
187 |
+
|
188 |
+
h_lengths = h_lengths.sum(dim=-1).squeeze(dim=1)
|
189 |
+
mask = (~make_pad_mask(h_lengths, max_len=h.shape[1])).to(h)
|
190 |
+
feat, _ = self.decoder(
|
191 |
+
mu=h.transpose(1, 2).contiguous(),
|
192 |
+
mask=mask.unsqueeze(1),
|
193 |
+
spks=embedding,
|
194 |
+
cond=conds,
|
195 |
+
n_timesteps=10,
|
196 |
+
streaming=streaming
|
197 |
+
) # [B, num_mels, T]
|
198 |
+
return feat.float(), h_lengths
|
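To make the tensor shapes concrete, the sketch below runs `CausalMaskedDiffWithXvec` on random inputs with the default sub-modules: 25 Hz speech tokens go in and roughly twice as many 80-bin mel frames come out (`token_mel_ratio = 2`). This is only a shape check with untrained weights; a real call would load a checkpoint and pass genuine prompt mels and x-vectors.

```python
import torch
from flashcosyvoice.modules.flow import CausalMaskedDiffWithXvec

flow = CausalMaskedDiffWithXvec().eval()

token = torch.randint(0, 6561, (1, 50))          # ~2 s of 25 Hz speech tokens
token_len = torch.tensor([50])
prompt_feat = torch.zeros(1, 0, 80)              # no prompt mel in this toy call
prompt_feat_len = torch.tensor([0])
embedding = torch.randn(1, 192)                  # x-vector speaker embedding

mel, mel_len = flow(token, token_len, prompt_feat, prompt_feat_len,
                    embedding, streaming=False, finalize=True)
print(mel.shape)   # roughly (1, 80, 100): 50 tokens * token_mel_ratio 2
```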
flashcosyvoice/modules/flow_components/__init__.py
ADDED
File without changes
|
flashcosyvoice/modules/flow_components/estimator.py
ADDED
@@ -0,0 +1,974 @@
|
1 |
+
import math
|
2 |
+
from typing import Any, Dict, Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from diffusers.models.attention import (GEGLU, GELU, AdaLayerNorm,
|
8 |
+
AdaLayerNormZero, ApproximateGELU)
|
9 |
+
from diffusers.models.attention_processor import Attention
|
10 |
+
from diffusers.models.lora import LoRACompatibleLinear
|
11 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
12 |
+
from einops import pack, rearrange, repeat
|
13 |
+
|
14 |
+
from flashcosyvoice.modules.flow_components.upsample_encoder import \
|
15 |
+
add_optional_chunk_mask
|
16 |
+
|
17 |
+
|
18 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
19 |
+
assert mask.dtype == torch.bool
|
20 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
21 |
+
mask = mask.to(dtype)
|
22 |
+
# attention mask bias
|
23 |
+
# NOTE(Mddct): torch.finfo jit issues
|
24 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
25 |
+
mask = (1.0 - mask) * -1.0e+10
|
26 |
+
return mask
|
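`mask_to_bias` turns a boolean attention mask into an additive bias: allowed positions contribute 0, masked positions get a large negative value, so a later softmax drives them to (almost) zero probability. For example:

```python
import torch

mask = torch.tensor([[True, True, False]])
bias = (1.0 - mask.float()) * -1.0e+10        # same formula as mask_to_bias
scores = torch.zeros(1, 3) + bias             # add to raw attention scores
print(torch.softmax(scores, dim=-1))          # ~[[0.5, 0.5, 0.0]]
```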
27 |
+
|
28 |
+
|
29 |
+
class SnakeBeta(nn.Module):
|
30 |
+
"""
|
31 |
+
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
32 |
+
Shape:
|
33 |
+
- Input: (B, C, T)
|
34 |
+
- Output: (B, C, T), same shape as the input
|
35 |
+
Parameters:
|
36 |
+
- alpha - trainable parameter that controls frequency
|
37 |
+
- beta - trainable parameter that controls magnitude
|
38 |
+
References:
|
39 |
+
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
40 |
+
https://arxiv.org/abs/2006.08195
|
41 |
+
Examples:
|
42 |
+
>>> a1 = snakebeta(256)
|
43 |
+
>>> x = torch.randn(256)
|
44 |
+
>>> x = a1(x)
|
45 |
+
|
46 |
+
Args:
|
47 |
+
in_features: shape of the input
|
48 |
+
out_features: shape of the output
|
49 |
+
alpha: trainable parameter that controls frequency
|
50 |
+
alpha_trainable: whether alpha is trainable
|
51 |
+
alpha_logscale: whether to use log scale for alpha
|
52 |
+
alpha is initialized to 1 by default, higher values = higher-frequency.
|
53 |
+
beta is initialized to 1 by default, higher values = higher-magnitude.
|
54 |
+
alpha will be trained along with the rest of your model.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self, in_features, out_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
|
58 |
+
super().__init__()
|
59 |
+
self.in_features = out_features if isinstance(out_features, list) else [out_features]
|
60 |
+
self.proj = LoRACompatibleLinear(in_features, out_features)
|
61 |
+
|
62 |
+
# initialize alpha
|
63 |
+
self.alpha_logscale = alpha_logscale
|
64 |
+
if self.alpha_logscale: # log scale alphas initialized to zeros
|
65 |
+
self.alpha = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
66 |
+
self.beta = nn.Parameter(torch.zeros(self.in_features) * alpha)
|
67 |
+
else: # linear scale alphas initialized to ones
|
68 |
+
self.alpha = nn.Parameter(torch.ones(self.in_features) * alpha)
|
69 |
+
self.beta = nn.Parameter(torch.ones(self.in_features) * alpha)
|
70 |
+
|
71 |
+
self.alpha.requires_grad = alpha_trainable
|
72 |
+
self.beta.requires_grad = alpha_trainable
|
73 |
+
|
74 |
+
self.no_div_by_zero = 0.000000001
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
"""
|
78 |
+
Forward pass of the function.
|
79 |
+
Applies the function to the input elementwise.
|
80 |
+
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
81 |
+
"""
|
82 |
+
x = self.proj(x)
|
83 |
+
if self.alpha_logscale:
|
84 |
+
alpha = torch.exp(self.alpha)
|
85 |
+
beta = torch.exp(self.beta)
|
86 |
+
else:
|
87 |
+
alpha = self.alpha
|
88 |
+
beta = self.beta
|
89 |
+
|
90 |
+
x = x + (1.0 / (beta + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)
|
91 |
+
|
92 |
+
return x
|
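The SnakeBeta activation applies x + (1/β)·sin²(α·x) after a linear projection, with per-channel α and β stored in log scale by default. A tiny standalone check of just the periodic part (without the LoRA-compatible projection) looks like:

```python
import torch

def snake_beta(x, log_alpha, log_beta, eps=1e-9):
    # x: (B, T, C); log_alpha / log_beta: (C,) parameters stored in log scale
    alpha, beta = torch.exp(log_alpha), torch.exp(log_beta)
    return x + (1.0 / (beta + eps)) * torch.sin(x * alpha) ** 2

x = torch.randn(2, 10, 4)
out = snake_beta(x, torch.zeros(4), torch.zeros(4))   # alpha = beta = 1
print(out.shape)   # (2, 10, 4)
```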
93 |
+
|
94 |
+
|
95 |
+
class FeedForward(nn.Module):
|
96 |
+
r"""
|
97 |
+
A feed-forward layer.
|
98 |
+
|
99 |
+
Parameters:
|
100 |
+
dim (`int`): The number of channels in the input.
|
101 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
102 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
103 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
104 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
105 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
106 |
+
"""
|
107 |
+
|
108 |
+
def __init__(
|
109 |
+
self,
|
110 |
+
dim: int,
|
111 |
+
dim_out: Optional[int] = None,
|
112 |
+
mult: int = 4,
|
113 |
+
dropout: float = 0.0,
|
114 |
+
activation_fn: str = "geglu",
|
115 |
+
final_dropout: bool = False,
|
116 |
+
):
|
117 |
+
super().__init__()
|
118 |
+
inner_dim = int(dim * mult)
|
119 |
+
dim_out = dim_out if dim_out is not None else dim
|
120 |
+
|
121 |
+
if activation_fn == "gelu":
|
122 |
+
act_fn = GELU(dim, inner_dim)
|
123 |
+
if activation_fn == "gelu-approximate":
|
124 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
125 |
+
elif activation_fn == "geglu":
|
126 |
+
act_fn = GEGLU(dim, inner_dim)
|
127 |
+
elif activation_fn == "geglu-approximate":
|
128 |
+
act_fn = ApproximateGELU(dim, inner_dim)
|
129 |
+
elif activation_fn == "snakebeta":
|
130 |
+
act_fn = SnakeBeta(dim, inner_dim)
|
131 |
+
|
132 |
+
self.net = nn.ModuleList([])
|
133 |
+
# project in
|
134 |
+
self.net.append(act_fn)
|
135 |
+
# project dropout
|
136 |
+
self.net.append(nn.Dropout(dropout))
|
137 |
+
# project out
|
138 |
+
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
139 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
140 |
+
if final_dropout:
|
141 |
+
self.net.append(nn.Dropout(dropout))
|
142 |
+
|
143 |
+
def forward(self, hidden_states):
|
144 |
+
for module in self.net:
|
145 |
+
hidden_states = module(hidden_states)
|
146 |
+
return hidden_states
|
147 |
+
|
148 |
+
|
149 |
+
@maybe_allow_in_graph
|
150 |
+
class BasicTransformerBlock(nn.Module):
|
151 |
+
r"""
|
152 |
+
A basic Transformer block.
|
153 |
+
|
154 |
+
Parameters:
|
155 |
+
dim (`int`): The number of channels in the input and output.
|
156 |
+
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
157 |
+
attention_head_dim (`int`): The number of channels in each head.
|
158 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
159 |
+
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
160 |
+
only_cross_attention (`bool`, *optional*):
|
161 |
+
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
162 |
+
double_self_attention (`bool`, *optional*):
|
163 |
+
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
164 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
165 |
+
num_embeds_ada_norm (:
|
166 |
+
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
167 |
+
attention_bias (:
|
168 |
+
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
169 |
+
"""
|
170 |
+
|
171 |
+
def __init__(
|
172 |
+
self,
|
173 |
+
dim: int,
|
174 |
+
num_attention_heads: int,
|
175 |
+
attention_head_dim: int,
|
176 |
+
dropout=0.0,
|
177 |
+
cross_attention_dim: Optional[int] = None,
|
178 |
+
activation_fn: str = "geglu",
|
179 |
+
num_embeds_ada_norm: Optional[int] = None,
|
180 |
+
attention_bias: bool = False,
|
181 |
+
only_cross_attention: bool = False,
|
182 |
+
double_self_attention: bool = False,
|
183 |
+
upcast_attention: bool = False,
|
184 |
+
norm_elementwise_affine: bool = True,
|
185 |
+
norm_type: str = "layer_norm",
|
186 |
+
final_dropout: bool = False,
|
187 |
+
):
|
188 |
+
super().__init__()
|
189 |
+
self.only_cross_attention = only_cross_attention
|
190 |
+
|
191 |
+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
192 |
+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
193 |
+
|
194 |
+
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
195 |
+
raise ValueError(
|
196 |
+
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
197 |
+
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
198 |
+
)
|
199 |
+
|
200 |
+
# Define 3 blocks. Each block has its own normalization layer.
|
201 |
+
# 1. Self-Attn
|
202 |
+
if self.use_ada_layer_norm:
|
203 |
+
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
204 |
+
elif self.use_ada_layer_norm_zero:
|
205 |
+
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
206 |
+
else:
|
207 |
+
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
208 |
+
self.attn1 = Attention(
|
209 |
+
query_dim=dim,
|
210 |
+
heads=num_attention_heads,
|
211 |
+
dim_head=attention_head_dim,
|
212 |
+
dropout=dropout,
|
213 |
+
bias=attention_bias,
|
214 |
+
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
|
215 |
+
upcast_attention=upcast_attention,
|
216 |
+
)
|
217 |
+
|
218 |
+
# 2. Cross-Attn
|
219 |
+
if cross_attention_dim is not None or double_self_attention:
|
220 |
+
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
221 |
+
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
222 |
+
# the second cross attention block.
|
223 |
+
self.norm2 = (
|
224 |
+
AdaLayerNorm(dim, num_embeds_ada_norm)
|
225 |
+
if self.use_ada_layer_norm
|
226 |
+
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
227 |
+
)
|
228 |
+
self.attn2 = Attention(
|
229 |
+
query_dim=dim,
|
230 |
+
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
|
231 |
+
heads=num_attention_heads,
|
232 |
+
dim_head=attention_head_dim,
|
233 |
+
dropout=dropout,
|
234 |
+
bias=attention_bias,
|
235 |
+
upcast_attention=upcast_attention,
|
236 |
+
# scale_qk=False, # uncomment this to not to use flash attention
|
237 |
+
) # is self-attn if encoder_hidden_states is none
|
238 |
+
else:
|
239 |
+
self.norm2 = None
|
240 |
+
self.attn2 = None
|
241 |
+
|
242 |
+
# 3. Feed-forward
|
243 |
+
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
244 |
+
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
245 |
+
|
246 |
+
# let chunk size default to None
|
247 |
+
self._chunk_size = None
|
248 |
+
self._chunk_dim = 0
|
249 |
+
|
250 |
+
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
|
251 |
+
# Sets chunk feed-forward
|
252 |
+
self._chunk_size = chunk_size
|
253 |
+
self._chunk_dim = dim
|
254 |
+
|
255 |
+
def forward(
|
256 |
+
self,
|
257 |
+
hidden_states: torch.FloatTensor,
|
258 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
259 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
260 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
261 |
+
timestep: Optional[torch.LongTensor] = None,
|
262 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
263 |
+
class_labels: Optional[torch.LongTensor] = None,
|
264 |
+
):
|
265 |
+
# Notice that normalization is always applied before the real computation in the following blocks.
|
266 |
+
# 1. Self-Attention
|
267 |
+
if self.use_ada_layer_norm:
|
268 |
+
norm_hidden_states = self.norm1(hidden_states, timestep)
|
269 |
+
elif self.use_ada_layer_norm_zero:
|
270 |
+
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
271 |
+
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
272 |
+
)
|
273 |
+
else:
|
274 |
+
norm_hidden_states = self.norm1(hidden_states)
|
275 |
+
|
276 |
+
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
277 |
+
|
278 |
+
attn_output = self.attn1(
|
279 |
+
norm_hidden_states,
|
280 |
+
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
|
281 |
+
attention_mask=encoder_attention_mask if self.only_cross_attention else attention_mask,
|
282 |
+
**cross_attention_kwargs,
|
283 |
+
)
|
284 |
+
if self.use_ada_layer_norm_zero:
|
285 |
+
attn_output = gate_msa.unsqueeze(1) * attn_output
|
286 |
+
hidden_states = attn_output + hidden_states
|
287 |
+
|
288 |
+
# 2. Cross-Attention
|
289 |
+
if self.attn2 is not None:
|
290 |
+
norm_hidden_states = (
|
291 |
+
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
292 |
+
)
|
293 |
+
|
294 |
+
attn_output = self.attn2(
|
295 |
+
norm_hidden_states,
|
296 |
+
encoder_hidden_states=encoder_hidden_states,
|
297 |
+
attention_mask=encoder_attention_mask,
|
298 |
+
**cross_attention_kwargs,
|
299 |
+
)
|
300 |
+
hidden_states = attn_output + hidden_states
|
301 |
+
|
302 |
+
# 3. Feed-forward
|
303 |
+
norm_hidden_states = self.norm3(hidden_states)
|
304 |
+
|
305 |
+
if self.use_ada_layer_norm_zero:
|
306 |
+
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
307 |
+
|
308 |
+
if self._chunk_size is not None:
|
309 |
+
# "feed_forward_chunk_size" can be used to save memory
|
310 |
+
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
|
311 |
+
raise ValueError(
|
312 |
+
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
|
313 |
+
)
|
314 |
+
|
315 |
+
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
|
316 |
+
ff_output = torch.cat(
|
317 |
+
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
|
318 |
+
dim=self._chunk_dim,
|
319 |
+
)
|
320 |
+
else:
|
321 |
+
ff_output = self.ff(norm_hidden_states)
|
322 |
+
|
323 |
+
if self.use_ada_layer_norm_zero:
|
324 |
+
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
325 |
+
|
326 |
+
hidden_states = ff_output + hidden_states
|
327 |
+
|
328 |
+
return hidden_states
|
329 |
+
|
330 |
+
|
331 |
+
class SinusoidalPosEmb(torch.nn.Module):
|
332 |
+
def __init__(self, dim):
|
333 |
+
super().__init__()
|
334 |
+
self.dim = dim
|
335 |
+
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
336 |
+
|
337 |
+
def forward(self, x, scale=1000):
|
338 |
+
if x.ndim < 1:
|
339 |
+
x = x.unsqueeze(0)
|
340 |
+
device = x.device
|
341 |
+
half_dim = self.dim // 2
|
342 |
+
emb = math.log(10000) / (half_dim - 1)
|
343 |
+
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
344 |
+
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
345 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
346 |
+
return emb
|
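`SinusoidalPosEmb` maps scalar diffusion timesteps to a `dim`-dimensional vector of sines and cosines at geometrically spaced frequencies (the `scale=1000` factor follows the usual diffusion convention). For instance:

```python
import torch
from flashcosyvoice.modules.flow_components.estimator import SinusoidalPosEmb

emb = SinusoidalPosEmb(dim=320)
t = torch.tensor([0.0, 0.5, 1.0])      # three timesteps in [0, 1]
print(emb(t).shape)                    # (3, 320): sin half and cos half of 160 each
```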
347 |
+
|
348 |
+
|
349 |
+
class Block1D(torch.nn.Module):
|
350 |
+
def __init__(self, dim, dim_out, groups=8):
|
351 |
+
super().__init__()
|
352 |
+
self.block = torch.nn.Sequential(
|
353 |
+
torch.nn.Conv1d(dim, dim_out, 3, padding=1),
|
354 |
+
torch.nn.GroupNorm(groups, dim_out),
|
355 |
+
nn.Mish(),
|
356 |
+
)
|
357 |
+
|
358 |
+
def forward(self, x, mask):
|
359 |
+
output = self.block(x * mask)
|
360 |
+
return output * mask
|
361 |
+
|
362 |
+
|
363 |
+
class ResnetBlock1D(torch.nn.Module):
|
364 |
+
def __init__(self, dim, dim_out, time_emb_dim, groups=8):
|
365 |
+
super().__init__()
|
366 |
+
self.mlp = torch.nn.Sequential(nn.Mish(), torch.nn.Linear(time_emb_dim, dim_out))
|
367 |
+
|
368 |
+
self.block1 = Block1D(dim, dim_out, groups=groups)
|
369 |
+
self.block2 = Block1D(dim_out, dim_out, groups=groups)
|
370 |
+
|
371 |
+
self.res_conv = torch.nn.Conv1d(dim, dim_out, 1)
|
372 |
+
|
373 |
+
def forward(self, x, mask, time_emb):
|
374 |
+
h = self.block1(x, mask)
|
375 |
+
h += self.mlp(time_emb).unsqueeze(-1)
|
376 |
+
h = self.block2(h, mask)
|
377 |
+
output = h + self.res_conv(x * mask)
|
378 |
+
return output
|
379 |
+
|
380 |
+
|
381 |
+
class Downsample1D(nn.Module):
|
382 |
+
def __init__(self, dim):
|
383 |
+
super().__init__()
|
384 |
+
self.conv = torch.nn.Conv1d(dim, dim, 3, 2, 1)
|
385 |
+
|
386 |
+
def forward(self, x):
|
387 |
+
return self.conv(x)
|
388 |
+
|
389 |
+
|
390 |
+
class TimestepEmbedding(nn.Module):
|
391 |
+
def __init__(
|
392 |
+
self,
|
393 |
+
in_channels: int,
|
394 |
+
time_embed_dim: int,
|
395 |
+
act_fn: str = "silu",
|
396 |
+
out_dim: int = None,
|
397 |
+
post_act_fn: Optional[str] = None,
|
398 |
+
cond_proj_dim=None,
|
399 |
+
):
|
400 |
+
super().__init__()
|
401 |
+
assert act_fn == "silu", "act_fn must be silu"
|
402 |
+
|
403 |
+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
404 |
+
|
405 |
+
if cond_proj_dim is not None:
|
406 |
+
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
407 |
+
else:
|
408 |
+
self.cond_proj = None
|
409 |
+
|
410 |
+
self.act = nn.SiLU()
|
411 |
+
|
412 |
+
if out_dim is not None:
|
413 |
+
time_embed_dim_out = out_dim
|
414 |
+
else:
|
415 |
+
time_embed_dim_out = time_embed_dim
|
416 |
+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
|
417 |
+
|
418 |
+
if post_act_fn is None:
|
419 |
+
self.post_act = None
|
420 |
+
else:
|
421 |
+
self.post_act = nn.SiLU()
|
422 |
+
|
423 |
+
def forward(self, sample, condition=None):
|
424 |
+
if condition is not None:
|
425 |
+
sample = sample + self.cond_proj(condition)
|
426 |
+
sample = self.linear_1(sample)
|
427 |
+
|
428 |
+
if self.act is not None:
|
429 |
+
sample = self.act(sample)
|
430 |
+
|
431 |
+
sample = self.linear_2(sample)
|
432 |
+
|
433 |
+
if self.post_act is not None:
|
434 |
+
sample = self.post_act(sample)
|
435 |
+
return sample
|
436 |
+
|
437 |
+
|
438 |
+
class Upsample1D(nn.Module):
|
439 |
+
"""A 1D upsampling layer with an optional convolution.
|
440 |
+
|
441 |
+
Parameters:
|
442 |
+
channels (`int`):
|
443 |
+
number of channels in the inputs and outputs.
|
444 |
+
use_conv (`bool`, default `False`):
|
445 |
+
option to use a convolution.
|
446 |
+
use_conv_transpose (`bool`, default `False`):
|
447 |
+
option to use a convolution transpose.
|
448 |
+
out_channels (`int`, optional):
|
449 |
+
number of output channels. Defaults to `channels`.
|
450 |
+
"""
|
451 |
+
|
452 |
+
def __init__(self, channels, use_conv=False, use_conv_transpose=True, out_channels=None, name="conv"):
|
453 |
+
super().__init__()
|
454 |
+
self.channels = channels
|
455 |
+
self.out_channels = out_channels or channels
|
456 |
+
self.use_conv = use_conv
|
457 |
+
self.use_conv_transpose = use_conv_transpose
|
458 |
+
self.name = name
|
459 |
+
|
460 |
+
self.conv = None
|
461 |
+
if use_conv_transpose:
|
462 |
+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
463 |
+
elif use_conv:
|
464 |
+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
465 |
+
|
466 |
+
def forward(self, inputs):
|
467 |
+
assert inputs.shape[1] == self.channels
|
468 |
+
if self.use_conv_transpose:
|
469 |
+
return self.conv(inputs)
|
470 |
+
|
471 |
+
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
472 |
+
|
473 |
+
if self.use_conv:
|
474 |
+
outputs = self.conv(outputs)
|
475 |
+
|
476 |
+
return outputs
|
477 |
+
|
478 |
+
|
479 |
+
class Transpose(torch.nn.Module):
|
480 |
+
def __init__(self, dim0: int, dim1: int):
|
481 |
+
super().__init__()
|
482 |
+
self.dim0 = dim0
|
483 |
+
self.dim1 = dim1
|
484 |
+
|
485 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
486 |
+
x = torch.transpose(x, self.dim0, self.dim1)
|
487 |
+
return x
|
488 |
+
|
489 |
+
|
490 |
+
class CausalConv1d(torch.nn.Conv1d):
|
491 |
+
def __init__(
|
492 |
+
self,
|
493 |
+
in_channels: int,
|
494 |
+
out_channels: int,
|
495 |
+
kernel_size: int,
|
496 |
+
stride: int = 1,
|
497 |
+
dilation: int = 1,
|
498 |
+
groups: int = 1,
|
499 |
+
bias: bool = True,
|
500 |
+
padding_mode: str = 'zeros',
|
501 |
+
device=None,
|
502 |
+
dtype=None
|
503 |
+
) -> None:
|
504 |
+
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
505 |
+
kernel_size, stride,
|
506 |
+
padding=0, dilation=dilation,
|
507 |
+
groups=groups, bias=bias,
|
508 |
+
padding_mode=padding_mode,
|
509 |
+
device=device, dtype=dtype)
|
510 |
+
assert stride == 1
|
511 |
+
self.causal_padding = kernel_size - 1
|
512 |
+
|
513 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
514 |
+
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
515 |
+
x = super(CausalConv1d, self).forward(x)
|
516 |
+
return x
|
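`CausalConv1d` keeps the convolution causal by left-padding the input with `kernel_size - 1` zeros instead of using symmetric padding, so output frame t never looks at frames later than t and the sequence length is preserved. A quick shape and causality check:

```python
import torch
from flashcosyvoice.modules.flow_components.estimator import CausalConv1d

conv = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3)
x = torch.randn(1, 4, 20)
y = conv(x)
print(y.shape)                  # (1, 4, 20): length preserved by the left padding

# causality: changing a future frame must not affect earlier outputs
x2 = x.clone()
x2[:, :, 10:] = 0.0
y2 = conv(x2)
print(torch.allclose(y[:, :, :10], y2[:, :, :10]))   # True
```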
517 |
+
|
518 |
+
|
519 |
+
class CausalBlock1D(Block1D):
|
520 |
+
def __init__(self, dim: int, dim_out: int):
|
521 |
+
super(CausalBlock1D, self).__init__(dim, dim_out)
|
522 |
+
self.block = torch.nn.Sequential(
|
523 |
+
CausalConv1d(dim, dim_out, 3),
|
524 |
+
Transpose(1, 2),
|
525 |
+
nn.LayerNorm(dim_out),
|
526 |
+
Transpose(1, 2),
|
527 |
+
nn.Mish(),
|
528 |
+
)
|
529 |
+
|
530 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
531 |
+
output = self.block(x * mask)
|
532 |
+
return output * mask
|
533 |
+
|
534 |
+
|
535 |
+
class CausalResnetBlock1D(ResnetBlock1D):
|
536 |
+
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
537 |
+
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
538 |
+
self.block1 = CausalBlock1D(dim, dim_out)
|
539 |
+
self.block2 = CausalBlock1D(dim_out, dim_out)
|
540 |
+
|
541 |
+
|
542 |
+
class ConditionalDecoder(nn.Module):
|
543 |
+
"""
|
544 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
545 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
546 |
+
|
547 |
+
Args:
|
548 |
+
in_channels: number of input channels
|
549 |
+
out_channels: number of output channels
|
550 |
+
channels: tuple of channel dimensions
|
551 |
+
dropout: dropout rate
|
552 |
+
attention_head_dim: dimension of attention heads
|
553 |
+
n_blocks: number of transformer blocks
|
554 |
+
num_mid_blocks: number of middle blocks
|
555 |
+
num_heads: number of attention heads
|
556 |
+
act_fn: activation function name
|
557 |
+
"""
|
558 |
+
|
559 |
+
def __init__(
|
560 |
+
self,
|
561 |
+
in_channels,
|
562 |
+
out_channels,
|
563 |
+
channels=(256, 256),
|
564 |
+
dropout=0.05,
|
565 |
+
attention_head_dim=64,
|
566 |
+
n_blocks=1,
|
567 |
+
num_mid_blocks=2,
|
568 |
+
num_heads=4,
|
569 |
+
act_fn="snake",
|
570 |
+
):
|
571 |
+
super().__init__()
|
572 |
+
channels = tuple(channels)
|
573 |
+
self.in_channels = in_channels
|
574 |
+
self.out_channels = out_channels
|
575 |
+
|
576 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
577 |
+
time_embed_dim = channels[0] * 4
|
578 |
+
self.time_mlp = TimestepEmbedding(
|
579 |
+
in_channels=in_channels,
|
580 |
+
time_embed_dim=time_embed_dim,
|
581 |
+
act_fn="silu",
|
582 |
+
)
|
583 |
+
self.down_blocks = nn.ModuleList([])
|
584 |
+
self.mid_blocks = nn.ModuleList([])
|
585 |
+
self.up_blocks = nn.ModuleList([])
|
586 |
+
|
587 |
+
output_channel = in_channels
|
588 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
589 |
+
input_channel = output_channel
|
590 |
+
output_channel = channels[i]
|
591 |
+
is_last = i == len(channels) - 1
|
592 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
593 |
+
transformer_blocks = nn.ModuleList(
|
594 |
+
[
|
595 |
+
BasicTransformerBlock(
|
596 |
+
dim=output_channel,
|
597 |
+
num_attention_heads=num_heads,
|
598 |
+
attention_head_dim=attention_head_dim,
|
599 |
+
dropout=dropout,
|
600 |
+
activation_fn=act_fn,
|
601 |
+
)
|
602 |
+
for _ in range(n_blocks)
|
603 |
+
]
|
604 |
+
)
|
605 |
+
downsample = (
|
606 |
+
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
607 |
+
)
|
608 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
609 |
+
|
610 |
+
for _ in range(num_mid_blocks):
|
611 |
+
input_channel = channels[-1]
|
612 |
+
out_channels = channels[-1]
|
613 |
+
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
614 |
+
|
615 |
+
transformer_blocks = nn.ModuleList(
|
616 |
+
[
|
617 |
+
BasicTransformerBlock(
|
618 |
+
dim=output_channel,
|
619 |
+
num_attention_heads=num_heads,
|
620 |
+
attention_head_dim=attention_head_dim,
|
621 |
+
dropout=dropout,
|
622 |
+
activation_fn=act_fn,
|
623 |
+
)
|
624 |
+
for _ in range(n_blocks)
|
625 |
+
]
|
626 |
+
)
|
627 |
+
|
628 |
+
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
629 |
+
|
630 |
+
channels = channels[::-1] + (channels[0],)
|
631 |
+
for i in range(len(channels) - 1):
|
632 |
+
input_channel = channels[i] * 2
|
633 |
+
output_channel = channels[i + 1]
|
634 |
+
is_last = i == len(channels) - 2
|
635 |
+
resnet = ResnetBlock1D(
|
636 |
+
dim=input_channel,
|
637 |
+
dim_out=output_channel,
|
638 |
+
time_emb_dim=time_embed_dim,
|
639 |
+
)
|
640 |
+
transformer_blocks = nn.ModuleList(
|
641 |
+
[
|
642 |
+
BasicTransformerBlock(
|
643 |
+
dim=output_channel,
|
644 |
+
num_attention_heads=num_heads,
|
645 |
+
attention_head_dim=attention_head_dim,
|
646 |
+
dropout=dropout,
|
647 |
+
activation_fn=act_fn,
|
648 |
+
)
|
649 |
+
for _ in range(n_blocks)
|
650 |
+
]
|
651 |
+
)
|
652 |
+
upsample = (
|
653 |
+
Upsample1D(output_channel, use_conv_transpose=True)
|
654 |
+
if not is_last
|
655 |
+
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
656 |
+
)
|
657 |
+
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
658 |
+
self.final_block = Block1D(channels[-1], channels[-1])
|
659 |
+
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
660 |
+
self.initialize_weights()
|
661 |
+
|
662 |
+
def initialize_weights(self):
|
663 |
+
for m in self.modules():
|
664 |
+
if isinstance(m, nn.Conv1d):
|
665 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
666 |
+
if m.bias is not None:
|
667 |
+
nn.init.constant_(m.bias, 0)
|
668 |
+
elif isinstance(m, nn.GroupNorm):
|
669 |
+
nn.init.constant_(m.weight, 1)
|
670 |
+
nn.init.constant_(m.bias, 0)
|
671 |
+
elif isinstance(m, nn.Linear):
|
672 |
+
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
673 |
+
if m.bias is not None:
|
674 |
+
nn.init.constant_(m.bias, 0)
|
675 |
+
|
676 |
+
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
677 |
+
"""Forward pass of the UNet1DConditional model.
|
678 |
+
|
679 |
+
Args:
|
680 |
+
x (torch.Tensor): shape (batch_size, in_channels, time)
|
681 |
+
mask (_type_): shape (batch_size, 1, time)
|
682 |
+
t (_type_): shape (batch_size)
|
683 |
+
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
684 |
+
cond (_type_, optional): placeholder for future use. Defaults to None.
|
685 |
+
|
686 |
+
Raises:
|
687 |
+
ValueError: _description_
|
688 |
+
ValueError: _description_
|
689 |
+
|
690 |
+
Returns:
|
691 |
+
_type_: _description_
|
692 |
+
"""
|
693 |
+
|
694 |
+
t = self.time_embeddings(t).to(t.dtype)
|
695 |
+
t = self.time_mlp(t)
|
696 |
+
|
697 |
+
x = pack([x, mu], "b * t")[0]
|
698 |
+
|
699 |
+
if spks is not None:
|
700 |
+
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
701 |
+
x = pack([x, spks], "b * t")[0]
|
702 |
+
if cond is not None:
|
703 |
+
x = pack([x, cond], "b * t")[0]
|
704 |
+
|
705 |
+
hiddens = []
|
706 |
+
masks = [mask]
|
707 |
+
for resnet, transformer_blocks, downsample in self.down_blocks:
|
708 |
+
mask_down = masks[-1]
|
709 |
+
x = resnet(x, mask_down, t)
|
710 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
711 |
+
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
712 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
713 |
+
for transformer_block in transformer_blocks:
|
714 |
+
x = transformer_block(
|
715 |
+
hidden_states=x,
|
716 |
+
attention_mask=attn_mask,
|
717 |
+
timestep=t,
|
718 |
+
)
|
719 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
720 |
+
hiddens.append(x) # Save hidden states for skip connections
|
721 |
+
x = downsample(x * mask_down)
|
722 |
+
masks.append(mask_down[:, :, ::2])
|
723 |
+
masks = masks[:-1]
|
724 |
+
mask_mid = masks[-1]
|
725 |
+
|
726 |
+
for resnet, transformer_blocks in self.mid_blocks:
|
727 |
+
x = resnet(x, mask_mid, t)
|
728 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
729 |
+
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
730 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
731 |
+
for transformer_block in transformer_blocks:
|
732 |
+
x = transformer_block(
|
733 |
+
hidden_states=x,
|
734 |
+
attention_mask=attn_mask,
|
735 |
+
timestep=t,
|
736 |
+
)
|
737 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
738 |
+
|
739 |
+
for resnet, transformer_blocks, upsample in self.up_blocks:
|
740 |
+
mask_up = masks.pop()
|
741 |
+
skip = hiddens.pop()
|
742 |
+
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
743 |
+
x = resnet(x, mask_up, t)
|
744 |
+
x = rearrange(x, "b c t -> b t c").contiguous()
|
745 |
+
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
746 |
+
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
747 |
+
for transformer_block in transformer_blocks:
|
748 |
+
x = transformer_block(
|
749 |
+
hidden_states=x,
|
750 |
+
attention_mask=attn_mask,
|
751 |
+
timestep=t,
|
752 |
+
)
|
753 |
+
x = rearrange(x, "b t c -> b c t").contiguous()
|
754 |
+
x = upsample(x * mask_up)
|
755 |
+
x = self.final_block(x, mask_up)
|
756 |
+
output = self.final_proj(x * mask_up)
|
757 |
+
return output * mask
|
758 |
+
|
759 |
+
|
760 |
+
class CausalConditionalDecoder(ConditionalDecoder):
|
761 |
+
"""
|
762 |
+
This decoder requires an input with the same shape of the target. So, if your text content
|
763 |
+
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
764 |
+
|
765 |
+
Args:
|
766 |
+
in_channels: number of input channels
|
767 |
+
out_channels: number of output channels
|
768 |
+
channels: list of channel dimensions
|
769 |
+
dropout: dropout rate
|
770 |
+
attention_head_dim: dimension of attention heads
|
771 |
+
n_blocks: number of transformer blocks
|
772 |
+
num_mid_blocks: number of middle blocks
|
773 |
+
num_heads: number of attention heads
|
774 |
+
act_fn: activation function name
|
775 |
+
static_chunk_size: size of static chunks
|
776 |
+
num_decoding_left_chunks: number of left chunks for decoding
|
777 |
+
"""
|
778 |
+
|
779 |
+
def __init__(
|
780 |
+
self,
|
781 |
+
in_channels=320,
|
782 |
+
out_channels=80,
|
783 |
+
channels=[256], # noqa
|
784 |
+
dropout=0.0,
|
785 |
+
attention_head_dim=64,
|
786 |
+
n_blocks=4,
|
787 |
+
num_mid_blocks=12,
|
788 |
+
num_heads=8,
|
789 |
+
act_fn="gelu",
|
790 |
+
static_chunk_size=50,
|
791 |
+
num_decoding_left_chunks=-1,
|
792 |
+
):
|
793 |
+
torch.nn.Module.__init__(self)
|
794 |
+
channels = tuple(channels)
|
795 |
+
self.in_channels = in_channels
|
796 |
+
self.out_channels = out_channels
|
797 |
+
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
798 |
+
time_embed_dim = channels[0] * 4
|
799 |
+
self.time_mlp = TimestepEmbedding(
|
800 |
+
in_channels=in_channels,
|
801 |
+
time_embed_dim=time_embed_dim,
|
802 |
+
act_fn="silu",
|
803 |
+
)
|
804 |
+
self.static_chunk_size = static_chunk_size
|
805 |
+
self.num_decoding_left_chunks = num_decoding_left_chunks
|
806 |
+
self.down_blocks = nn.ModuleList([])
|
807 |
+
self.mid_blocks = nn.ModuleList([])
|
808 |
+
self.up_blocks = nn.ModuleList([])
|
809 |
+
|
810 |
+
output_channel = in_channels
|
811 |
+
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
812 |
+
input_channel = output_channel
|
813 |
+
output_channel = channels[i]
|
814 |
+
is_last = i == len(channels) - 1
|
815 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
816 |
+
transformer_blocks = nn.ModuleList(
|
817 |
+
[
|
818 |
+
BasicTransformerBlock(
|
819 |
+
dim=output_channel,
|
820 |
+
num_attention_heads=num_heads,
|
821 |
+
attention_head_dim=attention_head_dim,
|
822 |
+
dropout=dropout,
|
823 |
+
activation_fn=act_fn,
|
824 |
+
)
|
825 |
+
for _ in range(n_blocks)
|
826 |
+
]
|
827 |
+
)
|
828 |
+
downsample = (
|
829 |
+
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
|
830 |
+
)
|
831 |
+
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
832 |
+
|
833 |
+
for _ in range(num_mid_blocks):
|
834 |
+
input_channel = channels[-1]
|
835 |
+
out_channels = channels[-1]
|
836 |
+
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
837 |
+
|
838 |
+
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )

            self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))

        channels = channels[::-1] + (channels[0],)
        for i in range(len(channels) - 1):
            input_channel = channels[i] * 2
            output_channel = channels[i + 1]
            is_last = i == len(channels) - 2
            resnet = CausalResnetBlock1D(
                dim=input_channel,
                dim_out=output_channel,
                time_emb_dim=time_embed_dim,
            )
            transformer_blocks = nn.ModuleList(
                [
                    BasicTransformerBlock(
                        dim=output_channel,
                        num_attention_heads=num_heads,
                        attention_head_dim=attention_head_dim,
                        dropout=dropout,
                        activation_fn=act_fn,
                    )
                    for _ in range(n_blocks)
                ]
            )
            upsample = (
                Upsample1D(output_channel, use_conv_transpose=True)
                if not is_last
                else CausalConv1d(output_channel, output_channel, 3)
            )
            self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
        self.final_block = CausalBlock1D(channels[-1], channels[-1])
        self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
        self.initialize_weights()

    def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
        """Forward pass of the UNet1DConditional model.

        Args:
            x (torch.Tensor): shape (batch_size, in_channels, time)
            mask (_type_): shape (batch_size, 1, time)
            t (_type_): shape (batch_size)
            spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
            cond (_type_, optional): placeholder for future use. Defaults to None.

        Raises:
            ValueError: _description_
            ValueError: _description_

        Returns:
            _type_: _description_
        """
        t = self.time_embeddings(t).to(t.dtype)
        t = self.time_mlp(t)

        x = pack([x, mu], "b * t")[0]

        if spks is not None:
            spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
            x = pack([x, spks], "b * t")[0]
        if cond is not None:
            x = pack([x, cond], "b * t")[0]

        hiddens = []
        masks = [mask]
        for resnet, transformer_blocks, downsample in self.down_blocks:
            mask_down = masks[-1]
            x = resnet(x, mask_down, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            hiddens.append(x)  # Save hidden states for skip connections
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, ::2])
        masks = masks[:-1]
        mask_mid = masks[-1]

        for resnet, transformer_blocks in self.mid_blocks:
            x = resnet(x, mask_mid, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()

        for resnet, transformer_blocks, upsample in self.up_blocks:
            mask_up = masks.pop()
            skip = hiddens.pop()
            x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
            x = resnet(x, mask_up, t)
            x = rearrange(x, "b c t -> b t c").contiguous()
            if streaming is True:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
            else:
                attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
            attn_mask = mask_to_bias(attn_mask, x.dtype)
            for transformer_block in transformer_blocks:
                x = transformer_block(
                    hidden_states=x,
                    attention_mask=attn_mask,
                    timestep=t,
                )
            x = rearrange(x, "b t c -> b c t").contiguous()
            x = upsample(x * mask_up)
        x = self.final_block(x, mask_up)
        output = self.final_proj(x * mask_up)
        return output * mask
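For reference, a minimal shape sketch of how this forward pass is typically driven; every size below is an assumption for illustration, not a value taken from the file above, and no model is actually constructed here.

import torch

batch, mel_dim, time_steps = 2, 80, 100            # assumed sizes
x = torch.randn(batch, mel_dim, time_steps)        # noisy sample, (B, C, T)
mu = torch.randn(batch, mel_dim, time_steps)       # token-conditioned mean, packed with x along channels
mask = torch.ones(batch, 1, time_steps)            # (B, 1, T), 1 = valid frame
t = torch.rand(batch)                              # flow-matching timestep, (B,)
spks = torch.randn(batch, mel_dim)                 # speaker embedding, repeated over T inside forward
# estimator(x, mask, mu, t, spks=spks) would return a (B, out_channels, T) tensor,
# zeroed outside `mask` as in the final `return output * mask`.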
flashcosyvoice/modules/flow_components/upsample_encoder.py
ADDED
@@ -0,0 +1,998 @@
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
    pos_idx = torch.arange(size, device=device)
    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
    return ret


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding,
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in future computation!')
        chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
    return chunk_masks


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


class EspnetRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module (new implementation).

    Details can be found in https://github.com/espnet/espnet/pull/2816.

    See : Appendix B in https://arxiv.org/abs/1901.02860

    Args:
        d_model (int): Embedding dimension.
        max_len (int): Maximum input length.

    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super(EspnetRelPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor):
        """Reset the positional encodings."""
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(1) * 2 - 1:
                if self.pe.dtype != x.dtype or self.pe.device != x.device:
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` means the position of the query vector and `j` means the
        # position of the key vector. We use positive relative positions when keys
        # are to the left (i>j) and negative relative positions otherwise (i<j).
        pe_positive = torch.zeros(x.size(1), self.d_model)
        pe_negative = torch.zeros(x.size(1), self.d_model)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices. This is used to support the shifting trick
        # as in https://arxiv.org/abs/1901.02860
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
            -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).

        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).

        """
        self.extend_pe(x)
        x = x * self.xscale
        pos_emb = self.position_encoding(size=x.size(1), offset=offset)
        return x, pos_emb

    def position_encoding(self,
                          offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        """ For getting encoding in a streaming fashion

        Attention!!!!!
        we apply dropout only once at the whole utterance level in a non-
        streaming way, but will call this function several times with
        increasing input size in a streaming scenario, so the dropout will
        be applied several times.

        Args:
            offset (int or torch.tensor): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        # How to subscript a Union type:
        #   https://github.com/pytorch/pytorch/issues/69434
        if isinstance(offset, int):
            pos_emb = self.pe[
                :,
                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
            ]
        elif isinstance(offset, torch.Tensor):
            pos_emb = self.pe[
                :,
                self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
            ]
        return pos_emb


class LinearNoSubsampling(torch.nn.Module):
    """Linear transform the input without subsampling

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        pos_enc_class (torch.nn.Module): Positional encoding class.

    """

    def __init__(self, idim: int, odim: int,
                 pos_enc_class: torch.nn.Module):
        super().__init__()
        self.out = torch.nn.Sequential(
            torch.nn.Linear(idim, odim),
            torch.nn.LayerNorm(odim, eps=1e-5),
        )
        self.pos_enc = pos_enc_class
        self.right_context = 0
        self.subsampling_rate = 1

    def forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        offset: Union[int, torch.Tensor] = 0
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Input x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: linear input tensor (#batch, time', odim),
                where time' = time .
            torch.Tensor: linear input mask (#batch, 1, time'),
                where time' = time .

        """
        x = self.out(x)
        x, pos_emb = self.pos_enc(x, offset)
        return x, pos_emb, x_mask

    def position_encoding(self, offset: Union[int, torch.Tensor],
                          size: int) -> torch.Tensor:
        return self.pos_enc.position_encoding(offset, size)


class Upsample1D(nn.Module):
    """A 1D upsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        use_conv_transpose (`bool`, default `False`):
            option to use a convolution transpose.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
    """

    def __init__(self, channels: int, out_channels: int, stride: int = 2):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels
        self.stride = stride
        # In this mode, first repeat-interpolate, then conv with stride=1
        self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)

    def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
        outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
        outputs = self.conv(outputs)
        return outputs, input_lengths * self.stride


class PreLookaheadLayer(nn.Module):
    def __init__(self, channels: int, pre_lookahead_len: int = 1):
        super().__init__()
        self.channels = channels
        self.pre_lookahead_len = pre_lookahead_len
        self.conv1 = nn.Conv1d(
            channels, channels,
            kernel_size=pre_lookahead_len + 1,
            stride=1, padding=0,
        )
        self.conv2 = nn.Conv1d(
            channels, channels,
            kernel_size=3, stride=1, padding=0,
        )

    def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
        """
        inputs: (batch_size, seq_len, channels)
        """
        outputs = inputs.transpose(1, 2).contiguous()
        context = context.transpose(1, 2).contiguous()
        # look ahead
        if context.size(2) == 0:
            outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
        else:
            assert self.training is False, 'you have passed context, make sure that you are running inference mode'
            assert context.size(2) == self.pre_lookahead_len
            outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
        outputs = F.leaky_relu(self.conv1(outputs))
        # outputs
        outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
        outputs = self.conv2(outputs)
        outputs = outputs.transpose(1, 2).contiguous()

        # residual connection
        outputs = outputs + inputs
        return outputs


+
class MultiHeadedAttention(nn.Module):
|
369 |
+
"""Multi-Head Attention layer.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
n_head (int): The number of heads.
|
373 |
+
n_feat (int): The number of features.
|
374 |
+
dropout_rate (float): Dropout rate.
|
375 |
+
key_bias (bool): Whether to use bias in key linear layer.
|
376 |
+
|
377 |
+
"""
|
378 |
+
|
379 |
+
def __init__(self,
|
380 |
+
n_head: int,
|
381 |
+
n_feat: int,
|
382 |
+
dropout_rate: float,
|
383 |
+
key_bias: bool = True):
|
384 |
+
super().__init__()
|
385 |
+
assert n_feat % n_head == 0
|
386 |
+
# We assume d_v always equals d_k
|
387 |
+
self.d_k = n_feat // n_head
|
388 |
+
self.h = n_head
|
389 |
+
self.linear_q = nn.Linear(n_feat, n_feat)
|
390 |
+
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
391 |
+
self.linear_v = nn.Linear(n_feat, n_feat)
|
392 |
+
self.linear_out = nn.Linear(n_feat, n_feat)
|
393 |
+
self.dropout = nn.Dropout(p=dropout_rate)
|
394 |
+
|
395 |
+
def forward_qkv(
|
396 |
+
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
397 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
398 |
+
"""Transform query, key and value.
|
399 |
+
|
400 |
+
Args:
|
401 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
402 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
403 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
torch.Tensor: Transformed query tensor, size
|
407 |
+
(#batch, n_head, time1, d_k).
|
408 |
+
torch.Tensor: Transformed key tensor, size
|
409 |
+
(#batch, n_head, time2, d_k).
|
410 |
+
torch.Tensor: Transformed value tensor, size
|
411 |
+
(#batch, n_head, time2, d_k).
|
412 |
+
|
413 |
+
"""
|
414 |
+
n_batch = query.size(0)
|
415 |
+
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
416 |
+
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
417 |
+
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
418 |
+
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
419 |
+
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
420 |
+
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
421 |
+
|
422 |
+
return q, k, v
|
423 |
+
|
424 |
+
def forward_attention(
|
425 |
+
self,
|
426 |
+
value: torch.Tensor,
|
427 |
+
scores: torch.Tensor,
|
428 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
429 |
+
) -> torch.Tensor:
|
430 |
+
"""Compute attention context vector.
|
431 |
+
|
432 |
+
Args:
|
433 |
+
value (torch.Tensor): Transformed value, size
|
434 |
+
(#batch, n_head, time2, d_k).
|
435 |
+
scores (torch.Tensor): Attention score, size
|
436 |
+
(#batch, n_head, time1, time2).
|
437 |
+
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
438 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
439 |
+
|
440 |
+
Returns:
|
441 |
+
torch.Tensor: Transformed value (#batch, time1, d_model)
|
442 |
+
weighted by the attention score (#batch, time1, time2).
|
443 |
+
|
444 |
+
"""
|
445 |
+
n_batch = value.size(0)
|
446 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
447 |
+
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
448 |
+
# 1st chunk to ease the onnx export.]
|
449 |
+
# 2. pytorch training
|
450 |
+
if mask.size(2) > 0: # time2 > 0
|
451 |
+
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
452 |
+
# For last chunk, time2 might be larger than scores.size(-1)
|
453 |
+
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
454 |
+
scores = scores.masked_fill(mask, -float('inf'))
|
455 |
+
attn = torch.softmax(scores, dim=-1).masked_fill(
|
456 |
+
mask, 0.0) # (batch, head, time1, time2)
|
457 |
+
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
458 |
+
# 1. onnx(16/-1, -1/-1, 16/0)
|
459 |
+
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
460 |
+
else:
|
461 |
+
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
462 |
+
|
463 |
+
p_attn = self.dropout(attn)
|
464 |
+
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
465 |
+
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
466 |
+
self.h * self.d_k)
|
467 |
+
) # (batch, time1, d_model)
|
468 |
+
|
469 |
+
return self.linear_out(x) # (batch, time1, d_model)
|
470 |
+
|
471 |
+
def forward(
|
472 |
+
self,
|
473 |
+
query: torch.Tensor,
|
474 |
+
key: torch.Tensor,
|
475 |
+
value: torch.Tensor,
|
476 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
477 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
478 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
479 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
480 |
+
"""Compute scaled dot product attention.
|
481 |
+
|
482 |
+
Args:
|
483 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
484 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
485 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
486 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
487 |
+
(#batch, time1, time2).
|
488 |
+
1.When applying cross attention between decoder and encoder,
|
489 |
+
the batch padding mask for input is in (#batch, 1, T) shape.
|
490 |
+
2.When applying self attention of encoder,
|
491 |
+
the mask is in (#batch, T, T) shape.
|
492 |
+
3.When applying self attention of decoder,
|
493 |
+
the mask is in (#batch, L, L) shape.
|
494 |
+
4.If the different position in decoder see different block
|
495 |
+
of the encoder, such as Mocha, the passed in mask could be
|
496 |
+
in (#batch, L, T) shape. But there is no such case in current
|
497 |
+
CosyVoice.
|
498 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
499 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
500 |
+
and `head * d_k == size`
|
501 |
+
|
502 |
+
|
503 |
+
Returns:
|
504 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
505 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
506 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
507 |
+
and `head * d_k == size`
|
508 |
+
|
509 |
+
"""
|
510 |
+
q, k, v = self.forward_qkv(query, key, value)
|
511 |
+
|
512 |
+
# NOTE(xcsong):
|
513 |
+
# when export onnx model, for 1st chunk, we feed
|
514 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
515 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
516 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
517 |
+
# and we will always do splitting and
|
518 |
+
# concatnation(this will simplify onnx export). Note that
|
519 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
520 |
+
# when export jit model, for 1st chunk, we always feed
|
521 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
522 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
523 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
524 |
+
# >>> c = torch.cat((a, b), dim=2)
|
525 |
+
# >>> torch.equal(b, c) # True
|
526 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
527 |
+
# >>> torch.equal(d[0], d[1]) # True
|
528 |
+
if cache.size(0) > 0:
|
529 |
+
key_cache, value_cache = torch.split(cache,
|
530 |
+
cache.size(-1) // 2,
|
531 |
+
dim=-1)
|
532 |
+
k = torch.cat([key_cache, k], dim=2)
|
533 |
+
v = torch.cat([value_cache, v], dim=2)
|
534 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
535 |
+
# non-trivial to calculate `next_cache_start` here.
|
536 |
+
new_cache = torch.cat((k, v), dim=-1)
|
537 |
+
|
538 |
+
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
539 |
+
return self.forward_attention(v, scores, mask), new_cache
|
540 |
+
|
541 |
+
|
542 |
+
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
543 |
+
"""Multi-Head Attention layer with relative position encoding.
|
544 |
+
Paper: https://arxiv.org/abs/1901.02860
|
545 |
+
Args:
|
546 |
+
n_head (int): The number of heads.
|
547 |
+
n_feat (int): The number of features.
|
548 |
+
dropout_rate (float): Dropout rate.
|
549 |
+
key_bias (bool): Whether to use bias in key linear layer.
|
550 |
+
"""
|
551 |
+
|
552 |
+
def __init__(self,
|
553 |
+
n_head: int,
|
554 |
+
n_feat: int,
|
555 |
+
dropout_rate: float,
|
556 |
+
key_bias: bool = True):
|
557 |
+
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
558 |
+
# linear transformation for positional encoding
|
559 |
+
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
560 |
+
# these two learnable bias are used in matrix c and matrix d
|
561 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
562 |
+
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
563 |
+
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
564 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
565 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
566 |
+
|
567 |
+
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
568 |
+
"""Compute relative positional encoding.
|
569 |
+
|
570 |
+
Args:
|
571 |
+
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
572 |
+
time1 means the length of query vector.
|
573 |
+
|
574 |
+
Returns:
|
575 |
+
torch.Tensor: Output tensor.
|
576 |
+
|
577 |
+
"""
|
578 |
+
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
579 |
+
device=x.device,
|
580 |
+
dtype=x.dtype)
|
581 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
582 |
+
|
583 |
+
x_padded = x_padded.view(x.size()[0],
|
584 |
+
x.size()[1],
|
585 |
+
x.size(3) + 1, x.size(2))
|
586 |
+
x = x_padded[:, :, 1:].view_as(x)[
|
587 |
+
:, :, :, : x.size(-1) // 2 + 1
|
588 |
+
] # only keep the positions from 0 to time2
|
589 |
+
return x
|
590 |
+
|
591 |
+
def forward(
|
592 |
+
self,
|
593 |
+
query: torch.Tensor,
|
594 |
+
key: torch.Tensor,
|
595 |
+
value: torch.Tensor,
|
596 |
+
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
597 |
+
pos_emb: torch.Tensor = torch.empty(0),
|
598 |
+
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
599 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
600 |
+
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
601 |
+
Args:
|
602 |
+
query (torch.Tensor): Query tensor (#batch, time1, size).
|
603 |
+
key (torch.Tensor): Key tensor (#batch, time2, size).
|
604 |
+
value (torch.Tensor): Value tensor (#batch, time2, size).
|
605 |
+
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
606 |
+
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
607 |
+
pos_emb (torch.Tensor): Positional embedding tensor
|
608 |
+
(#batch, time2, size).
|
609 |
+
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
610 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
611 |
+
and `head * d_k == size`
|
612 |
+
Returns:
|
613 |
+
torch.Tensor: Output tensor (#batch, time1, d_model).
|
614 |
+
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
615 |
+
where `cache_t == chunk_size * num_decoding_left_chunks`
|
616 |
+
and `head * d_k == size`
|
617 |
+
"""
|
618 |
+
q, k, v = self.forward_qkv(query, key, value)
|
619 |
+
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
620 |
+
|
621 |
+
# NOTE(xcsong):
|
622 |
+
# when export onnx model, for 1st chunk, we feed
|
623 |
+
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
624 |
+
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
625 |
+
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
626 |
+
# and we will always do splitting and
|
627 |
+
# concatnation(this will simplify onnx export). Note that
|
628 |
+
# it's OK to concat & split zero-shaped tensors(see code below).
|
629 |
+
# when export jit model, for 1st chunk, we always feed
|
630 |
+
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
631 |
+
# >>> a = torch.ones((1, 2, 0, 4))
|
632 |
+
# >>> b = torch.ones((1, 2, 3, 4))
|
633 |
+
# >>> c = torch.cat((a, b), dim=2)
|
634 |
+
# >>> torch.equal(b, c) # True
|
635 |
+
# >>> d = torch.split(a, 2, dim=-1)
|
636 |
+
# >>> torch.equal(d[0], d[1]) # True
|
637 |
+
if cache.size(0) > 0:
|
638 |
+
key_cache, value_cache = torch.split(cache,
|
639 |
+
cache.size(-1) // 2,
|
640 |
+
dim=-1)
|
641 |
+
k = torch.cat([key_cache, k], dim=2)
|
642 |
+
v = torch.cat([value_cache, v], dim=2)
|
643 |
+
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
644 |
+
# non-trivial to calculate `next_cache_start` here.
|
645 |
+
new_cache = torch.cat((k, v), dim=-1)
|
646 |
+
|
647 |
+
n_batch_pos = pos_emb.size(0)
|
648 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
649 |
+
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
650 |
+
|
651 |
+
# (batch, head, time1, d_k)
|
652 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
653 |
+
# (batch, head, time1, d_k)
|
654 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
655 |
+
|
656 |
+
# compute attention score
|
657 |
+
# first compute matrix a and matrix c
|
658 |
+
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
659 |
+
# (batch, head, time1, time2)
|
660 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
661 |
+
|
662 |
+
# compute matrix b and matrix d
|
663 |
+
# (batch, head, time1, time2)
|
664 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
665 |
+
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
666 |
+
if matrix_ac.shape != matrix_bd.shape:
|
667 |
+
matrix_bd = self.rel_shift(matrix_bd)
|
668 |
+
|
669 |
+
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
670 |
+
self.d_k) # (batch, head, time1, time2)
|
671 |
+
|
672 |
+
return self.forward_attention(v, scores, mask), new_cache
|
673 |
+
|
674 |
+
|
675 |
+
class PositionwiseFeedForward(torch.nn.Module):
|
676 |
+
"""Positionwise feed forward layer.
|
677 |
+
|
678 |
+
FeedForward are appied on each position of the sequence.
|
679 |
+
The output dim is same with the input dim.
|
680 |
+
|
681 |
+
Args:
|
682 |
+
idim (int): Input dimenstion.
|
683 |
+
hidden_units (int): The number of hidden units.
|
684 |
+
dropout_rate (float): Dropout rate.
|
685 |
+
activation (torch.nn.Module): Activation function
|
686 |
+
"""
|
687 |
+
|
688 |
+
def __init__(
|
689 |
+
self,
|
690 |
+
idim: int,
|
691 |
+
hidden_units: int,
|
692 |
+
dropout_rate: float,
|
693 |
+
activation: torch.nn.Module = torch.nn.ReLU(),
|
694 |
+
):
|
695 |
+
super(PositionwiseFeedForward, self).__init__()
|
696 |
+
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
697 |
+
self.activation = activation
|
698 |
+
self.dropout = torch.nn.Dropout(dropout_rate)
|
699 |
+
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
700 |
+
|
701 |
+
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
702 |
+
"""Forward function.
|
703 |
+
|
704 |
+
Args:
|
705 |
+
xs: input tensor (B, L, D)
|
706 |
+
Returns:
|
707 |
+
output tensor, (B, L, D)
|
708 |
+
"""
|
709 |
+
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
710 |
+
|
711 |
+
|
712 |
+
class ConformerEncoderLayer(nn.Module):
|
713 |
+
"""Encoder layer module.
|
714 |
+
Args:
|
715 |
+
size (int): Input dimension.
|
716 |
+
self_attn (torch.nn.Module): Self-attention module instance.
|
717 |
+
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
718 |
+
instance can be used as the argument.
|
719 |
+
feed_forward (torch.nn.Module): Feed-forward module instance.
|
720 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
721 |
+
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
722 |
+
instance.
|
723 |
+
`PositionwiseFeedForward` instance can be used as the argument.
|
724 |
+
conv_module (torch.nn.Module): Convolution module instance.
|
725 |
+
`ConvlutionModule` instance can be used as the argument.
|
726 |
+
dropout_rate (float): Dropout rate.
|
727 |
+
normalize_before (bool):
|
728 |
+
True: use layer_norm before each sub-block.
|
729 |
+
False: use layer_norm after each sub-block.
|
730 |
+
"""
|
731 |
+
|
732 |
+
def __init__(
|
733 |
+
self,
|
734 |
+
size: int,
|
735 |
+
self_attn: torch.nn.Module,
|
736 |
+
feed_forward: Optional[nn.Module] = None,
|
737 |
+
feed_forward_macaron: Optional[nn.Module] = None,
|
738 |
+
conv_module: Optional[nn.Module] = None,
|
739 |
+
dropout_rate: float = 0.0,
|
740 |
+
normalize_before: bool = True,
|
741 |
+
):
|
742 |
+
super().__init__()
|
743 |
+
self.self_attn = self_attn
|
744 |
+
self.feed_forward = feed_forward
|
745 |
+
self.feed_forward_macaron = feed_forward_macaron
|
746 |
+
self.conv_module = conv_module
|
747 |
+
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
748 |
+
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
749 |
+
if feed_forward_macaron is not None:
|
750 |
+
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
751 |
+
self.ff_scale = 0.5
|
752 |
+
else:
|
753 |
+
self.ff_scale = 1.0
|
754 |
+
if self.conv_module is not None:
|
755 |
+
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
756 |
+
self.norm_final = nn.LayerNorm(
|
757 |
+
size, eps=1e-12) # for the final output of the block
|
758 |
+
self.dropout = nn.Dropout(dropout_rate)
|
759 |
+
self.size = size
|
760 |
+
self.normalize_before = normalize_before
|
761 |
+
|
762 |
+
def forward(
|
763 |
+
self,
|
764 |
+
x: torch.Tensor,
|
765 |
+
mask: torch.Tensor,
|
766 |
+
pos_emb: torch.Tensor,
|
767 |
+
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
768 |
+
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
769 |
+
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
770 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
771 |
+
"""Compute encoded features.
|
772 |
+
|
773 |
+
Args:
|
774 |
+
x (torch.Tensor): (#batch, time, size)
|
775 |
+
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
776 |
+
(0, 0, 0) means fake mask.
|
777 |
+
pos_emb (torch.Tensor): positional encoding, must not be None
|
778 |
+
for ConformerEncoderLayer.
|
779 |
+
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
780 |
+
(#batch, 1,time), (0, 0, 0) means fake mask.
|
781 |
+
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
782 |
+
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
783 |
+
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
784 |
+
(#batch=1, size, cache_t2)
|
785 |
+
Returns:
|
786 |
+
torch.Tensor: Output tensor (#batch, time, size).
|
787 |
+
torch.Tensor: Mask tensor (#batch, time, time).
|
788 |
+
torch.Tensor: att_cache tensor,
|
789 |
+
(#batch=1, head, cache_t1 + time, d_k * 2).
|
790 |
+
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
|
791 |
+
"""
|
792 |
+
|
793 |
+
# whether to use macaron style
|
794 |
+
if self.feed_forward_macaron is not None:
|
795 |
+
residual = x
|
796 |
+
if self.normalize_before:
|
797 |
+
x = self.norm_ff_macaron(x)
|
798 |
+
x = residual + self.ff_scale * self.dropout(
|
799 |
+
self.feed_forward_macaron(x))
|
800 |
+
if not self.normalize_before:
|
801 |
+
x = self.norm_ff_macaron(x)
|
802 |
+
|
803 |
+
# multi-headed self-attention module
|
804 |
+
residual = x
|
805 |
+
if self.normalize_before:
|
806 |
+
x = self.norm_mha(x)
|
807 |
+
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
808 |
+
att_cache)
|
809 |
+
x = residual + self.dropout(x_att)
|
810 |
+
if not self.normalize_before:
|
811 |
+
x = self.norm_mha(x)
|
812 |
+
|
813 |
+
# convolution module
|
814 |
+
# Fake new cnn cache here, and then change it in conv_module
|
815 |
+
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
816 |
+
if self.conv_module is not None:
|
817 |
+
residual = x
|
818 |
+
if self.normalize_before:
|
819 |
+
x = self.norm_conv(x)
|
820 |
+
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
821 |
+
x = residual + self.dropout(x)
|
822 |
+
|
823 |
+
if not self.normalize_before:
|
824 |
+
x = self.norm_conv(x)
|
825 |
+
|
826 |
+
# feed forward module
|
827 |
+
residual = x
|
828 |
+
if self.normalize_before:
|
829 |
+
x = self.norm_ff(x)
|
830 |
+
|
831 |
+
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
832 |
+
if not self.normalize_before:
|
833 |
+
x = self.norm_ff(x)
|
834 |
+
|
835 |
+
if self.conv_module is not None:
|
836 |
+
x = self.norm_final(x)
|
837 |
+
|
838 |
+
return x, mask, new_att_cache, new_cnn_cache
|
839 |
+
|
840 |
+
|
841 |
+
class UpsampleConformerEncoder(torch.nn.Module):
|
842 |
+
"""
|
843 |
+
Args:
|
844 |
+
input_size (int): input dim
|
845 |
+
output_size (int): dimension of attention
|
846 |
+
attention_heads (int): the number of heads of multi head attention
|
847 |
+
linear_units (int): the hidden units number of position-wise feed
|
848 |
+
forward
|
849 |
+
num_blocks (int): the number of decoder blocks
|
850 |
+
static_chunk_size (int): chunk size for static chunk training and
|
851 |
+
decoding
|
852 |
+
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
853 |
+
training or not, You can only use fixed chunk(chunk_size > 0)
|
854 |
+
or dyanmic chunk size(use_dynamic_chunk = True)
|
855 |
+
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
856 |
+
dynamic chunk training
|
857 |
+
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
858 |
+
"""
|
859 |
+
|
860 |
+
def __init__(
|
861 |
+
self,
|
862 |
+
input_size: int = 512,
|
863 |
+
output_size: int = 512,
|
864 |
+
attention_heads: int = 8,
|
865 |
+
linear_units: int = 2048,
|
866 |
+
num_blocks: int = 6,
|
867 |
+
static_chunk_size: int = 25,
|
868 |
+
use_dynamic_chunk: bool = False,
|
869 |
+
use_dynamic_left_chunk: bool = False,
|
870 |
+
key_bias: bool = True,
|
871 |
+
):
|
872 |
+
super().__init__()
|
873 |
+
self._output_size = output_size
|
874 |
+
|
875 |
+
self.embed = LinearNoSubsampling(
|
876 |
+
input_size, output_size,
|
877 |
+
EspnetRelPositionalEncoding(output_size),
|
878 |
+
)
|
879 |
+
|
880 |
+
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
881 |
+
self.static_chunk_size = static_chunk_size
|
882 |
+
self.use_dynamic_chunk = use_dynamic_chunk
|
883 |
+
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
884 |
+
activation = torch.nn.SiLU()
|
885 |
+
# self-attention module definition
|
886 |
+
encoder_selfattn_layer_args = (
|
887 |
+
attention_heads,
|
888 |
+
output_size,
|
889 |
+
0.0,
|
890 |
+
key_bias,
|
891 |
+
)
|
892 |
+
# feed-forward module definition
|
893 |
+
positionwise_layer_args = (
|
894 |
+
output_size,
|
895 |
+
linear_units,
|
896 |
+
0.0,
|
897 |
+
activation,
|
898 |
+
)
|
899 |
+
# convolution module definition
|
900 |
+
self.pre_lookahead_layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
|
901 |
+
self.encoders = torch.nn.ModuleList([
|
902 |
+
ConformerEncoderLayer(
|
903 |
+
output_size,
|
904 |
+
RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
|
905 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
906 |
+
) for _ in range(num_blocks)
|
907 |
+
])
|
908 |
+
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
909 |
+
self.up_embed = LinearNoSubsampling(
|
910 |
+
input_size, output_size,
|
911 |
+
EspnetRelPositionalEncoding(output_size),
|
912 |
+
)
|
913 |
+
self.up_encoders = torch.nn.ModuleList([
|
914 |
+
ConformerEncoderLayer(
|
915 |
+
output_size,
|
916 |
+
RelPositionMultiHeadedAttention(*encoder_selfattn_layer_args),
|
917 |
+
PositionwiseFeedForward(*positionwise_layer_args),
|
918 |
+
) for _ in range(4)
|
919 |
+
])
|
920 |
+
|
921 |
+
def output_size(self) -> int:
|
922 |
+
return self._output_size
|
923 |
+
|
924 |
+
def forward(
|
925 |
+
self,
|
926 |
+
xs: torch.Tensor,
|
927 |
+
xs_lens: torch.Tensor,
|
928 |
+
context: torch.Tensor = torch.zeros(0, 0, 0),
|
929 |
+
decoding_chunk_size: int = 0,
|
930 |
+
num_decoding_left_chunks: int = -1,
|
931 |
+
streaming: bool = False,
|
932 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
933 |
+
"""Embed positions in tensor.
|
934 |
+
|
935 |
+
Args:
|
936 |
+
xs: padded input tensor (B, T, D)
|
937 |
+
xs_lens: input length (B)
|
938 |
+
decoding_chunk_size: decoding chunk size for dynamic chunk
|
939 |
+
0: default for training, use random dynamic chunk.
|
940 |
+
<0: for decoding, use full chunk.
|
941 |
+
>0: for decoding, use fixed chunk size as set.
|
942 |
+
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
943 |
+
the chunk size is decoding_chunk_size.
|
944 |
+
>=0: use num_decoding_left_chunks
|
945 |
+
<0: use all left chunks
|
946 |
+
Returns:
|
947 |
+
encoder output tensor xs, and subsampled masks
|
948 |
+
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
949 |
+
masks: torch.Tensor batch padding mask after subsample
|
950 |
+
(B, 1, T' ~= T/subsample_rate)
|
951 |
+
NOTE(xcsong):
|
952 |
+
We pass the `__call__` method of the modules instead of `forward` to the
|
953 |
+
checkpointing API because `__call__` attaches all the hooks of the module.
|
954 |
+
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
955 |
+
"""
|
956 |
+
T = xs.size(1)
|
957 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
958 |
+
xs, pos_emb, masks = self.embed(xs, masks)
|
959 |
+
if context.size(1) != 0:
|
960 |
+
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
961 |
+
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
|
962 |
+
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
|
963 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
964 |
+
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
965 |
+
# lookahead + conformer encoder
|
966 |
+
xs = self.pre_lookahead_layer(xs, context=context)
|
967 |
+
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
968 |
+
|
969 |
+
# upsample + conformer encoder
|
970 |
+
xs = xs.transpose(1, 2).contiguous()
|
971 |
+
xs, xs_lens = self.up_layer(xs, xs_lens)
|
972 |
+
xs = xs.transpose(1, 2).contiguous()
|
973 |
+
T = xs.size(1)
|
974 |
+
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
975 |
+
xs, pos_emb, masks = self.up_embed(xs, masks)
|
976 |
+
mask_pad = masks # (B, 1, T/subsample_rate)
|
977 |
+
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
|
978 |
+
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
979 |
+
|
980 |
+
xs = self.after_norm(xs)
|
981 |
+
# Here we assume the mask is not changed in encoder layers, so just
|
982 |
+
# return the masks before encoder layers, and the masks will be used
|
983 |
+
# for cross attention with decoder later
|
984 |
+
return xs, masks
|
985 |
+
|
986 |
+
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
987 |
+
pos_emb: torch.Tensor,
|
988 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
989 |
+
for layer in self.encoders:
|
990 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
991 |
+
return xs
|
992 |
+
|
993 |
+
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
994 |
+
pos_emb: torch.Tensor,
|
995 |
+
mask_pad: torch.Tensor) -> torch.Tensor:
|
996 |
+
for layer in self.up_encoders:
|
997 |
+
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
998 |
+
return xs
|
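For reference, a minimal smoke-test sketch of the encoder defined above, using its default hyper-parameters; the batch size, lengths, and shapes below are assumptions for illustration, not values from the commit.

import torch

encoder = UpsampleConformerEncoder()       # defaults: 512-dim, 6 encoder + 4 up_encoder layers
xs = torch.randn(2, 50, 512)               # (B, T, D) input features (assumed shapes)
xs_lens = torch.tensor([50, 32])           # valid lengths per utterance
ys, masks = encoder(xs, xs_lens)           # ys: (B, 2*T, 512) after the stride-2 Upsample1D; masks: (B, 1, 2*T)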
flashcosyvoice/modules/hifigan.py
ADDED
@@ -0,0 +1,249 @@
1 |
+
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""HIFI-GAN"""
|
16 |
+
|
17 |
+
from typing import Dict, List
|
18 |
+
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from scipy.signal import get_window
|
24 |
+
from torch.nn import Conv1d, ConvTranspose1d
|
25 |
+
from torch.nn.utils import remove_weight_norm
|
26 |
+
|
27 |
+
try:
|
28 |
+
from torch.nn.utils.parametrizations import weight_norm
|
29 |
+
except ImportError:
|
30 |
+
from torch.nn.utils import weight_norm # noqa
|
31 |
+
|
32 |
+
from flashcosyvoice.modules.hifigan_components.layers import (
|
33 |
+
ResBlock, SourceModuleHnNSF, SourceModuleHnNSF2, init_weights)
|
34 |
+
|
35 |
+
|
36 |
+
class ConvRNNF0Predictor(nn.Module):
|
37 |
+
def __init__(self,
|
38 |
+
num_class: int = 1,
|
39 |
+
in_channels: int = 80,
|
40 |
+
cond_channels: int = 512
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self.num_class = num_class
|
45 |
+
self.condnet = nn.Sequential(
|
46 |
+
weight_norm( # noqa
|
47 |
+
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
48 |
+
),
|
49 |
+
nn.ELU(),
|
50 |
+
weight_norm( # noqa
|
51 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
52 |
+
),
|
53 |
+
nn.ELU(),
|
54 |
+
weight_norm( # noqa
|
55 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
56 |
+
),
|
57 |
+
nn.ELU(),
|
58 |
+
weight_norm( # noqa
|
59 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
60 |
+
),
|
61 |
+
nn.ELU(),
|
62 |
+
weight_norm( # noqa
|
63 |
+
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
64 |
+
),
|
65 |
+
nn.ELU(),
|
66 |
+
)
|
67 |
+
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
68 |
+
|
69 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
70 |
+
x = self.condnet(x)
|
71 |
+
x = x.transpose(1, 2)
|
72 |
+
return torch.abs(self.classifier(x).squeeze(-1))
|
73 |
+
|
74 |
+
|
75 |
+
class HiFTGenerator(nn.Module):
|
76 |
+
"""
|
77 |
+
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
78 |
+
https://arxiv.org/abs/2309.09493
|
79 |
+
"""
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
in_channels: int = 80,
|
83 |
+
base_channels: int = 512,
|
84 |
+
nb_harmonics: int = 8,
|
85 |
+
sampling_rate: int = 24000,
|
86 |
+
nsf_alpha: float = 0.1,
|
87 |
+
nsf_sigma: float = 0.003,
|
88 |
+
nsf_voiced_threshold: float = 10,
|
89 |
+
upsample_rates: List[int] = [8, 5, 3], # noqa
|
90 |
+
upsample_kernel_sizes: List[int] = [16, 11, 7], # noqa
|
91 |
+
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4}, # noqa
|
92 |
+
resblock_kernel_sizes: List[int] = [3, 7, 11], # noqa
|
93 |
+
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
|
94 |
+
source_resblock_kernel_sizes: List[int] = [7, 7, 11], # noqa
|
95 |
+
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]], # noqa
|
96 |
+
lrelu_slope: float = 0.1,
|
97 |
+
audio_limit: float = 0.99,
|
98 |
+
f0_predictor: torch.nn.Module = None,
|
99 |
+
):
|
100 |
+
super(HiFTGenerator, self).__init__()
|
101 |
+
|
102 |
+
self.out_channels = 1
|
103 |
+
self.nb_harmonics = nb_harmonics
|
104 |
+
self.sampling_rate = sampling_rate
|
105 |
+
self.istft_params = istft_params
|
106 |
+
self.lrelu_slope = lrelu_slope
|
107 |
+
self.audio_limit = audio_limit
|
108 |
+
|
109 |
+
self.num_kernels = len(resblock_kernel_sizes)
|
        self.num_upsamples = len(upsample_rates)
        # NOTE in CosyVoice2, we use the original SourceModuleHnNSF implementation
        this_SourceModuleHnNSF = SourceModuleHnNSF if self.sampling_rate == 22050 else SourceModuleHnNSF2
        self.m_source = this_SourceModuleHnNSF(
            sampling_rate=sampling_rate,
            upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
            harmonic_num=nb_harmonics,
            sine_amp=nsf_alpha,
            add_noise_std=nsf_sigma,
            voiced_threshod=nsf_voiced_threshold)
        self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])

        self.conv_pre = weight_norm(  # noqa
            Conv1d(in_channels, base_channels, 7, 1, padding=3)
        )

        # Up
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(  # noqa
                    ConvTranspose1d(
                        base_channels // (2**i),
                        base_channels // (2**(i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # Down
        self.source_downs = nn.ModuleList()
        self.source_resblocks = nn.ModuleList()
        downsample_rates = [1] + upsample_rates[::-1][:-1]
        downsample_cum_rates = np.cumprod(downsample_rates)
        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
            if u == 1:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
                )
            else:
                self.source_downs.append(
                    Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
                )

            self.source_resblocks.append(
                ResBlock(base_channels // (2 ** (i + 1)), k, d)
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = base_channels // (2**(i + 1))
            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(ResBlock(ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))  # noqa
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
        self.f0_predictor = ConvRNNF0Predictor() if f0_predictor is None else f0_predictor

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for up in self.ups:
            remove_weight_norm(up)
        for resblock in self.resblocks:
            resblock.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
        self.m_source.remove_weight_norm()
        for source_down in self.source_downs:
            remove_weight_norm(source_down)
        for source_resblock in self.source_resblocks:
            source_resblock.remove_weight_norm()

    def _stft(self, x):
        spec = torch.stft(
            x,
            self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
            return_complex=True)
        spec = torch.view_as_real(spec)  # [B, F, TT, 2]
        return spec[..., 0], spec[..., 1]

    def _istft(self, magnitude, phase):
        magnitude = torch.clip(magnitude, max=1e2)
        real = magnitude * torch.cos(phase)
        img = magnitude * torch.sin(phase)
        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
        return inverse_transform

    def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
        s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)

        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, self.lrelu_slope)
            x = self.ups[i](x)

            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)

            # fusion
            si = self.source_downs[i](s_stft)
            si = self.source_resblocks[i](si)
            x = x + si

            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels

        x = F.leaky_relu(x)
        x = self.conv_post(x)
        magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
        phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :])  # actually, sin is redundancy

        x = self._istft(magnitude, phase)
        x = torch.clamp(x, -self.audio_limit, self.audio_limit)
        return x

    @torch.inference_mode()
    def forward(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
        # mel->f0
        f0 = self.f0_predictor(speech_feat)
        # f0->source
        s = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
        s, _, _ = self.m_source(s)
        s = s.transpose(1, 2)
        # use cache_source to avoid glitch
        if cache_source.shape[2] != 0:
            s[:, :, :cache_source.shape[2]] = cache_source
        generated_speech = self.decode(x=speech_feat, s=s)
        return generated_speech, s
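The generator above vocodes by predicting a log-magnitude and a phase-like spectrum and running an inverse STFT, as in `_stft`/`_istft`. A minimal standalone sketch of that magnitude/phase round trip follows; the `n_fft`/`hop_len` values are toy numbers for the demo, not the model's configuration.

# Illustrative sketch, not part of this commit: magnitude/phase -> iSTFT reconstruction.
import numpy as np
import torch
from scipy.signal import get_window

n_fft, hop_len = 16, 4
window = torch.from_numpy(get_window("hann", n_fft, fftbins=True).astype(np.float32))
x = torch.randn(1, 4000)
spec = torch.stft(x, n_fft, hop_len, n_fft, window=window, return_complex=True)
magnitude, phase = spec.abs(), spec.angle()
real, imag = magnitude * torch.cos(phase), magnitude * torch.sin(phase)
y = torch.istft(torch.complex(real, imag), n_fft, hop_len, n_fft, window=window)
print((x[:, :y.shape[-1]] - y).abs().max())  # reconstruction error is tiny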
flashcosyvoice/modules/hifigan_components/__init__.py
ADDED
File without changes
flashcosyvoice/modules/hifigan_components/layers.py
ADDED
@@ -0,0 +1,433 @@
from typing import List

import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform
from torch.nn import Conv1d
from torch.nn.utils import remove_weight_norm

try:
    from torch.nn.utils.parametrizations import weight_norm
except ImportError:
    from torch.nn.utils import weight_norm  # noqa


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


"""hifigan based generator implementation.

This code is modified from https://github.com/jik876/hifi-gan
 ,https://github.com/kan-bayashi/ParallelWaveGAN and
 https://github.com/NVIDIA/BigVGAN

"""


# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.
class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = snake(256)
        >>> x = torch.randn(256)
        >>> x = a1(x)

    Args:
        in_features: shape of the input
        alpha: trainable parameter
        alpha_trainable: whether alpha is trainable
        alpha_logscale: whether to use log scale for alpha
        alpha is initialized to 1 by default, higher values = higher-frequency.
        alpha will be trained along with the rest of your model.
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = nn.Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake ∶= x + 1/a * sin^2 (xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * torch.pow(torch.sin(x * alpha), 2)

        return x


class ResBlock(torch.nn.Module):
    """Residual block module in HiFiGAN/BigVGAN."""
    def __init__(
        self,
        channels: int = 512,
        kernel_size: int = 3,
        dilations: List[int] = [1, 3, 5],  # noqa
    ):
        super(ResBlock, self).__init__()
        self.convs1 = nn.ModuleList()
        self.convs2 = nn.ModuleList()

        for dilation in dilations:
            self.convs1.append(
                weight_norm(  # noqa
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation,
                        padding=get_padding(kernel_size, dilation)
                    )
                )
            )
            self.convs2.append(
                weight_norm(  # noqa
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1)
                    )
                )
            )
        self.convs1.apply(init_weights)
        self.convs2.apply(init_weights)
        self.activations1 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs1))
        ])
        self.activations2 = nn.ModuleList([
            Snake(channels, alpha_logscale=False)
            for _ in range(len(self.convs2))
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for idx in range(len(self.convs1)):
            xt = self.activations1[idx](x)
            xt = self.convs1[idx](xt)
            xt = self.activations2[idx](xt)
            xt = self.convs2[idx](xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for idx in range(len(self.convs1)):
            remove_weight_norm(self.convs1[idx])
            remove_weight_norm(self.convs2[idx])


class SineGen(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-wavefrom (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_thoreshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SinGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    @torch.no_grad()
    def forward(self, f0):
        """
        :param f0: [B, 1, sample_len], Hz
        :return: [B, 1, sample_len]
        """

        F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
        for i in range(self.harmonic_num + 1):
            F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate

        theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
        u_dist = Uniform(low=-np.pi, high=np.pi)
        phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
        phase_vec[:, 0, :] = 0

        # generate sine waveforms
        sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        # .      for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threhold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x.transpose(1, 2))
            sine_wavs = sine_wavs.transpose(1, 2)
            uv = uv.transpose(1, 2)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class SineGen2(torch.nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine-wavefrom (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_thoreshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SinGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(np.pi) or cos(0)
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen2, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The interger part n can be ignored
        # because 2 * np.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
                                                         scale_factor=1 / self.upsample_scale,
                                                         mode="linear").transpose(1, 2)

            phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
            phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
                                                    scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segments is sin(pi) or cos(0)
            # This is used for pulse-train generation

            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)

            # get the instantanouse phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segments
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum

            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)

            # get the sines
            sines = torch.cos(i_phase * 2 * np.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))

        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp

        # generate uv signal
        uv = self._f02uv(f0)

        # noise: for unvoiced should be similar to sine_amp
        #        std = self.sine_amp/3 -> max value ~ self.sine_amp
        # .      for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)

        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF2(torch.nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonic above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threhold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF2, self).__init__()

        self.sine_amp = sine_amp
        self.noise_std = add_noise_std

        # to produce sine waveforms
        self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num,
                                  sine_amp, add_noise_std, voiced_threshod)

        # to merge source harmonics into a single excitation
        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = torch.nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))

        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
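The Snake activation defined above is x + (1/alpha) * sin^2(alpha * x). A minimal sketch checking the module against the formula directly; it assumes the Snake class listed above is in scope.

# Illustrative sketch, not part of this commit: verify Snake(x) == x + (1/alpha)*sin(alpha*x)**2.
import torch

snake = Snake(in_features=4, alpha_logscale=False)
x = torch.randn(2, 4, 8)                      # (B, C, T)
alpha = snake.alpha.view(1, -1, 1)            # broadcast over batch and time
manual = x + (1.0 / (alpha + 1e-9)) * torch.sin(x * alpha) ** 2
print(torch.allclose(snake(x), manual, atol=1e-6))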
flashcosyvoice/modules/qwen2.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from transformers import AutoConfig

from flashcosyvoice.config import CosyVoice2LLMConfig
from flashcosyvoice.modules.qwen2_components.layers import (
    ParallelLMHead, Qwen2DecoderLayer, RMSNorm, VocabParallelEmbedding)


class Qwen2Model(nn.Module):

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Qwen2ForCausalLM(nn.Module):
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: CosyVoice2LLMConfig | AutoConfig
    ):
        super().__init__()
        self.model = Qwen2Model(config)
        if hasattr(config, "speech_vocab_size"):
            self.lm_head = ParallelLMHead(config.speech_vocab_size, config.hidden_size, bias=getattr(config, "lm_head_bias", True))
            self.model_type = "speech_llm"
        else:
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False)
            self.model_type = "text_llm"
        self.tie_word_embeddings = config.tie_word_embeddings
        if self.tie_word_embeddings:
            if self.model_type == "speech_llm":
                assert config.vocab_size == config.speech_vocab_size, "vocab_size and speech_vocab_size must be the same when tie_word_embeddings is True"
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        logits = self.lm_head(hidden_states)
        return logits
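Qwen2ForCausalLM above optionally ties lm_head.weight to the token embedding matrix. A minimal sketch of the same weight-tying pattern with plain nn.Embedding/nn.Linear; the sizes are hypothetical and chosen only for illustration.

# Illustrative sketch, not part of this commit: embedding / lm_head weight tying.
import torch
import torch.nn as nn

vocab, hidden = 100, 16
embed = nn.Embedding(vocab, hidden)
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight = embed.weight                 # logits projection reuses the embedding matrix

h = embed(torch.tensor([[3, 7]]))             # (1, 2, hidden)
logits = lm_head(h)                           # (1, 2, vocab)
print(logits.shape, lm_head.weight.data_ptr() == embed.weight.data_ptr())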
flashcosyvoice/modules/qwen2_components/__init__.py
ADDED
File without changes
flashcosyvoice/modules/qwen2_components/layers.py
ADDED
@@ -0,0 +1,616 @@
from functools import lru_cache

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from flashcosyvoice.config import CosyVoice2LLMConfig
from flashcosyvoice.utils.context import get_context


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, y = x.chunk(2, -1)
        return F.silu(x) * y


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x

    @torch.compile
    def add_rms_forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.to(torch.float32).add_(residual.to(torch.float32))
        residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    slot = tl.load(slot_mapping_ptr + idx)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.block_tables is not None:  # prefix cache
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                       softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:  # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o


class VocabParallelEmbedding(nn.Module):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        # TODO(xcsong): support tp > 1
        self.tp_rank = 0  # dist.get_rank()
        self.tp_size = 1  # dist.get_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        if self.tp_size > 1:
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            y = mask.unsqueeze(1) * y
            dist.all_reduce(y)
        return y


class ParallelLMHead(VocabParallelEmbedding):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        super().__init__(num_embeddings, embedding_dim)
        if bias:
            self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor):
        context = get_context()
        if context.is_prefill:
            last_indices = context.cu_seqlens_q[1:] - 1
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight, self.bias)
        if self.tp_size > 1:
            all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
            dist.gather(logits, all_logits, 0)
            logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
        return logits


def divide(numerator, denominator):
    assert numerator % denominator == 0
    return numerator // denominator


class LinearBase(nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        tp_dim: int | None = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tp_dim = tp_dim
        # TODO(xcsong): support tp > 1
        self.tp_rank = 0  # dist.get_rank()
        self.tp_size = 1  # dist.get_world_size()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size)
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class ColumnParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 0)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class MergedColumnParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        super().__init__(input_size, sum(output_sizes), bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
        param_data = param.data
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # TODO(xcsong): support tp > 1
        tp_size = 1  # dist.get_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        input_size = self.hidden_size
        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,  # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
        param_data = param.data
        assert loaded_shard_id in ["q", "k", "v"]
        if loaded_shard_id == "q":
            shard_size = self.num_heads * self.head_size
            shard_offset = 0
        elif loaded_shard_id == "k":
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size
        else:
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(x.dtype)


class RotaryEmbedding(nn.Module):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        assert rotary_dim == head_size
        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    @torch.compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query = apply_rotary_emb(query, cos, sin).view(query_shape)
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key = apply_rotary_emb(key, cos, sin).view(key_shape)
        return query, key


@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    assert rope_scaling is None
    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
    return rotary_emb


class Qwen2Attention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        head_dim: int | None = None,
        rms_norm_eps: float = 1e-06,
        qkv_bias: bool = True,
        rope_theta: float = 1000000.0,
        rope_scaling: tuple | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        # TODO(xcsong): support tp > 1
        tp_size = 1  # dist.get_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = head_dim or hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=qkv_bias,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        o = self.attn(q, k, v)
        output = self.o_proj(o)
        return output


class Qwen2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
        )
        assert hidden_act == "silu"
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x = self.down_proj(x)
        return x


class Qwen2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            max_position=config.max_position_embeddings,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, "qkv_bias", True),
            head_dim=getattr(config, "head_dim", None),
            rope_theta=getattr(config, "rope_theta", 1000000.0),
            rope_scaling=getattr(config, "rope_scaling", None),
        )
        self.mlp = Qwen2MLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
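apply_rotary_emb above rotates each (x1, x2) feature pair by the position angle, which is the same as multiplying x1 + i*x2 by exp(i*theta). A small sketch checking that equivalence for a single pair; the angle and values are arbitrary demo numbers.

# Illustrative sketch, not part of this commit: pairwise rotation vs. complex multiplication.
import torch

theta = torch.tensor([0.3])
cos, sin = theta.cos(), theta.sin()
x1, x2 = torch.tensor(2.0), torch.tensor(1.0)
rotated = torch.stack((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
z = (x1 + 1j * x2) * torch.exp(1j * theta)    # same rotation in complex form
print(rotated, torch.view_as_real(z))         # both print the same two numbers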
flashcosyvoice/modules/sampler.py
ADDED
@@ -0,0 +1,231 @@
import torch
from torch import nn


class Sampler(nn.Module):
    """
    Optimized sampler implementation using vectorized operations instead of loops, significantly improving performance

    Performance optimizations:
    1. Using batch processing instead of sequence loops, reducing Python loop overhead
    2. Using PyTorch's vectorized operations (like torch.sort, torch.gather) for parallel computation
    3. Using mask operations to apply top-k filtering at once, avoiding per-sequence processing
    """
    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor, top_k: int = None):
        """
        Perform sampling operation using vectorized method for top-k filtering

        Args:
            logits: Logits tensor with shape [batch_size, vocab_size]
            temperatures: Temperature parameters with shape [batch_size]
            top_k: Top-k value for filtering (uniform across all sequences)

        Returns:
            Sampled token IDs
        """
        logits = logits.to(torch.float)
        greedy_tokens = logits.argmax(dim=-1)  # Greedy decoding result, used when temperature=0
        logits.div_(temperatures.unsqueeze(dim=1))  # Apply temperature scaling

        # Apply uniform top-k filtering if top_k is provided
        if top_k is not None and top_k > 0:
            vocab_size = logits.size(-1)

            # Create a mask to store which positions should be kept
            mask = torch.zeros_like(logits, dtype=torch.bool)

            # Batch sorting for all sequences at once
            sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)

            # Get threshold for each sequence (the k-th largest value)
            k_value = min(top_k, vocab_size)  # Ensure k doesn't exceed vocab size
            thresholds = sorted_logits[:, k_value-1:k_value]  # Shape [batch_size, 1]
            thresholds = thresholds.expand(-1, vocab_size)  # Expand to match logits shape

            # Create mask: only keep logits greater than or equal to threshold
            mask = logits >= thresholds

            # Apply mask: set logits not in top-k to negative infinity
            logits = torch.where(mask, logits, torch.tensor(float('-inf'), device=logits.device))

        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        return torch.where(temperatures == 0, greedy_tokens, sample_tokens)


class RasSampler(nn.Module):
    """
    Optimized Repetition Aware Sampling implementation

    Performance optimizations:
    1. Using vectorized nucleus sampling instead of loop implementation, improving sampling efficiency
    2. Using tensor operations to calculate repetition rate, reducing Python loop overhead
    3. Optimizing EOS handling logic, reducing unnecessary resampling
    4. Using PyTorch's vectorized operations for parallel computation
    5. Batch processing for all sequences, dramatically improving throughput
    6. Robust handling for sequences of any length, including empty sequences
    """
    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, decoded_tokens_list: list,
                win_size: int = 10, tau_r: float = 0.1,
                top_p: float = 0.8, top_k: int = 25,
                eos_token: int = 6561, min_tokens: list[int] = None):
        """
        Execute repetition-aware sampling using optimized vectorized operations with batch processing

        Args:
            logits: Input logits with shape [batch_size, vocab_size]
            decoded_tokens_list: List of decoded tokens, each element is a token list for a batch
            win_size: Window size for repetition detection (uniform across all batch items)
            tau_r: Repetition threshold (uniform across all batch items)
            top_p: Nucleus sampling probability threshold (uniform across all batch items)
            top_k: Nucleus sampling top-k threshold (uniform across all batch items)
            eos_token: End of sequence token ID (uniform across all batch items)
            min_tokens: List of minimum tokens to generate before allowing EOS, one per batch item
        Returns:
            Selected token IDs
        """
        batch_size = logits.size(0)
        device = logits.device
        result = torch.zeros(batch_size, dtype=torch.long, device=device)

        # Set default values if not provided
        if min_tokens is None:
            min_tokens = [2] * batch_size

        # Ensure min_tokens list has the correct length
        assert len(min_tokens) == batch_size, f"min_tokens length {len(min_tokens)} != batch_size {batch_size}"

        # Force continue decode first token
        for i in range(batch_size):
            if i < len(decoded_tokens_list) and len(decoded_tokens_list[i]) == 0:
                logits[i, eos_token] = -float('inf')

        # 1. First, perform nucleus sampling for all sequences
        probs = torch.softmax(logits, dim=-1)

        # Use vectorized nucleus sampling for all sequences
        # This can be done in batch since top_p and top_k are uniform
        sorted_probs, sorted_indices = probs.sort(dim=-1, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Create masks for top-p and top-k filtering
        top_p_mask = cumulative_probs <= top_p

        # Create top-k mask (first top_k positions are True)
        top_k_mask = torch.zeros_like(top_p_mask)
        top_k_mask[:, :top_k] = True

        # Combine masks
        mask = top_p_mask & top_k_mask

        # Ensure at least one token is selected per sequence
        first_token_mask = torch.zeros_like(mask)
        first_token_mask[:, 0] = True
        mask = mask | first_token_mask

        # Sample from the filtered distribution
        sample_probs = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs))
        sample_probs = sample_probs / sample_probs.sum(dim=-1, keepdim=True)

        # Sample indices from the filtered distribution
        sampled_indices = torch.multinomial(sample_probs, 1).squeeze(-1)
        top_ids = torch.gather(sorted_indices, -1, sampled_indices.unsqueeze(-1)).squeeze(-1)

        # 2. Check for repetitions and apply random sampling if needed
        # Extract recent tokens for each sequence, handling empty or short sequences
        recent_tokens_list = []
        for i in range(batch_size):
            # Handle index out of range or empty tokens
            if i < len(decoded_tokens_list):
                tokens = decoded_tokens_list[i]
                if len(tokens) > 0:
                    start_idx = max(0, len(tokens) - win_size)
                    recent_tokens_list.append(tokens[start_idx:])
                else:
                    recent_tokens_list.append([])  # Empty list for empty tokens
            else:
                recent_tokens_list.append([])  # Empty list for missing batch items

        # Check if we have any tokens to process for repetition detection
        if any(len(tokens) > 0 for tokens in recent_tokens_list):
            # Convert to padded tensor for batch processing
            max_recent_len = max(len(tokens) for tokens in recent_tokens_list)
            if max_recent_len > 0:  # Only proceed if we have tokens
                recent_tokens_tensor = torch.zeros((batch_size, max_recent_len), dtype=torch.long, device=device) - 1
                for i, tokens in enumerate(recent_tokens_list):
                    if len(tokens) > 0:
                        recent_tokens_tensor[i, -len(tokens):] = torch.tensor(tokens, device=device)

                # Create a mask for valid positions and to avoid division by zero
|
167 |
+
valid_positions_mask = torch.zeros_like(recent_tokens_tensor, dtype=torch.bool)
|
168 |
+
for i, tokens in enumerate(recent_tokens_list):
|
169 |
+
if len(tokens) > 0:
|
170 |
+
valid_positions_mask[i, -len(tokens):] = True
|
171 |
+
|
172 |
+
# Check repetition rates
|
173 |
+
repetition_counts = torch.zeros(batch_size, device=device)
|
174 |
+
for i in range(batch_size):
|
175 |
+
if len(recent_tokens_list[i]) > 0:
|
176 |
+
repetition_counts[i] = (recent_tokens_tensor[i] == top_ids[i]).sum()
|
177 |
+
|
178 |
+
# Calculate repetition rates, avoiding division by zero
|
179 |
+
recent_lengths = torch.tensor([max(1, len(tokens)) for tokens in recent_tokens_list], device=device)
|
180 |
+
repetition_rates = repetition_counts / recent_lengths
|
181 |
+
|
182 |
+
# Identify sequences needing random sampling
|
183 |
+
need_random = repetition_rates >= tau_r
|
184 |
+
|
185 |
+
# Apply random sampling where needed
|
186 |
+
if need_random.any():
|
187 |
+
random_indices = torch.multinomial(probs[need_random], 1).squeeze(-1)
|
188 |
+
top_ids[need_random] = random_indices
|
189 |
+
|
190 |
+
# 3. Handle EOS tokens
|
191 |
+
# Create mask for sequences that should ignore EOS tokens
|
192 |
+
ignore_eos_mask = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
193 |
+
for i in range(batch_size):
|
194 |
+
if i < len(decoded_tokens_list):
|
195 |
+
ignore_eos_mask[i] = len(decoded_tokens_list[i]) < min_tokens[i]
|
196 |
+
else:
|
197 |
+
ignore_eos_mask[i] = True # Default to ignoring EOS for missing sequences
|
198 |
+
|
199 |
+
is_eos_mask = top_ids == eos_token
|
200 |
+
need_resample = ignore_eos_mask & is_eos_mask
|
201 |
+
|
202 |
+
# Resample for sequences that need it
|
203 |
+
if need_resample.any():
|
204 |
+
max_trials = 100
|
205 |
+
for attempt in range(max_trials):
|
206 |
+
# Break if no more resampling needed
|
207 |
+
if not need_resample.any():
|
208 |
+
break
|
209 |
+
|
210 |
+
# Sample new tokens for sequences that need resampling
|
211 |
+
new_samples = torch.multinomial(probs[need_resample], 1).squeeze(-1)
|
212 |
+
|
213 |
+
# Update top_ids with new samples
|
214 |
+
top_ids[need_resample] = new_samples
|
215 |
+
|
216 |
+
# Update which sequences still need resampling
|
217 |
+
is_eos_mask = top_ids == eos_token
|
218 |
+
need_resample = ignore_eos_mask & is_eos_mask
|
219 |
+
|
220 |
+
# If still have EOS tokens that should be ignored, force them to be non-EOS
|
221 |
+
if need_resample.any():
|
222 |
+
# Force to a non-EOS token (e.g., the second most likely token)
|
223 |
+
for i in range(batch_size):
|
224 |
+
if need_resample[i]:
|
225 |
+
# Get second most likely token (or first if only one token)
|
226 |
+
second_best_idx = 1 if sorted_indices.size(1) > 1 else 0
|
227 |
+
top_ids[i] = sorted_indices[i, second_best_idx]
|
228 |
+
|
229 |
+
result = top_ids
|
230 |
+
|
231 |
+
return result
|
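A minimal usage sketch for RasSampler (illustrative, not part of the commit): the batch size, vocabulary size, and token histories below are assumptions chosen to match the defaults documented in the docstring above.

# Illustrative only: one repetition-aware decoding step on random logits.
import torch
from flashcosyvoice.modules.sampler import RasSampler

sampler = RasSampler()
batch_size, vocab_size = 2, 6562                  # assumed CosyVoice2 speech vocab size
logits = torch.randn(batch_size, vocab_size)
decoded = [[12, 845, 845, 845], []]               # per-sequence histories; the second is empty
next_ids = sampler(logits, decoded, win_size=10, tau_r=0.1,
                   top_p=0.8, top_k=25, eos_token=6561, min_tokens=[2, 2])
print(next_ids.shape)                             # torch.Size([2])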
flashcosyvoice/utils/__init__.py
ADDED
File without changes
flashcosyvoice/utils/audio.py
ADDED
@@ -0,0 +1,77 @@
import numpy as np
import torch
from librosa.filters import mel as librosa_mel_fn
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0


def load_wav(full_path):
    sampling_rate, data = read(full_path)
    return data, sampling_rate


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)


def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    output = dynamic_range_compression_torch(magnitudes)
    return output


def spectral_de_normalize_torch(magnitudes):
    output = dynamic_range_decompression_torch(magnitudes)
    return output


mel_basis = {}
hann_window = {}


def mel_spectrogram(y, n_fft=1920, num_mels=80, sampling_rate=24000, hop_size=480,
                    win_size=1920, fmin=0, fmax=8000, center=False):
    global mel_basis, hann_window  # pylint: disable=global-statement
    if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)

    y = torch.nn.functional.pad(
        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
    )
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[str(y.device)],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
    spec = spectral_normalize_torch(spec)

    return spec
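As a quick sanity check (illustrative, not part of the commit), the defaults above produce one 80-bin mel frame per 480 samples of 24 kHz audio:

# Illustrative only: expected output shape for one second of 24 kHz audio.
import torch
from flashcosyvoice.utils.audio import mel_spectrogram

y = torch.randn(1, 24000)        # [batch, samples]
mel = mel_spectrogram(y)         # defaults: n_fft=1920, hop_size=480, num_mels=80
print(mel.shape)                 # torch.Size([1, 80, 50]); 24000 / 480 = 50 frames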
flashcosyvoice/utils/context.py
ADDED
@@ -0,0 +1,28 @@
from dataclasses import dataclass

import torch


@dataclass
class Context:
    is_prefill: bool = False
    cu_seqlens_q: torch.Tensor | None = None
    cu_seqlens_k: torch.Tensor | None = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0
    slot_mapping: torch.Tensor | None = None
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None

_CONTEXT = Context()

def get_context():
    return _CONTEXT

def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None):
    global _CONTEXT
    _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables)

def reset_context():
    global _CONTEXT
    _CONTEXT = Context()
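A small usage sketch (illustrative, not part of the commit): the module-level Context is presumably set before a forward pass and cleared afterwards, so lower-level modules can read batching metadata without extra arguments; the values below are made up.

# Illustrative only: set, read, and reset the global Context around a prefill pass.
import torch
from flashcosyvoice.utils.context import get_context, reset_context, set_context

cu_seqlens = torch.tensor([0, 7, 12], dtype=torch.int32)   # two packed sequences (lengths 7 and 5)
set_context(True, cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
            max_seqlen_q=7, max_seqlen_k=7)
assert get_context().is_prefill
# ... run the model forward pass here ...
reset_context()
assert not get_context().is_prefill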
flashcosyvoice/utils/loader.py
ADDED
@@ -0,0 +1,116 @@
import os
from glob import glob

import torch
from safetensors import safe_open
from torch import nn

from flashcosyvoice.config import CosyVoice2LLMConfig


def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor):
    param.data.copy_(loaded_weight)


def load_text_llm(model: nn.Module, path: str):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})
    for file in glob(os.path.join(path, "*.safetensors")):
        with safe_open(file, "pt", "cpu") as f:
            for weight_name in f.keys():
                for k in packed_modules_mapping:
                    if k in weight_name:
                        v, shard_id = packed_modules_mapping[k]
                        param_name = weight_name.replace(k, v)
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, f.get_tensor(weight_name), shard_id)
                        break
                else:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, f.get_tensor(weight_name))


def load_speech_llm(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig):
    packed_modules_mapping = getattr(model, "packed_modules_mapping", {})

    # NOTE(xcsong): 1. load speech embedding + sos/taskid embedding + lm head
    embedding_weights = {}
    tmp_weights = torch.load(f"{path}/llm.pt", map_location="cpu", weights_only=True)
    missed, missed_names = 0, []
    for k, v in tmp_weights.items():
        if k == "speech_embedding.weight":  # torch.Size([6564, 896])
            speech_embedding_size = hf_config.speech_vocab_size  # 6562
            # NOTE(xcsong): padding to 6592 for vllm tensor parallel
            if speech_embedding_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert speech_embedding_size <= v.shape[0], f"speech_embedding_size should be less than or equal to {v.shape[0]}, but got {speech_embedding_size}"
                v = v[:speech_embedding_size, :]
            embedding_weights["speech_embedding.weight"] = v
        elif k == "llm_embedding.weight":  # torch.Size([2, 896]), eos and task_id
            assert v.shape[0] == 2, f"llm_embedding.weight should be of shape [2, 896], but got {v.shape}"
            embedding_weights["llm_embedding.weight"] = v
        elif k == "llm.model.model.embed_tokens.weight":  # torch.Size([151936, 896])
            embedding_weights["model.embed_tokens.weight"] = v
        elif k == "llm_decoder.weight":  # torch.Size([6564, 896])
            lm_head_size = hf_config.speech_vocab_size  # 6562
            if lm_head_size != v.shape[0]:  # [6564, 896] -> [6562, 896]
                assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
                v = v[:lm_head_size, :]
            param = model.get_parameter("lm_head.weight")
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif k == "llm_decoder.bias":  # torch.Size([6564])
            lm_head_size = hf_config.speech_vocab_size  # 6562
            if lm_head_size != v.shape[0]:  # [6564] -> [6562]
                assert lm_head_size <= v.shape[0], f"lm_head_size should be less than or equal to {v.shape[0]}, but got {lm_head_size}"
                v = v[:lm_head_size]
            param = model.get_parameter("lm_head.bias")
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, v)
        elif "llm.model." in k:
            weight_name = k.replace("llm.model.", "")
            for kk in packed_modules_mapping:
                if kk in weight_name:
                    vv, shard_id = packed_modules_mapping[kk]
                    param_name = weight_name.replace(kk, vv)
                    try:
                        param = model.get_parameter(param_name)
                        weight_loader = param.weight_loader
                        weight_loader(param, v, shard_id)
                        break
                    except Exception as e:
                        print(e)
                        print(f"skip parameter (1): {weight_name}")
                        continue
            else:
                try:
                    param = model.get_parameter(weight_name)
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, v)
                except Exception as e:
                    print(e)
                    print(f"skip parameter (2): {weight_name}")
                    continue
        else:
            missed += 1
            missed_names.append(k)
            continue
    print(f"missed {missed} parameters: {missed_names}")

    # NOTE(xcsong): 2. merge text embedding, sos/taskid embedding, and speech embedding
    text_embedding_weight = embedding_weights["model.embed_tokens.weight"].cpu()  # [151936, 896]
    sos_taskid_embedding_weight = embedding_weights["llm_embedding.weight"].cpu()  # [2, 896]
    speech_embedding_weight = embedding_weights["speech_embedding.weight"].cpu()  # [6562, 896]
    final_embedding_weight = torch.cat([speech_embedding_weight, sos_taskid_embedding_weight, text_embedding_weight], dim=0)  # [158500, 896]
    param = model.get_parameter("model.embed_tokens.weight")
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, final_embedding_weight)


def load_model(model: nn.Module, path: str, hf_config: CosyVoice2LLMConfig | None = None):
    if model.model_type == "speech_llm":
        load_speech_llm(model, path, hf_config)
    elif model.model_type == "text_llm":
        load_text_llm(model, path)
    else:
        raise ValueError(f"Unsupported model type: {model.model_type}")
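The concatenation at the end of load_speech_llm fixes the token-id layout of the merged embedding table; the sketch below only restates that arithmetic and is an inference from this loader, not a separate specification.

# Illustrative only: id ranges implied by torch.cat([speech, sos/taskid, text], dim=0).
speech_vocab_size = 6562      # rows 0..6561 hold speech tokens (6561 is the eos_token default used by RasSampler above)
sos_taskid_rows = 2           # rows 6562..6563 hold the sos and task_id embeddings
text_vocab_size = 151936      # rows 6564..158499 hold the original Qwen2 text tokens

assert speech_vocab_size + sos_taskid_rows + text_vocab_size == 158500  # matches the [158500, 896] comment

def text_token_to_unified_id(text_id: int) -> int:
    """Map a Qwen2 text-token id into the merged embedding table (assumed convention)."""
    return speech_vocab_size + sos_taskid_rows + text_id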
flashcosyvoice/utils/memory.py
ADDED
@@ -0,0 +1,19 @@
import os

import torch
from pynvml import *  # noqa


def get_gpu_memory():
    torch.cuda.synchronize()
    nvmlInit()
    visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
    cuda_device_idx = torch.cuda.current_device()
    cuda_device_idx = visible_device[cuda_device_idx]
    handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
    mem_info = nvmlDeviceGetMemoryInfo(handle)
    total_memory = mem_info.total
    used_memory = mem_info.used
    free_memory = mem_info.free
    nvmlShutdown()
    return total_memory, used_memory, free_memory
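A short usage sketch (illustrative, not part of the commit): the helper returns raw NVML byte counts for the device PyTorch is currently using.

# Illustrative only: print the current device's memory budget in GiB.
from flashcosyvoice.utils.memory import get_gpu_memory

total, used, free = get_gpu_memory()
gib = 1024 ** 3
print(f"total={total / gib:.1f} GiB, used={used / gib:.1f} GiB, free={free / gib:.1f} GiB")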
stepaudio2.py
ADDED
@@ -0,0 +1,204 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

from utils import compute_token_num, load_audio, log_mel_spectrogram, padding_mels


class StepAudio2Base:

    def __init__(self, model_path: str):
        self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side="right")
        self.llm = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
        self.eos_token_id = self.llm_tokenizer.eos_token_id

    def __call__(self, messages: list, **kwargs):
        messages, mels = self.apply_chat_template(messages)

        # Tokenize prompts
        prompt_ids = []
        for msg in messages:
            if isinstance(msg, str):
                prompt_ids.append(self.llm_tokenizer(text=msg, return_tensors="pt", padding=True)["input_ids"])
            elif isinstance(msg, list):
                prompt_ids.append(torch.tensor([msg], dtype=torch.int32))
            else:
                raise ValueError(f"Unsupported content type: {type(msg)}")
        prompt_ids = torch.cat(prompt_ids, dim=-1).cuda()
        attention_mask = torch.ones_like(prompt_ids)

        # mels = None if len(mels) == 0 else torch.stack(mels).cuda()
        # mel_lengths = None if mels is None else torch.tensor([mel.shape[1] - 2 for mel in mels], dtype=torch.int32, device='cuda')
        if len(mels) == 0:
            mels = None
            mel_lengths = None
        else:
            mels, mel_lengths = padding_mels(mels)
            mels = mels.cuda()
            mel_lengths = mel_lengths.cuda()

        generate_inputs = {
            "input_ids": prompt_ids,
            "wavs": mels,
            "wav_lens": mel_lengths,
            "attention_mask": attention_mask,
        }

        generation_config = dict(max_new_tokens=2048,
                                 pad_token_id=self.llm_tokenizer.pad_token_id,
                                 eos_token_id=self.eos_token_id,
                                 )
        generation_config.update(kwargs)
        generation_config = GenerationConfig(**generation_config)

        outputs = self.llm.generate(**generate_inputs, generation_config=generation_config)
        output_token_ids = outputs[0, prompt_ids.shape[-1]:-1].tolist()
        output_text_tokens = [i for i in output_token_ids if i < 151688]
        output_audio_tokens = [i - 151696 for i in output_token_ids if i > 151695]
        output_text = self.llm_tokenizer.decode(output_text_tokens)
        return output_token_ids, output_text, output_audio_tokens

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            content = msg
            if isinstance(content, str):
                text_with_audio = content
                results.append(text_with_audio)
            elif isinstance(content, dict):
                if content["type"] == "text":
                    results.append(f"{content['text']}")
                elif content["type"] == "audio":
                    audio = load_audio(content['audio'])
                    for i in range(0, audio.shape[0], 16000 * 25):
                        mel = log_mel_spectrogram(audio[i:i + 16000 * 25], n_mels=128, padding=479)
                        mels.append(mel)
                        audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                        results.append(f"<audio_start>{audio_tokens}<audio_end>")
                elif content["type"] == "token":
                    results.append(content["token"])
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        # print(results)
        return results, mels


class StepAudio2(StepAudio2Base):

    def __init__(self, model_path: str):
        super().__init__(model_path)
        self.llm_tokenizer.eos_token = "<|EOT|>"
        self.llm.config.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")
        self.eos_token_id = self.llm_tokenizer.convert_tokens_to_ids("<|EOT|>")

    def apply_chat_template(self, messages: list):
        results = []
        mels = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                role = "human"
            if isinstance(content, str):
                text_with_audio = f"<|BOT|>{role}\n{content}"
                text_with_audio += '<|EOT|>' if msg.get('eot', True) else ''
                results.append(text_with_audio)
            elif isinstance(content, list):
                results.append(f"<|BOT|>{role}\n")
                for item in content:
                    if item["type"] == "text":
                        results.append(f"{item['text']}")
                    elif item["type"] == "audio":
                        audio = load_audio(item['audio'])
                        for i in range(0, audio.shape[0], 16000 * 25):
                            mel = log_mel_spectrogram(audio[i:i + 16000 * 25], n_mels=128, padding=479)
                            mels.append(mel)
                            audio_tokens = "<audio_patch>" * compute_token_num(mel.shape[1])
                            results.append(f"<audio_start>{audio_tokens}<audio_end>")
                    elif item["type"] == "token":
                        results.append(item["token"])
                if msg.get('eot', True):
                    results.append('<|EOT|>')
            elif content is None:
                results.append(f"<|BOT|>{role}\n")
            else:
                raise ValueError(f"Unsupported content type: {type(content)}")
        # print(results)
        return results, mels


if __name__ == '__main__':
    from token2wav import Token2wav

    model = StepAudio2('/mnt/gpfs/lijingbei/Step-Audio-2-mini')
    token2wav = Token2wav('/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav')

    # Text-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Text-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": "Give me a brief introduction to the Great Wall."},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for a speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_male.wav')
    with open('output-male.wav', 'wb') as f:
        f.write(audio)

    # Speech-to-text conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, _ = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Speech-to-speech conversation
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": "<tts_start>", "eot": False},  # Insert <tts_start> for a speech response
    ]
    tokens, text, audio = model(messages, max_new_tokens=4096, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
    print(tokens)
    audio = token2wav(audio, prompt_wav='assets/default_female.wav')
    with open('output-female.wav', 'wb') as f:
        f.write(audio)

    # Multi-turn conversation
    print()
    messages.pop(-1)
    messages += [
        {"role": "assistant", "content": [{"type": "text", "text": "<tts_start>"},
                                          {"type": "token", "token": tokens}]},
        {"role": "human", "content": "Now write a 4-line poem about it."},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)

    # Multi-modal inputs
    print()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "human", "content": [{"type": "text", "text": "Translate the speech into Chinese."},
                                      {"type": "audio", "audio": "assets/give_me_a_brief_introduction_to_the_great_wall.wav"}]},
        {"role": "assistant", "content": None}
    ]
    tokens, text, audio = model(messages, max_new_tokens=256, temperature=0.7, repetition_penalty=1.05, top_p=0.9, do_sample=True)
    print(text)
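The thresholds used in __call__ imply a fixed layout for the generated token ids; the sketch below only restates that split and is an inference from the code above, not a documented specification.

# Illustrative only: how __call__ separates text tokens from audio tokens.
# ids < 151688             -> text tokens (decoded with the tokenizer)
# 151688 <= ids <= 151695  -> ids dropped from both streams (neither text nor audio)
# ids > 151695             -> audio tokens, re-based by subtracting 151696 before Token2wav
mixed = [105, 3201, 151696, 151699, 42, 151800]
text_tokens = [i for i in mixed if i < 151688]            # [105, 3201, 42]
audio_tokens = [i - 151696 for i in mixed if i > 151695]  # [0, 3, 104]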
token2wav.py
ADDED
@@ -0,0 +1,79 @@
import io

import torch
import torchaudio
import s3tokenizer
import onnxruntime

import torchaudio.compliance.kaldi as kaldi
from flashcosyvoice.modules.hifigan import HiFTGenerator
from flashcosyvoice.utils.audio import mel_spectrogram
from hyperpyyaml import load_hyperpyyaml


class Token2wav():

    def __init__(self, model_path, float16=False):
        self.float16 = float16

        self.audio_tokenizer = s3tokenizer.load_model(f"{model_path}/speech_tokenizer_v2_25hz.onnx").cuda().eval()

        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{model_path}/campplus.onnx", sess_options=option, providers=["CPUExecutionProvider"])

        with open(f"{model_path}/flow.yaml", "r") as f:
            configs = load_hyperpyyaml(f)
        self.flow = configs['flow']
        if float16:
            self.flow.half()
        self.flow.load_state_dict(torch.load(f"{model_path}/flow.pt", map_location="cpu", weights_only=True), strict=True)
        self.flow.cuda().eval()

        self.hift = HiFTGenerator()
        hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(f"{model_path}/hift.pt", map_location="cpu", weights_only=True).items()}
        self.hift.load_state_dict(hift_state_dict, strict=True)
        self.hift.cuda().eval()

    def __call__(self, generated_speech_tokens, prompt_wav):
        audio = s3tokenizer.load_audio(prompt_wav, sr=16000)  # [T]
        mels = s3tokenizer.log_mel_spectrogram(audio)
        mels, mels_lens = s3tokenizer.padding([mels])
        prompt_speech_tokens, prompt_speech_tokens_lens = self.audio_tokenizer.quantize(mels.cuda(), mels_lens.cuda())

        spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
        spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
        spk_emb = torch.tensor(self.spk_model.run(
            None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
        )[0], device='cuda')

        audio, sample_rate = torchaudio.load(prompt_wav, backend='soundfile')
        audio = audio.mean(dim=0, keepdim=True)  # [1, T]
        if sample_rate != 24000:
            audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
        prompt_mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
        prompt_mels = prompt_mel.unsqueeze(0).cuda()
        prompt_mels_lens = torch.tensor([prompt_mels.shape[1]], dtype=torch.int32, device='cuda')

        generated_speech_tokens = torch.tensor([generated_speech_tokens], dtype=torch.int32, device='cuda')
        generated_speech_tokens_lens = torch.tensor([generated_speech_tokens.shape[1]], dtype=torch.int32, device='cuda')

        with torch.amp.autocast("cuda", dtype=torch.float16 if self.float16 else torch.float32):
            mel = self.flow.inference(generated_speech_tokens, generated_speech_tokens_lens,
                                      prompt_speech_tokens, prompt_speech_tokens_lens,
                                      prompt_mels, prompt_mels_lens, spk_emb, 10)

        wav, _ = self.hift(speech_feat=mel)
        output = io.BytesIO()
        torchaudio.save(output, wav.cpu(), sample_rate=24000, format='wav')

        return output.getvalue()


if __name__ == '__main__':
    token2wav = Token2wav('/mnt/gpfs/lijingbei/Step-Audio-2-mini/token2wav')

    tokens = [1493, 4299, 4218, 2049, 528, 2752, 4850, 4569, 4575, 6372, 2127, 4068, 2312, 4993, 4769, 2300, 226, 2175, 2160, 2152, 6311, 6065, 4859, 5102, 4615, 6534, 6426, 1763, 2249, 2209, 5938, 1725, 6048, 3816, 6058, 958, 63, 4460, 5914, 2379, 735, 5319, 4593, 2328, 890, 35, 751, 1483, 1484, 1483, 2112, 303, 4753, 2301, 5507, 5588, 5261, 5744, 5501, 2341, 2001, 2252, 2344, 1860, 2031, 414, 4366, 4366, 6059, 5300, 4814, 5092, 5100, 1923, 3054, 4320, 4296, 2148, 4371, 5831, 5084, 5027, 4946, 4946, 2678, 575, 575, 521, 518, 638, 1367, 2804, 3402, 4299]
    audio = token2wav(tokens, 'assets/default_male.wav')
    with open('assets/give_me_a_brief_introduction_to_the_great_wall.wav', 'wb') as f:
        f.write(audio)
utils.py
ADDED
@@ -0,0 +1,91 @@
import librosa
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchaudio
from typing import List


def _mel_filters(n_mels: int) -> torch.Tensor:
    """Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
    if n_mels == 128:
        return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=128))
    else:
        return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))


def load_audio(file_path, target_rate=16000, max_length=None):
    """
    Open an audio file and read it as a mono waveform, resampling as necessary.
    If max_length is provided, truncate the audio to that length.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != target_rate:
        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
    audio = waveform[0]  # get the first channel

    # Truncate audio if it exceeds max_length
    if max_length is not None and audio.shape[0] > max_length:
        audio = audio[:max_length]

    return audio


def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
    """
    Compute the log-Mel spectrogram with the specific padding used by StepAudio.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(400).to(audio.device)
    stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2
    filters = _mel_filters(n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec


def compute_token_num(max_feature_len):
    # First, audio goes through the encoder:
    # 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
    # 2. conv2: kernel=3, stride=2, padding=1 -> size/2
    # 3. avg_pooler: kernel=2, stride=2 -> size/2
    max_feature_len = max_feature_len - 2  # remove padding
    encoder_output_dim = (max_feature_len + 1) // 2 // 2  # after conv2 and avg_pooler

    # Then through the adaptor (parameters from the config file):
    padding = 1
    kernel_size = 3  # from config: audio_encoder_config.kernel_size
    stride = 2  # from config: audio_encoder_config.adapter_stride
    adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
    return adapter_output_dim


def padding_mels(data: List[torch.Tensor]):
    """ Padding the data into batch data

    Parameters
    ----------
    data: List[Tensor], shape of Tensor (128, T)

    Returns
    -------
    feats, feats lengths
    """
    sample = data
    assert isinstance(sample, list)
    feats_lengths = torch.tensor([s.size(1) - 2 for s in sample],
                                 dtype=torch.int32)
    feats = [s.t() for s in sample]
    padded_feats = pad_sequence(feats,
                                batch_first=True,
                                padding_value=0)

    return padded_feats.transpose(1, 2), feats_lengths
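A worked example of the compute_token_num arithmetic (illustrative, not part of the commit): stepaudio2.py feeds it mel chunks of at most 25 s of 16 kHz audio padded by 479 samples, which comes out to roughly 2502 mel frames at the 160-sample hop used above.

# Illustrative only: number of <audio_patch> positions for one full 25 s chunk.
from utils import compute_token_num

frames = 2502                                          # assumed mel length for a full chunk
after_pad_removal = frames - 2                         # 2500
after_encoder = (after_pad_removal + 1) // 2 // 2      # conv2 (stride 2) then avg_pooler (stride 2) -> 625
after_adapter = (after_encoder + 2 * 1 - 3) // 2 + 1   # adaptor conv: k=3, s=2, p=1 -> 313
assert compute_token_num(frames) == after_adapter == 313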