import os
import warnings
import io
import torch
import torchaudio
import numpy as np
from typing import Union, Callable, Optional

from .audio import load_audio
from .result import WhisperResult

AUDIO_TYPES = ('str', 'byte', 'torch', 'numpy')


def transcribe_any(
        inference_func: Callable,
        audio: Union[str, np.ndarray, torch.Tensor, bytes],
        audio_type: Optional[str] = None,
        input_sr: Optional[int] = None,
        model_sr: Optional[int] = None,
        inference_kwargs: Optional[dict] = None,
        temp_file: Optional[str] = None,
        verbose: Optional[bool] = False,
        regroup: Union[bool, str] = True,
        suppress_silence: bool = True,
        suppress_word_ts: bool = True,
        q_levels: int = 20,
        k_size: int = 5,
        demucs: bool = False,
        demucs_device: Optional[str] = None,
        demucs_output: Optional[str] = None,
        demucs_options: Optional[dict] = None,
        vad: bool = False,
        vad_threshold: float = 0.35,
        vad_onnx: bool = False,
        min_word_dur: float = 0.1,
        nonspeech_error: float = 0.3,
        use_word_position: bool = True,
        only_voice_freq: bool = False,
        only_ffmpeg: bool = False,
        force_order: bool = False,
        check_sorted: bool = True
) -> WhisperResult:
    """
    Transcribe ``audio`` using any ASR system.

    Parameters
    ----------
    inference_func : Callable
        Function that runs ASR when provided the [audio] and return data in the appropriate format.
        For format examples see, https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb.
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
    audio_type : {'str', 'byte', 'torch', 'numpy', None}, default None, meaning same type as ``audio``
        The type that ``audio`` needs to be for ``inference_func``.
        'str' is a path to the file.
        'byte' is bytes (used for APIs or to avoid writing any data to hard drive).
        'torch' is an instance of :class:`torch.Tensor` containing the audio waveform, in float32 dtype, on CPU.
        'numpy' is an instance of :class:`numpy.ndarray` containing the audio waveform, in float32 dtype.
    input_sr : int, default None, meaning auto-detected if ``audio`` is ``str`` or ``bytes``
        The sample rate of ``audio``.
    model_sr : int, default None, meaning same sample rate as ``input_sr``
        The sample rate to resample the audio into for ``inference_func``.
    inference_kwargs : dict, optional
        Dictionary of arguments to pass into ``inference_func``. The dictionary is copied; the prepared audio is
        added to the copy under the key ``'audio'``.
    temp_file : str, default './_temp_stable-ts_audio_.wav'
        Temporary path for the preprocessed audio when ``audio_type = 'str'``.
    verbose : bool or None, default False
        Whether to display all the details during transcription. If ``False``, displays progressbar.
        If ``None``, does not display anything.
    regroup : str or bool, default True
        String representation of a custom regrouping algorithm or ``True`` to use the default algorithm 'da'.
        Only applied when ``result.has_words`` is true.
    suppress_silence : bool, default True
        Whether to enable timestamps adjustments based on the detected silence.
    suppress_word_ts : bool, default True
        Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
        Acts as a threshold to marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    demucs : bool or torch.nn.Module, default False
        Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
        a Demucs model to avoid reloading the model for each run.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_output : str, optional
        Deprecated; use ``demucs_options=dict(save_path=...)``.
        Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
        Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
    demucs_options : dict, optional
        Options to use for :func:`stable_whisper.audio.demucs_audio`. The dictionary is copied, not modified.
    demucs_device : str, default None, meaning 'cuda' if cuda is available with ``torch`` else 'cpu'
        Deprecated; use ``demucs_options=dict(device=...)``.
        Device to use for demucs.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for silence suppression.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for silence suppression.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
    only_voice_freq : bool, default False
        Whether to only use sound between 200 - 5000 Hz, where majority of human speech are.
    only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of yt-dlp) for URLs.
    force_order : bool, default False
        Whether to use adjacent timestamps to replace timestamps that are out of order. Use this parameter only if
        the words/segments returned by ``inference_func`` are expected to be in chronological order.
    check_sorted : bool, default True
        Whether to raise an error when timestamps returned by ``inference_func`` are not in ascending order.

    Returns
    -------
    stable_whisper.result.WhisperResult
        All timestamps, words, probabilities, and other data from the transcription of ``audio``.

    Raises
    ------
    NotImplementedError
        If ``audio_type`` is not one of ``AUDIO_TYPES``.
    TypeError
        If ``audio_type`` is ``None`` and the type of ``audio`` is not supported.
    ValueError
        If a sample rate that is needed (``input_sr`` or ``model_sr``) was not provided.

    Notes
    -----
    For ``audio_type = 'str'``:
        If ``audio`` is a file and no audio preprocessing is set, ``audio`` will be directly passed into
        ``inference_func``.
        If audio preprocessing is ``demucs`` or ``only_voice_freq``, the processed audio will be encoded into
        ``temp_file`` and then passed into ``inference_func``.

    For ``audio_type = 'byte'``:
        If ``audio`` is file, the bytes of file will be passed into ``inference_func``.
        If ``audio`` is :class:`torch.Tensor` or :class:`numpy.ndarray`, the bytes of the ``audio`` will be encoded
        into WAV format then passed into ``inference_func``.

    Resampling is only performed on ``audio`` when ``model_sr`` does not match the sample rate of the ``audio`` before
    passing into ``inference_func`` due to ``input_sr`` not matching ``model_sr``, or sample rate changes due to
    audio preprocessing from ``demucs = True``.
    """
    # Work on copies so the dictionaries supplied by the caller are never mutated.
    demucs_options = dict(demucs_options) if demucs_options else {}
    inference_kwargs = dict(inference_kwargs) if inference_kwargs else {}

    # Fold the deprecated Demucs arguments into ``demucs_options`` (explicit options take precedence).
    if demucs_output:
        if 'save_path' not in demucs_options:
            demucs_options['save_path'] = demucs_output
        warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
                      'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
                      DeprecationWarning, stacklevel=2)
    if demucs_device:
        if 'device' not in demucs_options:
            demucs_options['device'] = demucs_device
        warnings.warn('``demucs_device`` is deprecated. Use ``demucs_options`` with ``device`` instead. '
                      'E.g. demucs_options=dict(device="cpu")',
                      DeprecationWarning, stacklevel=2)

    if audio_type is not None and (audio_type := audio_type.lower()) not in AUDIO_TYPES:
        raise NotImplementedError(f'[audio_type]={audio_type} is not supported. Types: {AUDIO_TYPES}')

    if audio_type is None:
        # Default to handing ``inference_func`` the same kind of object the caller gave us.
        if isinstance(audio, str):
            audio_type = 'str'
        elif isinstance(audio, bytes):
            audio_type = 'byte'
        elif isinstance(audio, torch.Tensor):
            # Must be 'torch' (a member of AUDIO_TYPES) so the tensor passthrough branch below is taken.
            audio_type = 'torch'
        elif isinstance(audio, np.ndarray):
            audio_type = 'numpy'
        else:
            raise TypeError(f'{type(audio)} is not supported for [audio].')

    if (
            input_sr is None and
            isinstance(audio, (np.ndarray, torch.Tensor)) and
            (demucs or only_voice_freq or suppress_silence or model_sr)
    ):
        raise ValueError('[input_sr] is required when [audio] is a PyTorch tensor or NumPy array.')

    if (
            model_sr is None and
            isinstance(audio, (str, bytes)) and
            audio_type in ('torch', 'numpy')
    ):
        raise ValueError('[model_sr] is required when [audio_type] is a "torch" or "numpy".')

    if isinstance(audio, str):
        from .audio import _load_file
        audio = _load_file(audio, verbose=verbose, only_ffmpeg=only_ffmpeg)

    temp_file = os.path.abspath(temp_file or './_temp_stable-ts_audio_.wav')
    temp_audio_file = None  # set only when this call creates ``temp_file`` and must remove it afterwards
    curr_sr = input_sr  # sample rate of ``audio`` as it moves through preprocessing

    if demucs:
        if demucs is True:
            from .audio import load_demucs_model
            demucs_model = load_demucs_model()
        else:
            # A model instance was passed in place of the flag.
            demucs_model = demucs
            demucs = True
    else:
        demucs_model = None

    def get_input_sr():
        """Return ``input_sr``, detecting it from the file/bytes the first time it is needed."""
        nonlocal input_sr
        if not input_sr and isinstance(audio, (str, bytes)):
            from .audio import get_samplerate
            input_sr = get_samplerate(audio)
        return input_sr

    if only_voice_freq:
        from .audio import voice_freq_filter
        if demucs_model is None:
            curr_sr = model_sr or get_input_sr()
        else:
            # Load at the Demucs model's rate because Demucs runs on the filtered audio next.
            curr_sr = demucs_model.samplerate
            if model_sr is None:
                model_sr = get_input_sr()
        audio = load_audio(audio, sr=curr_sr, verbose=verbose, only_ffmpeg=only_ffmpeg)
        audio = voice_freq_filter(audio, curr_sr)

    if demucs:
        from .audio import demucs_audio
        if demucs_device is None:
            demucs_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        demucs_kwargs = dict(
            audio=audio,
            input_sr=curr_sr,
            model=demucs_model,
            save_path=demucs_output,
            device=demucs_device,
            verbose=verbose
        )
        demucs_kwargs.update(demucs_options)
        audio = demucs_audio(**demucs_kwargs)
        curr_sr = demucs_model.samplerate
        if demucs_output and audio_type == 'str':
            # Reuse the file Demucs was asked to write. ``demucs_options['save_path']`` may override the
            # deprecated ``demucs_output``, so use the path that was actually passed to ``demucs_audio``.
            audio = demucs_kwargs['save_path']

    # ``audio`` is kept as-is for silence detection; ``final_audio`` is what ``inference_func`` receives.
    final_audio = audio

    if model_sr is not None:
        if curr_sr is None:
            curr_sr = get_input_sr()
        if curr_sr != model_sr:
            if isinstance(final_audio, (str, bytes)):
                final_audio = load_audio(
                    final_audio,
                    sr=model_sr,
                    verbose=verbose,
                    only_ffmpeg=only_ffmpeg
                )
            else:
                if isinstance(final_audio, np.ndarray):
                    final_audio = torch.from_numpy(final_audio)
                if isinstance(final_audio, torch.Tensor):
                    # NOTE(review): "kaiser_window" may be a renamed/deprecated method name in newer
                    # torchaudio releases - confirm against the installed version.
                    final_audio = torchaudio.functional.resample(
                        final_audio,
                        orig_freq=curr_sr,
                        new_freq=model_sr,
                        resampling_method="kaiser_window"
                    )

    def encoding_input():
        """Return ``final_audio`` as an (at least) 2D tensor plus the sample rate to encode it with."""
        # Without ``model_sr`` no resampling happened, so the waveform is still at ``curr_sr``.
        sample_rate = model_sr if model_sr is not None else curr_sr
        if sample_rate is None:
            raise ValueError('[input_sr] is required when [audio] is a PyTorch tensor or NumPy array.')
        waveform = final_audio
        if isinstance(waveform, np.ndarray):
            waveform = torch.from_numpy(waveform)
        if waveform.ndim < 2:
            waveform = waveform[None]  # add a leading dimension before saving, as 1D input is not saved directly
        return waveform, sample_rate

    if audio_type in ('torch', 'numpy'):
        if isinstance(final_audio, (str, bytes)):
            final_audio = load_audio(
                final_audio,
                sr=model_sr,
                verbose=verbose,
                only_ffmpeg=only_ffmpeg
            )
        else:
            if audio_type == 'torch':
                if isinstance(final_audio, np.ndarray):
                    final_audio = torch.from_numpy(final_audio)
            elif audio_type == 'numpy' and isinstance(final_audio, torch.Tensor):
                final_audio = final_audio.cpu().numpy()
    elif audio_type == 'str':
        if isinstance(final_audio, (torch.Tensor, np.ndarray)):
            waveform, sample_rate = encoding_input()
            torchaudio.save(temp_file, waveform, sample_rate)
            final_audio = temp_audio_file = temp_file
        elif isinstance(final_audio, bytes):
            with open(temp_file, 'wb') as f:
                f.write(final_audio)
            final_audio = temp_audio_file = temp_file
    else:  # audio_type == 'byte'
        if isinstance(final_audio, (torch.Tensor, np.ndarray)):
            waveform, sample_rate = encoding_input()
            with io.BytesIO() as f:
                torchaudio.save(f, waveform, sample_rate, format="wav")
                f.seek(0)
                final_audio = f.read()
        elif isinstance(final_audio, str):
            with open(final_audio, 'rb') as f:
                final_audio = f.read()

    inference_kwargs['audio'] = final_audio

    result = None
    try:
        result = inference_func(**inference_kwargs)
        if not isinstance(result, WhisperResult):
            result = WhisperResult(result, force_order=force_order, check_sorted=check_sorted)
        if suppress_silence:
            # NOTE(review): ``verbose=True`` is hard-coded here and ignores the caller's ``verbose``;
            # confirm whether ``verbose is not None`` was intended.
            result.adjust_by_silence(
                audio, vad,
                vad_onnx=vad_onnx, vad_threshold=vad_threshold,
                q_levels=q_levels, k_size=k_size,
                sample_rate=curr_sr, min_word_dur=min_word_dur,
                word_level=suppress_word_ts, verbose=True,
                nonspeech_error=nonspeech_error,
                use_word_position=use_word_position
            )

        if result.has_words and regroup:
            result.regroup(regroup)

    finally:
        # Best-effort cleanup: failing to delete the temporary file must not hide the real outcome.
        if temp_audio_file is not None:
            try:
                os.unlink(temp_audio_file)
            except Exception as e:
                warnings.warn(f'Failed to remove temporary audio file {temp_audio_file}. {e}')

    return result