stable-ts / stable_whisper / non_whisper.py
import os
import warnings
import io
import torch
import torchaudio
import numpy as np
from typing import Union, Callable, Optional
from .audio import load_audio
from .result import WhisperResult
AUDIO_TYPES = ('str', 'byte', 'torch', 'numpy')
def transcribe_any(
inference_func: Callable,
audio: Union[str, np.ndarray, torch.Tensor, bytes],
audio_type: str = None,
input_sr: int = None,
model_sr: int = None,
inference_kwargs: dict = None,
temp_file: str = None,
verbose: Optional[bool] = False,
regroup: Union[bool, str] = True,
suppress_silence: bool = True,
suppress_word_ts: bool = True,
q_levels: int = 20,
k_size: int = 5,
demucs: bool = False,
demucs_device: str = None,
demucs_output: str = None,
demucs_options: dict = None,
vad: bool = False,
vad_threshold: float = 0.35,
vad_onnx: bool = False,
min_word_dur: float = 0.1,
nonspeech_error: float = 0.3,
use_word_position: bool = True,
only_voice_freq: bool = False,
only_ffmpeg: bool = False,
force_order: bool = False,
check_sorted: bool = True
) -> WhisperResult:
"""
Transcribe ``audio`` using any ASR system.
Parameters
----------
inference_func : Callable
        Function that runs ASR when provided the ``audio`` and returns data in the appropriate format.
        For format examples, see https://github.com/jianfch/stable-ts/blob/main/examples/non-whisper.ipynb.
audio : str or numpy.ndarray or torch.Tensor or bytes
Path/URL to the audio file, the audio waveform, or bytes of audio file.
audio_type : {'str', 'byte', 'torch', 'numpy', None}, default None, meaning same type as ``audio``
The type that ``audio`` needs to be for ``inference_func``.
'str' is a path to the file.
'byte' is bytes (used for APIs or to avoid writing any data to hard drive).
'torch' is an instance of :class:`torch.Tensor` containing the audio waveform, in float32 dtype, on CPU.
'numpy' is an instance of :class:`numpy.ndarray` containing the audio waveform, in float32 dtype.
input_sr : int, default None, meaning auto-detected if ``audio`` is ``str`` or ``bytes``
The sample rate of ``audio``.
model_sr : int, default None, meaning same sample rate as ``input_sr``
The sample rate to resample the audio into for ``inference_func``.
inference_kwargs : dict, optional
Dictionary of arguments to pass into ``inference_func``.
temp_file : str, default './_temp_stable-ts_audio_.wav'
Temporary path for the preprocessed audio when ``audio_type = 'str'``.
    verbose : bool, default False
        Whether to display all the details during transcription. If ``False``, displays a progressbar. If ``None``,
        does not display anything.
    regroup : str or bool, default True
        String representation of a custom regrouping algorithm or ``True`` to use the default algorithm 'da'. Only
        applies if the result contains word timestamps.
suppress_silence : bool, default True
Whether to enable timestamps adjustments based on the detected silence.
suppress_word_ts : bool, default True
Whether to adjust word timestamps based on the detected silence. Only enabled if ``suppress_silence = True``.
q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = True``.
        Acts as a threshold for marking sound as silent.
Fewer levels will increase the threshold of volume at which to mark a sound as silent.
k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = True``.
Recommend 5 or 3; higher sizes will reduce detection of silence.
demucs : bool or torch.nn.Module, default False
Whether to preprocess ``audio`` with Demucs to isolate vocals / remove noise. Set ``demucs`` to an instance of
a Demucs model to avoid reloading the model for each run.
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_output : str, optional
Path to save the vocals isolated by Demucs as WAV file. Ignored if ``demucs = False``.
Demucs must be installed to use. Official repo, https://github.com/facebookresearch/demucs.
demucs_options : dict, optional
Options to use for :func:`stable_whisper.audio.demucs_audio`.
    demucs_device : str, default None, meaning 'cuda' if CUDA is available with ``torch``, else 'cpu'
Device to use for demucs.
vad : bool, default False
Whether to use Silero VAD to generate timestamp suppression mask.
Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
vad_threshold : float, default 0.35
Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
vad_onnx : bool, default False
Whether to use ONNX for Silero VAD.
min_word_dur : float, default 0.1
Shortest duration each word is allowed to reach for silence suppression.
nonspeech_error : float, default 0.3
Relative error of non-speech sections that appear in between a word for silence suppression.
use_word_position : bool, default True
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
only_voice_freq : bool, default False
        Whether to only use sound between 200 and 5000 Hz, where the majority of human speech is.
only_ffmpeg : bool, default False
        Whether to use only FFmpeg (instead of yt-dlp) for URLs.
force_order : bool, default False
Whether to use adjacent timestamps to replace timestamps that are out of order. Use this parameter only if
the words/segments returned by ``inference_func`` are expected to be in chronological order.
check_sorted : bool, default True
Whether to raise an error when timestamps returned by ``inference_func`` are not in ascending order.
Returns
-------
stable_whisper.result.WhisperResult
All timestamps, words, probabilities, and other data from the transcription of ``audio``.
Notes
-----
For ``audio_type = 'str'``:
If ``audio`` is a file and no audio preprocessing is set, ``audio`` will be directly passed into
``inference_func``.
        If audio preprocessing (``demucs`` or ``only_voice_freq``) is enabled, the processed audio will be encoded
        into ``temp_file`` and then passed into ``inference_func``.
For ``audio_type = 'byte'``:
        If ``audio`` is a file, the bytes of the file will be passed into ``inference_func``.
If ``audio`` is :class:`torch.Tensor` or :class:`numpy.ndarray`, the bytes of the ``audio`` will be encoded
into WAV format then passed into ``inference_func``.
    Resampling is only performed on ``audio`` before it is passed into ``inference_func`` when its sample rate does
    not match ``model_sr``, either because ``input_sr`` differs from ``model_sr`` or because audio preprocessing
    (e.g. ``demucs = True``) changed the sample rate.
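    Examples
    --------
    A minimal sketch of plugging a custom ASR callable into ``transcribe_any``. The ``my_asr`` function and the
    segment format it returns are illustrative placeholders; see the linked notebook above for the accepted formats.

    >>> import stable_whisper
    >>> def my_asr(audio, **kwargs):
    ...     # ``audio`` arrives as a float32 waveform because ``audio_type='numpy'`` is requested below.
    ...     # Run any ASR system on it and return segments with start/end timestamps in seconds.
    ...     return [dict(start=0.0, end=1.5, text=' Hello world.')]
    >>> result = stable_whisper.transcribe_any(my_asr, 'audio.mp3', audio_type='numpy', model_sr=16000)
    >>> result.to_srt_vtt('audio.srt')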
"""
if demucs_options is None:
demucs_options = {}
if demucs_output:
if 'save_path' not in demucs_options:
demucs_options['save_path'] = demucs_output
warnings.warn('``demucs_output`` is deprecated. Use ``demucs_options`` with ``save_path`` instead. '
'E.g. demucs_options=dict(save_path="demucs_output.mp3")',
DeprecationWarning, stacklevel=2)
if demucs_device:
if 'device' not in demucs_options:
demucs_options['device'] = demucs_device
warnings.warn('``demucs_device`` is deprecated. Use ``demucs_options`` with ``device`` instead. '
'E.g. demucs_options=dict(device="cpu")',
DeprecationWarning, stacklevel=2)
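    # Validate ``audio_type``, or infer it from the type of ``audio`` when it is not specified.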
if audio_type is not None and (audio_type := audio_type.lower()) not in AUDIO_TYPES:
raise NotImplementedError(f'[audio_type]={audio_type} is not supported. Types: {AUDIO_TYPES}')
if audio_type is None:
if isinstance(audio, str):
audio_type = 'str'
elif isinstance(audio, bytes):
audio_type = 'byte'
elif isinstance(audio, torch.Tensor):
            audio_type = 'torch'
elif isinstance(audio, np.ndarray):
audio_type = 'numpy'
else:
raise TypeError(f'{type(audio)} is not supported for [audio].')
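    # Sample rates must be known up front: ``input_sr`` for raw waveforms that need processing, and ``model_sr``
    # when a file or bytes input must be decoded into a waveform for ``inference_func``.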
if (
input_sr is None and
isinstance(audio, (np.ndarray, torch.Tensor)) and
(demucs or only_voice_freq or suppress_silence or model_sr)
):
raise ValueError('[input_sr] is required when [audio] is a PyTorch tensor or NumPy array.')
if (
model_sr is None and
isinstance(audio, (str, bytes)) and
audio_type in ('torch', 'numpy')
):
        raise ValueError('[model_sr] is required when [audio_type] is "torch" or "numpy".')
if isinstance(audio, str):
from .audio import _load_file
audio = _load_file(audio, verbose=verbose, only_ffmpeg=only_ffmpeg)
if inference_kwargs is None:
inference_kwargs = {}
temp_file = os.path.abspath(temp_file or './_temp_stable-ts_audio_.wav')
temp_audio_file = None
curr_sr = input_sr
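    # Load a Demucs model if requested, or reuse the model instance passed in via ``demucs``.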
if demucs:
if demucs is True:
from .audio import load_demucs_model
demucs_model = load_demucs_model()
else:
demucs_model = demucs
demucs = True
else:
demucs_model = None
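    # Lazily read the sample rate from the file/bytes only when it is actually needed.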
def get_input_sr():
nonlocal input_sr
if not input_sr and isinstance(audio, (str, bytes)):
from .audio import get_samplerate
input_sr = get_samplerate(audio)
return input_sr
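    # Optionally band-pass the audio to the typical frequency range of human speech (see ``only_voice_freq``).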
if only_voice_freq:
from .audio import voice_freq_filter
if demucs_model is None:
curr_sr = model_sr or get_input_sr()
else:
curr_sr = demucs_model.samplerate
if model_sr is None:
model_sr = get_input_sr()
audio = load_audio(audio, sr=curr_sr, verbose=verbose, only_ffmpeg=only_ffmpeg)
audio = voice_freq_filter(audio, curr_sr)
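    # Optionally isolate vocals with Demucs; the output is at the Demucs model's sample rate.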
if demucs:
from .audio import demucs_audio
if demucs_device is None:
demucs_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
demucs_kwargs = dict(
audio=audio,
input_sr=curr_sr,
model=demucs_model,
save_path=demucs_output,
device=demucs_device,
verbose=verbose
)
demucs_kwargs.update(demucs_options or {})
audio = demucs_audio(
**demucs_kwargs
)
curr_sr = demucs_model.samplerate
if demucs_output and audio_type == 'str':
audio = demucs_output
final_audio = audio
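    # Resample to ``model_sr`` when the audio's current sample rate differs from it.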
if model_sr is not None:
if curr_sr is None:
curr_sr = get_input_sr()
if curr_sr != model_sr:
if isinstance(final_audio, (str, bytes)):
final_audio = load_audio(
final_audio,
sr=model_sr,
verbose=verbose,
only_ffmpeg=only_ffmpeg
)
else:
if isinstance(final_audio, np.ndarray):
final_audio = torch.from_numpy(final_audio)
if isinstance(final_audio, torch.Tensor):
final_audio = torchaudio.functional.resample(
final_audio,
orig_freq=curr_sr,
new_freq=model_sr,
resampling_method="kaiser_window"
)
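    # Convert the (possibly preprocessed) audio into the container type expected by ``inference_func``.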
if audio_type in ('torch', 'numpy'):
if isinstance(final_audio, (str, bytes)):
final_audio = load_audio(
final_audio,
sr=model_sr,
verbose=verbose,
only_ffmpeg=only_ffmpeg
)
else:
if audio_type == 'torch':
if isinstance(final_audio, np.ndarray):
final_audio = torch.from_numpy(final_audio)
elif audio_type == 'numpy' and isinstance(final_audio, torch.Tensor):
final_audio = final_audio.cpu().numpy()
elif audio_type == 'str':
if isinstance(final_audio, (torch.Tensor, np.ndarray)):
if isinstance(final_audio, np.ndarray):
final_audio = torch.from_numpy(final_audio)
if final_audio.ndim < 2:
final_audio = final_audio[None]
torchaudio.save(temp_file, final_audio, model_sr)
final_audio = temp_audio_file = temp_file
elif isinstance(final_audio, bytes):
with open(temp_file, 'wb') as f:
f.write(final_audio)
final_audio = temp_audio_file = temp_file
else: # audio_type == 'byte'
if isinstance(final_audio, (torch.Tensor, np.ndarray)):
if isinstance(final_audio, np.ndarray):
final_audio = torch.from_numpy(final_audio)
if final_audio.ndim < 2:
final_audio = final_audio[None]
with io.BytesIO() as f:
torchaudio.save(f, final_audio, model_sr, format="wav")
f.seek(0)
final_audio = f.read()
elif isinstance(final_audio, str):
with open(final_audio, 'rb') as f:
final_audio = f.read()
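    # Run the user-provided ASR function, wrap its output in a WhisperResult, and refine the timestamps.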
inference_kwargs['audio'] = final_audio
result = None
try:
result = inference_func(**inference_kwargs)
if not isinstance(result, WhisperResult):
result = WhisperResult(result, force_order=force_order, check_sorted=check_sorted)
if suppress_silence:
result.adjust_by_silence(
audio, vad,
vad_onnx=vad_onnx, vad_threshold=vad_threshold,
q_levels=q_levels, k_size=k_size,
sample_rate=curr_sr, min_word_dur=min_word_dur,
word_level=suppress_word_ts, verbose=True,
nonspeech_error=nonspeech_error,
use_word_position=use_word_position
)
if result.has_words and regroup:
result.regroup(regroup)
finally:
if temp_audio_file is not None:
try:
os.unlink(temp_audio_file)
except Exception as e:
warnings.warn(f'Failed to remove temporary audio file {temp_audio_file}. {e}')
return result