Update moonshine to support batch decoding
Browse files- modeling_moonshine.py +4 -3
modeling_moonshine.py
CHANGED
|
@@ -398,7 +398,7 @@ class AudioPreprocessor(nn.Module):
|
|
| 398 |
assert (
|
| 399 |
src.shape[-1] >= 1023
|
| 400 |
), f"src shape[-1] {src.shape[-1]} should be at least 1023"
|
| 401 |
-
src = src.
|
| 402 |
return self.audio_preprocess(src)
|
| 403 |
|
| 404 |
|
|
@@ -435,7 +435,8 @@ class MoonshineModelTorch(nn.Module):
|
|
| 435 |
sot_token = 1
|
| 436 |
eot_token = 2
|
| 437 |
|
| 438 |
-
|
|
|
|
| 439 |
|
| 440 |
vals = self.decoder_initial(x=seq, enc_src=enc)
|
| 441 |
logits = vals[0]
|
|
@@ -448,7 +449,7 @@ class MoonshineModelTorch(nn.Module):
|
|
| 448 |
seq = torch.cat((seq, sample), dim=-1)
|
| 449 |
|
| 450 |
seq_len = int(src.shape[-1] * 6.5 / 16000)
|
| 451 |
-
while
|
| 452 |
vals = self.decoder(
|
| 453 |
seq,
|
| 454 |
*k_cache,
|
|
|
|
| 398 |
assert (
|
| 399 |
src.shape[-1] >= 1023
|
| 400 |
), f"src shape[-1] {src.shape[-1]} should be at least 1023"
|
| 401 |
+
src = src.reshape((-1, 1, src.shape[-1]))
|
| 402 |
return self.audio_preprocess(src)
|
| 403 |
|
| 404 |
|
|
|
|
| 435 |
sot_token = 1
|
| 436 |
eot_token = 2
|
| 437 |
|
| 438 |
+
sot_array = [[sot_token] for _ in range(enc.shape[0])]
|
| 439 |
+
seq = torch.as_tensor(sot_array).to(src.device)
|
| 440 |
|
| 441 |
vals = self.decoder_initial(x=seq, enc_src=enc)
|
| 442 |
logits = vals[0]
|
|
|
|
| 449 |
seq = torch.cat((seq, sample), dim=-1)
|
| 450 |
|
| 451 |
seq_len = int(src.shape[-1] * 6.5 / 16000)
|
| 452 |
+
while any([eot_token not in sub_seq for sub_seq in seq]) and seq.shape[-1] <= seq_len:
|
| 453 |
vals = self.decoder(
|
| 454 |
seq,
|
| 455 |
*k_cache,
|