import os |
import warnings |
import io |
import torch |
import torchaudio |
import numpy as np |
from typing import Union, Callable, Optional |
from .audio import load_audio |
from .result import WhisperResult |
# Values accepted for the ``audio_type`` argument of ``transcribe_any``.
AUDIO_TYPES = (
    'str',
    'byte',
    'torch',
    'numpy',
)
def _to_2d_tensor(waveform: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
    """Return ``waveform`` as a :class:`torch.Tensor` with at least two dimensions."""
    if isinstance(waveform, np.ndarray):
        waveform = torch.from_numpy(waveform)
    if waveform.ndim < 2:
        # prepend a dimension so a 1D waveform is saved as a single-channel 2D tensor
        waveform = waveform[None]
    return waveform


def _resample_waveform(waveform: torch.Tensor, orig_sr: int, new_sr: int) -> torch.Tensor:
    """
    Resample ``waveform`` from ``orig_sr`` to ``new_sr`` with Kaiser-windowed sinc interpolation.

    torchaudio 2.0 renamed this method from ``"kaiser_window"`` to ``"sinc_interp_kaiser"`` and later
    releases reject the legacy name, so the current name is tried first and the legacy name is used
    as a fallback for older installations.
    """
    try:
        return torchaudio.functional.resample(
            waveform,
            orig_freq=orig_sr,
            new_freq=new_sr,
            resampling_method="sinc_interp_kaiser"
        )
    except ValueError:
        return torchaudio.functional.resample(
            waveform,
            orig_freq=orig_sr,
            new_freq=new_sr,
            resampling_method="kaiser_window"
        )


def transcribe_any(
        inference_func: Callable,
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        audio_type: Optional[str] = None,
        input_sr: Optional[int] = None,
        model_sr: Optional[int] = None,
        inference_kwargs: Optional[dict] = None,
        temp_file: Optional[str] = None,
        verbose: Optional[bool] = False,
        regroup: Union[bool, str] = True,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        demucs: Union[bool, torch.nn.Module] = False,
        demucs_device: Optional[str] = None,
        demucs_output: Optional[str] = None,
        demucs_options: Optional[dict] = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        use_word_position: bool = True,
        only_voice_freq: bool = False,
        only_ffmpeg: bool = False,
        force_order: bool = False,
        check_sorted: bool = True
) -> 'WhisperResult':
    """
    Transcribe ``audio`` using any ASR system.

    Parameters
    ----------
    inference_func : Callable
        Function that runs ASR when provided the [audio] and return data in the appropriate format.
        For format examples see, https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
    audio_type : {'str', 'byte', 'torch', 'numpy', None}, default None, meaning same type as ``audio``
        The type that ``audio`` needs to be for ``inference_func``.
        'str' is a path to the file.
        'byte' is bytes (used for APIs or to avoid writing any data to hard drive).
        'torch' is an instance of :class:`torch.Tensor` containing the audio waveform, in float32 dtype, on CPU.
        'numpy' is an instance of :class:`numpy.ndarray` containing the audio waveform, in float32 dtype.
    input_sr : int, default None, meaning auto-detected if ``audio`` is ``str`` or ``bytes``
        The sample rate of ``audio``.
    model_sr : int, default None, meaning same sample rate as ``input_sr``
        The sample rate to resample the audio into for ``inference_func``.
    inference_kwargs : dict, optional
        Dictionary of arguments to pass into ``inference_func``. The dictionary is copied; the caller's object is
        not modified.
    temp_file : str, default './_temp_stable-ts_audio_.wav'
        Temporary path for the preprocessed audio when ``audio_type = 'str'``.
    verbose : bool, default False
        Whether to display all the details during transcription. If ``False``, displays progressbar. If ``None``,
        does not display anything.
    regroup : str or bool, default True
        String representation of a custom regrouping algorithm or ``True`` to use the default algorithm 'da'. Only
        applies if ``word_timestamps = False``.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
        Acts as a threshold to marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Deprecated, use ``demucs_options=dict(save_path=...)`` instead.
        Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`. The dictionary is copied; the caller's object
        is not modified.
    demucs_device : str, default None, meaning 'cuda' if cuda is available with ``torch`` else 'cpu'
        Deprecated, use ``demucs_options=dict(device=...)`` instead.
        Device to use for demucs.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where majority of human speech are.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of yt-dlp) for URLs.
    force_order : bool, default False
        Whether to use adjacent timestamps to replace timestamps that are out of order. Use this parameter only if
        the words/segments returned by ``inference_func`` are expected to be in chronological order.
    check_sorted : bool, default True
        Whether to raise an error when timestamps returned by ``inference_func`` are not in ascending order.

    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    Raises
    ------
    NotImplementedError
        If ``audio_type`` is not one of the supported types.
    TypeError
        If ``audio`` is not a ``str``, ``bytes``, :class:`torch.Tensor` or :class:`numpy.ndarray` and
        ``audio_type`` is not specified.
    ValueError
        If a sample rate that is needed (``input_sr`` or ``model_sr``) is missing.

    Notes
    -----
    For ``audio_type = 'str'``:
        If ``audio`` is a file and no audio preprocessing is set, ``audio`` will be directly passed into
        ``inference_func``.
        If audio preprocessing is ``demucs`` or ``only_voice_freq``, the processed audio will be encoded into
        ``temp_file`` and then passed into ``inference_func``.

    For ``audio_type = 'byte'``:
        If ``audio`` is file, the bytes of file will be passed into ``inference_func``.
        If ``audio`` is :class:`torch.Tensor` or :class:`numpy.ndarray`, the bytes of the ``audio`` will be encoded
        into WAV format then passed into ``inference_func``.

    Resampling is only performed on ``audio`` when ``model_sr`` does not match the sample rate of the ``audio`` before
    passing into ``inference_func`` due to ``input_sr`` not matching ``model_sr``, or sample rate changes due to
    audio preprocessing from ``demucs = True``.
    """
    # Copy the option dictionaries so the keys added below never leak into the caller's objects.
    demucs_options = dict(demucs_options) if demucs_options else {}
    inference_kwargs = dict(inference_kwargs) if inference_kwargs else {}

    # Map the deprecated Demucs arguments onto ``demucs_options`` (explicit options take priority).
    if demucs_output:
        demucs_options.setdefault('save_path', demucs_output)
        warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                      'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                      DeprecationWarning, stacklevel=2)
    if demucs_device:
        demucs_options.setdefault('device', demucs_device)
        warnings.warn('``demucs_device`` is deprecated. Use ``demucs_options`` with ``device`` instead. '
                      'E.g. demucs_options=dict(device="cpu")',
                      DeprecationWarning, stacklevel=2)

    # Validate the requested type, or infer it from ``audio`` when it is not specified.
    if audio_type is not None:
        audio_type = audio_type.lower()
        if audio_type not in AUDIO_TYPES:
            raise NotImplementedError(f'[audio_type]={audio_type} is not supported. Types: {AUDIO_TYPES}')
    elif isinstance(audio, str):
        audio_type = 'str'
    elif isinstance(audio, bytes):
        audio_type = 'byte'
    elif isinstance(audio, torch.Tensor):
        # must match the name in AUDIO_TYPES, otherwise the tensor is treated as 'byte' further down
        audio_type = 'torch'
    elif isinstance(audio, np.ndarray):
        audio_type = 'numpy'
    else:
        raise TypeError(f'{type(audio)} is not supported for [audio].')

    if (
            input_sr is None and
            isinstance(audio, (np.ndarray, torch.Tensor)) and
            (demucs or only_voice_freq or suppress_silence or model_sr)
    ):
        raise ValueError('[input_sr] is required when [audio] is a PyTorch tensor or NumPy array.')

    if (
            model_sr is None and
            isinstance(audio, (str, bytes)) and
            audio_type in ('torch', 'numpy')
    ):
        raise ValueError('[model_sr] is required when [audio_type] is a "torch" or "numpy".')

    if isinstance(audio, str):
        from .audio import _load_file
        audio = _load_file(audio, verbose=verbose, only_ffmpeg=only_ffmpeg)

    temp_file = os.path.abspath(temp_file or './_temp_stable-ts_audio_.wav')
    temp_audio_file = None  # set only when this call writes ``temp_file``, so only that file is removed
    curr_sr = input_sr  # sample rate of ``audio`` as preprocessing progresses

    if demucs:
        if demucs is True:
            from .audio import load_demucs_model
            demucs_model = load_demucs_model()
        else:
            # a preloaded model instance was passed in
            demucs_model = demucs
            demucs = True
    else:
        demucs_model = None

    def get_input_sr():
        """Return ``input_sr``, detecting it first when ``audio`` is a path or bytes."""
        nonlocal input_sr
        if not input_sr and isinstance(audio, (str, bytes)):
            from .audio import get_samplerate
            input_sr = get_samplerate(audio)
        return input_sr

    if only_voice_freq:
        from .audio import voice_freq_filter
        if demucs_model is None:
            curr_sr = model_sr or get_input_sr()
        else:
            # load at the Demucs model's sample rate since Demucs runs next
            curr_sr = demucs_model.samplerate
            if model_sr is None:
                model_sr = get_input_sr()
        audio = load_audio(audio, sr=curr_sr, verbose=verbose, only_ffmpeg=only_ffmpeg)
        audio = voice_freq_filter(audio, curr_sr)

    if demucs:
        from .audio import demucs_audio
        if demucs_device is None:
            demucs_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        demucs_kwargs = dict(
            audio=audio,
            input_sr=curr_sr,
            model=demucs_model,
            save_path=demucs_output,
            device=demucs_device,
            verbose=verbose
        )
        # user-supplied options override the defaults above
        demucs_kwargs.update(demucs_options)
        audio = demucs_audio(**demucs_kwargs)
        curr_sr = demucs_model.samplerate
        if demucs_output and audio_type == 'str':
            audio = demucs_output

    # ``audio`` is kept for silence detection; ``final_audio`` is what ``inference_func`` receives.
    final_audio = audio

    if model_sr is not None:
        if curr_sr is None:
            curr_sr = get_input_sr()
        if curr_sr != model_sr:
            if isinstance(final_audio, (str, bytes)):
                final_audio = load_audio(
                    final_audio,
                    sr=model_sr,
                    verbose=verbose,
                    only_ffmpeg=only_ffmpeg
                )
            else:
                if isinstance(final_audio, np.ndarray):
                    final_audio = torch.from_numpy(final_audio)
                if isinstance(final_audio, torch.Tensor):
                    final_audio = _resample_waveform(final_audio, curr_sr, model_sr)

    # Sample rate of ``final_audio`` when it is a waveform: ``model_sr`` if resampling was requested,
    # otherwise whatever rate the (possibly preprocessed) audio currently has.
    final_sr = model_sr if model_sr is not None else curr_sr

    def get_final_sr():
        """Return the sample rate to encode ``final_audio`` with, failing clearly if it is unknown."""
        if final_sr is None:
            raise ValueError('[input_sr] or [model_sr] is required to encode '
                             'a PyTorch tensor or NumPy array as audio.')
        return final_sr

    if audio_type in ('torch', 'numpy'):
        if isinstance(final_audio, (str, bytes)):
            final_audio = load_audio(
                final_audio,
                sr=model_sr,
                verbose=verbose,
                only_ffmpeg=only_ffmpeg
            )
        # Always normalize to the requested array type, including audio that was just loaded.
        if audio_type == 'torch':
            if isinstance(final_audio, np.ndarray):
                final_audio = torch.from_numpy(final_audio)
        elif isinstance(final_audio, torch.Tensor):
            final_audio = final_audio.cpu().numpy()

    elif audio_type == 'str':
        if isinstance(final_audio, (torch.Tensor, np.ndarray)):
            torchaudio.save(temp_file, _to_2d_tensor(final_audio), get_final_sr())
            final_audio = temp_audio_file = temp_file
        elif isinstance(final_audio, bytes):
            with open(temp_file, 'wb') as f:
                f.write(final_audio)
            final_audio = temp_audio_file = temp_file

    else:  # audio_type == 'byte'
        if isinstance(final_audio, (torch.Tensor, np.ndarray)):
            with io.BytesIO() as f:
                torchaudio.save(f, _to_2d_tensor(final_audio), get_final_sr(), format="wav")
                final_audio = f.getvalue()
        elif isinstance(final_audio, str):
            with open(final_audio, 'rb') as f:
                final_audio = f.read()

    inference_kwargs['audio'] = final_audio

    result = None
    try:
        result = inference_func(**inference_kwargs)
        if not isinstance(result, WhisperResult):
            result = WhisperResult(result, force_order=force_order, check_sorted=check_sorted)
        if suppress_silence:
            # NOTE(review): ``verbose=True`` is hard-coded here instead of forwarding ``verbose``;
            # confirm against ``WhisperResult.adjust_by_silence`` that this is intentional.
            result.adjust_by_silence(
                audio, vad,
                vad_onnx=vad_onnx, vad_threshold=vad_threshold,
                q_levels=q_levels, k_size=k_size,
                sample_rate=curr_sr, min_word_dur=min_word_dur,
                word_level=suppress_word_ts, verbose=True,
                nonspeech_error=nonspeech_error,
                use_word_position=use_word_position
            )

        if result.has_words and regroup:
            result.regroup(regroup)

    finally:
        # best-effort cleanup of the temporary file written above; never mask the original error
        if temp_audio_file is not None:
            try:
                os.unlink(temp_audio_file)
            except OSError as e:
                warnings.warn(f'Failed to remove temporary audio file {temp_audio_file}. {e}')

    return result