Rolando
Set it up
8718761
raw
history blame
91.2 kB
import warnings
import re
import torch
import numpy as np
from typing import Union, List, Tuple, Optional, Callable
from dataclasses import dataclass
from copy import deepcopy
from itertools import chain
from .stabilization import suppress_silence, get_vad_silence_func, mask2timing, wav2mask
from .text_output import *
from .utils import str_to_valid_type, format_timestamp, UnsortedException
__all__ = ['WhisperResult', 'Segment']
def _combine_attr(obj: object, other_obj: object, attr: str):
if (val := getattr(obj, attr)) is not None:
other_val = getattr(other_obj, attr)
if isinstance(val, list):
if other_val is None:
setattr(obj, attr, None)
else:
val.extend(other_val)
else:
new_val = None if other_val is None else ((val + other_val) / 2)
setattr(obj, attr, new_val)
def _increment_attr(obj: object, attr: str, val: Union[int, float]):
if (curr_val := getattr(obj, attr, None)) is not None:
setattr(obj, attr, curr_val + val)
@dataclass
class WordTiming:
    """
    Timing data of a single word.

    ``left_locked`` / ``right_locked`` are flags read by the lock-aware helpers of this module; they are excluded from
    :meth:`to_dict`.
    """
    word: str
    start: float
    end: float
    probability: float = None
    tokens: List[int] = None
    left_locked: bool = False
    right_locked: bool = False
    segment_id: Optional[int] = None
    id: Optional[int] = None

    def __len__(self):
        """Number of characters in ``word``."""
        return len(self.word)

    def __add__(self, other: 'WordTiming'):
        """Return a new word spanning both words, with the text concatenated and lock flags OR-ed together."""
        merged = self.copy()
        merged.start = min(merged.start, other.start)
        merged.end = max(other.end, merged.end)
        merged.word += other.word
        merged.left_locked = merged.left_locked or other.left_locked
        merged.right_locked = merged.right_locked or other.right_locked
        for attr_name in ('probability', 'tokens'):
            _combine_attr(merged, other, attr_name)
        return merged

    def __deepcopy__(self, memo=None):
        """Deep copies are produced by :meth:`copy`."""
        return self.copy()

    def copy(self):
        """Return a new instance with the same field values; ``tokens`` is copied as a new list."""
        token_copy = self.tokens.copy() if self.tokens is not None else None
        return WordTiming(
            word=self.word,
            start=self.start,
            end=self.end,
            probability=self.probability,
            tokens=token_copy,
            left_locked=self.left_locked,
            right_locked=self.right_locked,
            segment_id=self.segment_id,
            id=self.id
        )

    @property
    def duration(self):
        """``end - start`` rounded to 3 decimal places."""
        return round(self.end - self.start, 3)

    def round_all_timestamps(self):
        """Round ``start`` and ``end`` to 3 decimal places in place."""
        self.start, self.end = round(self.start, 3), round(self.end, 3)

    def offset_time(self, offset_seconds: float):
        """Shift ``start`` and ``end`` by ``offset_seconds`` (results rounded to 3 decimal places)."""
        self.start = round(self.start + offset_seconds, 3)
        self.end = round(self.end + offset_seconds, 3)

    def to_dict(self):
        """Return the fields of a copy of this word as a dict, without the lock flags."""
        hidden = ('left_locked', 'right_locked')
        return {key: value for key, value in self.copy().__dict__.items() if key not in hidden}

    def lock_left(self):
        self.left_locked = True

    def lock_right(self):
        self.right_locked = True

    def lock_both(self):
        self.lock_left()
        self.lock_right()

    def unlock_both(self):
        self.left_locked = False
        self.right_locked = False

    def suppress_silence(self,
                         silent_starts: np.ndarray,
                         silent_ends: np.ndarray,
                         min_word_dur: float = 0.1,
                         nonspeech_error: float = 0.3,
                         keep_end: Optional[bool] = True):
        """Forward this word and all arguments to the module-level ``suppress_silence`` helper, then return ``self``."""
        suppress_silence(self, silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
        return self

    def rescale_time(self, scale_factor: float):
        """Multiply ``start`` and ``end`` by ``scale_factor`` (results rounded to 3 decimal places)."""
        self.start = round(self.start * scale_factor, 3)
        self.end = round(self.end * scale_factor, 3)

    def clamp_max(self, max_dur: float, clip_start: bool = False, verbose: bool = False):
        """
        Shorten the word to ``max_dur`` if its duration is longer.

        The start is moved when ``clip_start`` is true, otherwise the end is moved. ``verbose`` prints the change.
        """
        if not self.duration > max_dur:
            return
        if clip_start:
            new_start = round(self.end - max_dur, 3)
            if verbose:
                print(f'Start: {self.start} -> {new_start}\nEnd: {self.end}\nText:"{self.word}"\n')
            self.start = new_start
            return
        new_end = round(self.start + max_dur, 3)
        if verbose:
            print(f'Start: {self.start}\nEnd: {self.end} -> {new_end}\nText:"{self.word}"\n')
        self.end = new_end

    def set_segment(self, segment: 'Segment'):
        """Remember ``segment`` as the segment this word belongs to."""
        self._segment = segment

    def get_segment(self) -> Union['Segment', None]:
        """
        Return instance of :class:`stable_whisper.result.Segment` that this instance is a part of.

        ``None`` is returned if no segment was set.
        """
        return getattr(self, '_segment', None)
def _words_by_lock(words: List[WordTiming], only_text: bool = False, include_single: bool = False):
    """
    Return a nested list of words such that each sublist contains words that are locked together.

    A word joins the previous group when the previous word is right-locked or the word itself is left-locked.
    With ``only_text`` the sublists hold the ``word`` strings; groups of one word are dropped unless
    ``include_single`` is true.
    """
    groups = []
    for word in words:
        joins_previous = bool(groups) and (groups[-1][-1].right_locked or word.left_locked)
        if joins_previous:
            groups[-1].append(word)
        else:
            groups.append([word])
    if only_text:
        groups = [[w.word for w in group] for group in groups]
    if not include_single:
        groups = [group for group in groups if len(group) > 1]
    return groups
@dataclass
class Segment:
    """
    A transcription segment with optional word-level timestamps.

    ``words`` may be given as :class:`WordTiming` instances or as dicts of their fields; dicts are converted in
    ``__post_init__``. ``ori_has_words`` records whether the segment had words when it was created.
    """
    start: float
    end: float
    text: str
    seek: float = None
    tokens: List[int] = None
    temperature: float = None
    avg_logprob: float = None
    compression_ratio: float = None
    no_speech_prob: float = None
    words: Union[List[WordTiming], List[dict]] = None
    ori_has_words: bool = None
    id: int = None

    def __getitem__(self, index: int) -> WordTiming:
        """Return the word at ``index``; raise ``ValueError`` if ``words`` is ``None``."""
        if self.words is None:
            raise ValueError('segment contains no words')
        return self.words[index]

    def __delitem__(self, index: int):
        """Delete the word at ``index``, then renumber the words and refresh the segment from its words."""
        if self.words is None:
            raise ValueError('segment contains no words')
        del self.words[index]
        self.reassign_ids()
        self.update_seg_with_words()

    def __deepcopy__(self, memo=None):
        """Deep copies are produced by :meth:`copy`."""
        return self.copy()

    def copy(self, new_words: Optional[List[WordTiming]] = None):
        """
        Return a new segment with copied words.

        Parameters
        ----------
        new_words : list of WordTiming, optional
            Words to use (as copies) instead of this segment's own words.
        """
        if new_words is None:
            words = None if self.words is None else [w.copy() for w in self.words]
        else:
            words = [w.copy() for w in new_words]

        # copy the token list so that in-place edits of the copy (e.g. ``extend`` in ``__add__``)
        # never leak back into this segment
        tokens = self.tokens.copy() if isinstance(self.tokens, list) else self.tokens

        new_seg = Segment(
            start=self.start,
            end=self.end,
            text=self.text,
            seek=self.seek,
            tokens=tokens,
            temperature=self.temperature,
            avg_logprob=self.avg_logprob,
            compression_ratio=self.compression_ratio,
            no_speech_prob=self.no_speech_prob,
            words=words,
            id=self.id
        )
        new_seg.update_seg_with_words()
        return new_seg

    def to_display_str(self, only_segment: bool = False):
        """Return a printable line for the segment, followed by one line per word unless ``only_segment``."""
        line = f'[{format_timestamp(self.start)} --> {format_timestamp(self.end)}] "{self.text}"'
        if self.has_words and not only_segment:
            line += '\n' + '\n'.join(
                f"-[{format_timestamp(w.start)}] -> [{format_timestamp(w.end)}] \"{w.word}\"" for w in self.words
            ) + '\n'
        return line

    @property
    def has_words(self):
        """Whether ``words`` is a non-empty list."""
        return bool(self.words)

    @property
    def duration(self):
        return self.end - self.start

    def word_count(self):
        """Return the number of words, or ``-1`` if the segment has no words."""
        if self.has_words:
            return len(self.words)
        return -1

    def char_count(self):
        """Return the total character count of the words, or of ``text`` if the segment has no words."""
        if self.has_words:
            return sum(len(w) for w in self.words)
        return len(self.text)

    def __post_init__(self):
        if self.has_words:
            # accept plain dicts (e.g. from loaded JSON) as word definitions
            self.words: List[WordTiming] = \
                [WordTiming(**word) if isinstance(word, dict) else word for word in self.words]
            for w in self.words:
                w.set_segment(self)
        if self.ori_has_words is None:
            self.ori_has_words = self.has_words
        self.round_all_timestamps()

    def __add__(self, other: 'Segment'):
        """
        Return a new segment that spans both segments.

        Text is concatenated and the remaining attributes are merged with ``_combine_attr``. Words of ``other`` are
        appended (not copied) if both segments have words; otherwise the result has no words.
        """
        self_copy = deepcopy(self)

        self_copy.start = min(self_copy.start, other.start)
        self_copy.end = max(other.end, self_copy.end)
        self_copy.text += other.text

        _combine_attr(self_copy, other, 'tokens')
        _combine_attr(self_copy, other, 'temperature')
        _combine_attr(self_copy, other, 'avg_logprob')
        _combine_attr(self_copy, other, 'compression_ratio')
        _combine_attr(self_copy, other, 'no_speech_prob')
        if self_copy.has_words:
            if other.has_words:
                self_copy.words.extend(other.words)
            else:
                self_copy.words = None

        return self_copy

    def _word_operations(self, operation: str, *args, **kwargs):
        """Call the method named ``operation`` on every word with the given arguments."""
        if self.has_words:
            for w in self.words:
                getattr(w, operation)(*args, **kwargs)

    def round_all_timestamps(self):
        """Round the timestamps of the segment and its words to 3 decimal places."""
        self.start = round(self.start, 3)
        self.end = round(self.end, 3)
        if self.has_words:
            for word in self.words:
                word.round_all_timestamps()

    def offset_time(self, offset_seconds: float):
        """Shift the segment, its ``seek`` (if set) and its words by ``offset_seconds``."""
        self.start = round(self.start + offset_seconds, 3)
        self.end = round(self.end + offset_seconds, 3)
        _increment_attr(self, 'seek', offset_seconds)
        self._word_operations('offset_time', offset_seconds)

    def add_words(self, index0: int, index1: int, inplace: bool = False):
        """
        Merge the words at ``index0`` and ``index1`` and return the merged word.

        With ``inplace`` the merged word replaces the word at the lower index and the other word is removed.
        Returns ``None`` if the segment has no words.
        """
        if self.has_words:
            new_word = self.words[index0] + self.words[index1]
            if inplace:
                i0, i1 = sorted([index0, index1])
                self.words[i0] = new_word
                del self.words[i1]
            return new_word

    def rescale_time(self, scale_factor: float):
        """Multiply all timestamps (and ``seek`` if set) by ``scale_factor``."""
        self.start = round(self.start * scale_factor, 3)
        self.end = round(self.end * scale_factor, 3)
        if self.seek is not None:
            self.seek = round(self.seek * scale_factor, 3)
        self._word_operations('rescale_time', scale_factor)
        self.update_seg_with_words()

    def apply_min_dur(self, min_dur: float, inplace: bool = False):
        """
        Merge any word with adjacent word if its duration is less than ``min_dur``.

        Returns the modified segment: ``self`` if ``inplace`` else a copy.
        """
        segment = self if inplace else deepcopy(self)
        if not self.has_words:
            return segment
        max_i = len(segment.words) - 1
        if max_i == 0:
            return segment
        # walk backwards so that deletions do not shift the indices still to be visited
        for i in reversed(range(len(segment.words))):
            if max_i == 0:
                break
            if segment.words[i].duration < min_dur:
                if i == max_i:
                    segment.add_words(i-1, i, inplace=True)
                elif i == 0:
                    segment.add_words(i, i+1, inplace=True)
                else:
                    if segment.words[i+1].duration < segment.words[i-1].duration:
                        segment.add_words(i-1, i, inplace=True)
                    else:
                        segment.add_words(i, i+1, inplace=True)
                max_i -= 1
        return segment

    def _to_reverse_text(
            self,
            prepend_punctuations: str = None,
            append_punctuations: str = None
    ):
        """
        Return a copy with words reversed order per segment.

        Leading characters found in ``prepend_punctuations`` are moved to the end of each word and trailing
        characters found in ``append_punctuations`` are moved to its start before the word order is reversed.
        """
        if prepend_punctuations is None:
            prepend_punctuations = "\"'“¿([{-"
        if prepend_punctuations and ' ' not in prepend_punctuations:
            prepend_punctuations += ' '
        if append_punctuations is None:
            append_punctuations = "\"'.。,,!!??::”)]}、"
        self_copy = deepcopy(self)
        has_prepend = bool(prepend_punctuations)
        has_append = bool(append_punctuations)
        if has_prepend or has_append:
            word_objs = (
                self_copy.words
                if self_copy.has_words else
                [WordTiming(w, 0, 1, 0) for w in self_copy.text.split(' ')]
            )
            for word in word_objs:
                new_append = ''
                if has_prepend:
                    for _ in range(len(word)):
                        char = word.word[0]
                        if char in prepend_punctuations:
                            new_append += char
                            word.word = word.word[1:]
                        else:
                            break
                new_prepend = ''
                if has_append:
                    for _ in range(len(word)):
                        char = word.word[-1]
                        if char in append_punctuations:
                            new_prepend += char
                            word.word = word.word[:-1]
                        else:
                            break
                word.word = f'{new_prepend}{word.word}{new_append[::-1]}'
            self_copy.text = ''.join(w.word for w in reversed(word_objs))

        return self_copy

    def to_dict(self, reverse_text: Union[bool, tuple] = False):
        """
        Return the segment as a dict.

        ``reverse_text`` may be ``True`` or a tuple of arguments for :meth:`_to_reverse_text`; in both cases the
        dict is built from the reversed copy and gets ``reversed_text=True``. ``words`` is omitted if the segment
        never had words, and ``id`` is omitted if it is ``None``.
        """
        if reverse_text:
            seg_dict = (
                (self._to_reverse_text(*reverse_text)
                 if isinstance(reverse_text, tuple) else
                 self._to_reverse_text()).__dict__
            )
        else:
            seg_dict = deepcopy(self).__dict__
        seg_dict.pop('ori_has_words')
        if self.has_words:
            seg_dict['words'] = [w.to_dict() for w in seg_dict['words']]
        elif self.ori_has_words:
            seg_dict['words'] = []
        else:
            seg_dict.pop('words')
        if self.id is None:
            seg_dict.pop('id')
        if reverse_text:
            seg_dict['reversed_text'] = True
        return seg_dict

    def words_by_lock(self, only_text: bool = True, include_single: bool = False):
        """Return the words grouped by lock (see ``_words_by_lock``); empty list if there are no words."""
        return _words_by_lock(self.words or [], only_text=only_text, include_single=include_single)

    @property
    def left_locked(self):
        if self.has_words:
            return self.words[0].left_locked
        return False

    @property
    def right_locked(self):
        if self.has_words:
            return self.words[-1].right_locked
        return False

    def lock_left(self):
        if self.has_words:
            self.words[0].lock_left()

    def lock_right(self):
        if self.has_words:
            self.words[-1].lock_right()

    def lock_both(self):
        self.lock_left()
        self.lock_right()

    def unlock_all_words(self):
        self._word_operations('unlock_both')

    def reassign_ids(self):
        """Set ``segment_id`` of every word to this segment's ``id`` and number the words from 0."""
        if self.has_words:
            for i, w in enumerate(self.words):
                w.segment_id = self.id
                w.id = i

    def update_seg_with_words(self):
        """Recompute ``start``, ``end``, ``text`` and ``tokens`` from the words and re-link the words to ``self``."""
        if self.has_words:
            self.start = self.words[0].start
            self.end = self.words[-1].end
            self.text = ''.join(w.word for w in self.words)
            # tokens are only known for the segment if they are known for every word
            self.tokens = (
                None
                if any(w.tokens is None for w in self.words) else
                [t for w in self.words for t in w.tokens]
            )
            for w in self.words:
                w.set_segment(self)

    def suppress_silence(self,
                         silent_starts: np.ndarray,
                         silent_ends: np.ndarray,
                         min_word_dur: float = 0.1,
                         word_level: bool = True,
                         nonspeech_error: float = 0.3,
                         use_word_position: bool = True):
        """
        Apply silence suppression to the words, or to the segment itself if it has no words.

        Unless ``word_level`` is true, only the first and last word are adjusted. With ``use_word_position`` the
        first word is passed ``keep_end=True``, the last word ``keep_end=False`` and all others ``keep_end=None``.
        Returns ``self``.
        """
        if self.has_words:
            words = self.words if word_level or len(self.words) == 1 else [self.words[0], self.words[-1]]
            for i, w in enumerate(words, 1):
                if use_word_position:
                    keep_end = True if i == 1 else (False if i == len(words) else None)
                else:
                    keep_end = None
                w.suppress_silence(silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
            self.update_seg_with_words()
        else:
            suppress_silence(self,
                             silent_starts,
                             silent_ends,
                             min_word_dur,
                             nonspeech_error)

        return self

    def get_locked_indices(self):
        """Return each index ``i`` where word ``i`` is right-locked or word ``i+1`` is left-locked."""
        if not self.has_words:
            return []
        locked_indices = [i
                          for i, (left, right) in enumerate(zip(self.words[1:], self.words[:-1]))
                          if left.left_locked or right.right_locked]
        return locked_indices

    def get_gaps(self, as_ndarray=False):
        """Return the gap between each pair of consecutive words (next start minus current end)."""
        if self.has_words:
            s_ts = np.array([w.start for w in self.words])
            e_ts = np.array([w.end for w in self.words])
            gap = s_ts[1:] - e_ts[:-1]
            return gap if as_ndarray else gap.tolist()
        return []

    def get_gap_indices(self, max_gap: float = 0.1):  # for splitting
        """Return unlocked indices of words that are followed by a gap larger than ``max_gap``."""
        if not self.has_words or len(self.words) < 2:
            return []
        if max_gap is None:
            max_gap = 0
        indices = (self.get_gaps(True) > max_gap).nonzero()[0].tolist()
        return sorted(set(indices) - set(self.get_locked_indices()))

    def get_punctuation_indices(self, punctuation: Union[List[str], List[Tuple[str, str]], str]):  # for splitting
        """
        Return unlocked indices of words after which the segment can be split by punctuation.

        A string matches a word ending with it (split after that word) or a word starting with it (split before
        that word). A tuple ``(ending, beginning)`` matches a word ending with ``ending`` followed by a word
        starting with ``beginning``.
        """
        if not self.has_words or len(self.words) < 2:
            return []
        if isinstance(punctuation, str):
            punctuation = [punctuation]
        indices = []
        for p in punctuation:
            if isinstance(p, str):
                for i, s in enumerate(self.words[:-1]):
                    if s.word.endswith(p):
                        indices.append(i)
                    elif i != 0 and s.word.startswith(p):
                        indices.append(i-1)
            else:
                ending, beginning = p
                indices.extend([i for i, (w0, w1) in enumerate(zip(self.words[:-1], self.words[1:]))
                                if w0.word.endswith(ending) and w1.word.startswith(beginning)])

        return sorted(set(indices) - set(self.get_locked_indices()))

    def get_length_indices(self, max_chars: int = None, max_words: int = None, even_split: bool = True,
                           include_lock: bool = False):
        """
        Return indices of words after which to split so that parts respect ``max_chars`` / ``max_words``.

        With ``even_split`` the split points are chosen closest to equal-sized parts; otherwise a split is made
        right before the word that would exceed a limit (locked positions are skipped if ``include_lock``).
        """
        # for splitting
        if not self.has_words or (max_chars is None and max_words is None):
            return []
        assert max_chars != 0 and max_words != 0, \
            f'max_chars and max_words must be greater 0, but got {max_chars} and {max_words}'
        if len(self.words) < 2:
            return []
        indices = []
        if even_split:
            char_count = -1 if max_chars is None else sum(map(len, self.words))
            word_count = -1 if max_words is None else len(self.words)
            exceed_chars = max_chars is not None and char_count > max_chars
            exceed_words = max_words is not None and word_count > max_words
            if exceed_chars:
                splits = np.ceil(char_count / max_chars)
                chars_per_split = char_count / splits
                cum_char_count = np.cumsum([len(w.word) for w in self.words[:-1]])
                indices = [
                    (np.abs(cum_char_count-(i*chars_per_split))).argmin()
                    for i in range(1, int(splits))
                ]
                if max_words is not None:
                    # re-check the word limit against the parts produced by the character split
                    exceed_words = any(j-i+1 > max_words for i, j in zip([0]+indices, indices+[len(self.words)]))

            if exceed_words:
                splits = np.ceil(word_count / max_words)
                words_per_split = word_count / splits
                cum_word_count = np.array(range(1, len(self.words)+1))
                indices = [
                    np.abs(cum_word_count-(i*words_per_split)).argmin()
                    for i in range(1, int(splits))
                ]

        else:
            curr_words = 0
            curr_chars = 0
            locked_indices = []
            if include_lock:
                locked_indices = self.get_locked_indices()
            for i, word in enumerate(self.words):
                curr_words += 1
                curr_chars += len(word)
                if i != 0:
                    if (
                            max_chars is not None and curr_chars > max_chars
                            or
                            max_words is not None and curr_words > max_words
                    ) and i-1 not in locked_indices:
                        indices.append(i-1)
                        curr_words = 1
                        curr_chars = len(word)
        return indices

    def get_duration_indices(self, max_dur: float, even_split: bool = True, include_lock: bool = False):
        """
        Return indices of words after which to split so that the summed word duration of each part is limited.

        Returns an empty list if there are no words or the summed word durations do not exceed ``max_dur``.
        """
        if not self.has_words or (total_duration := np.sum([w.duration for w in self.words])) <= max_dur:
            return []
        if even_split:
            splits = np.ceil(total_duration / max_dur)
            dur_per_split = total_duration / splits
            cum_dur = np.cumsum([w.duration for w in self.words[:-1]])
            indices = [
                (np.abs(cum_dur - (i * dur_per_split))).argmin()
                for i in range(1, int(splits))
            ]
        else:
            indices = []
            curr_total_dur = 0.0
            locked_indices = self.get_locked_indices() if include_lock else []
            for i, word in enumerate(self.words):
                curr_total_dur += word.duration
                if i != 0:
                    if curr_total_dur > max_dur and i - 1 not in locked_indices:
                        indices.append(i - 1)
                        curr_total_dur = word.duration
        return indices

    def split(self, indices: List[int]):
        """
        Return copies of this segment split after each word index in ``indices``.

        The index of the last word is added if missing so that the trailing words are kept. ``indices`` itself is
        not modified. Returns an empty list if ``indices`` is empty.
        """
        if len(indices) == 0:
            return []
        # work on a copy: the caller's list must not grow as a side effect
        indices = list(indices)
        if indices[-1] != len(self.words) - 1:
            indices.append(len(self.words) - 1)
        seg_copies = []
        prev_i = 0
        for i in indices:
            i += 1
            c = deepcopy(self)
            c.words = c.words[prev_i:i]
            c.update_seg_with_words()
            seg_copies.append(c)
            prev_i = i
        return seg_copies

    def set_result(self, result: 'WhisperResult'):
        """Remember ``result`` as the result this segment belongs to."""
        self._result = result

    def get_result(self) -> Union['WhisperResult', None]:
        """
        Return outer instance of :class:`stable_whisper.result.WhisperResult` that ``self`` is a part of.
        """
        return getattr(self, '_result', None)
class WhisperResult:
def __init__(
        self,
        result: Union[str, dict, list],
        force_order: bool = False,
        check_sorted: Union[bool, str] = True,
        show_unsorted: bool = True
):
    """
    Build the result from a path, a dict, or a list of segments / words.

    ``result`` is normalized by :meth:`_standardize_result`. With ``force_order`` the timestamps are forced into
    order first; they are then checked with :meth:`raise_for_unsorted` using ``check_sorted`` and ``show_unsorted``.
    """
    result, self.path = self._standardize_result(result)
    self.ori_dict = result.get('ori_dict') or result
    self.language = self.ori_dict.get('language')
    self._regroup_history = result.get('regroup_history', '')
    self._nonspeech_sections = result.get('nonspeech_sections', [])

    # segments are deep-copied so the input data is not changed by later edits
    raw_segments = deepcopy(result.get('segments', self.ori_dict.get('segments')))
    if raw_segments:
        self.segments: List[Segment] = [Segment(**seg_kwargs) for seg_kwargs in raw_segments]
    else:
        self.segments: List[Segment] = []

    self._forced_order = force_order
    if self._forced_order:
        self.force_order()
    self.raise_for_unsorted(check_sorted, show_unsorted)

    # once any segment has words, segments without words are dropped regardless of their origin
    any_has_words = any(seg.has_words for seg in self.segments)
    self.remove_no_word_segments(any_has_words)
    self.update_all_segs_with_words()
def __getitem__(self, index: int) -> Segment:
return self.segments[index]
def __delitem__(self, index: int):
del self.segments[index]
self.reassign_ids(True)
@staticmethod
def _standardize_result(result: Union[str, dict, list]):
path = None
if isinstance(result, str):
path = result
result = load_result(path)
if isinstance(result, list):
if isinstance(result[0], list):
if not isinstance(result[0][0], dict):
raise NotImplementedError(f'Got list of list of {type(result[0])} but expects list of list of dict')
result = dict(
segments=[
dict(
start=words[0]['start'],
end=words[-1]['end'],
text=''.join(w['word'] for w in words),
words=words
)
for words in result
]
)
elif isinstance(result[0], dict):
result = dict(segments=result)
else:
raise NotImplementedError(f'Got list of {type(result[0])} but expects list of list/dict')
return result, path
def force_order(self):
    """
    Adjust the timestamps of all words (or segments) in place so that they are in ascending order.

    Each start is raised to the previous end if it is earlier. Parts whose start is after their end are repaired
    by moving either their own timestamps or those of the preceding parts.
    """
    prev_ts_end = 0
    timestamps = self.all_words_or_segments()
    # ``i`` is 1-based, so ``timestamps[i]`` is the part that follows ``ts``
    for i, ts in enumerate(timestamps, 1):
        if ts.start < prev_ts_end:
            ts.start = prev_ts_end
        if ts.start > ts.end:
            if prev_ts_end > ts.end:
                # the previous part already ends after this one: collapse this part to zero duration
                # and pull every earlier timestamp back so none exceeds this part's end
                warnings.warn('Multiple consecutive timestamps are out of order. Some parts will have no duration.')
                ts.start = ts.end
                for j in range(i-2, -1, -1):
                    if timestamps[j].end > ts.end:
                        timestamps[j].end = ts.end
                    if timestamps[j].start > ts.end:
                        timestamps[j].start = ts.end
            else:
                if ts.start != prev_ts_end:
                    ts.start = prev_ts_end
                else:
                    # last part: zero duration; otherwise extend the end to the start of the next part
                    ts.end = ts.start if i == len(timestamps) else timestamps[i].start
        prev_ts_end = ts.end
    if self.has_words:
        self.update_all_segs_with_words()
def raise_for_unsorted(self, check_sorted: Union[bool, str] = True, show_unsorted: bool = True):
    """
    Check that all start/end timestamps of the words (or segments) are in ascending order.

    Parameters
    ----------
    check_sorted : bool or str, default True
        ``False`` skips the check. ``True`` raises ``UnsortedException`` for unsorted timestamps. Any other value
        causes a warning instead and is passed with the data to ``save_as_json``
        (presumably as the output path -- confirm against ``save_as_json``).
    show_unsorted : bool, default True
        Whether to print every part that is out of order.
    """
    if check_sorted is False:
        return
    all_parts = self.all_words_or_segments()
    has_words = self.has_words
    # flattened as start0, end0, start1, end1, ...
    timestamps = np.array(list(chain.from_iterable((p.start, p.end) for p in all_parts)))
    if len(timestamps) > 1 and (unsorted_mask := timestamps[:-1] > timestamps[1:]).any():
        if show_unsorted:
            def get_part_info(idx):
                curr_part = all_parts[idx]
                seg_id = curr_part.segment_id if has_words else curr_part.id
                word_id_str = f'Word ID: {curr_part.id}\n' if has_words else ''
                return (
                    f'Segment ID: {seg_id}\n{word_id_str}'
                    f'Start: {curr_part.start}\nEnd: {curr_part.end}\n'
                    f'Text: "{curr_part.word if has_words else curr_part.text}"'
                ), curr_part.start, curr_part.end

            # counting from 2 makes ``i//2-1`` the index of the part owning the first timestamp of the pair;
            # even ``i`` compares a part's start with its own end, odd ``i`` its end with the next part's start
            for i, unsorted in enumerate(unsorted_mask, 2):
                if unsorted:
                    word_id = i//2-1
                    part_info, start, end = get_part_info(word_id)
                    if i % 2 == 1:
                        next_info, next_start, _ = get_part_info(word_id+1)
                        part_info += f'\nConflict: end ({end}) > next start ({next_start})\n{next_info}'
                    else:
                        part_info += f'\nConflict: start ({start}) > end ({end})'
                    print(part_info, end='\n\n')

        data = self.to_dict()
        if check_sorted is True:
            raise UnsortedException(data=data)
        warnings.warn('Timestamps are not in ascending order. '
                      'If data is produced by Stable-ts, please submit an issue with the saved data.')
        save_as_json(data, check_sorted)
def update_all_segs_with_words(self):
for seg in self.segments:
seg.update_seg_with_words()
seg.set_result(self)
def update_nonspeech_sections(self, silent_starts, silent_ends):
self._nonspeech_sections = [dict(start=s, end=e) for s, e in zip(silent_starts, silent_ends)]
def add_segments(self, index0: int, index1: int, inplace: bool = False, lock: bool = False):
new_seg = self.segments[index0] + self.segments[index1]
new_seg.update_seg_with_words()
if lock and self.segments[index0].has_words:
lock_idx = len(self.segments[index0].words)
new_seg.words[lock_idx - 1].lock_right()
if lock_idx < len(new_seg.words):
new_seg.words[lock_idx].lock_left()
if inplace:
i0, i1 = sorted([index0, index1])
self.segments[i0] = new_seg
del self.segments[i1]
return new_seg
def rescale_time(self, scale_factor: float):
for s in self.segments:
s.rescale_time(scale_factor)
def apply_min_dur(self, min_dur: float, inplace: bool = False):
"""
Merge any word/segment with adjacent word/segment if its duration is less than ``min_dur``.
"""
result = self if inplace else deepcopy(self)
max_i = len(result.segments) - 1
if max_i == 0:
return result
for i in reversed(range(len(result.segments))):
if max_i == 0:
break
if result.segments[i].duration < min_dur:
if i == max_i:
result.add_segments(i-1, i, inplace=True)
elif i == 0:
result.add_segments(i, i+1, inplace=True)
else:
if result.segments[i+1].duration < result.segments[i-1].duration:
result.add_segments(i-1, i, inplace=True)
else:
result.add_segments(i, i+1, inplace=True)
max_i -= 1
result.reassign_ids()
for s in result.segments:
s.apply_min_dur(min_dur, inplace=True)
return result
def offset_time(self, offset_seconds: float):
for s in self.segments:
s.offset_time(offset_seconds)
def suppress_silence(
self,
silent_starts: np.ndarray,
silent_ends: np.ndarray,
min_word_dur: float = 0.1,
word_level: bool = True,
nonspeech_error: float = 0.3,
use_word_position: bool = True
) -> "WhisperResult":
"""
Move any start/end timestamps in silence parts of audio to the boundaries of the silence.
Parameters
----------
silent_starts : numpy.ndarray
An array starting timestamps of silent sections of audio.
silent_ends : numpy.ndarray
An array ending timestamps of silent sections of audio.
min_word_dur : float, default 0.1
Shortest duration each word is allowed to reach for adjustments.
word_level : bool, default False
Whether to settings to word level timestamps.
nonspeech_error : float, default 0.3
Relative error of non-speech sections that appear in between a word for adjustments.
use_word_position : bool, default True
Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
for s in self.segments:
s.suppress_silence(
silent_starts,
silent_ends,
min_word_dur,
word_level=word_level,
nonspeech_error=nonspeech_error,
use_word_position=use_word_position
)
return self
def adjust_by_silence(
        self,
        audio: Union[torch.Tensor, np.ndarray, str, bytes],
        vad: bool = False,
        *,
        verbose: Optional[bool] = False,
        sample_rate: int = None,
        vad_onnx: bool = False,
        vad_threshold: float = 0.35,
        q_levels: int = 20,
        k_size: int = 5,
        min_word_dur: float = 0.1,
        word_level: bool = True,
        nonspeech_error: float = 0.3,
        use_word_position: bool = True
) -> "WhisperResult":
    """
    Adjust timestamps based on detected speech gaps.

    This method combines :meth:`stable_whisper.result.WhisperResult.suppress_silence` with silence detection.

    Parameters
    ----------
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    verbose : bool or None, default False
        If ``False``, mute messages about hitting local caches. Note that the message about first download cannot be
        muted. Only applies if ``vad = True``.
    sample_rate : int, default None, meaning ``whisper.audio.SAMPLE_RATE``, 16kHZ
        The sample rate of ``audio``.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
        Acts as a threshold to marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach from adjustments.
    word_level : bool, default True
        Whether to apply the adjustments to word-level timestamps.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for adjustments.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    This operation is already performed by :func:`stable_whisper.whisper_word_level.transcribe_stable` /
    :func:`stable_whisper.whisper_word_level.transcribe_minimal`/
    :func:`stable_whisper.non_whisper.transcribe_any` / :func:`stable_whisper.alignment.align`
    if ``suppress_silence = True``.
    """
    if vad:
        silent_timings = get_vad_silence_func(
            onnx=vad_onnx,
            verbose=verbose
        )(audio, speech_threshold=vad_threshold, sr=sample_rate)
    else:
        silent_timings = mask2timing(
            wav2mask(audio, q_levels=q_levels, k_size=k_size, sr=sample_rate)
        )
    # no silence timings available: leave the timestamps as they are
    if silent_timings is None:
        return self
    self.suppress_silence(
        *silent_timings,
        min_word_dur=min_word_dur,
        word_level=word_level,
        nonspeech_error=nonspeech_error,
        use_word_position=use_word_position
    )
    self.update_nonspeech_sections(*silent_timings)
    return self
def adjust_by_result(
        self,
        other_result: "WhisperResult",
        min_word_dur: float = 0.1,
        verbose: bool = False
):
    """
    Minimize the duration of words using timestamps of another result.

    Each word is narrowed to the overlap of its own time range and the range of the matching word in
    ``other_result``.

    Parameters
    ----------
    other_result : "WhisperResult"
        Timing data of the same words in a WhisperResult instance.
    min_word_dur : float, default 0.1
        Prevent changes to timestamps if the resultant word duration is less than ``min_word_dur``.
    verbose : bool, default False
        Whether to print out the timestamp changes.
    """
    if not (self.has_words and other_result.has_words):
        raise NotImplementedError('This operation can only be performed on results with word timestamps')
    own_words = self.all_words()
    other_words = other_result.all_words()
    assert [w.word for w in own_words] == [w.word for w in other_words], \
        'The words in [other_result] do not match the current words.'
    for word, other_word in zip(own_words, other_words):
        # only words whose ranges can overlap are adjusted
        if not word.end > other_word.start:
            continue
        new_start = max(word.start, other_word.start)
        new_end = min(word.end, other_word.end)
        if not new_end - new_start >= min_word_dur:
            continue
        changes = []
        if word.start != new_start:
            if verbose:
                changes.append(f'[Start:{word.start:.3f}->{new_start:.3f}] ')
            word.start = new_start
        if word.end != new_end:
            if verbose:
                changes.append(f'[End:{word.end:.3f}->{new_end:.3f}] ')
            word.end = new_end
        line = ''.join(changes)
        if line:
            print(f'{line}"{word.word}"')
    self.update_all_segs_with_words()
def reassign_ids(self, only_segments: bool = False):
for i, s in enumerate(self.segments):
s.id = i
if not only_segments:
s.reassign_ids()
def remove_no_word_segments(self, ignore_ori=False):
for i in reversed(range(len(self.segments))):
if (ignore_ori or self.segments[i].ori_has_words) and not self.segments[i].has_words:
del self.segments[i]
self.reassign_ids()
def get_locked_indices(self):
locked_indices = [i
for i, (left, right) in enumerate(zip(self.segments[1:], self.segments[:-1]))
if left.left_locked or right.right_locked]
return locked_indices
def get_gaps(self, as_ndarray=False):
s_ts = np.array([s.start for s in self.segments])
e_ts = np.array([s.end for s in self.segments])
gap = s_ts[1:] - e_ts[:-1]
return gap if as_ndarray else gap.tolist()
def get_gap_indices(self, min_gap: float = 0.1): # for merging
if len(self.segments) < 2:
return []
if min_gap is None:
min_gap = 0
indices = (self.get_gaps(True) <= min_gap).nonzero()[0].tolist()
return sorted(set(indices) - set(self.get_locked_indices()))
def get_punctuation_indices(self, punctuation: Union[List[str], List[Tuple[str, str]], str]): # for merging
if len(self.segments) < 2:
return []
if isinstance(punctuation, str):
punctuation = [punctuation]
indices = []
for p in punctuation:
if isinstance(p, str):
for i, s in enumerate(self.segments[:-1]):
if s.text.endswith(p):
indices.append(i)
elif i != 0 and s.text.startswith(p):
indices.append(i-1)
else:
ending, beginning = p
indices.extend([i for i, (s0, s1) in enumerate(zip(self.segments[:-1], self.segments[1:]))
if s0.text.endswith(ending) and s1.text.startswith(beginning)])
return sorted(set(indices) - set(self.get_locked_indices()))
def all_words(self):
return list(chain.from_iterable(s.words for s in self.segments))
def all_words_or_segments(self):
    """Return all words if ``self.has_words`` is true, otherwise the list of segments."""
    if self.has_words:
        return self.all_words()
    return self.segments
def all_words_by_lock(self, only_text: bool = True, by_segment: bool = False, include_single: bool = False):
    """
    Return words grouped by lock.

    With ``by_segment`` the grouping is done per segment and one nested list per segment is returned;
    otherwise all words of the result are grouped together.
    """
    options = dict(only_text=only_text, include_single=include_single)
    if not by_segment:
        return _words_by_lock(self.all_words(), **options)
    return [segment.words_by_lock(**options) for segment in self.segments]
def all_tokens(self):
return list(chain.from_iterable(s.tokens for s in self.all_words()))
def to_dict(self):
    """Return the result as a dict: text, segment dicts, language, original dict, regroup history and
    non-speech sections."""
    return {
        'text': self.text,
        'segments': self.segments_to_dicts(),
        'language': self.language,
        'ori_dict': self.ori_dict,
        'regroup_history': self._regroup_history,
        'nonspeech_sections': self._nonspeech_sections,
    }
def segments_to_dicts(self, reverse_text: Union[bool, tuple] = False):
return [s.to_dict(reverse_text=reverse_text) for s in self.segments]
def _split_segments(self, get_indices, args: list = None, *, lock: bool = False, newline: bool = False):
if args is None:
args = []
no_words = False
for i in reversed(range(0, len(self.segments))):
no_words = no_words or not self.segments[i].has_words
indices = sorted(set(get_indices(self.segments[i], *args)))
if not indices:
continue
if newline:
if indices[-1] == len(self.segments[i].words) - 1:
del indices[-1]
if not indices:
continue
for word_idx in indices:
if self.segments[i].words[word_idx].word.endswith('\n'):
continue
self.segments[i].words[word_idx].word += '\n'
if lock:
self.segments[i].words[word_idx].lock_right()
if word_idx + 1 < len(self.segments[i].words):
self.segments[i].words[word_idx+1].lock_left()
self.segments[i].update_seg_with_words()
else:
new_segments = self.segments[i].split(indices)
if lock:
for s in new_segments:
if s == new_segments[0]:
s.lock_right()
elif s == new_segments[-1]:
s.lock_left()
else:
s.lock_both()
del self.segments[i]
for s in reversed(new_segments):
self.segments.insert(i, s)
if no_words:
warnings.warn('Found segment(s) without word timings. These segment(s) cannot be split.')
self.remove_no_word_segments()
def _merge_segments(self, indices: List[int],
*, max_words: int = None, max_chars: int = None, is_sum_max: bool = False, lock: bool = False):
if len(indices) == 0:
return
for i in reversed(indices):
seg = self.segments[i]
if (
(
max_words and
seg.has_words and
(
(seg.word_count() + self.segments[i + 1].word_count() > max_words)
if is_sum_max else
(seg.word_count() > max_words and self.segments[i + 1].word_count() > max_words)
)
) or
(
max_chars and
(
(seg.char_count() + self.segments[i + 1].char_count() > max_chars)
if is_sum_max else
(seg.char_count() > max_chars and self.segments[i + 1].char_count() > max_chars)
)
)
):
continue
self.add_segments(i, i + 1, inplace=True, lock=lock)
self.remove_no_word_segments()
def get_content_by_time(
self,
time: Union[float, Tuple[float, float], dict],
within: bool = False,
segment_level: bool = False
) -> Union[List[WordTiming], List[Segment]]:
"""
Return content in the ``time`` range.
Parameters
----------
time : float or tuple of (float, float) or dict
Range of time to find content. For tuple of two floats, first value is the start time and second value is
the end time. For a single float value, it is treated as both the start and end time.
within : bool, default False
Whether to only find content fully overlaps with ``time`` range.
segment_level : bool, default False
Whether to look only on the segment level and return instances of :class:`stable_whisper.result.Segment`
instead of :class:`stable_whisper.result.WordTiming`.
Returns
-------
list of stable_whisper.result.WordTiming or list of stable_whisper.result.Segment
List of contents in the ``time`` range. The contents are instances of
:class:`stable_whisper.result.Segment` if ``segment_level = True`` else
:class:`stable_whisper.result.WordTiming`.
"""
if not segment_level and not self.has_words:
raise ValueError('Missing word timestamps in result. Use ``segment_level=True`` instead.')
contents = self.segments if segment_level else self.all_words()
if isinstance(time, (float, int)):
time = [time, time]
elif isinstance(time, dict):
time = [time['start'], time['end']]
start, end = time
if within:
def is_in_range(c):
return start <= c.start and end >= c.end
else:
def is_in_range(c):
return start <= c.end and end >= c.start
return [c for c in contents if is_in_range(c)]
def split_by_gap(
self,
max_gap: float = 0.1,
lock: bool = False,
newline: bool = False
) -> "WhisperResult":
"""
Split (in-place) any segment where the gap between two of its words is greater than ``max_gap``.
Parameters
----------
max_gap : float, default 0.1
Maximum second(s) allowed between two words if the same segment.
lock : bool, default False
Whether to prevent future splits/merges from altering changes made by this method.
newline: bool, default False
Whether to insert line break at the split points instead of splitting into separate segments.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
self._split_segments(lambda x: x.get_gap_indices(max_gap), lock=lock, newline=newline)
if self._regroup_history:
self._regroup_history += '_'
self._regroup_history += f'sg={max_gap}+{int(lock)}+{int(newline)}'
return self
def merge_by_gap(
self,
min_gap: float = 0.1,
max_words: int = None,
max_chars: int = None,
is_sum_max: bool = False,
lock: bool = False
) -> "WhisperResult":
"""
Merge (in-place) any pair of adjacent segments if the gap between them <= ``min_gap``.
Parameters
----------
min_gap : float, default 0.1
Minimum second(s) allow between two segment.
max_words : int, optional
Maximum number of words allowed in each segment.
max_chars : int, optional
Maximum number of characters allowed in each segment.
is_sum_max : bool, default False
Whether ``max_words`` and ``max_chars`` is applied to the merged segment instead of the individual segments
to be merged.
lock : bool, default False
Whether to prevent future splits/merges from altering changes made by this method.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
indices = self.get_gap_indices(min_gap)
self._merge_segments(indices,
max_words=max_words, max_chars=max_chars, is_sum_max=is_sum_max, lock=lock)
if self._regroup_history:
self._regroup_history += '_'
self._regroup_history += f'mg={min_gap}+{max_words or ""}+{max_chars or ""}+{int(is_sum_max)}+{int(lock)}'
return self
def split_by_punctuation(
self,
punctuation: Union[List[str], List[Tuple[str, str]], str],
lock: bool = False,
newline: bool = False,
min_words: Optional[int] = None,
min_chars: Optional[int] = None,
min_dur: Optional[int] = None
) -> "WhisperResult":
"""
Split (in-place) segments at words that start/end with ``punctuation``.
Parameters
----------
punctuation : list of str of list of tuple of (str, str) or str
Punctuation(s) to split segments by.
lock : bool, default False
Whether to prevent future splits/merges from altering changes made by this method.
newline : bool, default False
Whether to insert line break at the split points instead of splitting into separate segments.
min_words : int, optional
Split segments with words >= ``min_words``.
min_chars : int, optional
Split segments with characters >= ``min_chars``.
min_dur : int, optional
split segments with duration (in seconds) >= ``min_dur``.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
def _over_max(x: Segment):
return (
(min_words and len(x.words) >= min_words) or
(min_chars and x.char_count() >= min_chars) or
(min_dur and x.duration >= min_dur)
)
indices = set(s.id for s in self.segments if _over_max(s)) if any((min_words, min_chars, min_dur)) else None
def _get_indices(x: Segment):
return x.get_punctuation_indices(punctuation) if indices is None or x.id in indices else []
self._split_segments(_get_indices, lock=lock, newline=newline)
if self._regroup_history:
self._regroup_history += '_'
punct_str = '/'.join(p if isinstance(p, str) else '*'.join(p) for p in punctuation)
self._regroup_history += f'sp={punct_str}+{int(lock)}+{int(newline)}'
self._regroup_history += f'+{min_words or ""}+{min_chars or ""}+{min_dur or ""}'.rstrip('+')
return self
def merge_by_punctuation(
self,
punctuation: Union[List[str], List[Tuple[str, str]], str],
max_words: int = None,
max_chars: int = None,
is_sum_max: bool = False,
lock: bool = False
) -> "WhisperResult":
"""
Merge (in-place) any two segments that has specific punctuations inbetween.
Parameters
----------
punctuation : list of str of list of tuple of (str, str) or str
Punctuation(s) to merge segments by.
max_words : int, optional
Maximum number of words allowed in each segment.
max_chars : int, optional
Maximum number of characters allowed in each segment.
is_sum_max : bool, default False
Whether ``max_words`` and ``max_chars`` is applied to the merged segment instead of the individual segments
to be merged.
lock : bool, default False
Whether to prevent future splits/merges from altering changes made by this method.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
indices = self.get_punctuation_indices(punctuation)
self._merge_segments(indices,
max_words=max_words, max_chars=max_chars, is_sum_max=is_sum_max, lock=lock)
if self._regroup_history:
self._regroup_history += '_'
punct_str = '/'.join(p if isinstance(p, str) else '*'.join(p) for p in punctuation)
self._regroup_history += f'mp={punct_str}+{max_words or ""}+{max_chars or ""}+{int(is_sum_max)}+{int(lock)}'
return self
def merge_all_segments(self) -> "WhisperResult":
"""
Merge all segments into one segment.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
if not self.segments:
return self
if self.has_words:
self.segments[0].words = self.all_words()
else:
self.segments[0].text += ''.join(s.text for s in self.segments[1:])
if all(s.tokens is not None for s in self.segments):
self.segments[0].tokens += list(chain.from_iterable(s.tokens for s in self.segments[1:]))
self.segments[0].end = self.segments[-1].end
self.segments = [self.segments[0]]
self.reassign_ids()
self.update_all_segs_with_words()
if self._regroup_history:
self._regroup_history += '_'
self._regroup_history += 'ms'
return self
def split_by_length(
self,
max_chars: int = None,
max_words: int = None,
even_split: bool = True,
force_len: bool = False,
lock: bool = False,
include_lock: bool = False,
newline: bool = False
) -> "WhisperResult":
"""
Split (in-place) any segment that exceeds ``max_chars`` or ``max_words`` into smaller segments.
Parameters
----------
max_chars : int, optional
Maximum number of characters allowed in each segment.
max_words : int, optional
Maximum number of words allowed in each segment.
even_split : bool, default True
Whether to evenly split a segment in length if it exceeds ``max_chars`` or ``max_words``.
force_len : bool, default False
Whether to force a constant length for each segment except the last segment.
This will ignore all previous non-locked segment boundaries.
lock : bool, default False
Whether to prevent future splits/merges from altering changes made by this method.
include_lock: bool, default False
Whether to include previous lock before splitting based on max_words, if ``even_split = False``.
Splitting will be done after the first non-locked word > ``max_chars`` / ``max_words``.
newline: bool, default False
Whether to insert line break at the split points instead of splitting into separate segments.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
Notes
-----
If ``even_split = True``, segments can still exceed ``max_chars`` and locked words will be ignored to avoid
uneven splitting.
"""
if force_len:
self.merge_all_segments()
self._split_segments(
lambda x: x.get_length_indices(
max_chars=max_chars,
max_words=max_words,
even_split=even_split,
include_lock=include_lock
),
lock=lock,
newline=newline
)
if self._regroup_history:
self._regroup_history += '_'
self._regroup_history += (f'sl={max_chars or ""}+{max_words or ""}+{int(even_split)}+{int(force_len)}'
f'+{int(lock)}+{int(include_lock)}+{int(newline)}')
return self
def split_by_duration(
self,
max_dur: float,
even_split: bool = True,
force_len: bool = False,
lock: bool = False,
include_lock: bool = False,
newline: bool = False
) -> "WhisperResult":
"""
Split (in-place) any segment that exceeds ``max_dur`` into smaller segments.
Parameters
----------
max_dur : float
Maximum duration (in seconds) per segment.
even_split : bool, default True
Whether to evenly split a segment in length if it exceeds ``max_dur``.
force_len : bool, default False
Whether to force a constant length for each segment except the last segment.
This will ignore all previous non-locked segment boundaries.
lock : bool, default False
Whether to prevent future splits/merges from altering changes made by this method.
include_lock: bool, default False
Whether to include previous lock before splitting based on max_words, if ``even_split = False``.
Splitting will be done after the first non-locked word > ``max_dur``.
newline: bool, default False
Whether to insert line break at the split points instead of splitting into separate segments.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
Notes
-----
If ``even_split = True``, segments can still exceed ``max_dur`` and locked words will be ignored to avoid
uneven splitting.
"""
if force_len:
self.merge_all_segments()
self._split_segments(
lambda x: x.get_duration_indices(
max_dur=max_dur,
even_split=even_split,
include_lock=include_lock
),
lock=lock,
newline=newline
)
if self._regroup_history:
self._regroup_history += '_'
self._regroup_history += (f'sd={max_dur}+{int(even_split)}+{int(force_len)}'
f'+{int(lock)}+{int(include_lock)}+{int(newline)}')
return self
def clamp_max(
self,
medium_factor: float = 2.5,
max_dur: float = None,
clip_start: Optional[bool] = None,
verbose: bool = False
) -> "WhisperResult":
"""
Clamp all word durations above certain value.
This is most effective when applied before and after other regroup operations.
Parameters
----------
medium_factor : float, default 2.5
Clamp durations above (``medium_factor`` * medium duration) per segment.
If ``medium_factor = None/0`` or segment has less than 3 words, it will be ignored and use only ``max_dur``.
max_dur : float, optional
Clamp durations above ``max_dur``.
clip_start : bool or None, default None
Whether to clamp the start of a word. If ``None``, clamp the start of first word and end of last word per
segment.
verbose : bool, default False
Whether to print out the timestamp changes.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
if not (medium_factor or max_dur):
raise ValueError('At least one of following arguments requires non-zero value: medium_factor; max_dur')
if not self.has_words:
warnings.warn('Cannot clamp due to missing/no word-timestamps')
return self
for seg in self.segments:
curr_max_dur = None
if medium_factor and len(seg.words) > 2:
durations = np.array([word.duration for word in seg.words])
durations.sort()
curr_max_dur = medium_factor * durations[len(durations)//2 + 1]
if max_dur and (not curr_max_dur or curr_max_dur > max_dur):
curr_max_dur = max_dur
if not curr_max_dur:
continue
if clip_start is None:
seg.words[0].clamp_max(curr_max_dur, clip_start=True, verbose=verbose)
seg.words[-1].clamp_max(curr_max_dur, clip_start=False, verbose=verbose)
else:
for i, word in enumerate(seg.words):
word.clamp_max(curr_max_dur, clip_start=clip_start, verbose=verbose)
seg.update_seg_with_words()
if self._regroup_history:
self._regroup_history += '_'
self._regroup_history += f'cm={medium_factor}+{max_dur or ""}+{clip_start or ""}+{int(verbose)}'
return self
def lock(
self,
startswith: Union[str, List[str]] = None,
endswith: Union[str, List[str]] = None,
right: bool = True,
left: bool = False,
case_sensitive: bool = False,
strip: bool = True
) -> "WhisperResult":
"""
Lock words/segments with matching prefix/suffix to prevent splitting/merging.
Parameters
----------
startswith: str or list of str
Prefixes to lock.
endswith: str or list of str
Suffixes to lock.
right : bool, default True
Whether prevent splits/merges with the next word/segment.
left : bool, default False
Whether prevent splits/merges with the previous word/segment.
case_sensitive : bool, default False
Whether to match the case of the prefixes/suffixes with the words/segments.
strip : bool, default True
Whether to ignore spaces before and after both words/segments and prefixes/suffixes.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
assert startswith or endswith, 'Must specify [startswith] or/and [endswith].'
startswith = [] if startswith is None else ([startswith] if isinstance(startswith, str) else startswith)
endswith = [] if endswith is None else ([endswith] if isinstance(endswith, str) else endswith)
if not case_sensitive:
startswith = [t.lower() for t in startswith]
endswith = [t.lower() for t in endswith]
if strip:
startswith = [t.strip() for t in startswith]
endswith = [t.strip() for t in endswith]
for part in self.all_words_or_segments():
text = part.word if hasattr(part, 'word') else part.text
if not case_sensitive:
text = text.lower()
if strip:
text = text.strip()
for prefix in startswith:
if text.startswith(prefix):
if right:
part.lock_right()
if left:
part.lock_left()
for suffix in endswith:
if text.endswith(suffix):
if right:
part.lock_right()
if left:
part.lock_left()
if self._regroup_history:
self._regroup_history += '_'
startswith_str = (startswith if isinstance(startswith, str) else '/'.join(startswith)) if startswith else ""
endswith_str = (endswith if isinstance(endswith, str) else '/'.join(endswith)) if endswith else ""
self._regroup_history += (f'l={startswith_str}+{endswith_str}'
f'+{int(right)}+{int(left)}+{int(case_sensitive)}+{int(strip)}')
return self
def remove_word(
        self,
        word: Union[WordTiming, Tuple[int, int]],
        reassign_ids: bool = True,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove a word.

    Parameters
    ----------
    word : WordTiming or tuple of (int, int)
        Instance of :class:`stable_whisper.result.WordTiming` or tuple of (segment index, word index).
    reassign_ids : bool, default True
        Whether to reassign segment and word ids (indices) after removing ``word``.
    verbose : bool, default True
        Whether to print detail of the removed word.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Raises
    ------
    ValueError
        If ``word`` is a ``WordTiming`` that cannot be located in this result, even after the ids
        have been reassigned.
    """
    if not isinstance(word, WordTiming):
        seg_id, word_id = word
    else:
        def _found_at_own_ids():
            return self[word.segment_id][word.id] is word

        if not _found_at_own_ids():
            # The ids stored on the word may be stale; refresh them once before giving up.
            self.reassign_ids()
            if not _found_at_own_ids():
                raise ValueError('word not in result')
        seg_id, word_id = word.segment_id, word.id

    if verbose:
        print(f'Removed: {self[seg_id][word_id].to_dict()}')
    del self.segments[seg_id].words[word_id]

    if reassign_ids:
        if self[seg_id].has_words:
            self[seg_id].reassign_ids()
        else:
            self.remove_no_word_segments()
    return self
def remove_segment(
        self,
        segment: Union[Segment, int],
        reassign_ids: bool = True,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove a segment.

    Parameters
    ----------
    segment : Segment or int
        Instance of :class:`stable_whisper.result.Segment` or segment index.
    reassign_ids : bool, default True
        Whether to reassign segment IDs (indices) after removing ``segment``.
    verbose : bool, default True
        Whether to print detail of the removed segment.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Raises
    ------
    ValueError
        If ``segment`` is a ``Segment`` that cannot be located in this result, even after the ids
        have been reassigned.
    """
    if isinstance(segment, Segment):
        target = segment
        if self[target.id] is not target:
            # The id stored on the segment may be stale; refresh ids once before giving up.
            self.reassign_ids()
            if self[target.id] is not target:
                raise ValueError('segment not in result')
        segment = target.id

    if verbose:
        print(f'Removed: [id:{self[segment].id}] {self[segment].to_display_str(True)}')
    del self.segments[segment]

    if reassign_ids:
        self.reassign_ids(True)
    return self
def remove_repetition(
        self,
        max_words: int = 1,
        case_sensitive: bool = False,
        strip: bool = True,
        ignore_punctuations: str = "\"',.?!",
        extend_duration: bool = True,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove words that repeat consecutively.

    Repetitions of 1 word are removed first, then repetitions of 2 words, and so on up to ``max_words``.

    Parameters
    ----------
    max_words : int
        Maximum number of words to look for consecutively.
    case_sensitive : bool, default False
        Whether the case of words need to match to be considered as repetition.
    strip : bool, default True
        Whether to ignore spaces before and after each word.
    ignore_punctuations : str, default '"',.?!'
        Ending punctuations to ignore. The characters are placed inside a regular-expression character
        class as-is.
    extend_duration : bool, default True
        Whether to extend the duration of the previous word to cover the duration of the repetition.
    verbose : bool, default True
        Whether to print detail of the removed repetitions.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    if not self.has_words:
        return self

    # One pass per repetition length, shortest first; the word list is re-read after each pass.
    for count in range(1, max_words + 1):
        all_words = self.all_words()
        if len(all_words) < 2:
            return self
        # Build the normalized strings used only for comparison (the words themselves are not changed).
        all_words_str = [w.word for w in all_words]
        if strip:
            all_words_str = [w.strip() for w in all_words_str]
        if ignore_punctuations:
            ptn = f'[{ignore_punctuations}]+$'
            all_words_str = [re.sub(ptn, '', w) for w in all_words_str]
        if not case_sensitive:
            all_words_str = [w.lower() for w in all_words_str]
        next_i = None
        changes = []
        # ``i`` is the exclusive end of a candidate repetition; scan from the end so removals do not
        # disturb the indices that are still to be checked.
        for i in reversed(range(count*2, len(all_words_str)+1)):
            # After a removal, skip ahead until ``i`` reaches the start of the removed group so the
            # kept copy is compared next with the group before it.
            if next_i is not None:
                if next_i != i:
                    continue
                else:
                    next_i = None
            s = i - count  # start of the candidate repetition [s, i)
            # A repetition is a group [s, i) identical to the group directly before it [s - count, s).
            if all_words_str[s - count:s] != all_words_str[s:i]:
                continue
            next_i = s
            if extend_duration:
                all_words[s-1].end = all_words[i-1].end
            temp_changes = []
            for j in reversed(range(s, i)):
                if verbose:
                    temp_changes.append(f'- {all_words[j].to_dict()}')
                # ids are not reassigned here (second argument is False).
                self.remove_word(all_words[j], False, verbose=False)
            if temp_changes:
                changes.append(
                    f'Remove: [{format_timestamp(all_words[s].start)} -> {format_timestamp(all_words[i-1].end)}] '
                    + ''.join(_w.word for _w in all_words[s:i]) + '\n'
                    + '\n'.join(reversed(temp_changes)) + '\n'
                )
            # Where the removed copy of a word is the longer string, put it in the kept word's slot
            # with the kept word's timing.
            # NOTE(review): ``all_words`` still references the replaced object at ``i0`` and the
            # substituted word keeps its own ``segment_id``/``id``; a later ``remove_word`` on
            # ``all_words[i0]`` may fail to locate it -- confirm.
            for i0, i1 in zip(range(s - count, s), range(s, i)):
                if len(all_words[i0].word) < len(all_words[i1].word):
                    all_words[i1].start = all_words[i0].start
                    all_words[i1].end = all_words[i0].end
                    _sid, _wid = all_words[i0].segment_id, all_words[i0].id
                    self.segments[_sid].words[_wid] = all_words[i1]

        if changes:
            # Changes were collected back-to-front; print them in chronological order.
            print('\n'.join(reversed(changes)))

    self.remove_no_word_segments()
    self.update_all_segs_with_words()

    return self
def remove_words_by_str(
self,
words: Union[str, List[str], None],
case_sensitive: bool = False,
strip: bool = True,
ignore_punctuations: str = "\"',.?!",
min_prob: float = None,
filters: Callable = None,
verbose: bool = True
) -> 'WhisperResult':
"""
Remove words that match ``words``.
Parameters
----------
words : str or list of str or None
A word or list of words to remove.``None`` for all words to be passed into ``filters``.
case_sensitive : bool, default False
Whether the case of words need to match to be considered as repetition.
strip : bool, default True
Whether to ignore spaces before and after each word.
ignore_punctuations : bool, default '"',.?!'
Ending punctuations to ignore.
min_prob : float, optional
Acts as the first filter the for the words that match ``words``. Words with probability < ``min_prob`` will
be removed if ``filters`` is ``None``, else pass the words into ``filters``. Words without probability will
be treated as having probability < ``min_prob``.
filters : Callable, optional
A function that takes an instance of :class:`stable_whisper.result.WordTiming` as its only argument.
This function is custom filter for the words that match ``words`` and were not caught by ``min_prob``.
verbose:
Whether to print detail of the removed words.
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
"""
if not self.has_words:
return self
if isinstance(words, str):
words = [words]
all_words = self.all_words()
all_words_str = [w.word for w in all_words]
if strip:
all_words_str = [w.strip() for w in all_words_str]
words = [w.strip() for w in words]
if ignore_punctuations:
ptn = f'[{ignore_punctuations}]+$'
all_words_str = [re.sub(ptn, '', w) for w in all_words_str]
words = [re.sub(ptn, '', w) for w in words]
if not case_sensitive:
all_words_str = [w.lower() for w in all_words_str]
words = [w.lower() for w in words]
changes = []
for i, w in reversed(list(enumerate(all_words_str))):
if not (words is None or any(w == _w for _w in words)):
continue
if (
(min_prob is None or all_words[i].probability is None or min_prob > all_words[i].probability) and
(filters is None or filters(all_words[i]))
):
if verbose:
changes.append(f'Removed: {all_words[i].to_dict()}')
self.remove_word(all_words[i], False, verbose=False)
if changes:
print('\n'.join(reversed(changes)))
self.remove_no_word_segments()
self.update_all_segs_with_words()
return self
def fill_in_gaps(
        self,
        other_result: Union['WhisperResult', str],
        min_gap: float = 0.1,
        case_sensitive: bool = False,
        strip: bool = True,
        ignore_punctuations: str = "\"',.?!",
        verbose: bool = True
) -> 'WhisperResult':
    """
    Fill in segment gaps larger than ``min_gap`` with content from ``other_result`` at the times of gaps.

    Besides the gaps between adjacent segments, the span before the first segment and the span after the
    last segment are also treated as gaps (bounded by the start/end of ``other_result``).

    Parameters
    ----------
    other_result : WhisperResult or str
        Another transcription result as an instance of :class:`stable_whisper.result.WhisperResult` or path
        to the JSON of the result.
    min_gap : float, default 0.1
        The minimum seconds of a gap between segments that must be exceeded to be filled in.
    case_sensitive : bool, default False
        Whether to consider the case of the first and last word of the gap to determine overlapping words to
        remove before filling in.
    strip : bool, default True
        Whether to ignore spaces before and after the first and last word of the gap to determine overlapping
        words to remove before filling in.
    ignore_punctuations : str, default '"',.?!'
        Ending punctuations to ignore in the first and last word of the gap to determine overlapping words to
        remove before filling in.
    verbose : bool, default True
        Whether to print detail of the filled content.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes. Results with fewer than 2 segments are returned unchanged.
    """
    if len(self.segments) < 2:
        return self
    if isinstance(other_result, str):
        other_result = WhisperResult(other_result)

    # Build the word normalizer in layers: spaces -> trailing punctuation -> case.
    if strip:
        def strip_space(w):
            return w.strip()
    else:
        def strip_space(w):
            return w

    if ignore_punctuations:
        ptn = f'[{ignore_punctuations}]+$'

        def strip_punctuations(w):
            return re.sub(ptn, '', strip_space(w))
    else:
        strip_punctuations = strip_space

    # NOTE: from here on the name ``strip`` refers to the normalizer function, not the boolean argument.
    if case_sensitive:
        strip = strip_punctuations
    else:
        def strip(w):
            return strip_punctuations(w).lower()

    # Each entry is (index of the left segment, (left segment, right segment)); ``None`` marks the
    # open-ended gap before the first segment (index -1) and after the last segment.
    seg_pairs = list(enumerate(zip(self.segments[:-1], self.segments[1:])))
    seg_pairs.insert(0, (-1, (None, self.segments[0])))
    seg_pairs.append((seg_pairs[-1][0]+1, (self.segments[-1], None)))

    changes = []
    # Process back-to-front so inserting new segments does not shift the indices still to be used.
    for i, (seg0, seg1) in reversed(seg_pairs):
        first_word = None if seg0 is None else seg0.words[-1]  # word just before the gap
        last_word = None if seg1 is None else seg1.words[0]  # word just after the gap
        start = (other_result[0].start if first_word is None else first_word.end)
        end = other_result[-1].end if last_word is None else last_word.start
        if end - start <= min_gap:
            continue
        gap_words = other_result.get_content_by_time((start, end))
        # Drop a boundary word from the gap if it duplicates the word on this side of the gap, and
        # stretch the existing word to cover it instead.
        if first_word is not None and gap_words and strip(first_word.word) == strip(gap_words[0].word):
            first_word.end = gap_words[0].end
            gap_words = gap_words[1:]
        if last_word is not None and gap_words and strip(last_word.word) == strip(gap_words[-1].word):
            last_word.start = gap_words[-1].start
            gap_words = gap_words[:-1]
        if not gap_words:
            continue
        # Keep the following word from overlapping the content being inserted.
        if last_word is not None and last_word.start < gap_words[-1].end:
            last_word.start = gap_words[-1].end
        # Rebuild segments for the inserted words: start from a copy of the originating segment made with
        # ``copy([])`` and group the copied words by their ``segment_id`` in ``other_result``.
        new_segments = [other_result[gap_words[0].segment_id].copy([])]
        for j, new_word in enumerate(gap_words):
            new_word = deepcopy(new_word)
            # Keep the first inserted word from overlapping the word before the gap.
            if j == 0 and first_word is not None and first_word.end > gap_words[0].start:
                new_word.start = first_word.end
            if new_segments[-1].id != new_word.segment_id:
                new_segments.append(other_result[new_word.segment_id].copy([]))
            new_segments[-1].words.append(new_word)
        if verbose:
            changes.append('\n'.join('Added: ' + s.to_display_str(True) for s in new_segments))
        self.segments = self.segments[:i+1] + new_segments + self.segments[i+1:]
    if changes:
        # Changes were collected back-to-front; print them in chronological order.
        print('\n'.join(reversed(changes)))
    self.reassign_ids()
    self.update_all_segs_with_words()

    return self
def regroup(
self,
regroup_algo: Union[str, bool] = None,
verbose: bool = False,
only_show: bool = False
) -> "WhisperResult":
"""
Regroup (in-place) words into segments.
Parameters
----------
regroup_algo: str or bool, default 'da'
String representation of a custom regrouping algorithm or ``True`` use to the default algorithm 'da'.
verbose : bool, default False
Whether to show all the methods and arguments parsed from ``regroup_algo``.
only_show : bool, default False
Whether to show the all methods and arguments parsed from ``regroup_algo`` without running the methods
Returns
-------
stable_whisper.result.WhisperResult
The current instance after the changes.
Notes
-----
Syntax for string representation of custom regrouping algorithm.
Method keys:
sg: split_by_gap
sp: split_by_punctuation
sl: split_by_length
sd: split_by_duration
mg: merge_by_gap
mp: merge_by_punctuation
ms: merge_all_segment
cm: clamp_max
l: lock
us: unlock_all_segments
da: default algorithm (cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?)
rw: remove_word
rs: remove_segment
rp: remove_repetition
rws: remove_words_by_str
fg: fill_in_gaps
Metacharacters:
= separates a method key and its arguments (not used if no argument)
_ separates method keys (after arguments if there are any)
+ separates arguments for a method key
/ separates an argument into list of strings
* separates an item in list of strings into a nested list of strings
Notes:
-arguments are parsed positionally
-if no argument is provided, the default ones will be used
-use 1 or 0 to represent True or False
Example 1:
merge_by_gap(.2, 10, lock=True)
mg=.2+10+++1
Note: [lock] is the 5th argument hence the 2 missing arguments inbetween the three + before 1
Example 2:
split_by_punctuation([('.', ' '), '。', '?', '?'], True)
sp=.* /。/?/?+1
Example 3:
merge_all_segments().split_by_gap(.5).merge_by_gap(.15, 3)
ms_sg=.5_mg=.15+3
"""
if regroup_algo is False:
return self
if regroup_algo is None or regroup_algo is True:
regroup_algo = 'da'
for method, kwargs, msg in self.parse_regroup_algo(regroup_algo, include_str=verbose or only_show):
if msg:
print(msg)
if not only_show:
method(**kwargs)
return self
def parse_regroup_algo(self, regroup_algo: str, include_str: bool = True) -> List[Tuple[Callable, dict, str]]:
methods = dict(
sg=self.split_by_gap,
sp=self.split_by_punctuation,
sl=self.split_by_length,
sd=self.split_by_duration,
mg=self.merge_by_gap,
mp=self.merge_by_punctuation,
ms=self.merge_all_segments,
cm=self.clamp_max,
us=self.unlock_all_segments,
l=self.lock,
rw=self.remove_word,
rs=self.remove_segment,
rp=self.remove_repetition,
rws=self.remove_words_by_str,
fg=self.fill_in_gaps,
)
if not regroup_algo:
return []
calls = regroup_algo.split('_')
if 'da' in calls:
default_calls = 'cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?'.split('_')
calls = chain.from_iterable(default_calls if method == 'da' else [method] for method in calls)
operations = []
for method in calls:
method, args = method.split('=', maxsplit=1) if '=' in method else (method, '')
if method not in methods:
raise NotImplementedError(f'{method} is not one of the available methods: {tuple(methods.keys())}')
args = [] if len(args) == 0 else list(map(str_to_valid_type, args.split('+')))
kwargs = {k: v for k, v in zip(methods[method].__code__.co_varnames[1:], args) if v is not None}
if include_str:
kwargs_str = ', '.join(f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' for k, v in kwargs.items())
op_str = f'{methods[method].__name__}({kwargs_str})'
else:
op_str = None
operations.append((methods[method], kwargs, op_str))
return operations
def find(self, pattern: str, word_level=True, flags=None) -> "WhisperResultMatches":
    """
    Find segments/words and timestamps with regular expression.

    Parameters
    ----------
    pattern : str
        RegEx pattern to search for.
    word_level : bool, default True
        Whether to search at word-level.
    flags : optional
        RegEx flags.

    Returns
    -------
    stable_whisper.result.WhisperResultMatches
        An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match
        ``pattern``.
    """
    # Start from a match set covering every segment, then narrow it down with ``pattern``.
    every_segment = WhisperResultMatches(self)
    return every_segment.find(pattern, word_level=word_level, flags=flags)
@property
def text(self):
return ''.join(s.text for s in self.segments)
@property
def regroup_history(self):
# same syntax as ``regroup_algo`` for :meth:``result.WhisperResult.regroup`
return self._regroup_history
@property
def nonspeech_sections(self):
return self._nonspeech_sections
def show_regroup_history(self):
"""
Print details of all regrouping operations that been performed on data.
"""
if not self._regroup_history:
print('Result has no history.')
for *_, msg in self.parse_regroup_algo(self._regroup_history):
print(f'.{msg}')
def __len__(self):
return len(self.segments)
def unlock_all_segments(self):
for s in self.segments:
s.unlock_all_words()
return self
def reset(self):
"""
Restore all values to that at initialization.
"""
self.language = self.ori_dict.get('language')
self._regroup_history = ''
segments = self.ori_dict.get('segments')
self.segments: List[Segment] = [Segment(**s) for s in segments] if segments else []
if self._forced_order:
self.force_order()
self.remove_no_word_segments(any(seg.has_words for seg in self.segments))
self.update_all_segs_with_words()
@property
def has_words(self):
return all(seg.has_words for seg in self.segments)
# Output/export helpers exposed under shorter names.
# NOTE(review): the right-hand names are not defined in this part of the file; presumably they come
# from the ``from .text_output import *`` import at the top -- confirm.
to_srt_vtt = result_to_srt_vtt
to_ass = result_to_ass
to_tsv = result_to_tsv
to_txt = result_to_txt
# Rebinds the existing ``save_as_json`` under the same name in this scope.
save_as_json = save_as_json
class SegmentMatch:
    """
    A single match spanning one or more segments and, optionally, specific words within them.

    Attributes are assigned in ``__init__``: ``segments``, ``word_indices``, ``words``, ``text`` and
    ``text_match``.
    """

    def __init__(
            self,
            segments: Union[List[Segment], Segment],
            _word_indices: List[List[int]] = None,
            _text_match: str = None
    ):
        # A lone segment is wrapped so that ``self.segments`` is always a sequence.
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments = segments
        self.word_indices = _word_indices if _word_indices is not None else []
        # ``word_indices[k]`` lists the matched word positions inside ``segments[k]``.
        matched_words = []
        for seg_pos, word_positions in enumerate(self.word_indices):
            for word_pos in word_positions:
                matched_words.append(self.segments[seg_pos].words[word_pos])
        self.words = matched_words
        if len(self.words) == 0:
            # No word-level match: fall back to the full text of the segments.
            self.text = ''.join(seg.text for seg in self.segments)
        else:
            self.text = ''.join(
                self.segments[seg_pos].words[word_pos].word
                for seg_pos, word_positions in enumerate(self.word_indices)
                for word_pos in word_positions
            )
        self.text_match = _text_match

    @property
    def start(self):
        """Start of the first matched word, else of the first segment, else ``None``."""
        if len(self.words) != 0:
            return self.words[0].start
        if len(self.segments) != 0:
            return self.segments[0].start
        return None

    @property
    def end(self):
        """End of the last matched word, else of the last segment, else ``None``."""
        if len(self.words) != 0:
            return self.words[-1].end
        if len(self.segments) != 0:
            return self.segments[-1].end
        return None

    def __len__(self):
        """Return the number of segments in the match."""
        return len(self.segments)

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return str(self.__dict__)
class WhisperResultMatches:
"""
RegEx matches for WhisperResults.
"""
# Use WhisperResult.find() instead of instantiating this class directly.
def __init__(
        self,
        matches: Union[List[SegmentMatch], WhisperResult],
        _segment_indices: List[List[int]] = None
):
    """
    Build the matches either from a whole result or from already computed matches.

    Parameters
    ----------
    matches : list of SegmentMatch or WhisperResult
        A result, in which case every segment becomes its own match, or a list of matches.
    _segment_indices : list of list of int, optional
        Segment indices of each match; required (and checked with assertions) when ``matches`` is a list.
    """
    if isinstance(matches, WhisperResult):
        self.matches = [SegmentMatch(segment) for segment in matches.segments]
        self._segment_indices = [[position] for position in range(len(matches.segments))]
        return

    self.matches = matches
    assert _segment_indices is not None
    assert len(self.matches) == len(_segment_indices)
    # Every match must come with exactly one index per segment it holds.
    assert all(
        len(single_match.segments) == len(_segment_indices[position])
        for position, single_match in enumerate(self.matches)
    )
    self._segment_indices = _segment_indices
@property
def segment_indices(self):
    """Read-only access to the per-match segment indices stored in ``_segment_indices``."""
    return self._segment_indices
def _curr_seg_groups(self) -> "List[List[Tuple[int, Segment]]]":
    """
    Group the segments of the current matches into runs of consecutive segment indices.

    Matches are visited in the order of ``self._segment_indices``/``self.matches``. A segment whose index
    is not greater than the largest index already seen is skipped, so a segment shared by several matches
    is only included once.

    Returns
    -------
    list of list of tuple of (int, Segment)
        One list per run. Each run holds ``(segment index, segment)`` pairs whose indices increase by
        exactly one from one pair to the next; a gap in the indices starts a new run.
    """
    seg_groups, curr_segs = [], []
    curr_max = -1
    for seg_indices, match in zip(self._segment_indices, self.matches):
        for i, seg in zip(sorted(seg_indices), match.segments):
            if i <= curr_max:
                # already part of a run (segment shared with an earlier match)
                continue
            # Close the open run BEFORE adding the segment, so that a segment following a gap starts
            # the next run instead of being attached to the previous one.
            if curr_segs and i - 1 != curr_max:
                seg_groups.append(curr_segs)
                curr_segs = []
            curr_segs.append((i, seg))
            curr_max = i
    if curr_segs:
        seg_groups.append(curr_segs)
    return seg_groups
def find(self, pattern: str, word_level=True, flags=None) -> "WhisperResultMatches":
    """
    Find segments/words and timestamps with regular expression.

    The search is restricted to the current matches: the text of each segment group returned by
    ``_curr_seg_groups()`` is joined and searched as one string, so a single regex match can span
    several segments of the same group.

    Parameters
    ----------
    pattern : str
        RegEx pattern to search for.
    word_level : bool, default True
        Whether to search at word-level. Falls back to segment-level, with a warning, when any segment of
        the current matches is missing word timestamps.
    flags : optional
        RegEx flags.

    Returns
    -------
    stable_whisper.result.WhisperResultMatches
        An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match ``pattern``.
        Zero-width regex matches are left out because they cover no word or segment.
    """
    seg_groups = self._curr_seg_groups()
    matches: List[SegmentMatch] = []
    match_seg_indices: List[List[int]] = []
    if word_level and not all(seg.has_words for match in self.matches for seg in match.segments):
        warnings.warn('Cannot perform word-level search with segment(s) missing word timestamps.')
        word_level = False
    for segs in seg_groups:
        # ``idxs[k]`` is the (segment index, word index) that character ``text[k]`` belongs to;
        # the word index is ``None`` for segment-level searches.
        if word_level:
            idxs = list(chain.from_iterable(
                [(i, j)] * len(word.word) for (i, seg) in segs for j, word in enumerate(seg.words)
            ))
            text = ''.join(word.word for (_, seg) in segs for word in seg.words)
        else:
            idxs = list(chain.from_iterable([(i, None)] * len(seg.text) for (i, seg) in segs))
            text = ''.join(seg.text for (_, seg) in segs)
        assert len(idxs) == len(text)
        for curr_match in re.finditer(pattern, text, flags=flags or 0):
            start, end = curr_match.span()
            if start == end:
                # A zero-width match (e.g. ``\b`` or ``x*``) maps to no character, hence to no
                # segment or word; it would only yield an empty match without timestamps.
                continue
            curr_idxs = idxs[start: end]
            curr_seg_idxs = sorted(set(i for i, _ in curr_idxs))
            if word_level:
                curr_word_idxs = [
                    sorted(set(j for i, j in curr_idxs if i == seg_idx))
                    for seg_idx in curr_seg_idxs
                ]
            else:
                curr_word_idxs = None
            matched_seg_idxs = set(curr_seg_idxs)
            matches.append(SegmentMatch(
                segments=[s for i, s in segs if i in matched_seg_idxs],
                _word_indices=curr_word_idxs,
                _text_match=curr_match.group()
            ))
            match_seg_indices.append(curr_seg_idxs)
    return WhisperResultMatches(matches, match_seg_indices)
def __len__(self):
    """Number of matches held."""
    held_matches = self.matches
    return len(held_matches)
def __bool__(self):
    """``True`` when ``__len__()`` reports a non-zero count."""
    return bool(self.__len__())
def __getitem__(self, idx):
    """Return ``self.matches[idx]``; ``idx`` is passed to the underlying container unchanged."""
    held_matches = self.matches
    return held_matches[idx]