from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from dataclasses import replace

import torch
import numpy as np

from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult

if TYPE_CHECKING:
    from whisper.model import Whisper


def _suppress_ts(ts_logits: torch.Tensor, ts_token_mask: Optional[torch.Tensor] = None):
    # set the logits of masked timestamp tokens to -inf so they are never sampled
    if ts_token_mask is not None:
        ts_logits[:, ts_token_mask] = -np.inf


# modified version of whisper.decoding.DecodingTask
class DecodingTaskStable(DecodingTask):

    def __init__(self, *args, **kwargs):
        # pop the extra arguments before delegating the rest to DecodingTask
        self.ts_token_mask: torch.Tensor = kwargs.pop('ts_token_mask', None)
        self.audio_features: torch.Tensor = kwargs.pop('audio_features', None)
        super(DecodingTaskStable, self).__init__(*args, **kwargs)

    def _get_audio_features(self, mel: torch.Tensor):
        if self.audio_features is None:
            # encode the Mel spectrogram and cache the encoder output for reuse (e.g. on fallback passes)
            audio_features = super()._get_audio_features(mel)
            self.audio_features = audio_features.detach().clone()
            return audio_features
        return self.audio_features.clone()

    # modified version of whisper.decoding.DecodingTask._main_loop
    def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
        n_batch = tokens.shape[0]
        sum_logprobs: torch.Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # suppress timestamp tokens where the audio is silent so that the decoder ignores those timestamps
                _suppress_ts(logits[:, self.tokenizer.timestamp_begin:], self.ts_token_mask)

                logits.nan_to_num_(-np.inf)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs


# modified version of whisper.decoding.decode
@torch.no_grad()
def decode_stable(model: "Whisper",
                  mel: torch.Tensor,
                  options: DecodingOptions = DecodingOptions(),
                  ts_token_mask: Optional[torch.Tensor] = None,
                  audio_features: Optional[torch.Tensor] = None,
                  **kwargs) -> Tuple[Union[DecodingResult, List[DecodingResult]], torch.Tensor]:
    """
    Perform decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of the Whisper ASR model.
    mel : torch.Tensor
        A tensor containing the Mel spectrogram(s). ``mel.shape`` must be (80, 3000) or (*, 80, 3000).
    options : whisper.decode.DecodingOptions, default whisper.decode.DecodingOptions()
        A dataclass containing all the options for decoding 30-second segments.
    ts_token_mask : torch.Tensor, optional
        Mask for suppressing timestamp token(s) during decoding.
    audio_features : torch.Tensor, optional
        Reused ``audio_features`` from the encoder, for fallback decoding.

    Returns
    -------
    tuple of (whisper.decode.DecodingResult or list of whisper.decode.DecodingResult, torch.Tensor)
        The result(s) of decoding, contained in ``whisper.decode.DecodingResult`` dataclass instance(s),
        paired with the encoder output (``audio_features``) for reuse on subsequent passes.
    """
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    task = DecodingTaskStable(model, options, ts_token_mask=ts_token_mask, audio_features=audio_features)
    result = task.run(mel)

    return (result[0] if single else result), task.audio_features