Upload modeling_moonshine.py

Add full support for batching. Update decoding loop and input mask logic.

modeling_moonshine.py  +47 -22  (changed)
@@ -113,11 +113,11 @@ class MultiHeadCrossAttentionWithKVCache(MultiHeadAttention):
     def __init__(self, dim, inner_dim, n_head):
         super().__init__(dim, inner_dim, n_head)
 
-    def forward(self, q, k_cache, v_cache):
+    def forward(self, q, k_cache, v_cache, mask):
         q = self.to_q(q)
         q = rearrange(q, "b n (h d) -> b h n d", h=self.n_head)
 
-        return super().sdp_attention(q, k_cache, v_cache)
+        return super().sdp_attention(q, k_cache, v_cache, mask=mask)
 
 
 class FFLinearGelu(nn.Module):
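Note on the new mask argument: the base-class sdp_attention that consumes it is not part of this diff, so the snippet below is only a generic illustration of how a broadcastable key-padding mask of shape (batch, 1, 1, n_keys), with True marking padded positions (the convention implied by the mask inversion later in generate()), is commonly applied to attention scores. All shapes and values here are made up.

import torch

# Toy shapes: 2 samples, 8 heads, 5 queries, 7 keys, 64-dim heads.
b, h, n_q, n_k, d = 2, 8, 5, 7, 64
q = torch.randn(b, h, n_q, d)
k = torch.randn(b, h, n_k, d)
v = torch.randn(b, h, n_k, d)

# True = padded key position; shape (b, 1, 1, n_k) broadcasts over heads and queries.
pad_mask = torch.zeros(b, 1, 1, n_k, dtype=torch.bool)
pad_mask[1, ..., 5:] = True  # last two keys of sample 1 are padding

scores = q @ k.transpose(-2, -1) / d**0.5
scores = scores.masked_fill(pad_mask, float("-inf"))
out = torch.softmax(scores, dim=-1) @ v  # padded keys receive zero attention weight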
@@ -162,10 +162,10 @@ class EncoderLayer(nn.Module):
 
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, rot_pos_emb):
+    def forward(self, x, rot_pos_emb, mask):
         _x = x
         x = self.norm1(x)
-        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb)
+        x, _, _ = self.attention(q=x, k=x, v=x, rot_pos_emb=rot_pos_emb, mask=mask)
         x = x + _x
 
         _x = x
@@ -187,12 +187,12 @@ class Encoder(nn.Module):
         )
         self.post_norm = nn.LayerNorm(dim, bias=False)
 
-    def forward(self, x):
-        pos = torch.arange(x.shape[1], device=x.device)
+    def forward(self, x, mask):
+        pos = torch.arange(x.shape[-2], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
 
-        for layer in self.layers:
-            x = layer(x, rot_pos_emb=rot_pos_emb)
+        for idx, layer in enumerate(self.layers):
+            x = layer(x, rot_pos_emb=rot_pos_emb, mask=mask)
         return self.post_norm(x)
 
 
@@ -214,7 +214,7 @@ class DecoderLayer(nn.Module):
         self.norm3 = nn.LayerNorm(dim, bias=False)
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb):
+    def forward(self, x, k_cache, v_cache, x_attn_k_cache, x_attn_v_cache, rot_pos_emb, input_mask):
         dim = x.size()[1]
         causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
         _x = x
@@ -232,7 +232,7 @@ class DecoderLayer(nn.Module):
 
         _x = x
         x = self.norm2(x)
-        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache)
+        x = self.cross_attention(q=x, k_cache=x_attn_k_cache, v_cache=x_attn_v_cache, mask=input_mask)
         x = x + _x
 
         _x = x
@@ -259,7 +259,7 @@ class Decoder(nn.Module):
         self.final_norm = nn.LayerNorm(dim, bias=False)
         self.token_embedding = nn.Embedding(dec_voc_size, dim)
 
-    def forward(self, x, *args):
+    def forward(self, x, input_mask, *args):
         pos = torch.arange(x.shape[1], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
         x = self.token_embedding(x)
@@ -279,6 +279,7 @@ class Decoder(nn.Module):
                 x_attn_k_cache=x_attn_k_cache[idx],
                 x_attn_v_cache=x_attn_v_cache[idx],
                 rot_pos_emb=rot_pos_emb,
+                input_mask=input_mask,
             )
             k_cache_new.append(new_k_line)
             v_cache_new.append(new_v_line)
@@ -306,7 +307,7 @@ class InitialDecoderLayer(nn.Module):
         self.norm3 = nn.LayerNorm(dim, bias=False)
         self.ff = FFSwiGLU(dim, ff_mult) if ff_swiglu else FFLinearGelu(dim, ff_mult)
 
-    def forward(self, x, context, rot_pos_emb):
+    def forward(self, x, context, rot_pos_emb, input_mask):
         dim = x.size()[1]
         causal_mask = torch.ones((dim, dim), dtype=torch.bool).triu(1).to(x.device)
         _x = x
@@ -323,7 +324,7 @@ class InitialDecoderLayer(nn.Module):
         _x = x
         x = self.norm2(x)
         x, x_attn_k_cache, x_attn_v_cache = self.cross_attention(
-            q=x, k=context, v=context
+            q=x, k=context, v=context, mask=input_mask,
         )
         x = x + _x
 
@@ -345,7 +346,7 @@ class DecoderInitial(Decoder):
             ]
         )
 
-    def forward(self, x, enc_src):
+    def forward(self, x, enc_src, input_mask):
         pos = torch.arange(x.shape[1], device=x.device)
         rot_pos_emb = self.rot_pos_emb(pos)
         x = self.token_embedding(x)
@@ -362,6 +363,7 @@ class DecoderInitial(Decoder):
                 x,
                 enc_src,
                 rot_pos_emb,
+                input_mask,
             )
 
             k_cache.append(new_k_line)
@@ -429,16 +431,34 @@ class MoonshineModelTorch(nn.Module):
         self.n_head = n_head
         self.d_head = inner_dim // n_head
 
-    def generate(self, src):
+    def generate(self, src, mask):
         preprocessed = self.preprocessor(src)
-        enc = self.encoder(preprocessed)
+        batch_size = preprocessed.shape[0]
+
+        # Get max sequence length based on number of unmasked inputs for each sample in batch.
+        token_limit_factor = 6.5 / 16000.0  # Maximum of 6.5 tokens per second.
+        if mask is not None:
+            seq_lens = torch.sum(mask, dim=-1, keepdim=True) * token_limit_factor
+        else:
+            token_limit = torch.tensor([src.shape[-1] * token_limit_factor])
+            seq_lens = torch.stack([token_limit for _ in range(batch_size)])
+        seq_lens = seq_lens.to(torch.int32).to(src.device).squeeze()
+
+        # Preprocess mask so that it matches preprocessed audio.
+        if mask is not None:
+            mask = mask[..., :-127:64][..., :-7:3][..., :-3:2].to(torch.bool)
+            mask = ~mask.reshape((batch_size, 1, 1, -1))
+            mask = torch.nn.functional.pad(mask, (0, preprocessed.shape[-2] - mask.shape[-1]))
+
+        enc = self.encoder(preprocessed, mask)
+
         sot_token = 1
         eot_token = 2
 
-        sot_array = [[sot_token] for _ in range(
+        sot_array = [[sot_token] for _ in range(batch_size)]
         seq = torch.as_tensor(sot_array).to(src.device)
 
-        vals = self.decoder_initial(x=seq, enc_src=enc)
+        vals = self.decoder_initial(x=seq, enc_src=enc, input_mask=mask)
         logits = vals[0]
         k_cache, v_cache, x_attn_k_cache, x_attn_v_cache = [
             vals[i : i + self.dec_depth]
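The new bookkeeping in generate() is easiest to see with concrete numbers. The sketch below uses hypothetical clip lengths: the per-sample token budget is the count of unmasked samples times 6.5 / 16000, and the three slicing steps appear to mirror the preprocessor's strided downsampling before the mask is padded to the exact feature length (the pad call in the diff absorbs any small mismatch).

import torch

sample_rate = 16_000
lengths = [10 * sample_rate, 7 * sample_rate]  # hypothetical true lengths per clip
max_len = max(lengths)
mask = torch.zeros((2, max_len), dtype=torch.int32)
for i, n in enumerate(lengths):
    mask[i, :n] = 1  # 1 = real audio, 0 = padding

# Token budget: at most 6.5 tokens per second of unpadded audio.
token_limit_factor = 6.5 / sample_rate
seq_lens = (mask.sum(dim=-1) * token_limit_factor).to(torch.int32)
print(seq_lens)  # tensor([65, 45], dtype=torch.int32)

# Downsample the sample-level mask the same way generate() does so it roughly
# lines up with the preprocessed feature frames.
frame_mask = mask[..., :-127:64][..., :-7:3][..., :-3:2]
print(mask.shape[-1], "->", frame_mask.shape[-1])  # 160000 -> a few hundred frames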
@@ -448,10 +468,11 @@ class MoonshineModelTorch(nn.Module):
         sample = logits[:, -1].argmax(dim=-1, keepdim=True)
         seq = torch.cat((seq, sample), dim=-1)
 
-
-        while
+        eot_mask = torch.zeros((batch_size), dtype=torch.bool).to(src.device)
+        while not torch.all(eot_mask):
             vals = self.decoder(
                 seq,
+                mask,
                 *k_cache,
                 *v_cache,
                 *x_attn_k_cache,
@@ -462,6 +483,10 @@ class MoonshineModelTorch(nn.Module):
             v_cache = vals[self.dec_depth + 1 :]
             logits = logits[:, -1]  # get last token
             sample = logits.argmax(dim=-1, keepdim=True)
+            # For each sample in batch detect EOT or token limit reached.
+            eot_mask = eot_mask | (sample.squeeze() == eot_token)
+            eot_mask = eot_mask | (seq.shape[-1] >= seq_lens)
+            sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
             seq = torch.cat((seq, sample), dim=-1)
 
         return seq
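A toy run of the per-sample stopping logic, with made-up values: once a row's eot_mask is set, masked_fill keeps writing eot_token into that row, so finished samples are padded with EOT while the rest of the batch keeps decoding.

import torch

eot_token = 2
eot_mask = torch.tensor([False, True, False])       # row 1 finished on an earlier step
sample = torch.tensor([[7], [9], [2]])              # row 2 has just emitted EOT
seq_len, seq_lens = 12, torch.tensor([65, 10, 45])  # current length vs. per-row budgets

eot_mask = eot_mask | (sample.squeeze() == eot_token)  # [False, True, True]
eot_mask = eot_mask | (seq_len >= seq_lens)            # row 1 is also over budget
sample = sample.masked_fill(eot_mask.reshape((-1, 1)), eot_token)
print(sample.squeeze().tolist())  # [7, 2, 2]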
@@ -483,5 +508,5 @@ class MoonshineModel(PreTrainedModel):
             dec_ff_swiglu = config.dec_ff_swiglu,
         )
 
-    def forward(self, tensor):
-        return self.model.generate(tensor)
+    def forward(self, tensor, input_mask=None):
+        return self.model.generate(tensor, input_mask)
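A hedged usage sketch, not part of this commit: assuming model is an already-loaded MoonshineModel and the clips are 16 kHz mono, a zero-padded batch plus a 0/1 input mask can be passed straight to forward(), which now hands the mask to generate().

import torch

wav_a = torch.randn(16_000 * 10)  # placeholder 10 s clip
wav_b = torch.randn(16_000 * 7)   # placeholder 7 s clip
max_len = max(wav_a.shape[-1], wav_b.shape[-1])

batch = torch.zeros((2, max_len))
input_mask = torch.zeros((2, max_len), dtype=torch.int32)
for i, wav in enumerate((wav_a, wav_b)):
    batch[i, : wav.shape[-1]] = wav
    input_mask[i, : wav.shape[-1]] = 1  # 1 = real audio, 0 = padding

# `model`: an already-loaded MoonshineModel (assumed, not constructed here).
tokens = model(batch, input_mask)  # MoonshineModel.forward -> generate(tensor, input_mask)
# Rows that finish early come back padded with the EOT token (id 2).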