File size: 4,347 Bytes
8718761
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import TYPE_CHECKING, List, Union
from dataclasses import replace

import torch
import numpy as np

from whisper.decoding import DecodingTask, DecodingOptions, DecodingResult


if TYPE_CHECKING:
    from whisper.model import Whisper


def _suppress_ts(ts_logits: torch.Tensor, ts_token_mask: torch.Tensor = None):
    """
    Set the masked columns of ``ts_logits`` to negative infinity, in place.

    Parameters
    ----------
    ts_logits : torch.Tensor
        Logits to modify; ``ts_token_mask`` indexes its second dimension.
    ts_token_mask : torch.Tensor, optional
        Tensor used to select the columns to suppress. When ``None``, ``ts_logits`` is left unchanged.
    """
    if ts_token_mask is None:
        # nothing to suppress
        return
    ts_logits[:, ts_token_mask] = float("-inf")


# modified version of whisper.decoding.DecodingTask
class DecodingTaskStable(DecodingTask):
    """
    ``DecodingTask`` variant that can mask timestamp-token logits and reuse audio features.

    Two extra keyword arguments are consumed before the remaining arguments are forwarded to ``DecodingTask``:

    ts_token_mask : torch.Tensor, optional
        Passed to ``_suppress_ts`` on every decoding step to suppress timestamp-token logits.
    audio_features : torch.Tensor, optional
        Audio features to return from ``_get_audio_features`` instead of computing them from the Mel input.
    """

    def __init__(self, *args, **kwargs):
        # take the extra keywords out first so the parent constructor never sees them
        ts_token_mask = kwargs.pop('ts_token_mask', None)
        audio_features = kwargs.pop('audio_features', None)
        self.ts_token_mask: torch.Tensor = ts_token_mask
        self.audio_features: torch.Tensor = audio_features
        super().__init__(*args, **kwargs)

    def _get_audio_features(self, mel: torch.Tensor):
        """
        Return the stored audio features when available; otherwise compute them with the parent
        implementation and keep a detached copy on ``self.audio_features``.
        """
        if self.audio_features is not None:
            # hand out a copy so the stored tensor is not affected by what the caller does with it
            return self.audio_features.clone()

        computed = super()._get_audio_features(mel)
        self.audio_features = computed.detach().clone()
        return computed

    # modified version of whisper.DecodingTask._main_loop
    def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
        """
        Run the decoding loop for at most ``self.sample_len`` steps.

        Returns
        -------
        tuple
            ``(tokens, sum_logprobs, no_speech_probs)`` where ``no_speech_probs`` is a list with one entry
            per batch item (``nan`` unless the tokenizer defines a ``no_speech`` token).
        """
        batch_size = tokens.shape[0]
        no_speech_probs = [np.nan] * batch_size
        sum_logprobs: torch.Tensor = torch.zeros(batch_size, device=audio_features.device)

        try:
            for step in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                # save no_speech_probs on the first step only
                if step == 0 and self.tokenizer.no_speech is not None:
                    sot_probs = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = sot_probs[:, self.tokenizer.no_speech].tolist()

                # from here on only the logits at the last token matter
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # suppress timestamp tokens where the audio is silent so that decoder ignores those timestamps;
                # the slice is a view, so the in-place write goes through to ``logits``
                ts_logits = logits[:, self.tokenizer.timestamp_begin:]
                _suppress_ts(ts_logits, self.ts_token_mask)

                # replace NaN logits with -inf before selecting the next tokens
                # NOTE(review): with posinf/neginf left at their defaults, torch's nan_to_num also maps
                # +/-inf to the dtype's extreme finite values -- confirm this is intended
                logits.nan_to_num_(-np.inf)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed:
                    break
                if tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs


# modified version of whisper.decoding.decode
@torch.no_grad()
def decode_stable(model: "Whisper",
                  mel: torch.Tensor,
                  options: DecodingOptions = DecodingOptions(),
                  ts_token_mask: torch.Tensor = None,
                  audio_features: torch.Tensor = None,
                  **kwargs, ) -> \
        Union[DecodingResult, List[DecodingResult], tuple]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model : whisper.model.Whisper
        An instance of Whisper ASR model.
    mel : torch.Tensor
        A tensor containing the Mel spectrogram(s). ``mel.shape`` must be (80, 3000) or (*, 80, 3000).
        A 2-D tensor is treated as a single segment and gets a leading batch dimension added.
    options : whisper.decoding.DecodingOptions, default whisper.decoding.DecodingOptions()
        A dataclass that contains all necessary options for decoding 30-second segments.
    ts_token_mask : torch.Tensor, optional
        Mask for suppressing timestamp token(s) during decoding.
    audio_features : torch.Tensor, optional
        Audio features from the encoder to reuse (e.g. for fallback) instead of recomputing them.
    kwargs
        Field overrides applied to ``options`` with ``dataclasses.replace``.

    Returns
    -------
    tuple
        A 2-tuple ``(result, audio_features)``. ``result`` is a single ``DecodingResult`` when ``mel`` was
        2-D, otherwise whatever the task's ``run`` returned for the batch. ``audio_features`` is the tensor
        held by the decoding task after the run.
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    task = DecodingTaskStable(
        model,
        options,
        ts_token_mask=ts_token_mask,
        audio_features=audio_features,
    )
    results = task.run(mel)
    decoded = results[0] if single else results

    return decoded, task.audio_features