|
import warnings |
|
from typing import List, Union, Tuple, Optional |
|
from itertools import chain |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
from whisper.audio import TOKENS_PER_SECOND, SAMPLE_RATE, N_SAMPLES_PER_TOKEN |
|
|
|
|
|
NONVAD_SAMPLE_RATES = (16000,) |
|
VAD_SAMPLE_RATES = (16000, 8000) |
|
|
|
|
|
def is_ascending_sequence( |
|
seq: List[Union[int, float]], |
|
        verbose: bool = True
|
) -> bool: |
|
""" |
|
    Check whether a sequence of numbers is in ascending order.
|
""" |
|
is_ascending = True |
|
for idx, (i, j) in enumerate(zip(seq[:-1], seq[1:])): |
|
if i > j: |
|
is_ascending = False |
|
if verbose: |
|
print(f'[Index{idx}]:{i} > [Index{idx + 1}]:{j}') |
|
else: |
|
break |
|
|
|
return is_ascending |
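
# Example (illustrative):
#   is_ascending_sequence([1, 2, 2, 3], verbose=False)  # -> True (equal neighbors count as ascending)
#   is_ascending_sequence([1, 3, 2])  # prints '[Index1]:3 > [Index2]:2' and returns False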
|
|
|
|
|
def valid_ts( |
|
ts: List[dict], |
|
        warn: bool = True
|
) -> bool: |
|
valid = is_ascending_sequence(list(chain.from_iterable([s['start'], s['end']] for s in ts)), False) |
|
if warn and not valid: |
|
warnings.warn(message='Found timestamp(s) jumping backwards in time. ' |
|
'Use word_timestamps=True to avoid the issue.') |
|
return valid |
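
# Example (illustrative):
#   valid_ts([{'start': 0.0, 'end': 1.5}, {'start': 1.4, 'end': 2.0}])
#   # -> warns about the backwards jump (1.5 -> 1.4) and returns False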
|
|
|
|
|
def mask2timing( |
|
        silence_mask: Union[np.ndarray, torch.Tensor],
|
time_offset: float = 0.0, |
|
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
if silence_mask is None or not silence_mask.any(): |
|
return |
|
assert silence_mask.ndim == 1 |
|
if isinstance(silence_mask, torch.Tensor): |
|
silences = silence_mask.cpu().numpy().copy() |
|
elif isinstance(silence_mask, np.ndarray): |
|
silences = silence_mask.copy() |
|
else: |
|
raise NotImplementedError(f'Expected torch.Tensor or numpy.ndarray, but got {type(silence_mask)}') |
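    # Force both endpoints to be non-silent so every silent run has a rising and a
    # falling edge; this shaves at most one frame off silences touching a boundary.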
|
silences[0] = False |
|
silences[-1] = False |
|
silent_starts = np.logical_and(~silences[:-1], silences[1:]).nonzero()[0] / TOKENS_PER_SECOND |
|
silent_ends = (np.logical_and(silences[:-1], ~silences[1:]).nonzero()[0] + 1) / TOKENS_PER_SECOND |
|
if time_offset: |
|
silent_starts += time_offset |
|
silent_ends += time_offset |
|
return silent_starts, silent_ends |
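
# Example (illustrative; assuming TOKENS_PER_SECOND == 50, i.e. 0.02 s per frame):
#   mask = torch.zeros(100, dtype=torch.bool)
#   mask[10:25] = True
#   mask2timing(mask)  # -> (array([0.18]), array([0.5]))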
|
|
|
|
|
def timing2mask( |
|
silent_starts: np.ndarray, |
|
silent_ends: np.ndarray, |
|
size: int, |
|
        time_offset: Optional[float] = None
|
) -> torch.Tensor: |
|
assert len(silent_starts) == len(silent_ends) |
|
ts_token_mask = torch.zeros(size, dtype=torch.bool) |
|
if time_offset: |
|
silent_starts = (silent_starts - time_offset).clip(min=0) |
|
silent_ends = (silent_ends - time_offset).clip(min=0) |
|
    mask_i = (silent_starts * TOKENS_PER_SECOND).round().astype(np.int32)  # int32: int16 overflows past ~11 min of audio
|
    mask_e = (silent_ends * TOKENS_PER_SECOND).round().astype(np.int32)
|
for mi, me in zip(mask_i, mask_e): |
|
ts_token_mask[mi:me+1] = True |
|
|
|
return ts_token_mask |
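
# Round-trip example (illustrative; assuming TOKENS_PER_SECOND == 50):
#   mask = timing2mask(np.array([0.2]), np.array([0.5]), size=100)
#   mask2timing(mask)  # -> (array([0.18]), array([0.52])), i.e. one frame wider on each side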
|
|
|
|
|
def suppress_silence( |
|
result_obj, |
|
silent_starts: Union[np.ndarray, List[float]], |
|
silent_ends: Union[np.ndarray, List[float]], |
|
min_word_dur: float, |
|
nonspeech_error: float = 0.3, |
|
keep_end: Optional[bool] = True |
|
): |
|
assert len(silent_starts) == len(silent_ends) |
|
if len(silent_starts) == 0 or (result_obj.end - result_obj.start) <= min_word_dur: |
|
return |
|
if isinstance(silent_starts, list): |
|
silent_starts = np.array(silent_starts) |
|
if isinstance(silent_ends, list): |
|
silent_ends = np.array(silent_ends) |
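    # Nonspeech sections that contain the segment's start: snap the start forward
    # to the end of the first such section.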
|
|
|
start_overlaps = np.all( |
|
(silent_starts <= result_obj.start, result_obj.start < silent_ends, silent_ends <= result_obj.end), |
|
axis=0 |
|
).nonzero()[0].tolist() |
|
if start_overlaps: |
|
new_start = silent_ends[start_overlaps[0]] |
|
result_obj.start = min(new_start, round(result_obj.end - min_word_dur, 3)) |
|
if (result_obj.end - result_obj.start) <= min_word_dur: |
|
return |
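    # Nonspeech sections that contain the segment's end: snap the end back to the
    # start of the first such section.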
|
|
|
end_overlaps = np.all( |
|
(result_obj.start <= silent_starts, silent_starts < result_obj.end, result_obj.end <= silent_ends), |
|
axis=0 |
|
).nonzero()[0].tolist() |
|
if end_overlaps: |
|
new_end = silent_starts[end_overlaps[0]] |
|
result_obj.end = max(new_end, round(result_obj.start + min_word_dur, 3)) |
|
if (result_obj.end - result_obj.start) <= min_word_dur: |
|
return |
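    # Nonspeech sections fully inside the segment: trim an edge if the speech
    # overshoot on that side is within ``nonspeech_error`` of the section duration.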
|
|
|
if nonspeech_error: |
|
matches = np.logical_and( |
|
result_obj.start <= silent_starts, |
|
result_obj.end >= silent_ends, |
|
).nonzero()[0].tolist() |
|
if len(matches) == 0: |
|
return |
|
silence_start = np.min(silent_starts[matches]) |
|
silence_end = np.max(silent_ends[matches]) |
|
start_extra = silence_start - result_obj.start |
|
end_extra = result_obj.end - silence_end |
|
silent_duration = silence_end - silence_start |
|
start_within_error = (start_extra / silent_duration) <= nonspeech_error |
|
end_within_error = (end_extra / silent_duration) <= nonspeech_error |
|
if keep_end is None: |
|
keep_end = start_extra <= end_extra |
|
within_error = start_within_error if keep_end else end_within_error |
|
else: |
|
within_error = start_within_error or end_within_error |
|
|
|
if within_error: |
|
if keep_end: |
|
result_obj.start = min(silence_end, round(result_obj.end - min_word_dur, 3)) |
|
else: |
|
result_obj.end = max(silence_start, round(result_obj.start + min_word_dur, 3)) |
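
# Usage sketch (illustrative; ``word`` stands for any object with mutable
# ``start``/``end`` attributes, such as a word or segment of a transcription result):
#   word.start, word.end = 1.0, 2.0
#   suppress_silence(word, np.array([1.8]), np.array([2.5]), min_word_dur=0.1)
#   # word.end is clipped to 1.8 because the nonspeech section overlaps its end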
|
|
|
|
|
def standardize_audio( |
|
audio: Union[torch.Tensor, np.ndarray, str, bytes], |
|
        resample_sr: Optional[Tuple[Optional[int], Union[int, Tuple[int, ...]]]] = None
|
) -> torch.Tensor: |
|
if isinstance(audio, (str, bytes)): |
|
from .audio import load_audio |
|
audio = load_audio(audio) |
|
if isinstance(audio, np.ndarray): |
|
audio = torch.from_numpy(audio) |
|
audio = audio.float() |
|
if resample_sr: |
|
in_sr, out_sr = resample_sr |
|
if in_sr: |
|
if isinstance(out_sr, int): |
|
out_sr = [out_sr] |
|
if in_sr not in out_sr: |
|
from torchaudio.functional import resample |
|
audio = resample(audio, in_sr, out_sr[0]) |
|
|
|
return audio |
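
# Example (illustrative; the file path and ``waveform`` are placeholders):
#   audio = standardize_audio('speech.mp3')  # decode the file into a float torch.Tensor
#   audio = standardize_audio(waveform, (44100, VAD_SAMPLE_RATES))  # resample a 44.1 kHz tensor to 16 kHz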
|
|
|
|
|
def audio2loudness( |
|
audio_tensor: torch.Tensor |
|
) -> Optional[torch.Tensor]:
|
assert audio_tensor.dim() == 1, f'waveform must be 1D, but got {audio_tensor.dim()}D' |
|
audio_tensor = audio_tensor.abs() |
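    # Use the ~99.9th-percentile absolute amplitude as the normalization threshold.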
|
k = int(audio_tensor.numel() * 0.001) |
|
if k: |
|
top_values, _ = torch.topk(audio_tensor, k) |
|
threshold = top_values[-1] |
|
else: |
|
threshold = audio_tensor.quantile(0.999, dim=-1) |
|
if (token_count := round(audio_tensor.shape[-1] / N_SAMPLES_PER_TOKEN)+1) > 2: |
|
if threshold < 1e-5: |
|
return torch.zeros(token_count, dtype=audio_tensor.dtype, device=audio_tensor.device) |
|
audio_tensor = audio_tensor / min(1., threshold * 1.75) |
|
audio_tensor = F.interpolate( |
|
audio_tensor[None, None], |
|
size=token_count, |
|
mode='linear', |
|
align_corners=False |
|
)[0, 0] |
|
return audio_tensor |
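
# Example (illustrative): reduce a waveform to one loudness value per token frame.
#   loudness = audio2loudness(standardize_audio('speech.mp3'))
#   # ``loudness`` is None when the audio is very short (roughly 0.03 s or less)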
|
|
|
|
|
def visualize_mask( |
|
loudness_tensor: torch.Tensor, |
|
        silence_mask: Optional[torch.Tensor] = None,
|
width: int = 1500, |
|
height: int = 200, |
|
        output: Optional[str] = None,
|
): |
|
no_silence = silence_mask is None or not silence_mask.any() |
|
assert no_silence or silence_mask.shape[0] == loudness_tensor.shape[0] |
|
if loudness_tensor.shape[0] < 2: |
|
raise NotImplementedError(f'audio size, {loudness_tensor.shape[0]}, is too short to visualize') |
|
else: |
|
width = loudness_tensor.shape[0] if width == -1 else width |
|
im = torch.zeros((height, width, 3), dtype=torch.uint8) |
|
mid = round(height / 2) |
|
for i, j in enumerate(loudness_tensor.tolist()): |
|
j = round(abs(j) * mid) |
|
if j == 0 or width <= i: |
|
continue |
|
im[mid - j:mid + 1, i] = 255 |
|
im[mid + 1:mid + j + 1, i] = 255 |
|
if not no_silence: |
|
im[:, silence_mask[:width], 1:] = 0 |
|
im = im.cpu().numpy() |
|
if output and not output.endswith('.png'): |
|
output += '.png' |
|
try: |
|
from PIL import Image |
|
except ModuleNotFoundError: |
|
try: |
|
import cv2 |
|
except ModuleNotFoundError: |
|
raise ModuleNotFoundError('Failed to import "PIL" or "cv2" to visualize suppression mask. ' |
|
'Try "pip install Pillow" or "pip install opencv-python"') |
|
else: |
|
im = im[..., [2, 1, 0]] |
|
if isinstance(output, str): |
|
cv2.imwrite(output, im) |
|
else: |
|
cv2.imshow('image', im) |
|
cv2.waitKey(0) |
|
else: |
|
im = Image.fromarray(im) |
|
if isinstance(output, str): |
|
im.save(output) |
|
else: |
|
            im.show()
|
if output: |
|
        print(f'Saved: {output}')
|
|
|
|
|
def wav2mask( |
|
        audio: Union[torch.Tensor, np.ndarray, str, bytes],
|
q_levels: int = 20, |
|
k_size: int = 5, |
|
        sr: Optional[int] = None
|
) -> Optional[torch.Tensor]:
|
""" |
|
    Generate a 1D silence mask (``True`` = silent) from a waveform for suppressing timestamp tokens.
|
""" |
|
audio = standardize_audio(audio, (sr, NONVAD_SAMPLE_RATES)) |
|
loudness_tensor = audio2loudness(audio) |
|
if loudness_tensor is None: |
|
return |
|
p = k_size // 2 if k_size else 0 |
|
if p and p < loudness_tensor.shape[-1]: |
|
assert k_size % 2, f'kernel_size must be odd but got {k_size}' |
|
mask = torch.avg_pool1d( |
|
F.pad( |
|
loudness_tensor[None], |
|
(p, p), |
|
'reflect' |
|
), |
|
kernel_size=k_size, |
|
stride=1 |
|
)[0] |
|
else: |
|
mask = loudness_tensor.clone() |
|
|
|
if q_levels: |
|
mask = mask.mul(q_levels).round() |
|
|
|
mask = mask.bool() |
|
|
|
if not mask.any(): |
|
return ~mask |
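    # Keep only sound regions longer than 0.1 s, then invert them into a silence mask.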
|
temp_timings = mask2timing(mask) |
|
s, e = temp_timings |
|
se_mask = (e - s) > 0.1 |
|
s = s[se_mask] |
|
e = e[se_mask] |
|
mask = ~timing2mask(s, e, loudness_tensor.shape[-1]) |
|
|
|
if not mask.any(): |
|
return |
|
|
|
return mask |
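
# Example (illustrative; the file path is a placeholder):
#   silence_mask = wav2mask('speech.mp3')  # bool tensor at TOKENS_PER_SECOND resolution, True = silent
#   timings = mask2timing(silence_mask)  # -> (silent_starts, silent_ends) in seconds, or None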
|
|
|
|
|
_model_cache = {} |
|
|
|
|
|
def get_vad_silence_func( |
|
        onnx: bool = False,
|
        verbose: Optional[bool] = False
|
): |
|
if onnx in _model_cache: |
|
model, get_ts = _model_cache[onnx] |
|
else: |
|
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad:master', |
|
model='silero_vad', |
|
verbose=verbose, |
|
onnx=onnx, |
|
trust_repo=True) |
|
get_ts = utils[0] |
|
_model_cache[onnx] = (model, get_ts) |
|
|
|
warnings.filterwarnings('ignore', message=r'operator \(\) profile_node.*', category=UserWarning) |
|
|
|
def get_speech_timestamps(wav: torch.Tensor, threshold: float = .35): |
|
return get_ts(wav, model, threshold, min_speech_duration_ms=100, min_silence_duration_ms=20) |
|
|
|
def vad_silence_timing( |
|
            audio: Union[torch.Tensor, np.ndarray, str, bytes],
|
speech_threshold: float = .35, |
|
            sr: Optional[int] = None
|
    ) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
|
|
audio = standardize_audio(audio, (sr, VAD_SAMPLE_RATES)) |
|
|
|
total_duration = round(audio.shape[-1] / SAMPLE_RATE, 3) |
|
if not total_duration: |
|
return |
|
ori_t = torch.get_num_threads() |
|
if verbose is not None: |
|
            print('Predicting silence(s) with VAD...\r', end='')
|
torch.set_num_threads(1) |
|
speech_ts = get_speech_timestamps(audio, speech_threshold) |
|
if verbose is not None: |
|
print('Predicted silence(s) with VAD. ') |
|
torch.set_num_threads(ori_t) |
|
if len(speech_ts) == 0: |
|
return np.array([0.0]), np.array([total_duration]) |
|
silent_starts = [] |
|
silent_ends = [] |
|
for ts in speech_ts: |
|
start = round(ts['start'] / SAMPLE_RATE, 3) |
|
end = round(ts['end'] / SAMPLE_RATE, 3) |
|
if start != 0: |
|
silent_ends.append(start) |
|
if len(silent_starts) == 0: |
|
silent_starts.append(0.0) |
|
if end < total_duration: |
|
silent_starts.append(end) |
|
|
|
if len(silent_starts) == 0 and len(silent_ends) == 0: |
|
return |
|
|
|
if len(silent_starts) != 0 and (len(silent_ends) == 0 or silent_ends[-1] < silent_starts[-1]): |
|
silent_ends.append(total_duration) |
|
|
|
silent_starts = np.array(silent_starts) |
|
silent_ends = np.array(silent_ends) |
|
|
|
return silent_starts, silent_ends |
|
|
|
return vad_silence_timing |
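
# Usage sketch (illustrative; the first call downloads Silero VAD via torch.hub):
#   vad_timing = get_vad_silence_func()
#   silences = vad_timing('speech.mp3')  # -> (silent_starts, silent_ends) in seconds, or None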
|
|
|
|
|
def visualize_suppression( |
|
audio: Union[torch.Tensor, np.ndarray, str, bytes], |
|
        output: Optional[str] = None,
|
q_levels: int = 20, |
|
k_size: int = 5, |
|
vad_threshold: float = 0.35, |
|
vad: bool = False, |
|
max_width: int = 1500, |
|
height: int = 200 |
|
): |
|
""" |
|
Visualize regions on the waveform of ``audio`` detected as silent. |
|
|
|
Regions on the waveform colored red are detected as silent. |
|
|
|
Parameters |
|
---------- |
|
audio : str or numpy.ndarray or torch.Tensor or bytes |
|
Path/URL to the audio file, the audio waveform, or bytes of audio file. |
|
        If audio is ``numpy.ndarray`` or ``torch.Tensor``, the audio must already be sampled at 16kHz.
|
    output : str, default None, meaning the image will be shown directly via Pillow or opencv-python
|
Path to save visualization. |
|
q_levels : int, default 20 |
|
        Quantization levels for generating the timestamp suppression mask; ignored if ``vad = True``.

        Acts as a threshold for marking sound as silent.

        Fewer levels will increase the volume threshold at which sound is marked as silent.
|
k_size : int, default 5 |
|
        Kernel size for avg-pooling the waveform to generate the timestamp suppression mask; ignored if ``vad = True``.
|
Recommend 5 or 3; higher sizes will reduce detection of silence. |
|
    vad_threshold : float, default 0.35

        Threshold for detecting speech with Silero VAD. A low threshold reduces false positives for silence detection.

    vad : bool, default False

        Whether to use Silero VAD to generate the timestamp suppression mask.

        Silero VAD requires PyTorch 1.12.0+. Official repo: https://github.com/snakers4/silero-vad.
|
max_width : int, default 1500 |
|
Maximum width of visualization to avoid overly large image from long audio. |
|
Each unit of pixel is equivalent to 1 token. Use -1 to visualize the entire audio track. |
|
height : int, default 200 |
|
Height of visualization. |
|
""" |
|
max_n_samples = None if max_width == -1 else round(max_width * N_SAMPLES_PER_TOKEN) |
|
|
|
audio = standardize_audio(audio) |
|
if max_n_samples is None: |
|
max_width = audio.shape[-1] |
|
else: |
|
audio = audio[:max_n_samples] |
|
    loudness_tensor = audio2loudness(audio)

    if loudness_tensor is None:

        raise NotImplementedError('Audio is too short and cannot be visualized.')

    width = min(max_width, loudness_tensor.shape[-1])
|
|
|
if vad: |
|
silence_timings = get_vad_silence_func()(audio, vad_threshold) |
|
silence_mask = None if silence_timings is None else timing2mask(*silence_timings, size=loudness_tensor.shape[0]) |
|
else: |
|
silence_mask = wav2mask(audio, q_levels=q_levels, k_size=k_size) |
|
|
|
visualize_mask(loudness_tensor, silence_mask, width=width, height=height, output=output) |
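
# Usage sketch (illustrative; the file paths are placeholders):
#   visualize_suppression('speech.mp3', 'mask.png')  # save the waveform with silences marked in red
#   visualize_suppression('speech.mp3', vad=True)  # display VAD-based suppression via Pillow/cv2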
|
|