import os
import warnings
import io
import torch
import torchaudio
import numpy as np
from typing import Union, Callable, Optional

from .audio import load_audio
from .result import WhisperResult

AUDIO_TYPES = ('str', 'byte', 'torch', 'numpy')


def transcribe_any(
        inference_func: Callable,
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        audio_type: str = None,
        input_sr: int = None,
        model_sr: int = None,
        inference_kwargs: dict = None,
        temp_file: str = None,
        verbose: Optional[bool] = False,
        regroup: Union[bool, str] = True,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        demucs: Union[bool, torch.nn.Module] = False,
        demucs_device: str = None,
        demucs_output: str = None,
        demucs_options: dict = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        use_word_position: bool = True,
        only_voice_freq: bool = False,
        only_ffmpeg: bool = False,
        force_order: bool = False,
        check_sorted: bool = True
) -> WhisperResult:
    """
    Transcribe ``audio`` using any ASR system.

    Parameters
    ----------
    inference_func : Callable
        Function that runs ASR when provided the ``audio`` and returns data in an appropriate format.
        For format examples, see https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
    audio_type : {'str', 'byte', 'torch', 'numpy', None}, default None, meaning same type as ``audio``
        The type that ``audio`` needs to be for ``inference_func``.
        'str' is a path to the file.
        'byte' is bytes (used for APIs or to avoid writing any data to hard drive).
        'torch' is an instance of :class:`torch.Tensor` containing the audio waveform, in float32 dtype, on CPU.
        'numpy' is an instance of :class:`numpy.ndarray` containing the audio waveform, in float32 dtype.
    input_sr : int, default None, meaning auto-detected if ``audio`` is ``str`` or ``bytes``
        The sample rate of ``audio``.
    model_sr : int, default None, meaning same sample rate as ``input_sr``
        The sample rate to resample the audio into for ``inference_func``.
    inference_kwargs : dict, optional
        Dictionary of arguments to pass into ``inference_func``.
    temp_file : str, default './_temp_stable-ts_audio_.wav'
        Temporary path for the preprocessed audio when ``audio_type = 'str'``.
    verbose : bool, default False
        Whether to display all the details during transcription. If ``False``, display a progress bar. If
        ``None``, display nothing.
    regroup : bool or str, default True
        String representation of a custom regrouping algorithm, or ``True`` to use the default algorithm 'da'.
        Only applied when the result contains word timestamps.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = True``.
        Acts as a threshold for marking sounds as silent.
        Fewer levels will increase the volume threshold at which a sound is marked as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = True``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Path to save the vocals isolated by Demucs as a WAV file. Ignored if ``demucs = False``.
        Deprecated; use ``demucs_options`` with the ``save_path`` key instead.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`.
    demucs_device : str, default None, meaning 'cuda' if CUDA is available with ``torch`` else 'cpu'
        Device to use for Demucs. Deprecated; use ``demucs_options`` with the ``device`` key instead.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    use_word_position : bool, default True
        Whether to use the position of the word in its segment to determine whether to keep the end or the start
        timestamp when adjustments are required: if it is the first word, keep the end; if it is the last word,
        keep the start.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where the majority of human speech lies.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of yt-dlp) for URLs.
    force_order : bool, default False
        Whether to use adjacent timestamps to replace timestamps that are out of order. Use this parameter only if
        the words/segments returned by ``inference_func`` are expected to be in chronological order.
    check_sorted : bool, default True
        Whether to raise an error when timestamps returned by ``inference_func`` are not in ascending order.

    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    Notes
    -----
    For ``audio_type = 'str'``:
        If ``audio`` is a file path and no audio preprocessing is enabled, ``audio`` will be passed directly into
            ``inference_func``.
        If ``demucs`` or ``only_voice_freq`` preprocessing is enabled, the processed audio will be encoded into
            ``temp_file`` and then passed into ``inference_func``.

    For ``audio_type = 'byte'``:
        If ``audio`` is a file path, the bytes of the file will be passed into ``inference_func``.
        If ``audio`` is a :class:`torch.Tensor` or :class:`numpy.ndarray`, it will be encoded into WAV format
            and the resulting bytes passed into ``inference_func``.

    ``audio`` is resampled only when its sample rate does not match ``model_sr`` before being passed into
        ``inference_func``, either because ``input_sr`` differs from ``model_sr`` or because audio preprocessing
        (e.g. ``demucs = True``) changed the sample rate.
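
    Examples
    --------
    A minimal sketch of using ``transcribe_any``; ``infer`` below is a hypothetical stand-in for any ASR system
    and simply returns hard-coded segments as plain dicts (see the notebook linked above for the full format).

    >>> def infer(audio, **kwargs):
    ...     return [dict(start=0.0, end=1.5, text=' Hello world.')]
    >>> result = transcribe_any(infer, 'audio.mp3', vad=True)
    >>> result.to_srt_vtt('audio.srt')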
    """
    if demucs_options is None:
        demucs_options = {}
    if demucs_output:
        if 'save_path' not in demucs_options:
            demucs_options['save_path'] = demucs_output
        warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                      'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                      DeprecationWarning, stacklevel=2)
    if demucs_device:
        if 'device' not in demucs_options:
            demucs_options['device'] = demucs_device
        warnings.warn('``demucs_device`` is deprecated. Use ``demucs_options`` with ``device`` instead. '
                      'E.g. demucs_options=dict(device="cpu")',
                      DeprecationWarning, stacklevel=2)

    if audio_type is not None and (audio_type := audio_type.lower()) not in AUDIO_TYPES:
        raise NotImplementedError(f'[audio_type]={audio_type} is not supported. Types: {AUDIO_TYPES}')

    if audio_type is None:
        if isinstance(audio, str):
            audio_type = 'str'
        elif isinstance(audio, bytes):
            audio_type = 'byte'
        elif isinstance(audio, torch.Tensor):
            audio_type = 'torch'
        elif isinstance(audio, np.ndarray):
            audio_type = 'numpy'
        else:
            raise TypeError(f'{type(audio)} is not supported for [audio].')

    if (
            input_sr is None and
            isinstance(audio, (np.ndarray, torch.Tensor)) and
            (demucs or only_voice_freq or suppress_silence or model_sr)
    ):
        raise ValueError('[input_sr] is required when [audio] is a PyTorch tensor or NumPy array.')

    if (
            model_sr is None and
            isinstance(audio, (str, bytes)) and
            audio_type in ('torch', 'numpy')
    ):
        raise ValueError('[model_sr] is required when [audio_type] is "torch" or "numpy".')

    if isinstance(audio, str):
        from .audio import _load_file
        audio = _load_file(audio, verbose=verbose, only_ffmpeg=only_ffmpeg)

    if inference_kwargs is None:
        inference_kwargs = {}

    temp_file = os.path.abspath(temp_file or './_temp_stable-ts_audio_.wav')
    temp_audio_file = None

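    # curr_sr tracks the sample rate of the audio as it moves through each preprocessing step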
    curr_sr = input_sr

    if demucs:
        if demucs is True:
            from .audio import load_demucs_model
            demucs_model = load_demucs_model()
        else:
            demucs_model = demucs
            demucs = True
    else:
        demucs_model = None

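    # read the sample rate from the file/bytes only when it is actually needed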
    def get_input_sr():
        nonlocal input_sr
        if not input_sr and isinstance(audio, (str, bytes)):
            from .audio import get_samplerate
            input_sr = get_samplerate(audio)
        return input_sr

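    # band-pass the audio to the typical voice frequency range (200 - 5000 Hz) before any further processing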
    if only_voice_freq:
        from .audio import voice_freq_filter
        if demucs_model is None:
            curr_sr = model_sr or get_input_sr()
        else:
            curr_sr = demucs_model.samplerate
            if model_sr is None:
                model_sr = get_input_sr()
        audio = load_audio(audio, sr=curr_sr, verbose=verbose, only_ffmpeg=only_ffmpeg)
        audio = voice_freq_filter(audio, curr_sr)

    if demucs:
        from .audio import demucs_audio
        if demucs_device is None:
            demucs_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        demucs_kwargs = dict(
            audio=audio,
            input_sr=curr_sr,
            model=demucs_model,
            save_path=demucs_output,
            device=demucs_device,
            verbose=verbose
        )
        demucs_kwargs.update(demucs_options or {})
        audio = demucs_audio(**demucs_kwargs)
        curr_sr = demucs_model.samplerate
        save_path = demucs_kwargs.get('save_path')
        if save_path and audio_type == 'str':
            audio = save_path

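    # final_audio is the version of the audio handed to inference_func; resample it to model_sr if needed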
    final_audio = audio

    if model_sr is not None:

        if curr_sr is None:
            curr_sr = get_input_sr()

        if curr_sr != model_sr:
            if isinstance(final_audio, (str, bytes)):
                final_audio = load_audio(
                    final_audio,
                    sr=model_sr,
                    verbose=verbose,
                    only_ffmpeg=only_ffmpeg
                )
            else:
                if isinstance(final_audio, np.ndarray):
                    final_audio = torch.from_numpy(final_audio)
                if isinstance(final_audio, torch.Tensor):
                    final_audio = torchaudio.functional.resample(
                        final_audio,
                        orig_freq=curr_sr,
                        new_freq=model_sr,
                        resampling_method="kaiser_window"
                    )

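    # convert final_audio into the type requested by audio_type, encoding to a temporary
    # WAV file or to in-memory WAV bytes when the target type requires it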
    if audio_type in ('torch', 'numpy'):

        if isinstance(final_audio, (str, bytes)):
            final_audio = load_audio(
                final_audio,
                sr=model_sr,
                verbose=verbose,
                only_ffmpeg=only_ffmpeg
            )

        else:
            if audio_type == 'torch':
                if isinstance(final_audio, np.ndarray):
                    final_audio = torch.from_numpy(final_audio)
            elif audio_type == 'numpy' and isinstance(final_audio, torch.Tensor):
                final_audio = final_audio.cpu().numpy()

    elif audio_type == 'str':

        if isinstance(final_audio, (torch.Tensor, np.ndarray)):
            if isinstance(final_audio, np.ndarray):
                final_audio = torch.from_numpy(final_audio)
            if final_audio.ndim < 2:
                final_audio = final_audio[None]
            torchaudio.save(temp_file, final_audio, model_sr or curr_sr)
            final_audio = temp_audio_file = temp_file

        elif isinstance(final_audio, bytes):
            with open(temp_file, 'wb') as f:
                f.write(final_audio)
            final_audio = temp_audio_file = temp_file

    else:  # audio_type == 'byte'

        if isinstance(final_audio, (torch.Tensor, np.ndarray)):
            if isinstance(final_audio, np.ndarray):
                final_audio = torch.from_numpy(final_audio)
            if final_audio.ndim < 2:
                final_audio = final_audio[None]
            with io.BytesIO() as f:
                torchaudio.save(f, final_audio, model_sr or curr_sr, format="wav")
                f.seek(0)
                final_audio = f.read()

        elif isinstance(final_audio, str):
            with open(final_audio, 'rb') as f:
                final_audio = f.read()

    inference_kwargs['audio'] = final_audio

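    # run inference and build the result; the temporary file is always removed afterwards, even on failure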
    result = None
    try:
        result = inference_func(**inference_kwargs)
        if not isinstance(result, WhisperResult):
            result = WhisperResult(result, force_order=force_order, check_sorted=check_sorted)
        if suppress_silence:
            result.adjust_by_silence(
                audio, vad,
                vad_onnx=vad_onnx, vad_threshold=vad_threshold,
                q_levels=q_levels, k_size=k_size,
                sample_rate=curr_sr, min_word_dur=min_word_dur,
                word_level=suppress_word_ts, verbose=verbose,
                nonspeech_error=nonspeech_error,
                use_word_position=use_word_position
            )

        if result.has_words and regroup:
            result.regroup(regroup)

    finally:
        if temp_audio_file is not None:
            try:
                os.unlink(temp_audio_file)
            except Exception as e:
                warnings.warn(f'Failed to remove temporary audio file {temp_audio_file}. {e}')

    return result