# stable-ts: stable_whisper/stabilization.py
import warnings
from typing import List, Union, Tuple, Optional
from itertools import chain
import torch
import torch.nn.functional as F
import numpy as np
from whisper.audio import TOKENS_PER_SECOND, SAMPLE_RATE, N_SAMPLES_PER_TOKEN
NONVAD_SAMPLE_RATES = (16000,)
VAD_SAMPLE_RATES = (16000, 8000)
def is_ascending_sequence(
seq: List[Union[int, float]],
verbose=True
) -> bool:
"""
check if a sequence of numbers are in ascending order
"""
is_ascending = True
for idx, (i, j) in enumerate(zip(seq[:-1], seq[1:])):
if i > j:
is_ascending = False
if verbose:
print(f'[Index{idx}]:{i} > [Index{idx + 1}]:{j}')
else:
break
return is_ascending
def valid_ts(
ts: List[dict],
warn=True
) -> bool:
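    """
    Check that the ``start``/``end`` timestamps of the segments in ``ts`` never jump backwards in time,
    optionally warning when they do.
    """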
valid = is_ascending_sequence(list(chain.from_iterable([s['start'], s['end']] for s in ts)), False)
if warn and not valid:
warnings.warn(message='Found timestamp(s) jumping backwards in time. '
'Use word_timestamps=True to avoid the issue.')
return valid
def mask2timing(
silence_mask: (np.ndarray, torch.Tensor),
time_offset: float = 0.0,
) -> (Tuple[np.ndarray, np.ndarray], None):
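    """
    Convert a 1D silence mask (one value per timestamp token, ``True`` = silent) into arrays of
    silence start and end times in seconds, shifted by ``time_offset``.
    Returns ``None`` if the mask is ``None`` or contains no silence.
    """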
if silence_mask is None or not silence_mask.any():
return
assert silence_mask.ndim == 1
if isinstance(silence_mask, torch.Tensor):
silences = silence_mask.cpu().numpy().copy()
elif isinstance(silence_mask, np.ndarray):
silences = silence_mask.copy()
else:
raise NotImplementedError(f'Expected torch.Tensor or numpy.ndarray, but got {type(silence_mask)}')
silences[0] = False
silences[-1] = False
silent_starts = np.logical_and(~silences[:-1], silences[1:]).nonzero()[0] / TOKENS_PER_SECOND
silent_ends = (np.logical_and(silences[:-1], ~silences[1:]).nonzero()[0] + 1) / TOKENS_PER_SECOND
if time_offset:
silent_starts += time_offset
silent_ends += time_offset
return silent_starts, silent_ends
def timing2mask(
silent_starts: np.ndarray,
silent_ends: np.ndarray,
size: int,
time_offset: float = None
) -> torch.Tensor:
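    """
    Inverse of ``mask2timing``: build a boolean mask of ``size`` timestamp tokens that is ``True``
    for tokens covered by the given silence timings (optionally shifted back by ``time_offset``).
    """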
assert len(silent_starts) == len(silent_ends)
ts_token_mask = torch.zeros(size, dtype=torch.bool)
if time_offset:
silent_starts = (silent_starts - time_offset).clip(min=0)
silent_ends = (silent_ends - time_offset).clip(min=0)
    # use wide integer indices to avoid overflow on long audio (int16 caps out near 11 minutes of tokens)
    mask_i = (silent_starts * TOKENS_PER_SECOND).round().astype(np.int64)
    mask_e = (silent_ends * TOKENS_PER_SECOND).round().astype(np.int64)
for mi, me in zip(mask_i, mask_e):
ts_token_mask[mi:me+1] = True
return ts_token_mask
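# Usage sketch (illustrative): with TOKENS_PER_SECOND == 50, marking 0.5s-1.5s as silent in a
# 2-second clip sets tokens 25 through 75 of the 100-token mask to True.
#     ts_mask = timing2mask(np.array([0.5]), np.array([1.5]), size=2 * TOKENS_PER_SECOND)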
def suppress_silence(
result_obj,
silent_starts: Union[np.ndarray, List[float]],
silent_ends: Union[np.ndarray, List[float]],
min_word_dur: float,
nonspeech_error: float = 0.3,
keep_end: Optional[bool] = True
):
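    """
    Nudge ``result_obj.start`` / ``result_obj.end`` out of the given silent intervals while keeping
    at least ``min_word_dur`` of duration.
    Silence fully contained within the interval can also trim a boundary (the start if ``keep_end``
    else the end) when the audible lead/tail is small relative to the silence, as set by ``nonspeech_error``.
    """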
assert len(silent_starts) == len(silent_ends)
if len(silent_starts) == 0 or (result_obj.end - result_obj.start) <= min_word_dur:
return
if isinstance(silent_starts, list):
silent_starts = np.array(silent_starts)
if isinstance(silent_ends, list):
silent_ends = np.array(silent_ends)
start_overlaps = np.all(
(silent_starts <= result_obj.start, result_obj.start < silent_ends, silent_ends <= result_obj.end),
axis=0
).nonzero()[0].tolist()
if start_overlaps:
new_start = silent_ends[start_overlaps[0]]
result_obj.start = min(new_start, round(result_obj.end - min_word_dur, 3))
if (result_obj.end - result_obj.start) <= min_word_dur:
return
end_overlaps = np.all(
(result_obj.start <= silent_starts, silent_starts < result_obj.end, result_obj.end <= silent_ends),
axis=0
).nonzero()[0].tolist()
if end_overlaps:
new_end = silent_starts[end_overlaps[0]]
result_obj.end = max(new_end, round(result_obj.start + min_word_dur, 3))
if (result_obj.end - result_obj.start) <= min_word_dur:
return
if nonspeech_error:
matches = np.logical_and(
result_obj.start <= silent_starts,
result_obj.end >= silent_ends,
).nonzero()[0].tolist()
if len(matches) == 0:
return
silence_start = np.min(silent_starts[matches])
silence_end = np.max(silent_ends[matches])
start_extra = silence_start - result_obj.start
end_extra = result_obj.end - silence_end
silent_duration = silence_end - silence_start
start_within_error = (start_extra / silent_duration) <= nonspeech_error
end_within_error = (end_extra / silent_duration) <= nonspeech_error
if keep_end is None:
keep_end = start_extra <= end_extra
within_error = start_within_error if keep_end else end_within_error
else:
within_error = start_within_error or end_within_error
if within_error:
if keep_end:
result_obj.start = min(silence_end, round(result_obj.end - min_word_dur, 3))
else:
result_obj.end = max(silence_start, round(result_obj.start + min_word_dur, 3))
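# Usage sketch (illustrative): ``result_obj`` can be any object with mutable ``start`` and ``end``
# attributes, e.g. a word/segment from a result; shown here with SimpleNamespace as a stand-in.
#     from types import SimpleNamespace
#     w = SimpleNamespace(start=0.9, end=2.0)
#     suppress_silence(w, silent_starts=[0.0], silent_ends=[1.2], min_word_dur=0.1)
#     # w.start is pushed to 1.2 because it fell inside the 0.0-1.2 silent span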
def standardize_audio(
audio: Union[torch.Tensor, np.ndarray, str, bytes],
resample_sr: Tuple[Optional[int], Union[int, Tuple[int]]] = None
) -> torch.Tensor:
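    """
    Return ``audio`` as a float ``torch.Tensor``, loading it from a path/URL or bytes if necessary and
    resampling from ``resample_sr[0]`` to the first rate in ``resample_sr[1]`` when needed.
    """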
if isinstance(audio, (str, bytes)):
from .audio import load_audio
audio = load_audio(audio)
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
audio = audio.float()
if resample_sr:
in_sr, out_sr = resample_sr
if in_sr:
if isinstance(out_sr, int):
out_sr = [out_sr]
if in_sr not in out_sr:
from torchaudio.functional import resample
audio = resample(audio, in_sr, out_sr[0])
return audio
def audio2loudness(
audio_tensor: torch.Tensor
) -> (torch.Tensor, None):
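    """
    Reduce a 1D waveform to a per-token loudness curve: absolute amplitude normalized by (roughly) the
    top 0.1% of samples and linearly interpolated to one value per timestamp token.
    Returns ``None`` if the audio spans 2 or fewer tokens; returns zeros if the audio is essentially silent.
    """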
assert audio_tensor.dim() == 1, f'waveform must be 1D, but got {audio_tensor.dim()}D'
audio_tensor = audio_tensor.abs()
k = int(audio_tensor.numel() * 0.001)
if k:
top_values, _ = torch.topk(audio_tensor, k)
threshold = top_values[-1]
else:
threshold = audio_tensor.quantile(0.999, dim=-1)
if (token_count := round(audio_tensor.shape[-1] / N_SAMPLES_PER_TOKEN)+1) > 2:
if threshold < 1e-5:
return torch.zeros(token_count, dtype=audio_tensor.dtype, device=audio_tensor.device)
audio_tensor = audio_tensor / min(1., threshold * 1.75)
audio_tensor = F.interpolate(
audio_tensor[None, None],
size=token_count,
mode='linear',
align_corners=False
)[0, 0]
return audio_tensor
def visualize_mask(
loudness_tensor: torch.Tensor,
silence_mask: torch.Tensor = None,
width: int = 1500,
height: int = 200,
output: str = None,
):
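    """
    Draw ``loudness_tensor`` as a waveform image of ``width`` x ``height`` pixels, color the waveform red
    in columns marked by ``silence_mask``, then save the image to ``output`` or display it.
    """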
no_silence = silence_mask is None or not silence_mask.any()
assert no_silence or silence_mask.shape[0] == loudness_tensor.shape[0]
if loudness_tensor.shape[0] < 2:
raise NotImplementedError(f'audio size, {loudness_tensor.shape[0]}, is too short to visualize')
else:
width = loudness_tensor.shape[0] if width == -1 else width
im = torch.zeros((height, width, 3), dtype=torch.uint8)
mid = round(height / 2)
for i, j in enumerate(loudness_tensor.tolist()):
j = round(abs(j) * mid)
if j == 0 or width <= i:
continue
im[mid - j:mid + 1, i] = 255
im[mid + 1:mid + j + 1, i] = 255
if not no_silence:
im[:, silence_mask[:width], 1:] = 0
im = im.cpu().numpy()
if output and not output.endswith('.png'):
output += '.png'
try:
from PIL import Image
except ModuleNotFoundError:
try:
import cv2
except ModuleNotFoundError:
raise ModuleNotFoundError('Failed to import "PIL" or "cv2" to visualize suppression mask. '
'Try "pip install Pillow" or "pip install opencv-python"')
else:
im = im[..., [2, 1, 0]]
if isinstance(output, str):
cv2.imwrite(output, im)
else:
cv2.imshow('image', im)
cv2.waitKey(0)
else:
im = Image.fromarray(im)
if isinstance(output, str):
im.save(output)
else:
            im.show()
if output:
        print(f'Saved: {output}')
def wav2mask(
audio: (torch.Tensor, np.ndarray, str, bytes),
q_levels: int = 20,
k_size: int = 5,
sr: int = None
) -> (torch.Tensor, None):
"""
    Generate a 1D silence mask (``True`` = silent) from the waveform for suppressing timestamp tokens.
"""
audio = standardize_audio(audio, (sr, NONVAD_SAMPLE_RATES))
loudness_tensor = audio2loudness(audio)
if loudness_tensor is None:
return
p = k_size // 2 if k_size else 0
if p and p < loudness_tensor.shape[-1]:
assert k_size % 2, f'kernel_size must be odd but got {k_size}'
mask = torch.avg_pool1d(
F.pad(
loudness_tensor[None],
(p, p),
'reflect'
),
kernel_size=k_size,
stride=1
)[0]
else:
mask = loudness_tensor.clone()
if q_levels:
mask = mask.mul(q_levels).round()
mask = mask.bool()
if not mask.any(): # entirely silent
return ~mask
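    # at this point ``mask`` marks audible regions; audible spans of 0.1s or less are dropped as noise
    # before the mask is inverted so that True marks silence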
temp_timings = mask2timing(mask)
s, e = temp_timings
se_mask = (e - s) > 0.1
s = s[se_mask]
e = e[se_mask]
mask = ~timing2mask(s, e, loudness_tensor.shape[-1])
if not mask.any(): # no silence
return
return mask
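# Usage sketch (illustrative; 'audio.wav' is a placeholder path):
#     silence_mask = wav2mask('audio.wav', q_levels=20, k_size=5)
#     if silence_mask is not None:
#         silent_starts, silent_ends = mask2timing(silence_mask)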
_model_cache = {}
def get_vad_silence_func(
onnx=False,
verbose: (bool, None) = False
):
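    """
    Load (and cache) the Silero VAD model via ``torch.hub`` and return a function that maps audio to
    arrays of silence start/end times in seconds (``None`` when no silence is detected).
    """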
if onnx in _model_cache:
model, get_ts = _model_cache[onnx]
else:
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad:master',
model='silero_vad',
verbose=verbose,
onnx=onnx,
trust_repo=True)
get_ts = utils[0]
_model_cache[onnx] = (model, get_ts)
warnings.filterwarnings('ignore', message=r'operator \(\) profile_node.*', category=UserWarning)
def get_speech_timestamps(wav: torch.Tensor, threshold: float = .35):
return get_ts(wav, model, threshold, min_speech_duration_ms=100, min_silence_duration_ms=20)
def vad_silence_timing(
audio: (torch.Tensor, np.ndarray, str, bytes),
speech_threshold: float = .35,
sr: int = None
) -> (Tuple[np.ndarray, np.ndarray], None):
audio = standardize_audio(audio, (sr, VAD_SAMPLE_RATES))
total_duration = round(audio.shape[-1] / SAMPLE_RATE, 3)
if not total_duration:
return
ori_t = torch.get_num_threads()
if verbose is not None:
            print('Predicting silence(s) with VAD...\r', end='')
        torch.set_num_threads(1)  # VAD was optimized for single-threaded performance
speech_ts = get_speech_timestamps(audio, speech_threshold)
if verbose is not None:
print('Predicted silence(s) with VAD. ')
torch.set_num_threads(ori_t)
if len(speech_ts) == 0: # all silent
return np.array([0.0]), np.array([total_duration])
silent_starts = []
silent_ends = []
for ts in speech_ts:
start = round(ts['start'] / SAMPLE_RATE, 3)
end = round(ts['end'] / SAMPLE_RATE, 3)
if start != 0:
silent_ends.append(start)
if len(silent_starts) == 0:
silent_starts.append(0.0)
if end < total_duration:
silent_starts.append(end)
if len(silent_starts) == 0 and len(silent_ends) == 0:
return
if len(silent_starts) != 0 and (len(silent_ends) == 0 or silent_ends[-1] < silent_starts[-1]):
silent_ends.append(total_duration)
silent_starts = np.array(silent_starts)
silent_ends = np.array(silent_ends)
return silent_starts, silent_ends
return vad_silence_timing
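# Usage sketch (illustrative; 'audio.wav' is a placeholder path):
#     vad_silence_timing = get_vad_silence_func(onnx=False)
#     timings = vad_silence_timing('audio.wav', speech_threshold=0.35)  # None or (starts, ends)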
def visualize_suppression(
audio: Union[torch.Tensor, np.ndarray, str, bytes],
output: str = None,
q_levels: int = 20,
k_size: int = 5,
vad_threshold: float = 0.35,
vad: bool = False,
max_width: int = 1500,
height: int = 200
):
"""
Visualize regions on the waveform of ``audio`` detected as silent.
Regions on the waveform colored red are detected as silent.
Parameters
----------
audio : str or numpy.ndarray or torch.Tensor or bytes
Path/URL to the audio file, the audio waveform, or bytes of audio file.
        If ``audio`` is ``numpy.ndarray`` or ``torch.Tensor``, it must already be sampled at 16kHz.
output : str, default None, meaning image will be shown directly via Pillow or opencv-python
Path to save visualization.
q_levels : int, default 20
        Quantization levels for generating the timestamp suppression mask; ignored if ``vad=True``.
        Acts as a threshold for marking sound as silent.
        Fewer levels raise the volume threshold at which sound is marked as silent.
k_size : int, default 5
        Kernel size for avg-pooling the waveform to generate the timestamp suppression mask; ignored if ``vad=True``.
Recommend 5 or 3; higher sizes will reduce detection of silence.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. A low threshold reduces false positives for silence detection.
    vad : bool, default False
        Whether to use Silero VAD to generate the timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo: https://github.com/snakers4/silero-vad.
max_width : int, default 1500
        Maximum width of visualization to avoid an overly large image from long audio.
        Each pixel of width corresponds to 1 token. Use -1 to visualize the entire audio track.
height : int, default 200
Height of visualization.
"""
max_n_samples = None if max_width == -1 else round(max_width * N_SAMPLES_PER_TOKEN)
audio = standardize_audio(audio)
if max_n_samples is None:
max_width = audio.shape[-1]
else:
audio = audio[:max_n_samples]
    loudness_tensor = audio2loudness(audio)
    if loudness_tensor is None:
        raise NotImplementedError('Audio is too short and cannot be visualized.')
    width = min(max_width, loudness_tensor.shape[-1])
if vad:
silence_timings = get_vad_silence_func()(audio, vad_threshold)
silence_mask = None if silence_timings is None else timing2mask(*silence_timings, size=loudness_tensor.shape[0])
else:
silence_mask = wav2mask(audio, q_levels=q_levels, k_size=k_size)
visualize_mask(loudness_tensor, silence_mask, width=width, height=height, output=output)
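# Usage sketch (illustrative; file paths are placeholders):
#     visualize_suppression('audio.wav', output='suppression.png', q_levels=20, k_size=5)
#     visualize_suppression('audio.wav', output='suppression_vad.png', vad=True)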