|
from dataclasses import replace
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult
|
|
|
|
|
if TYPE_CHECKING: |
|
from whisper.model import Whisper |
|
|
|
|
|
def _suppress_ts(ts_logits: torch.Tensor, ts_token_mask: torch.Tensor = None): |
|
if ts_token_mask is not None: |
|
ts_logits[:, ts_token_mask] = -np.inf |
|
|
|
|
|
|
|
class DecodingTaskStable(DecodingTask):
    """
    ``DecodingTask`` variant that supports timestamp-token suppression and
    reuse of previously computed encoder output.
    """

    def __init__(self, *args, **kwargs):
        # Pull out our extra keyword arguments before handing the rest to the base class.
        self.ts_token_mask: torch.Tensor = kwargs.pop('ts_token_mask', None)
        self.audio_features: torch.Tensor = kwargs.pop('audio_features', None)
        super().__init__(*args, **kwargs)

    def _get_audio_features(self, mel: torch.Tensor):
        # Reuse cached encoder output when available; otherwise encode and cache it.
        if self.audio_features is not None:
            return self.audio_features.clone()
        features = super()._get_audio_features(mel)
        self.audio_features = features.detach().clone()
        return features

    def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
        # Autoregressive sampling loop; mirrors the base implementation but adds
        # timestamp suppression via ``_suppress_ts`` before each decoder update.
        batch_size = tokens.shape[0]
        sum_logprobs: torch.Tensor = torch.zeros(batch_size, device=audio_features.device)
        no_speech_probs = [np.nan] * batch_size

        try:
            for step in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if step == 0 and self.tokenizer.no_speech is not None:
                    # Probability of the no-speech token, measured at the SOT position.
                    sot_probs = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = sot_probs[:, self.tokenizer.no_speech].tolist()

                # Only the logits for the next token matter from here on.
                logits = logits[:, -1]

                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # Suppress disallowed timestamp tokens (view into the timestamp region).
                _suppress_ts(logits[:, self.tokenizer.timestamp_begin:], self.ts_token_mask)

                logits.nan_to_num_(-np.inf)

                # Extend each sequence by one token.
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs
|
|
|
|
|
|
|
@torch.no_grad()
def decode_stable(model: "Whisper",
                  mel: torch.Tensor,
                  options: DecodingOptions = DecodingOptions(),
                  ts_token_mask: torch.Tensor = None,
                  audio_features: torch.Tensor = None,
                  **kwargs, ) -> \
        Union[DecodingResult, List[DecodingResult], tuple]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of Whisper ASR model.
    mel : torch.Tensor
        A tensor containing the Mel spectrogram(s). ``mel.shape`` must be (80, 3000) or (*, 80, 3000).
    options : whisper.decode.DecodingOptions, default whisper.decode.DecodingOptions()
        A dataclass that contains all necessary options for decoding 30-second segments.
    ts_token_mask : torch.Tensor, optional
        Mask for suppressing timestamp token(s) during decoding.
    audio_features : torch.Tensor, optional
        Reused ``audio_features`` from the encoder, to skip re-encoding (e.g. on fallback).
    kwargs
        Extra fields that override the corresponding attributes of ``options``
        via :func:`dataclasses.replace`.

    Returns
    -------
    tuple of (whisper.decode.DecodingResult or list of whisper.decode.DecodingResult, torch.Tensor)
        The result(s) of decoding contained in ``whisper.decode.DecodingResult`` dataclass
        instance(s), followed by the encoder's ``audio_features`` so callers can reuse them.
    """
    # Remember whether the input was a single (unbatched) spectrogram so the
    # result can be unwrapped from the batch dimension before returning.
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    # Keyword overrides produce a new options instance; the caller's is untouched.
    if kwargs:
        options = replace(options, **kwargs)

    task = DecodingTaskStable(model, options, ts_token_mask=ts_token_mask, audio_features=audio_features)
    result = task.run(mel)

    # Also return the encoder output so it can be reused (e.g. for fallback decodes).
    return result[0] if single else result, task.audio_features
|
|