|
import warnings |
|
import re |
|
import torch |
|
import numpy as np |
|
from typing import Union, List, Tuple, Optional, Callable |
|
from dataclasses import dataclass |
|
from copy import deepcopy |
|
from itertools import chain |
|
|
|
from .stabilization import suppress_silence, get_vad_silence_func, mask2timing, wav2mask |
|
from .text_output import * |
|
from .utils import str_to_valid_type, format_timestamp, UnsortedException |
|
|
|
|
|
# Names exported by a wildcard import of this module.
__all__ = [
    'WhisperResult',
    'Segment',
]
|
|
|
|
|
def _combine_attr(obj: object, other_obj: object, attr: str):
    """
    Merge attribute ``attr`` of ``other_obj`` into the same attribute of ``obj`` (in place).

    Nothing happens if the attribute of ``obj`` is ``None``. If the attribute of ``other_obj`` is ``None``,
    the attribute of ``obj`` is reset to ``None``. Otherwise lists are extended in place and any other
    value is replaced with the mean of the two values.
    """
    current = getattr(obj, attr)
    if current is None:
        return
    incoming = getattr(other_obj, attr)
    if incoming is None:
        # one side has no data, so the combined value is unknown
        setattr(obj, attr, None)
    elif isinstance(current, list):
        current.extend(incoming)
    else:
        setattr(obj, attr, (current + incoming) / 2)
|
|
|
|
|
def _increment_attr(obj: object, attr: str, val: Union[int, float]):
    """Add ``val`` to attribute ``attr`` of ``obj``; do nothing if the attribute is missing or ``None``."""
    current = getattr(obj, attr, None)
    if current is None:
        return
    setattr(obj, attr, current + val)
|
|
|
|
|
@dataclass
class WordTiming:
    """
    Text and timestamps of a single word.

    ``left_locked`` / ``right_locked`` are flags read by the grouping helpers in this module to keep a word
    together with its neighbour. ``segment_id`` and ``id`` are assigned by the owning segment.
    """
    word: str
    start: float
    end: float
    probability: float = None
    tokens: List[int] = None
    left_locked: bool = False
    right_locked: bool = False
    segment_id: Optional[int] = None
    id: Optional[int] = None

    def __len__(self):
        # length of a word is its number of characters
        return len(self.word)

    def __add__(self, other: 'WordTiming'):
        """Return a new word spanning both words, with the text of ``other`` appended."""
        merged = deepcopy(self)

        merged.start = min(merged.start, other.start)
        merged.end = max(other.end, merged.end)
        merged.word += other.word
        merged.left_locked = merged.left_locked or other.left_locked
        merged.right_locked = merged.right_locked or other.right_locked
        for attr_name in ('probability', 'tokens'):
            _combine_attr(merged, other, attr_name)

        return merged

    def __deepcopy__(self, memo=None):
        return self.copy()

    def copy(self):
        """Return a new instance with the same field values; ``tokens`` is copied, not shared."""
        tokens_copy = None if self.tokens is None else self.tokens.copy()
        return WordTiming(
            word=self.word,
            start=self.start,
            end=self.end,
            probability=self.probability,
            tokens=tokens_copy,
            left_locked=self.left_locked,
            right_locked=self.right_locked,
            segment_id=self.segment_id,
            id=self.id
        )

    @property
    def duration(self):
        """Duration in seconds, rounded to 3 decimal places."""
        return round(self.end - self.start, 3)

    def round_all_timestamps(self):
        """Round ``start`` and ``end`` to 3 decimal places."""
        self.start, self.end = round(self.start, 3), round(self.end, 3)

    def offset_time(self, offset_seconds: float):
        """Shift both timestamps by ``offset_seconds``."""
        self.start = round(self.start + offset_seconds, 3)
        self.end = round(self.end + offset_seconds, 3)

    def to_dict(self):
        """Return the attributes of a copy of this word as a dict, without the lock flags."""
        data = deepcopy(self).__dict__
        for lock_key in ('left_locked', 'right_locked'):
            data.pop(lock_key)
        return data

    def lock_left(self):
        self.left_locked = True

    def lock_right(self):
        self.right_locked = True

    def lock_both(self):
        self.lock_left()
        self.lock_right()

    def unlock_both(self):
        self.left_locked = False
        self.right_locked = False

    def suppress_silence(self,
                         silent_starts: np.ndarray,
                         silent_ends: np.ndarray,
                         min_word_dur: float = 0.1,
                         nonspeech_error: float = 0.3,
                         keep_end: Optional[bool] = True):
        """Apply the module-level ``suppress_silence`` function to this word and return ``self``."""
        # inside a method body this name resolves to the imported function, not to this method
        suppress_silence(self, silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
        return self

    def rescale_time(self, scale_factor: float):
        """Multiply both timestamps by ``scale_factor``."""
        self.start = round(self.start * scale_factor, 3)
        self.end = round(self.end * scale_factor, 3)

    def clamp_max(self, max_dur: float, clip_start: bool = False, verbose: bool = False):
        """
        Shorten the word to ``max_dur`` if it is longer.

        The start is moved when ``clip_start`` is true, otherwise the end is moved. The change is printed
        when ``verbose`` is true.
        """
        if not (self.duration > max_dur):
            return
        if clip_start:
            clipped_start = round(self.end - max_dur, 3)
            if verbose:
                print(f'Start: {self.start} -> {clipped_start}\nEnd: {self.end}\nText:"{self.word}"\n')
            self.start = clipped_start
            return
        clipped_end = round(self.start + max_dur, 3)
        if verbose:
            print(f'Start: {self.start}\nEnd: {self.end} -> {clipped_end}\nText:"{self.word}"\n')
        self.end = clipped_end

    def set_segment(self, segment: 'Segment'):
        """Store a reference to the segment that holds this word."""
        self._segment = segment

    def get_segment(self) -> Union['Segment', None]:
        """
        Return instance of :class:`stable_whisper.result.Segment` that this instance is a part of.
        """
        return getattr(self, '_segment', None)
|
|
|
|
|
def _words_by_lock(words: List['WordTiming'], only_text: bool = False, include_single: bool = False):
    """
    Return a nested list of words such that each sublist contains words that are locked together.

    A word joins the previous group when the previous word is right-locked or the word itself is left-locked.
    With ``only_text`` the groups hold the word strings instead of the word objects. Groups of a single
    word are dropped unless ``include_single`` is true.
    """
    groups = []
    for word in words:
        if groups and (groups[-1][-1].right_locked or word.left_locked):
            groups[-1].append(word)
        else:
            groups.append([word])
    if only_text:
        groups = [[w.word for w in group] for group in groups]
    if include_single:
        return groups
    return [group for group in groups if len(group) > 1]
|
|
|
|
|
@dataclass
class Segment:
    """
    A span of transcribed text with timestamps and, optionally, word-level timings.

    ``words`` may be given as :class:`WordTiming` instances or as dicts of ``WordTiming`` keyword arguments;
    dicts are converted in ``__post_init__``. ``ori_has_words`` records whether the segment had words when it
    was created and is derived from ``words`` if left as ``None``.
    """
    start: float
    end: float
    text: str
    seek: Optional[float] = None
    tokens: Optional[List[int]] = None
    temperature: Optional[float] = None
    avg_logprob: Optional[float] = None
    compression_ratio: Optional[float] = None
    no_speech_prob: Optional[float] = None
    words: Optional[Union[List['WordTiming'], List[dict]]] = None
    ori_has_words: Optional[bool] = None
    id: Optional[int] = None

    def __getitem__(self, index: int) -> 'WordTiming':
        if self.words is None:
            raise ValueError('segment contains no words')
        return self.words[index]

    def __delitem__(self, index: int):
        if self.words is None:
            raise ValueError('segment contains no words')
        del self.words[index]
        self.reassign_ids()
        self.update_seg_with_words()

    def __deepcopy__(self, memo=None):
        return self.copy()

    def copy(self, new_words: Optional[List['WordTiming']] = None):
        """
        Return a new segment with the same field values and copied words.

        If ``new_words`` is given, copies of those words are used instead of the current words. The ``tokens``
        list is copied so that merging into the copy (see ``__add__``) cannot modify this segment's tokens.
        ``ori_has_words`` is not carried over; it is derived again from the words of the copy.
        """
        if new_words is None:
            words = None if self.words is None else [w.copy() for w in self.words]
        else:
            words = [w.copy() for w in new_words]

        new_seg = Segment(
            start=self.start,
            end=self.end,
            text=self.text,
            seek=self.seek,
            tokens=None if self.tokens is None else list(self.tokens),
            temperature=self.temperature,
            avg_logprob=self.avg_logprob,
            compression_ratio=self.compression_ratio,
            no_speech_prob=self.no_speech_prob,
            words=words,
            id=self.id
        )
        new_seg.update_seg_with_words()
        return new_seg

    def to_display_str(self, only_segment: bool = False):
        """Return a printable line for the segment, followed by one line per word unless ``only_segment``."""
        line = f'[{format_timestamp(self.start)} --> {format_timestamp(self.end)}] "{self.text}"'
        if self.has_words and not only_segment:
            line += '\n' + '\n'.join(
                f"-[{format_timestamp(w.start)}] -> [{format_timestamp(w.end)}] \"{w.word}\"" for w in self.words
            ) + '\n'
        return line

    @property
    def has_words(self):
        """``True`` if ``words`` is a non-empty list."""
        return bool(self.words)

    @property
    def duration(self):
        return self.end - self.start

    def word_count(self):
        """Return the number of words, or ``-1`` if the segment has no words."""
        if self.has_words:
            return len(self.words)
        return -1

    def char_count(self):
        """Return the total characters of all words, or the length of ``text`` if there are no words."""
        if self.has_words:
            return sum(len(w) for w in self.words)
        return len(self.text)

    def __post_init__(self):
        if self.has_words:
            self.words: List[WordTiming] = \
                [WordTiming(**word) if isinstance(word, dict) else word for word in self.words]
            for w in self.words:
                w.set_segment(self)
        if self.ori_has_words is None:
            self.ori_has_words = self.has_words
        self.round_all_timestamps()

    def __add__(self, other: 'Segment'):
        """
        Return a new segment that spans both segments, with the text of ``other`` appended.

        The result has words only if both segments have words.
        """
        self_copy = deepcopy(self)

        self_copy.start = min(self_copy.start, other.start)
        self_copy.end = max(other.end, self_copy.end)
        self_copy.text += other.text

        for attr in ('tokens', 'temperature', 'avg_logprob', 'compression_ratio', 'no_speech_prob'):
            _combine_attr(self_copy, other, attr)
        if self_copy.has_words:
            if other.has_words:
                # NOTE: the word objects of ``other`` are shared with the result, not copied
                self_copy.words.extend(other.words)
            else:
                self_copy.words = None

        return self_copy

    def _word_operations(self, operation: str, *args, **kwargs):
        """Call the method named ``operation`` on every word with the given arguments."""
        if self.has_words:
            for w in self.words:
                getattr(w, operation)(*args, **kwargs)

    def round_all_timestamps(self):
        """Round the timestamps of the segment and of all its words to 3 decimal places."""
        self.start = round(self.start, 3)
        self.end = round(self.end, 3)
        if self.has_words:
            for word in self.words:
                word.round_all_timestamps()

    def offset_time(self, offset_seconds: float):
        """Shift the timestamps of the segment, its ``seek`` (if set) and all its words by ``offset_seconds``."""
        self.start = round(self.start + offset_seconds, 3)
        self.end = round(self.end + offset_seconds, 3)
        _increment_attr(self, 'seek', offset_seconds)
        self._word_operations('offset_time', offset_seconds)

    def add_words(self, index0: int, index1: int, inplace: bool = False):
        """
        Merge the words at ``index0`` and ``index1`` and return the merged word.

        With ``inplace`` the merged word replaces the word at the lower index and the word at the higher
        index is removed. Returns ``None`` if the segment has no words.
        """
        if self.has_words:
            new_word = self.words[index0] + self.words[index1]
            if inplace:
                i0, i1 = sorted([index0, index1])
                # merging goes through ``WordTiming.copy()``, which drops the segment back-reference
                new_word.set_segment(self)
                self.words[i0] = new_word
                del self.words[i1]
            return new_word

    def rescale_time(self, scale_factor: float):
        """Multiply the timestamps of the segment, its ``seek`` (if set) and all its words by ``scale_factor``."""
        self.start = round(self.start * scale_factor, 3)
        self.end = round(self.end * scale_factor, 3)
        if self.seek is not None:
            self.seek = round(self.seek * scale_factor, 3)
        self._word_operations('rescale_time', scale_factor)
        self.update_seg_with_words()

    def apply_min_dur(self, min_dur: float, inplace: bool = False):
        """
        Merge any word with adjacent word if its duration is less than ``min_dur``.

        Returns the modified segment: ``self`` if ``inplace``, otherwise a copy.
        """
        segment = self if inplace else deepcopy(self)
        if not self.has_words:
            return segment
        max_i = len(segment.words) - 1
        if max_i == 0:
            return segment
        # walk backwards so that removing a word does not shift the indices still to be visited
        for i in reversed(range(len(segment.words))):
            if max_i == 0:
                break
            if segment.words[i].duration < min_dur:
                if i == max_i:
                    segment.add_words(i-1, i, inplace=True)
                elif i == 0:
                    segment.add_words(i, i+1, inplace=True)
                else:
                    # merge with the neighbour that has the shorter duration
                    if segment.words[i+1].duration < segment.words[i-1].duration:
                        segment.add_words(i-1, i, inplace=True)
                    else:
                        segment.add_words(i, i+1, inplace=True)
                max_i -= 1
        return segment

    def _to_reverse_text(
            self,
            prepend_punctuations: str = None,
            append_punctuations: str = None
    ):
        """
        Return a copy with words reversed order per segment.

        Leading characters found in ``prepend_punctuations`` are moved to the end of each word and trailing
        characters found in ``append_punctuations`` are moved to the front, before the word order is reversed
        to build ``text``.
        """
        if prepend_punctuations is None:
            prepend_punctuations = "\"'“¿([{-"
        if prepend_punctuations and ' ' not in prepend_punctuations:
            prepend_punctuations += ' '
        if append_punctuations is None:
            append_punctuations = "\"'.。,,!!??::”)]}、"
        self_copy = deepcopy(self)
        has_prepend = bool(prepend_punctuations)
        has_append = bool(append_punctuations)
        if has_prepend or has_append:
            word_objs = (
                self_copy.words
                if self_copy.has_words else
                [WordTiming(w, 0, 1, 0) for w in self_copy.text.split(' ')]
            )
            for word in word_objs:
                new_append = ''
                if has_prepend:
                    for _ in range(len(word)):
                        char = word.word[0]
                        if char in prepend_punctuations:
                            new_append += char
                            word.word = word.word[1:]
                        else:
                            break
                new_prepend = ''
                if has_append:
                    for _ in range(len(word)):
                        char = word.word[-1]
                        if char in append_punctuations:
                            new_prepend += char
                            word.word = word.word[:-1]
                        else:
                            break
                word.word = f'{new_prepend}{word.word}{new_append[::-1]}'
            self_copy.text = ''.join(w.word for w in reversed(word_objs))

        return self_copy

    def to_dict(self, reverse_text: Union[bool, tuple] = False):
        """
        Return the segment as a dict.

        ``reverse_text`` may be ``True`` or a tuple of arguments for ``_to_reverse_text``. ``'words'`` is a list
        of word dicts, an empty list if the segment only had words originally, or absent otherwise.
        ``'id'`` is omitted when it is ``None``.
        """
        if reverse_text:
            seg_dict = (
                (self._to_reverse_text(*reverse_text)
                 if isinstance(reverse_text, tuple) else
                 self._to_reverse_text()).__dict__
            )
        else:
            seg_dict = deepcopy(self).__dict__
        seg_dict.pop('ori_has_words')
        if self.has_words:
            seg_dict['words'] = [w.to_dict() for w in seg_dict['words']]
        elif self.ori_has_words:
            seg_dict['words'] = []
        else:
            seg_dict.pop('words')
        if self.id is None:
            seg_dict.pop('id')
        if reverse_text:
            seg_dict['reversed_text'] = True
        return seg_dict

    def words_by_lock(self, only_text: bool = True, include_single: bool = False):
        """Return the words grouped by lock (see ``_words_by_lock``); empty if the segment has no words."""
        return _words_by_lock(self.words or [], only_text=only_text, include_single=include_single)

    @property
    def left_locked(self):
        if self.has_words:
            return self.words[0].left_locked
        return False

    @property
    def right_locked(self):
        if self.has_words:
            return self.words[-1].right_locked
        return False

    def lock_left(self):
        if self.has_words:
            self.words[0].lock_left()

    def lock_right(self):
        if self.has_words:
            self.words[-1].lock_right()

    def lock_both(self):
        self.lock_left()
        self.lock_right()

    def unlock_all_words(self):
        self._word_operations('unlock_both')

    def reassign_ids(self):
        """Set ``segment_id`` of every word to the ID of this segment and ``id`` to its position."""
        if self.has_words:
            for i, w in enumerate(self.words):
                w.segment_id = self.id
                w.id = i

    def update_seg_with_words(self):
        """Recompute ``start``, ``end``, ``text`` and ``tokens`` from the words, if there are any."""
        if self.has_words:
            self.start = self.words[0].start
            self.end = self.words[-1].end
            self.text = ''.join(w.word for w in self.words)
            # tokens are only known if every word has tokens
            self.tokens = (
                None
                if any(w.tokens is None for w in self.words) else
                [t for w in self.words for t in w.tokens]
            )
            for w in self.words:
                w.set_segment(self)

    def suppress_silence(self,
                         silent_starts: np.ndarray,
                         silent_ends: np.ndarray,
                         min_word_dur: float = 0.1,
                         word_level: bool = True,
                         nonspeech_error: float = 0.3,
                         use_word_position: bool = True):
        """
        Apply silence suppression to the words (or to the segment itself if it has no words); return ``self``.

        Unless ``word_level`` is true, only the first and last words are adjusted. With ``use_word_position``
        the first word is processed with ``keep_end=True``, the last with ``keep_end=False`` and the others
        with ``keep_end=None``.
        """
        if self.has_words:
            words = self.words if word_level or len(self.words) == 1 else [self.words[0], self.words[-1]]
            for i, w in enumerate(words, 1):
                if use_word_position:
                    keep_end = True if i == 1 else (False if i == len(words) else None)
                else:
                    keep_end = None
                w.suppress_silence(silent_starts, silent_ends, min_word_dur, nonspeech_error, keep_end)
            self.update_seg_with_words()
        else:
            # inside a method body this name resolves to the imported function, not to this method
            suppress_silence(self,
                             silent_starts,
                             silent_ends,
                             min_word_dur,
                             nonspeech_error)

        return self

    def get_locked_indices(self):
        """
        Return each index ``i`` where word ``i`` is right-locked or word ``i+1`` is left-locked.

        Returns an empty list if the segment has no words.
        """
        if not self.has_words:
            return []
        locked_indices = [i
                          for i, (next_word, prev_word) in enumerate(zip(self.words[1:], self.words[:-1]))
                          if next_word.left_locked or prev_word.right_locked]
        return locked_indices

    def get_gaps(self, as_ndarray=False):
        """Return the gap between each pair of consecutive words (start of next minus end of previous)."""
        if self.has_words:
            s_ts = np.array([w.start for w in self.words])
            e_ts = np.array([w.end for w in self.words])
            gap = s_ts[1:] - e_ts[:-1]
            return gap if as_ndarray else gap.tolist()
        return []

    def get_gap_indices(self, max_gap: float = 0.1):
        """Return the unlocked indices of words followed by a gap greater than ``max_gap``."""
        if not self.has_words or len(self.words) < 2:
            return []
        if max_gap is None:
            max_gap = 0
        indices = (self.get_gaps(True) > max_gap).nonzero()[0].tolist()
        return sorted(set(indices) - set(self.get_locked_indices()))

    def get_punctuation_indices(self, punctuation: Union[List[str], List[Tuple[str, str]], str]):
        """
        Return the unlocked indices of words after which ``punctuation`` occurs.

        A string matches a word that ends with it, or the word before a word that starts with it (the last
        word is not inspected). A pair ``(ending, beginning)`` matches a word ending with ``ending`` that is
        followed by a word starting with ``beginning``.
        """
        if not self.has_words or len(self.words) < 2:
            return []
        if isinstance(punctuation, str):
            punctuation = [punctuation]
        indices = []
        for p in punctuation:
            if isinstance(p, str):
                for i, s in enumerate(self.words[:-1]):
                    if s.word.endswith(p):
                        indices.append(i)
                    elif i != 0 and s.word.startswith(p):
                        indices.append(i-1)
            else:
                ending, beginning = p
                indices.extend([i for i, (w0, w1) in enumerate(zip(self.words[:-1], self.words[1:]))
                                if w0.word.endswith(ending) and w1.word.startswith(beginning)])

        return sorted(set(indices) - set(self.get_locked_indices()))

    def get_length_indices(self, max_chars: int = None, max_words: int = None, even_split: bool = True,
                           include_lock: bool = False):
        """
        Return the indices of words after which to split so parts stay within ``max_chars`` / ``max_words``.

        With ``even_split`` the split points are chosen to make the parts close to equal in size and
        ``include_lock`` is not consulted. Otherwise a split is placed before the word that exceeds a limit,
        skipping locked positions if ``include_lock`` is true.
        """

        if not self.has_words or (max_chars is None and max_words is None):
            return []
        assert max_chars != 0 and max_words != 0, \
            f'max_chars and max_words must be greater 0, but got {max_chars} and {max_words}'
        if len(self.words) < 2:
            return []
        indices = []
        if even_split:
            char_count = -1 if max_chars is None else sum(map(len, self.words))
            word_count = -1 if max_words is None else len(self.words)
            exceed_chars = max_chars is not None and char_count > max_chars
            exceed_words = max_words is not None and word_count > max_words
            if exceed_chars:
                splits = np.ceil(char_count / max_chars)
                chars_per_split = char_count / splits
                cum_char_count = np.cumsum([len(w.word) for w in self.words[:-1]])
                indices = [
                    (np.abs(cum_char_count-(i*chars_per_split))).argmin()
                    for i in range(1, int(splits))
                ]
                if max_words is not None:
                    # re-check the word limit against the parts produced by the character-based split
                    exceed_words = any(j-i+1 > max_words for i, j in zip([0]+indices, indices+[len(self.words)]))

            if exceed_words:
                splits = np.ceil(word_count / max_words)
                words_per_split = word_count / splits
                cum_word_count = np.array(range(1, len(self.words)+1))
                indices = [
                    np.abs(cum_word_count-(i*words_per_split)).argmin()
                    for i in range(1, int(splits))
                ]

        else:
            curr_words = 0
            curr_chars = 0
            locked_indices = []
            if include_lock:
                locked_indices = self.get_locked_indices()
            for i, word in enumerate(self.words):
                curr_words += 1
                curr_chars += len(word)
                if i != 0:
                    if (
                            max_chars is not None and curr_chars > max_chars
                            or
                            max_words is not None and curr_words > max_words
                    ) and i-1 not in locked_indices:
                        indices.append(i-1)
                        curr_words = 1
                        curr_chars = len(word)
        return indices

    def get_duration_indices(self, max_dur: float, even_split: bool = True, include_lock: bool = False):
        """
        Return the indices of words after which to split so the summed word durations stay within ``max_dur``.

        ``even_split`` and ``include_lock`` behave as in ``get_length_indices``.
        """
        if not self.has_words or (total_duration := np.sum([w.duration for w in self.words])) <= max_dur:
            return []
        if even_split:
            splits = np.ceil(total_duration / max_dur)
            dur_per_split = total_duration / splits
            cum_dur = np.cumsum([w.duration for w in self.words[:-1]])
            indices = [
                (np.abs(cum_dur - (i * dur_per_split))).argmin()
                for i in range(1, int(splits))
            ]
        else:
            indices = []
            curr_total_dur = 0.0
            locked_indices = self.get_locked_indices() if include_lock else []
            for i, word in enumerate(self.words):
                curr_total_dur += word.duration
                if i != 0:
                    if curr_total_dur > max_dur and i - 1 not in locked_indices:
                        indices.append(i - 1)
                        curr_total_dur = word.duration
        return indices

    def split(self, indices: List[int]):
        """
        Split the segment after each word index in ``indices`` and return the parts as new segments.

        Returns an empty list if ``indices`` is empty. The list passed in is not modified.
        """
        if len(indices) == 0:
            return []
        # work on a copy so that the caller's list is not extended
        indices = list(indices)
        if indices[-1] != len(self.words) - 1:
            indices.append(len(self.words) - 1)
        seg_copies = []
        prev_i = 0
        for i in indices:
            i += 1
            c = deepcopy(self)
            c.words = c.words[prev_i:i]
            c.update_seg_with_words()
            seg_copies.append(c)
            prev_i = i
        return seg_copies

    def set_result(self, result: 'WhisperResult'):
        """Store a reference to the result that holds this segment."""
        self._result = result

    def get_result(self) -> Union['WhisperResult', None]:
        """
        Return outer instance of :class:`stable_whisper.result.WhisperResult` that ``self`` is a part of.
        """
        return getattr(self, '_result', None)
|
|
|
|
|
class WhisperResult: |
|
|
|
def __init__(
        self,
        result: Union[str, dict, list],
        force_order: bool = False,
        check_sorted: Union[bool, str] = True,
        show_unsorted: bool = True
):
    """
    Build the result from a path, a result dict or a list of segments / word lists.

    ``result`` is normalised by ``_standardize_result``. If ``force_order`` is true, ``force_order()`` is
    applied before the order check in ``raise_for_unsorted(check_sorted, show_unsorted)``.
    """
    result, self.path = self._standardize_result(result)
    self.ori_dict = result.get('ori_dict') or result
    self.language = self.ori_dict.get('language')
    self._regroup_history = result.get('regroup_history', '')
    self._nonspeech_sections = result.get('nonspeech_sections', [])

    # copy the segment data so that building Segment objects cannot alter the source dict
    raw_segments = deepcopy(result.get('segments', self.ori_dict.get('segments')))
    if raw_segments:
        self.segments: List[Segment] = [Segment(**seg_kwargs) for seg_kwargs in raw_segments]
    else:
        self.segments: List[Segment] = []

    self._forced_order = force_order
    if self._forced_order:
        self.force_order()
    self.raise_for_unsorted(check_sorted, show_unsorted)

    # if any segment has words, segments without words are removed regardless of their original state
    any_words = any(seg.has_words for seg in self.segments)
    self.remove_no_word_segments(any_words)
    self.update_all_segs_with_words()
|
|
|
def __getitem__(self, index: int) -> 'Segment':
    """Return the segment at ``index``."""
    segment = self.segments[index]
    return segment
|
|
|
def __delitem__(self, index: int):
    """Remove the segment at ``index`` and renumber the remaining segments (segment IDs only)."""
    del self.segments[index]
    self.reassign_ids(only_segments=True)
|
|
|
@staticmethod
def _standardize_result(result: Union[str, dict, list]):
    """
    Normalise ``result`` into a dict and return it with the path it was loaded from.

    Parameters
    ----------
    result : str or dict or list
        A path passed to ``load_result``, a result dict (returned as is), a list of segment dicts, or a list
        of lists of word dicts (one list per segment).

    Returns
    -------
    tuple
        ``(result, path)`` where ``path`` is ``None`` unless ``result`` was given as a string.

    Raises
    ------
    NotImplementedError
        If the list holds anything other than dicts or lists of dicts.
    """
    path = None
    if isinstance(result, str):
        path = result
        result = load_result(path)
    if isinstance(result, list):
        if not result:
            # an empty list is a result without segments
            result = dict(segments=[])
        elif isinstance(result[0], list):
            # empty word lists carry no timing data, so they cannot form a segment
            word_lists = [words for words in result if words]
            if word_lists and not isinstance(word_lists[0][0], dict):
                raise NotImplementedError(
                    f'Got list of list of {type(word_lists[0][0])} but expects list of list of dict'
                )
            result = dict(
                segments=[
                    dict(
                        start=words[0]['start'],
                        end=words[-1]['end'],
                        text=''.join(w['word'] for w in words),
                        words=words
                    )
                    for words in word_lists
                ]
            )

        elif isinstance(result[0], dict):
            result = dict(segments=result)
        else:
            raise NotImplementedError(f'Got list of {type(result[0])} but expects list of list/dict')
    return result, path
|
|
|
def force_order(self):
    """
    Adjust timestamps in place so that they are in ascending order.

    Operates on the parts returned by ``all_words_or_segments()``. A start earlier than the previous end is
    moved to the previous end. If a part then starts after its own end, either earlier parts are clipped to
    that end (with a warning) or the part's start/end is moved to restore order.
    """
    parts = self.all_words_or_segments()
    last_end = 0
    for pos, part in enumerate(parts, 1):
        if part.start < last_end:
            part.start = last_end
        if part.start > part.end:
            if last_end > part.end:
                warnings.warn('Multiple consecutive timestamps are out of order. Some parts will have no duration.')
                part.start = part.end
                # clip every earlier part so nothing extends past the end of this part
                for earlier_idx in range(pos - 2, -1, -1):
                    earlier = parts[earlier_idx]
                    if earlier.end > part.end:
                        earlier.end = part.end
                    if earlier.start > part.end:
                        earlier.start = part.end
            elif part.start != last_end:
                part.start = last_end
            else:
                # ``pos`` is 1-based, so ``parts[pos]`` is the next part
                part.end = part.start if pos == len(parts) else parts[pos].start
        last_end = part.end
    if self.has_words:
        self.update_all_segs_with_words()
|
|
|
def raise_for_unsorted(self, check_sorted: Union[bool, str] = True, show_unsorted: bool = True):
    """
    Check that all timestamps are in ascending order.

    Parameters
    ----------
    check_sorted : bool or str, default True
        ``False`` skips the check. ``True`` raises ``UnsortedException`` when timestamps are unsorted.
        Any other value is passed to ``save_as_json`` together with the data after a warning is issued.
    show_unsorted : bool, default True
        Whether to print the parts with conflicting timestamps.
    """
    if check_sorted is False:
        return
    parts = self.all_words_or_segments()
    word_level = self.has_words
    # flattened as start0, end0, start1, end1, ...
    flat_ts = np.array(list(chain.from_iterable((p.start, p.end) for p in parts)))
    if not len(flat_ts) > 1:
        return
    unsorted_mask = flat_ts[:-1] > flat_ts[1:]
    if not unsorted_mask.any():
        return

    if show_unsorted:
        def describe(idx):
            part = parts[idx]
            seg_id = part.segment_id if word_level else part.id
            word_id_line = f'Word ID: {part.id}\n' if word_level else ''
            return (
                f'Segment ID: {seg_id}\n{word_id_line}'
                f'Start: {part.start}\nEnd: {part.end}\n'
                f'Text: "{part.word if word_level else part.text}"'
            ), part.start, part.end

        for pos, is_unsorted in enumerate(unsorted_mask, 2):
            if not is_unsorted:
                continue
            part_idx = pos//2-1
            report, start, end = describe(part_idx)
            if pos % 2 == 1:
                # odd position: the end of this part is later than the start of the next part
                next_report, next_start, _ = describe(part_idx+1)
                report += f'\nConflict: end ({end}) > next start ({next_start})\n{next_report}'
            else:
                report += f'\nConflict: start ({start}) > end ({end})'
            print(report, end='\n\n')

    data = self.to_dict()
    if check_sorted is True:
        raise UnsortedException(data=data)
    warnings.warn('Timestamps are not in ascending order. '
                  'If data is produced by Stable-ts, please submit an issue with the saved data.')
    save_as_json(data, check_sorted)
|
|
|
def update_all_segs_with_words(self):
    """Refresh every segment from its words and point each segment back to this result."""
    for segment in self.segments:
        segment.update_seg_with_words()
        segment.set_result(self)
|
|
|
def update_nonspeech_sections(self, silent_starts, silent_ends):
    """Replace the stored non-speech sections with the paired ``silent_starts`` / ``silent_ends``."""
    sections = []
    for section_start, section_end in zip(silent_starts, silent_ends):
        sections.append({'start': section_start, 'end': section_end})
    self._nonspeech_sections = sections
|
|
|
def add_segments(self, index0: int, index1: int, inplace: bool = False, lock: bool = False):
    """
    Merge the segments at ``index0`` and ``index1`` and return the merged segment.

    Parameters
    ----------
    index0 : int
        Index of the segment whose text and words come first in the merged segment.
    index1 : int
        Index of the segment that is appended.
    inplace : bool, default False
        Whether to replace the segment at the lower index with the merged segment and remove the other.
    lock : bool, default False
        Whether to lock the two words on either side of the joint. Ignored if the merged segment has no
        words.
    """
    new_seg = self.segments[index0] + self.segments[index1]
    new_seg.update_seg_with_words()
    # the merged segment has no words when either side has none, so there is nothing to lock
    if lock and self.segments[index0].has_words and new_seg.has_words:
        lock_idx = len(self.segments[index0].words)
        new_seg.words[lock_idx - 1].lock_right()
        if lock_idx < len(new_seg.words):
            new_seg.words[lock_idx].lock_left()
    if inplace:
        i0, i1 = sorted([index0, index1])
        self.segments[i0] = new_seg
        del self.segments[i1]
    return new_seg
|
|
|
def rescale_time(self, scale_factor: float):
    """Rescale the timestamps of every segment by ``scale_factor``."""
    for segment in self.segments:
        segment.rescale_time(scale_factor)
|
|
|
def apply_min_dur(self, min_dur: float, inplace: bool = False):
    """
    Merge any word/segment with adjacent word/segment if its duration is less than ``min_dur``.

    Segments are merged first, then the same rule is applied to the words of every remaining segment,
    including when the result holds only one segment.

    Parameters
    ----------
    min_dur : float
        Minimum duration of words/segments in seconds.
    inplace : bool, default False
        Whether to modify this instance instead of a deep copy.

    Returns
    -------
    The modified instance: ``self`` if ``inplace``, otherwise the copy.
    """
    result = self if inplace else deepcopy(self)
    max_i = len(result.segments) - 1
    # walk backwards so that removing a segment does not shift the indices still to be visited
    for i in reversed(range(len(result.segments))):
        if max_i <= 0:
            # a single segment has no neighbour to merge with
            break
        if result.segments[i].duration < min_dur:
            if i == max_i:
                result.add_segments(i-1, i, inplace=True)
            elif i == 0:
                result.add_segments(i, i+1, inplace=True)
            else:
                # merge with the neighbour that has the shorter duration
                if result.segments[i+1].duration < result.segments[i-1].duration:
                    result.add_segments(i-1, i, inplace=True)
                else:
                    result.add_segments(i, i+1, inplace=True)
            max_i -= 1
    result.reassign_ids()
    for s in result.segments:
        s.apply_min_dur(min_dur, inplace=True)
    return result
|
|
|
def offset_time(self, offset_seconds: float):
    """Shift the timestamps of every segment by ``offset_seconds``."""
    for segment in self.segments:
        segment.offset_time(offset_seconds)
|
|
|
def suppress_silence(
        self,
        silent_starts: np.ndarray,
        silent_ends: np.ndarray,
        min_word_dur: float = 0.1,
        word_level: bool = True,
        nonspeech_error: float = 0.3,
        use_word_position: bool = True
) -> "WhisperResult":
    """
    Move any start/end timestamps in silence parts of audio to the boundaries of the silence.

    Every segment is processed with its own ``suppress_silence`` method.

    Parameters
    ----------
    silent_starts : numpy.ndarray
        An array of starting timestamps of silent sections of audio.
    silent_ends : numpy.ndarray
        An array of ending timestamps of silent sections of audio.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach for adjustments.
    word_level : bool, default True
        Whether to apply the adjustments to word-level timestamps.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for adjustments.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    for segment in self.segments:
        segment.suppress_silence(
            silent_starts,
            silent_ends,
            min_word_dur,
            word_level=word_level,
            nonspeech_error=nonspeech_error,
            use_word_position=use_word_position
        )

    return self
|
|
|
def adjust_by_silence(
        self,
        audio: Union[torch.Tensor, np.ndarray, str, bytes],
        vad: bool = False,
        *,
        verbose: Optional[bool] = False,
        sample_rate: int = None,
        vad_onnx: bool = False,
        vad_threshold: float = 0.35,
        q_levels: int = 20,
        k_size: int = 5,
        min_word_dur: float = 0.1,
        word_level: bool = True,
        nonspeech_error: float = 0.3,
        use_word_position: bool = True

) -> "WhisperResult":
    """
    Adjust timestamps based on detected speech gaps.

    This method combines :meth:`stable_whisper.result.WhisperResult.suppress_silence` with silence detection.
    The detected silent sections are also stored as the non-speech sections of this instance.

    Parameters
    ----------
    audio : str or numpy.ndarray or torch.Tensor or bytes
        Path/URL to the audio file, the audio waveform, or bytes of audio file.
    vad : bool, default False
        Whether to use Silero VAD to generate timestamp suppression mask.
        Silero VAD requires PyTorch 1.12.0+. Official repo, https://github.com/snakers4/silero-vad.
    verbose : bool or None, default False
        If ``False``, mute messages about hitting local caches. Note that the message about first download cannot be
        muted. Only applies if ``vad = True``.
    sample_rate : int, default None, meaning ``whisper.audio.SAMPLE_RATE``, 16kHZ
        The sample rate of ``audio``.
    vad_onnx : bool, default False
        Whether to use ONNX for Silero VAD.
    vad_threshold : float, default 0.35
        Threshold for detecting speech with Silero VAD. Low threshold reduces false positives for silence detection.
    q_levels : int, default 20
        Quantization levels for generating timestamp suppression mask; ignored if ``vad = true``.
        Acts as a threshold to marking sound as silent.
        Fewer levels will increase the threshold of volume at which to mark a sound as silent.
    k_size : int, default 5
        Kernel size for avg-pooling waveform to generate timestamp suppression mask; ignored if ``vad = true``.
        Recommend 5 or 3; higher sizes will reduce detection of silence.
    min_word_dur : float, default 0.1
        Shortest duration each word is allowed to reach from adjustments.
    word_level : bool, default True
        Whether to apply the adjustments to word-level timestamps.
    nonspeech_error : float, default 0.3
        Relative error of non-speech sections that appear in between a word for adjustments.
    use_word_position : bool, default True
        Whether to use position of the word in its segment to determine whether to keep end or start timestamps if
        adjustments are required. If it is the first word, keep end. Else if it is the last word, keep the start.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    This operation is already performed by :func:`stable_whisper.whisper_word_level.transcribe_stable` /
    :func:`stable_whisper.whisper_word_level.transcribe_minimal`/
    :func:`stable_whisper.non_whisper.transcribe_any` / :func:`stable_whisper.alignment.align`
    if ``suppress_silence = True``.
    """
    if vad:
        silent_timings = get_vad_silence_func(
            onnx=vad_onnx,
            verbose=verbose
        )(audio, speech_threshold=vad_threshold, sr=sample_rate)
    else:
        silent_timings = mask2timing(
            wav2mask(audio, q_levels=q_levels, k_size=k_size, sr=sample_rate)
        )
    if silent_timings is None:
        # nothing to adjust when the detector returns no timings
        return self
    # ``silent_timings`` is unpacked as (silent_starts, silent_ends)
    self.suppress_silence(
        *silent_timings,
        min_word_dur=min_word_dur,
        word_level=word_level,
        nonspeech_error=nonspeech_error,
        use_word_position=use_word_position
    )
    self.update_nonspeech_sections(*silent_timings)
    return self
|
|
|
def adjust_by_result(
        self,
        other_result: "WhisperResult",
        min_word_dur: float = 0.1,
        verbose: bool = False
):
    """
    Minimize the duration of words using timestamps of another result.

    Each word that ends after the start of its counterpart is narrowed to the overlap of the two words,
    provided the overlap lasts at least ``min_word_dur``.

    Parameters
    ----------
    other_result : "WhisperResult"
        Timing data of the same words in a WhisperResult instance.
    min_word_dur : float, default 0.1
        Prevent changes to timestamps if the resultant word duration is less than ``min_word_dur``.
    verbose : bool, default False
        Whether to print out the timestamp changes.
    """
    if not (self.has_words and other_result.has_words):
        raise NotImplementedError('This operation can only be performed on results with word timestamps')
    own_words = self.all_words()
    ref_words = other_result.all_words()
    assert [w.word for w in own_words] == [w.word for w in ref_words], \
        'The words in [other_result] do not match the current words.'
    for word, ref_word in zip(own_words, ref_words):
        if not (word.end > ref_word.start):
            continue
        new_start = max(word.start, ref_word.start)
        new_end = min(word.end, ref_word.end)
        if not (new_end - new_start >= min_word_dur):
            continue
        change_log = ''
        if word.start != new_start:
            if verbose:
                change_log += f'[Start:{word.start:.3f}->{new_start:.3f}] '
            word.start = new_start
        if word.end != new_end:
            if verbose:
                change_log += f'[End:{word.end:.3f}->{new_end:.3f}] '
            word.end = new_end
        if change_log:
            print(f'{change_log}"{word.word}"')
    self.update_all_segs_with_words()
|
|
|
def reassign_ids(self, only_segments: bool = False):
    """Number the segments by position; unless ``only_segments``, also renumber the words of each segment."""
    renumber_words = not only_segments
    for position, segment in enumerate(self.segments):
        segment.id = position
        if renumber_words:
            segment.reassign_ids()
|
|
|
def remove_no_word_segments(self, ignore_ori=False):
    """
    Delete (in-place) segments that currently have no words, then reassign ids.

    A wordless segment is only deleted if its ``ori_has_words`` is true, unless ``ignore_ori`` is true,
    in which case every wordless segment is deleted.
    """
    # Walk backwards so deletions do not shift the indices still to be visited.
    for idx in range(len(self.segments) - 1, -1, -1):
        segment = self.segments[idx]
        removable = ignore_ori or segment.ori_has_words
        if removable and not segment.has_words:
            self.segments.pop(idx)
    self.reassign_ids()
|
|
|
def get_locked_indices(self):
    """
    Return the indices of locked boundaries between adjacent segments.

    Index ``i`` is included when segment ``i + 1`` is left-locked or segment ``i`` is right-locked.
    """
    locked = []
    for idx in range(len(self.segments) - 1):
        current, following = self.segments[idx], self.segments[idx + 1]
        if following.left_locked or current.right_locked:
            locked.append(idx)
    return locked
|
|
|
def get_gaps(self, as_ndarray=False):
    """
    Return the gap between each pair of adjacent segments.

    Element ``i`` is ``segments[i + 1].start - segments[i].end``. The result is a NumPy array if
    ``as_ndarray`` is true, otherwise a plain list.
    """
    starts = np.array([segment.start for segment in self.segments])
    ends = np.array([segment.end for segment in self.segments])
    gaps = starts[1:] - ends[:-1]
    if as_ndarray:
        return gaps
    return gaps.tolist()
|
|
|
def get_gap_indices(self, min_gap: float = 0.1):
    """
    Return sorted indices of unlocked segment boundaries whose gap is <= ``min_gap``.

    ``min_gap = None`` is treated as 0. With fewer than two segments an empty list is returned.
    """
    if len(self.segments) < 2:
        return []
    threshold = 0 if min_gap is None else min_gap
    is_small = self.get_gaps(True) <= threshold
    candidates = set(np.nonzero(is_small)[0].tolist())
    locked = set(self.get_locked_indices())
    return sorted(candidates - locked)
|
|
|
def get_punctuation_indices(self, punctuation: Union[List[str], List[Tuple[str, str]], str]):
    """
    Return sorted indices of unlocked segment boundaries that have ``punctuation`` at them.

    Index ``i`` refers to the boundary between segment ``i`` and segment ``i + 1``.

    Parameters
    ----------
    punctuation : list of str or list of tuple of (str, str) or str
        For a plain string, a boundary matches if the preceding segment's text ends with it or the
        following segment's text starts with it. For a tuple ``(ending, beginning)``, the preceding
        segment must end with ``ending`` and the following segment must start with ``beginning``.

    Returns
    -------
    list of int
        Matching boundary indices, excluding those returned by ``get_locked_indices()``.
    """
    if len(self.segments) < 2:
        return []
    if isinstance(punctuation, str):
        punctuation = [punctuation]

    texts = [segment.text for segment in self.segments]
    # One (preceding text, following text) pair per boundary, so the last segment's start and a
    # segment that both starts and ends with the punctuation are checked like any other.
    boundaries = list(zip(texts[:-1], texts[1:]))
    indices = set()
    for p in punctuation:
        if isinstance(p, str):
            indices.update(
                i for i, (before, after) in enumerate(boundaries)
                if before.endswith(p) or after.startswith(p)
            )
        else:
            ending, beginning = p
            indices.update(
                i for i, (before, after) in enumerate(boundaries)
                if before.endswith(ending) and after.startswith(beginning)
            )

    return sorted(indices - set(self.get_locked_indices()))
|
|
|
def all_words(self):
    """Return a new flat list of the words of every segment, in segment order."""
    return [word for segment in self.segments for word in segment.words]
|
|
|
def all_words_or_segments(self):
    """Return all words if the result has word timestamps, otherwise the segments."""
    if self.has_words:
        return self.all_words()
    return self.segments
|
|
|
def all_words_by_lock(self, only_text: bool = True, by_segment: bool = False, include_single: bool = False):
    """
    Return words grouped by their lock state.

    With ``by_segment = True`` the result is one entry per segment, each produced by that segment's
    ``words_by_lock``. Otherwise all words of the result are passed together to ``_words_by_lock``.
    ``only_text`` and ``include_single`` are forwarded unchanged in both cases.
    """
    if not by_segment:
        return _words_by_lock(self.all_words(), only_text=only_text, include_single=include_single)
    grouped = []
    for segment in self.segments:
        grouped.append(segment.words_by_lock(only_text=only_text, include_single=include_single))
    return grouped
|
|
|
def all_tokens(self):
    """Return a flat list of the tokens of every word, in word order."""
    return [token for word in self.all_words() for token in word.tokens]
|
|
|
def to_dict(self):
    """
    Return the result as a dict.

    Keys: ``text``, ``segments`` (from ``segments_to_dicts()``), ``language``, ``ori_dict``,
    ``regroup_history`` and ``nonspeech_sections``.
    """
    return {
        'text': self.text,
        'segments': self.segments_to_dicts(),
        'language': self.language,
        'ori_dict': self.ori_dict,
        'regroup_history': self._regroup_history,
        'nonspeech_sections': self._nonspeech_sections,
    }
|
|
|
def segments_to_dicts(self, reverse_text: Union[bool, tuple] = False):
    """Return ``segment.to_dict(reverse_text=reverse_text)`` for every segment, in order."""
    as_dicts = []
    for segment in self.segments:
        as_dicts.append(segment.to_dict(reverse_text=reverse_text))
    return as_dicts
|
|
|
def _split_segments(self, get_indices, args: list = None, *, lock: bool = False, newline: bool = False):
    """
    Split (in-place) every segment at the word indices chosen by ``get_indices``.

    Parameters
    ----------
    get_indices : Callable
        Called as ``get_indices(segment, *args)`` for each segment; returns the word indices to split at.
    args : list, optional
        Extra positional arguments passed to ``get_indices``.
    lock : bool, default False
        Whether to lock the sides of the new split points. When splitting into segments, the first
        piece is right-locked, the last piece is left-locked and every piece in between is locked on
        both sides. With ``newline``, the word at the split index is right-locked and the next word
        (if any) is left-locked.
    newline : bool, default False
        Whether to append a line break to the word at each split index instead of splitting the
        segment. A split index on the final word of a segment is ignored in this mode.

    Notes
    -----
    A warning is issued if any segment has no words. ``remove_no_word_segments()`` is always called
    at the end.
    """
    if args is None:
        args = []
    found_wordless = False
    # Iterate backwards so replacing one segment by several does not shift unvisited indices.
    for seg_idx in reversed(range(len(self.segments))):
        segment = self.segments[seg_idx]
        found_wordless = found_wordless or not segment.has_words
        indices = sorted(set(get_indices(segment, *args)))
        if not indices:
            continue

        if newline:
            words = segment.words
            if indices[-1] == len(words) - 1:
                del indices[-1]
                if not indices:
                    continue

            for word_idx in indices:
                word = words[word_idx]
                if word.word.endswith('\n'):
                    continue
                word.word += '\n'
                if lock:
                    word.lock_right()
                    if word_idx + 1 < len(words):
                        words[word_idx + 1].lock_left()
            segment.update_seg_with_words()
        else:
            new_segments = segment.split(indices)
            if lock:
                # Decide first/last by position: comparing segments with ``==`` could misidentify
                # a piece if segments compare equal by value.
                last_pos = len(new_segments) - 1
                for pos, piece in enumerate(new_segments):
                    if pos == 0:
                        piece.lock_right()
                    elif pos == last_pos:
                        piece.lock_left()
                    else:
                        piece.lock_both()
            # Replace the original segment with its pieces in a single operation.
            self.segments[seg_idx:seg_idx + 1] = new_segments
    if found_wordless:
        warnings.warn('Found segment(s) without word timings. These segment(s) cannot be split.')
    self.remove_no_word_segments()
|
|
|
def _merge_segments(self, indices: List[int],
                    *, max_words: int = None, max_chars: int = None, is_sum_max: bool = False, lock: bool = False):
    """
    Merge (in-place) segment ``i`` with segment ``i + 1`` for each ``i`` in ``indices``.

    A pair is skipped when it exceeds ``max_words`` (only checked if segment ``i`` has words) or
    ``max_chars``. With ``is_sum_max`` the combined count of both segments is compared to the limit;
    otherwise the pair is skipped only if both segments individually exceed the limit.
    ``remove_no_word_segments()`` is called afterwards, unless ``indices`` is empty.
    """
    if len(indices) == 0:
        return
    # Highest index first so earlier merges do not shift the indices still to be processed.
    for idx in reversed(indices):
        current = self.segments[idx]

        if max_words and current.has_words:
            if is_sum_max:
                too_many_words = current.word_count() + self.segments[idx + 1].word_count() > max_words
            else:
                too_many_words = (
                    current.word_count() > max_words and self.segments[idx + 1].word_count() > max_words
                )
            if too_many_words:
                continue

        if max_chars:
            if is_sum_max:
                too_many_chars = current.char_count() + self.segments[idx + 1].char_count() > max_chars
            else:
                too_many_chars = (
                    current.char_count() > max_chars and self.segments[idx + 1].char_count() > max_chars
                )
            if too_many_chars:
                continue

        self.add_segments(idx, idx + 1, inplace=True, lock=lock)
    self.remove_no_word_segments()
|
|
|
def get_content_by_time(
        self,
        time: Union[float, Tuple[float, float], dict],
        within: bool = False,
        segment_level: bool = False
) -> Union[List[WordTiming], List[Segment]]:
    """
    Return the words or segments found in the ``time`` range.

    Parameters
    ----------
    time : float or tuple of (float, float) or dict
        Range of time to search. A pair is read as (start, end). A single number is used as both the
        start and the end. A dict is read through its ``'start'`` and ``'end'`` keys.
    within : bool, default False
        Whether to only return content that lies entirely inside the ``time`` range. Otherwise any
        content that overlaps or touches the range is returned.
    segment_level : bool, default False
        Whether to search segments and return instances of :class:`stable_whisper.result.Segment`
        instead of :class:`stable_whisper.result.WordTiming`.

    Returns
    -------
    list of stable_whisper.result.WordTiming or list of stable_whisper.result.Segment
        Contents in the ``time`` range, in their original order.

    Raises
    ------
    ValueError
        If ``segment_level`` is false and the result has no word timestamps.
    """
    if not (segment_level or self.has_words):
        raise ValueError('Missing word timestamps in result. Use ``segment_level=True`` instead.')
    contents = self.segments if segment_level else self.all_words()

    if isinstance(time, dict):
        time = [time['start'], time['end']]
    elif isinstance(time, (float, int)):
        time = [time, time]
    start, end = time

    if within:
        return [c for c in contents if start <= c.start and end >= c.end]
    return [c for c in contents if start <= c.end and end >= c.start]
|
|
|
def split_by_gap(
        self,
        max_gap: float = 0.1,
        lock: bool = False,
        newline: bool = False
) -> "WhisperResult":
    """
    Split (in-place) any segment where the gap between two of its words is greater than ``max_gap``.

    Parameters
    ----------
    max_gap : float, default 0.1
        Maximum second(s) allowed between two words of the same segment.
    lock : bool, default False
        Whether to prevent future splits/merges from altering changes made by this method.
    newline : bool, default False
        Whether to insert line break at the split points instead of splitting into separate segments.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    def _indices_for(segment):
        return segment.get_gap_indices(max_gap)

    self._split_segments(_indices_for, lock=lock, newline=newline)
    # Operations in the regroup history are separated by underscores.
    separator = '_' if self._regroup_history else ''
    self._regroup_history += f'{separator}sg={max_gap}+{int(lock)}+{int(newline)}'
    return self
|
|
|
def merge_by_gap(
        self,
        min_gap: float = 0.1,
        max_words: int = None,
        max_chars: int = None,
        is_sum_max: bool = False,
        lock: bool = False
) -> "WhisperResult":
    """
    Merge (in-place) any pair of adjacent segments if the gap between them <= ``min_gap``.

    Parameters
    ----------
    min_gap : float, default 0.1
        Minimum second(s) allowed between two segments.
    max_words : int, optional
        Maximum number of words allowed in each segment.
    max_chars : int, optional
        Maximum number of characters allowed in each segment.
    is_sum_max : bool, default False
        Whether ``max_words`` and ``max_chars`` are applied to the merged segment instead of the
        individual segments to be merged.
    lock : bool, default False
        Whether to prevent future splits/merges from altering changes made by this method.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    boundary_indices = self.get_gap_indices(min_gap)
    self._merge_segments(
        boundary_indices,
        max_words=max_words,
        max_chars=max_chars,
        is_sum_max=is_sum_max,
        lock=lock,
    )
    # Operations in the regroup history are separated by underscores.
    separator = '_' if self._regroup_history else ''
    self._regroup_history += (
        f'{separator}mg={min_gap}+{max_words or ""}+{max_chars or ""}+{int(is_sum_max)}+{int(lock)}'
    )
    return self
|
|
|
def split_by_punctuation(
        self,
        punctuation: Union[List[str], List[Tuple[str, str]], str],
        lock: bool = False,
        newline: bool = False,
        min_words: Optional[int] = None,
        min_chars: Optional[int] = None,
        min_dur: Optional[float] = None
) -> "WhisperResult":
    """
    Split (in-place) segments at words that start/end with ``punctuation``.

    Parameters
    ----------
    punctuation : list of str or list of tuple of (str, str) or str
        Punctuation(s) to split segments by.
    lock : bool, default False
        Whether to prevent future splits/merges from altering changes made by this method.
    newline : bool, default False
        Whether to insert line break at the split points instead of splitting into separate segments.
    min_words : int, optional
        Split segments with words >= ``min_words``.
    min_chars : int, optional
        Split segments with characters >= ``min_chars``.
    min_dur : float, optional
        Split segments with duration (in seconds) >= ``min_dur``.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    If any of ``min_words``, ``min_chars`` and ``min_dur`` is given, only segments meeting at least
    one of the given minimums are split. Eligibility is decided before any segment is split.
    """
    def _meets_minimum(seg: 'Segment'):
        return (
            (min_words and len(seg.words) >= min_words) or
            (min_chars and seg.char_count() >= min_chars) or
            (min_dur and seg.duration >= min_dur)
        )

    # ``None`` means no minimum was requested, so every segment is eligible.
    if any((min_words, min_chars, min_dur)):
        eligible_ids = {seg.id for seg in self.segments if _meets_minimum(seg)}
    else:
        eligible_ids = None

    def _get_indices(seg: 'Segment'):
        if eligible_ids is None or seg.id in eligible_ids:
            return seg.get_punctuation_indices(punctuation)
        return []

    self._split_segments(_get_indices, lock=lock, newline=newline)
    if self._regroup_history:
        self._regroup_history += '_'
    # A single string is one punctuation; wrap it so a multi-character string is not recorded
    # in the history as several one-character punctuations.
    punct_items = [punctuation] if isinstance(punctuation, str) else punctuation
    punct_str = '/'.join(p if isinstance(p, str) else '*'.join(p) for p in punct_items)
    self._regroup_history += f'sp={punct_str}+{int(lock)}+{int(newline)}'
    self._regroup_history += f'+{min_words or ""}+{min_chars or ""}+{min_dur or ""}'.rstrip('+')
    return self
|
|
|
def merge_by_punctuation(
        self,
        punctuation: Union[List[str], List[Tuple[str, str]], str],
        max_words: int = None,
        max_chars: int = None,
        is_sum_max: bool = False,
        lock: bool = False
) -> "WhisperResult":
    """
    Merge (in-place) any two adjacent segments that have specific punctuation in between.

    Parameters
    ----------
    punctuation : list of str or list of tuple of (str, str) or str
        Punctuation(s) to merge segments by.
    max_words : int, optional
        Maximum number of words allowed in each segment.
    max_chars : int, optional
        Maximum number of characters allowed in each segment.
    is_sum_max : bool, default False
        Whether ``max_words`` and ``max_chars`` are applied to the merged segment instead of the
        individual segments to be merged.
    lock : bool, default False
        Whether to prevent future splits/merges from altering changes made by this method.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    indices = self.get_punctuation_indices(punctuation)
    self._merge_segments(indices,
                         max_words=max_words, max_chars=max_chars, is_sum_max=is_sum_max, lock=lock)
    if self._regroup_history:
        self._regroup_history += '_'
    # A single string is one punctuation; wrap it so a multi-character string is not recorded
    # in the history as several one-character punctuations.
    punct_items = [punctuation] if isinstance(punctuation, str) else punctuation
    punct_str = '/'.join(p if isinstance(p, str) else '*'.join(p) for p in punct_items)
    self._regroup_history += f'mp={punct_str}+{max_words or ""}+{max_chars or ""}+{int(is_sum_max)}+{int(lock)}'
    return self
|
|
|
def merge_all_segments(self) -> "WhisperResult":
    """
    Merge all segments into one segment.

    With word timestamps, the first segment receives the words of every segment. Without them, the
    texts are concatenated onto the first segment, its end is set to the last segment's end, and the
    tokens are concatenated only if every segment has tokens.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    if not self.segments:
        return self

    first = self.segments[0]
    others = self.segments[1:]
    if self.has_words:
        first.words = self.all_words()
    else:
        first.text += ''.join(segment.text for segment in others)
        if all(segment.tokens is not None for segment in self.segments):
            first.tokens += [token for segment in others for token in segment.tokens]
        first.end = self.segments[-1].end

    self.segments = [first]
    self.reassign_ids()
    self.update_all_segs_with_words()
    # Operations in the regroup history are separated by underscores.
    separator = '_' if self._regroup_history else ''
    self._regroup_history += f'{separator}ms'
    return self
|
|
|
def split_by_length(
        self,
        max_chars: int = None,
        max_words: int = None,
        even_split: bool = True,
        force_len: bool = False,
        lock: bool = False,
        include_lock: bool = False,
        newline: bool = False
) -> "WhisperResult":
    """
    Split (in-place) any segment that exceeds ``max_chars`` or ``max_words`` into smaller segments.

    Parameters
    ----------
    max_chars : int, optional
        Maximum number of characters allowed in each segment.
    max_words : int, optional
        Maximum number of words allowed in each segment.
    even_split : bool, default True
        Whether to evenly split a segment in length if it exceeds ``max_chars`` or ``max_words``.
    force_len : bool, default False
        Whether to force a constant length for each segment except the last segment.
        All segments are merged into one before splitting, so previous non-locked segment boundaries
        are ignored.
    lock : bool, default False
        Whether to prevent future splits/merges from altering changes made by this method.
    include_lock : bool, default False
        Whether to include previous lock before splitting based on max_words, if ``even_split = False``.
        Splitting will be done after the first non-locked word > ``max_chars`` / ``max_words``.
    newline : bool, default False
        Whether to insert line break at the split points instead of splitting into separate segments.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    If ``even_split = True``, segments can still exceed ``max_chars`` and locked words will be ignored
    to avoid uneven splitting.
    """
    if force_len:
        self.merge_all_segments()

    def _indices_for(segment):
        return segment.get_length_indices(
            max_chars=max_chars,
            max_words=max_words,
            even_split=even_split,
            include_lock=include_lock
        )

    self._split_segments(_indices_for, lock=lock, newline=newline)
    # Operations in the regroup history are separated by underscores.
    separator = '_' if self._regroup_history else ''
    self._regroup_history += (
        f'{separator}sl={max_chars or ""}+{max_words or ""}+{int(even_split)}+{int(force_len)}'
        f'+{int(lock)}+{int(include_lock)}+{int(newline)}'
    )
    return self
|
|
|
def split_by_duration(
        self,
        max_dur: float,
        even_split: bool = True,
        force_len: bool = False,
        lock: bool = False,
        include_lock: bool = False,
        newline: bool = False
) -> "WhisperResult":
    """
    Split (in-place) any segment that exceeds ``max_dur`` into smaller segments.

    Parameters
    ----------
    max_dur : float
        Maximum duration (in seconds) per segment.
    even_split : bool, default True
        Whether to evenly split a segment in length if it exceeds ``max_dur``.
    force_len : bool, default False
        Whether to force a constant length for each segment except the last segment.
        All segments are merged into one before splitting, so previous non-locked segment boundaries
        are ignored.
    lock : bool, default False
        Whether to prevent future splits/merges from altering changes made by this method.
    include_lock : bool, default False
        Whether to include previous lock before splitting based on ``max_dur``, if ``even_split = False``.
        Splitting will be done after the first non-locked word > ``max_dur``.
    newline : bool, default False
        Whether to insert line break at the split points instead of splitting into separate segments.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    If ``even_split = True``, segments can still exceed ``max_dur`` and locked words will be ignored
    to avoid uneven splitting.
    """
    if force_len:
        self.merge_all_segments()

    def _indices_for(segment):
        return segment.get_duration_indices(
            max_dur=max_dur,
            even_split=even_split,
            include_lock=include_lock
        )

    self._split_segments(_indices_for, lock=lock, newline=newline)
    # Operations in the regroup history are separated by underscores.
    separator = '_' if self._regroup_history else ''
    self._regroup_history += (
        f'{separator}sd={max_dur}+{int(even_split)}+{int(force_len)}'
        f'+{int(lock)}+{int(include_lock)}+{int(newline)}'
    )
    return self
|
|
|
def clamp_max(
        self,
        medium_factor: float = 2.5,
        max_dur: float = None,
        clip_start: Optional[bool] = None,
        verbose: bool = False
) -> "WhisperResult":
    """
    Clamp all word durations above certain value.

    This is most effective when applied before and after other regroup operations.

    Parameters
    ----------
    medium_factor : float, default 2.5
        Clamp durations above (``medium_factor`` * medium duration) per segment.
        If ``medium_factor = None/0`` or segment has less than 3 words, it will be ignored and use
        only ``max_dur``.
    max_dur : float, optional
        Clamp durations above ``max_dur``.
    clip_start : bool or None, default None
        Whether to clamp the start of a word. If ``None``, clamp the start of first word and end of
        last word per segment.
    verbose : bool, default False
        Whether to print out the timestamp changes.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Raises
    ------
    ValueError
        If both ``medium_factor`` and ``max_dur`` are zero or ``None``.
    """
    if not (medium_factor or max_dur):
        raise ValueError('At least one of following arguments requires non-zero value: medium_factor; max_dur')

    if not self.has_words:
        warnings.warn('Cannot clamp due to missing/no word-timestamps')
        return self

    for segment in self.segments:
        limit = None
        if medium_factor and len(segment.words) > 2:
            ordered = np.sort(np.array([word.duration for word in segment.words]))
            # NOTE(review): this picks the element one position above the middle of the sorted
            # durations (the maximum for 3- and 4-word segments), not the median - confirm intent.
            limit = medium_factor * ordered[len(ordered) // 2 + 1]

        # ``max_dur`` caps the per-segment limit and is the fallback when no limit was derived.
        if max_dur and (not limit or limit > max_dur):
            limit = max_dur

        if not limit:
            continue

        if clip_start is None:
            segment.words[0].clamp_max(limit, clip_start=True, verbose=verbose)
            segment.words[-1].clamp_max(limit, clip_start=False, verbose=verbose)
        else:
            for word in segment.words:
                word.clamp_max(limit, clip_start=clip_start, verbose=verbose)

        segment.update_seg_with_words()

    # Operations in the regroup history are separated by underscores.
    separator = '_' if self._regroup_history else ''
    self._regroup_history += f'{separator}cm={medium_factor}+{max_dur or ""}+{clip_start or ""}+{int(verbose)}'
    return self
|
|
|
def lock(
        self,
        startswith: Union[str, List[str]] = None,
        endswith: Union[str, List[str]] = None,
        right: bool = True,
        left: bool = False,
        case_sensitive: bool = False,
        strip: bool = True
) -> "WhisperResult":
    """
    Lock words/segments with matching prefix/suffix to prevent splitting/merging.

    Words are checked if the result has word timestamps, otherwise segments are checked.

    Parameters
    ----------
    startswith : str or list of str
        Prefixes to lock.
    endswith : str or list of str
        Suffixes to lock.
    right : bool, default True
        Whether to prevent splits/merges with the next word/segment.
    left : bool, default False
        Whether to prevent splits/merges with the previous word/segment.
    case_sensitive : bool, default False
        Whether to match the case of the prefixes/suffixes with the words/segments.
    strip : bool, default True
        Whether to ignore spaces before and after both words/segments and prefixes/suffixes.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    assert startswith or endswith, 'Must specify [startswith] or/and [endswith].'

    def _as_list(value):
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return value

    startswith = _as_list(startswith)
    endswith = _as_list(endswith)
    if not case_sensitive:
        startswith = [prefix.lower() for prefix in startswith]
        endswith = [suffix.lower() for suffix in endswith]
    if strip:
        startswith = [prefix.strip() for prefix in startswith]
        endswith = [suffix.strip() for suffix in endswith]

    for part in self.all_words_or_segments():
        text = part.word if hasattr(part, 'word') else part.text
        if not case_sensitive:
            text = text.lower()
        if strip:
            text = text.strip()
        matches = [prefix for prefix in startswith if text.startswith(prefix)]
        matches += [suffix for suffix in endswith if text.endswith(suffix)]
        # The lock calls are issued once per matching prefix/suffix.
        for _ in matches:
            if right:
                part.lock_right()
            if left:
                part.lock_left()

    # The history records the prefixes/suffixes after the case/space normalisation above.
    separator = '_' if self._regroup_history else ''
    startswith_str = '/'.join(startswith) if startswith else ""
    endswith_str = '/'.join(endswith) if endswith else ""
    self._regroup_history += (
        f'{separator}l={startswith_str}+{endswith_str}'
        f'+{int(right)}+{int(left)}+{int(case_sensitive)}+{int(strip)}'
    )
    return self
|
|
|
def remove_word(
        self,
        word: Union[WordTiming, Tuple[int, int]],
        reassign_ids: bool = True,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove a word.

    Parameters
    ----------
    word : WordTiming or tuple of (int, int)
        Instance of :class:`stable_whisper.result.WordTiming` or tuple of (segment index, word index).
    reassign_ids : bool, default True
        Whether to reassign segment and word ids (indices) after removing ``word``.
    verbose : bool, default True
        Whether to print detail of the removed word.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Raises
    ------
    ValueError
        If ``word`` is a WordTiming that is not found at its recorded ids, even after the ids of the
        result have been reassigned.
    """
    if not isinstance(word, WordTiming):
        seg_id, word_id = word
    else:
        if self[word.segment_id][word.id] is not word:
            # The ids may be stale; refresh them once before giving up.
            self.reassign_ids()
            if self[word.segment_id][word.id] is not word:
                raise ValueError('word not in result')
        seg_id, word_id = word.segment_id, word.id

    if verbose:
        print(f'Removed: {self[seg_id][word_id].to_dict()}')
    del self.segments[seg_id].words[word_id]

    if reassign_ids:
        if self[seg_id].has_words:
            self[seg_id].reassign_ids()
        else:
            self.remove_no_word_segments()
    return self
|
|
|
def remove_segment(
        self,
        segment: Union[Segment, int],
        reassign_ids: bool = True,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove a segment.

    Parameters
    ----------
    segment : Segment or int
        Instance of :class:`stable_whisper.result.Segment` or segment index.
    reassign_ids : bool, default True
        Whether to reassign segment IDs (indices) after removing ``segment``.
    verbose : bool, default True
        Whether to print detail of the removed segment.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Raises
    ------
    ValueError
        If ``segment`` is a Segment that is not found at its recorded id, even after the ids of the
        result have been reassigned.
    """
    if isinstance(segment, Segment):
        if self[segment.id] is not segment:
            # The id may be stale; refresh the ids once before giving up.
            self.reassign_ids()
            if self[segment.id] is not segment:
                raise ValueError('segment not in result')
        segment = segment.id

    if verbose:
        print(f'Removed: [id:{self[segment].id}] {self[segment].to_display_str(True)}')
    del self.segments[segment]

    if reassign_ids:
        self.reassign_ids(True)
    return self
|
|
|
def remove_repetition(
        self,
        max_words: int = 1,
        case_sensitive: bool = False,
        strip: bool = True,
        ignore_punctuations: str = "\"',.?!",
        extend_duration: bool = True,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove words that repeat consecutively.

    Parameters
    ----------
    max_words : int
        Maximum number of words to look for consecutively.
    case_sensitive : bool, default False
        Whether the case of words need to match to be considered as repetition.
    strip : bool, default True
        Whether to ignore spaces before and after each word.
    ignore_punctuations : str, default '"',.?!'
        Ending punctuations to ignore. Every character is matched literally.
    extend_duration : bool, default True
        Whether to extend the duration of the previous word to cover the duration of the repetition.
    verbose : bool, default True
        Whether to print detail of the removed repetitions.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    if not self.has_words:
        return self

    # Escape the characters so that ones like ``^``, ``]``, ``\`` or ``-`` are treated as plain
    # punctuation instead of altering (or breaking) the character class.
    trailing_punct = re.compile(f'[{re.escape(ignore_punctuations)}]+$') if ignore_punctuations else None

    def _normalize(text):
        if strip:
            text = text.strip()
        if trailing_punct is not None:
            text = trailing_punct.sub('', text)
        if not case_sensitive:
            text = text.lower()
        return text

    # One pass per repetition length: 1-word repeats first, then 2-word repeats, and so on.
    for count in range(1, max_words + 1):
        all_words = self.all_words()
        if len(all_words) < 2:
            return self
        all_words_str = [_normalize(w.word) for w in all_words]
        next_i = None
        changes = []
        # ``i`` is the exclusive end of the candidate repetition ``[s, i)``, which is compared with
        # the ``count`` words right before it. Scanning from the end keeps lower indices valid.
        for i in reversed(range(count*2, len(all_words_str)+1)):
            if next_i is not None:
                # After a removal, skip ahead to the start of the removed span.
                if next_i != i:
                    continue
                else:
                    next_i = None
            s = i - count
            if all_words_str[s - count:s] != all_words_str[s:i]:
                continue
            next_i = s
            if extend_duration:
                all_words[s-1].end = all_words[i-1].end
            temp_changes = []
            for j in reversed(range(s, i)):
                if verbose:
                    temp_changes.append(f'- {all_words[j].to_dict()}')
                self.remove_word(all_words[j], False, verbose=False)
            if temp_changes:
                changes.append(
                    f'Remove: [{format_timestamp(all_words[s].start)} -> {format_timestamp(all_words[i-1].end)}] '
                    + ''.join(_w.word for _w in all_words[s:i]) + '\n'
                    + '\n'.join(reversed(temp_changes)) + '\n'
                )
            # Keep the longer of each kept/removed word pair: the removed word takes over the
            # timing and the slot of the kept word when its text is longer.
            for i0, i1 in zip(range(s - count, s), range(s, i)):
                if len(all_words[i0].word) < len(all_words[i1].word):
                    all_words[i1].start = all_words[i0].start
                    all_words[i1].end = all_words[i0].end
                    _sid, _wid = all_words[i0].segment_id, all_words[i0].id
                    self.segments[_sid].words[_wid] = all_words[i1]

        if changes:
            print('\n'.join(reversed(changes)))

        self.remove_no_word_segments()
        self.update_all_segs_with_words()

    return self
|
|
|
def remove_words_by_str(
        self,
        words: Union[str, List[str], None],
        case_sensitive: bool = False,
        strip: bool = True,
        ignore_punctuations: str = "\"',.?!",
        min_prob: float = None,
        filters: Callable = None,
        verbose: bool = True
) -> 'WhisperResult':
    """
    Remove words that match ``words``.

    Parameters
    ----------
    words : str or list of str or None
        A word or list of words to remove. ``None`` for all words to be passed into ``filters``.
    case_sensitive : bool, default False
        Whether the case of words need to match to be considered a match.
    strip : bool, default True
        Whether to ignore spaces before and after each word.
    ignore_punctuations : str, default '"',.?!'
        Ending punctuations to ignore. Every character is matched literally.
    min_prob : float, optional
        Acts as the first filter for the words that match ``words``. Words with probability < ``min_prob``
        will be removed if ``filters`` is ``None``, else pass the words into ``filters``. Words without
        probability will be treated as having probability < ``min_prob``.
    filters : Callable, optional
        A function that takes an instance of :class:`stable_whisper.result.WordTiming` as its only argument.
        This function is a custom filter for the words that match ``words`` and pass the ``min_prob``
        check; a word is removed when it returns a true value.
    verbose : bool, default True
        Whether to print detail of the removed words.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.
    """
    if not self.has_words:
        return self
    if isinstance(words, str):
        words = [words]

    # Escape the characters so that ones like ``^``, ``]``, ``\`` or ``-`` are treated as plain
    # punctuation instead of altering (or breaking) the character class.
    trailing_punct = re.compile(f'[{re.escape(ignore_punctuations)}]+$') if ignore_punctuations else None

    def _normalize(text):
        if strip:
            text = text.strip()
        if trailing_punct is not None:
            text = trailing_punct.sub('', text)
        if not case_sensitive:
            text = text.lower()
        return text

    all_words = self.all_words()
    all_words_str = [_normalize(w.word) for w in all_words]
    # ``None`` means every word is a candidate, so there is nothing to normalize or match against.
    targets = None if words is None else {_normalize(w) for w in words}

    changes = []
    # Remove from the end so the positions of the words still to be visited are unaffected.
    for i in reversed(range(len(all_words))):
        if targets is not None and all_words_str[i] not in targets:
            continue
        word = all_words[i]
        passes_min_prob = min_prob is None or word.probability is None or min_prob > word.probability
        if passes_min_prob and (filters is None or filters(word)):
            if verbose:
                changes.append(f'Removed: {word.to_dict()}')
            self.remove_word(word, False, verbose=False)
    if changes:
        print('\n'.join(reversed(changes)))
    self.remove_no_word_segments()
    self.update_all_segs_with_words()

    return self
|
|
|
def fill_in_gaps(
        self,
        other_result: Union['WhisperResult', str],
        min_gap: float = 0.1,
        case_sensitive: bool = False,
        strip: bool = True,
        ignore_punctuations: str = "\"',.?!",
        verbose: bool = True
) -> 'WhisperResult':
    """
    Fill in segment gaps larger than ``min_gap`` with content from ``other_result`` at the times of gaps.

    Parameters
    ----------
    other_result : WhisperResult or str
        Another transcription result as an instance of :class:`stable_whisper.result.WhisperResult` or path
        to the JSON of the result.
    min_gap : float, default 0.1
        The minimum seconds of a gap between segments that must be exceeded to be filled in.
    case_sensitive : bool, default False
        Whether to consider the case of the first and last word of the gap to determine overlapping words
        to remove before filling in.
    strip : bool, default True
        Whether to ignore spaces before and after the first and last word of the gap to determine
        overlapping words to remove before filling in.
    ignore_punctuations : str, default '"',.?!'
        Ending punctuations to ignore in the first and last word of the gap to determine overlapping words
        to remove before filling in. Every character is matched literally.
    verbose : bool, default True
        Whether to print detail of the filled content.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    The span before the first segment and the span after the last segment are treated as gaps too,
    bounded by the start of the first and the end of the last segment of ``other_result``.
    """
    if len(self.segments) < 2:
        return self
    if isinstance(other_result, str):
        other_result = WhisperResult(other_result)

    # Escape the characters so that ones like ``^``, ``]``, ``\`` or ``-`` are treated as plain
    # punctuation instead of altering (or breaking) the character class.
    trailing_punct = re.compile(f'[{re.escape(ignore_punctuations)}]+$') if ignore_punctuations else None

    def _normalize(text):
        if strip:
            text = text.strip()
        if trailing_punct is not None:
            text = trailing_punct.sub('', text)
        if not case_sensitive:
            text = text.lower()
        return text

    # Each entry is (index of the segment before the gap, (segment before, segment after));
    # ``None`` stands for the outside of the result at either end.
    seg_pairs = list(enumerate(zip(self.segments[:-1], self.segments[1:])))
    seg_pairs.insert(0, (-1, (None, self.segments[0])))
    seg_pairs.append((seg_pairs[-1][0]+1, (self.segments[-1], None)))

    changes = []
    # Process the last gap first so insertions do not shift the indices of earlier gaps.
    for i, (seg0, seg1) in reversed(seg_pairs):
        first_word = None if seg0 is None else seg0.words[-1]
        last_word = None if seg1 is None else seg1.words[0]
        start = (other_result[0].start if first_word is None else first_word.end)
        end = other_result[-1].end if last_word is None else last_word.start
        if end - start <= min_gap:
            continue
        gap_words = other_result.get_content_by_time((start, end))
        # A gap word that duplicates the bordering word is dropped; the bordering word takes
        # over its outer timestamp instead.
        if first_word is not None and gap_words and _normalize(first_word.word) == _normalize(gap_words[0].word):
            first_word.end = gap_words[0].end
            gap_words = gap_words[1:]
        if last_word is not None and gap_words and _normalize(last_word.word) == _normalize(gap_words[-1].word):
            last_word.start = gap_words[-1].start
            gap_words = gap_words[:-1]
        if not gap_words:
            continue
        # Resolve overlaps in favour of the existing end of ``first_word`` and of the filled-in
        # words over the start of ``last_word``.
        if last_word is not None and last_word.start < gap_words[-1].end:
            last_word.start = gap_words[-1].end
        new_segments = [other_result[gap_words[0].segment_id].copy([])]
        for j, new_word in enumerate(gap_words):
            new_word = deepcopy(new_word)
            if j == 0 and first_word is not None and first_word.end > gap_words[0].start:
                new_word.start = first_word.end
            # Start a new segment whenever the gap words cross into another segment of ``other_result``.
            if new_segments[-1].id != new_word.segment_id:
                new_segments.append(other_result[new_word.segment_id].copy([]))
            new_segments[-1].words.append(new_word)
        if verbose:
            changes.append('\n'.join('Added: ' + s.to_display_str(True) for s in new_segments))
        self.segments = self.segments[:i+1] + new_segments + self.segments[i+1:]
    if changes:
        print('\n'.join(reversed(changes)))
    self.reassign_ids()
    self.update_all_segs_with_words()

    return self
|
|
|
def regroup(
        self,
        regroup_algo: Union[str, bool] = None,
        verbose: bool = False,
        only_show: bool = False
) -> "WhisperResult":
    """
    Regroup (in-place) words into segments.

    Parameters
    ----------
    regroup_algo : str or bool, optional
        String representation of a custom regrouping algorithm. ``None`` or ``True`` selects the default
        algorithm 'da'; ``False`` leaves the result untouched.
    verbose : bool, default False
        Whether to show all the methods and arguments parsed from ``regroup_algo``.
    only_show : bool, default False
        Whether to show all the methods and arguments parsed from ``regroup_algo`` without running the methods.

    Returns
    -------
    stable_whisper.result.WhisperResult
        The current instance after the changes.

    Notes
    -----
    Syntax for string representation of custom regrouping algorithm.
        Method keys:
            sg: split_by_gap
            sp: split_by_punctuation
            sl: split_by_length
            sd: split_by_duration
            mg: merge_by_gap
            mp: merge_by_punctuation
            ms: merge_all_segments
            cm: clamp_max
            l: lock
            us: unlock_all_segments
            da: default algorithm (cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?)
            rw: remove_word
            rs: remove_segment
            rp: remove_repetition
            rws: remove_words_by_str
            fg: fill_in_gaps
        Metacharacters:
            = separates a method key and its arguments (not used if no argument)
            _ separates method keys (after arguments if there are any)
            + separates arguments for a method key
            / separates an argument into list of strings
            * separates an item in list of strings into a nested list of strings
        Notes:
        -arguments are parsed positionally
        -if no argument is provided, the default ones will be used
        -use 1 or 0 to represent True or False
        Example 1:
            merge_by_gap(.2, 10, lock=True)
            mg=.2+10+++1
            Note: [lock] is the 5th argument hence the 2 missing arguments inbetween the three + before 1
        Example 2:
            split_by_punctuation([('.', ' '), '。', '?', '?'], True)
            sp=.* /。/?/?+1
        Example 3:
            merge_all_segments().split_by_gap(.5).merge_by_gap(.15, 3)
            ms_sg=.5_mg=.15+3
    """
    if regroup_algo is False:
        return self

    algo = 'da' if (regroup_algo is None or regroup_algo is True) else regroup_algo
    show_operations = verbose or only_show

    operations = self.parse_regroup_algo(algo, include_str=show_operations)
    for operation, arguments, description in operations:
        if description:
            print(description)
        if only_show:
            continue
        operation(**arguments)

    return self
|
|
|
def parse_regroup_algo(
        self,
        regroup_algo: str,
        include_str: bool = True
) -> List[Tuple[Callable, dict, Optional[str]]]:
    """
    Parse the string representation of a regrouping algorithm into callable operations.

    Parameters
    ----------
    regroup_algo : str
        Method keys and arguments: ``_`` separates method calls, ``=`` separates a key from its arguments and
        ``+`` separates the arguments, which are assigned positionally to the parameters of the method.
    include_str : bool, default True
        Whether to include a readable representation of each call in the output.

    Returns
    -------
    list of tuple
        One ``(method, kwargs, description)`` tuple per call, in order. ``method`` is the bound method,
        ``kwargs`` holds the arguments that were not parsed as ``None`` and ``description`` is ``None`` when
        ``include_str`` is ``False``. An empty list is returned if ``regroup_algo`` is empty.

    Raises
    ------
    NotImplementedError
        If a method key is not one of the available methods.
    TypeError
        If a call has more non-empty arguments than its method has parameters.
    """
    methods = dict(
        sg=self.split_by_gap,
        sp=self.split_by_punctuation,
        sl=self.split_by_length,
        sd=self.split_by_duration,
        mg=self.merge_by_gap,
        mp=self.merge_by_punctuation,
        ms=self.merge_all_segments,
        cm=self.clamp_max,
        us=self.unlock_all_segments,
        l=self.lock,
        rw=self.remove_word,
        rs=self.remove_segment,
        rp=self.remove_repetition,
        rws=self.remove_words_by_str,
        fg=self.fill_in_gaps,
    )
    if not regroup_algo:
        return []

    calls = regroup_algo.split('_')
    if 'da' in calls:
        # NOTE(review): '?' is listed twice in both punctuation lists below; one of each pair was
        # presumably a full-width character lost to text normalization -- confirm against upstream.
        default_calls = 'cm_sp=.* /。/?/?/,* /,_sg=.5_mg=.3+3_sp=.* /。/?/?'.split('_')
        # Expand every 'da' in place with the calls of the default algorithm.
        calls = chain.from_iterable(default_calls if method == 'da' else [method] for method in calls)

    operations = []
    for call in calls:
        # Only the first '=' separates the key from the arguments; arguments may contain '=' themselves.
        key, _, raw_args = call.partition('=')
        if key not in methods:
            raise NotImplementedError(f'{key} is not one of the available methods: {tuple(methods.keys())}')
        method = methods[key]
        args = list(map(str_to_valid_type, raw_args.split('+'))) if raw_args else []

        # ``co_varnames`` starts with the parameters and continues with the local variables of the
        # method, so it is cut off after the declared parameters (``self`` is skipped).
        code = method.__code__
        param_names = code.co_varnames[1:code.co_argcount + code.co_kwonlyargcount]
        surplus = [arg for arg in args[len(param_names):] if arg is not None]
        if surplus:
            raise TypeError(
                f'{method.__name__} takes at most {len(param_names)} argument(s) '
                f'but {len(param_names) + len(surplus)} or more were given in "{call}"'
            )
        # Arguments parsed as None are left out so that the method's own defaults apply.
        kwargs = {k: v for k, v in zip(param_names, args) if v is not None}

        if include_str:
            kwargs_str = ', '.join(f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' for k, v in kwargs.items())
            op_str = f'{method.__name__}({kwargs_str})'
        else:
            op_str = None
        operations.append((method, kwargs, op_str))

    return operations
|
|
|
def find(self, pattern: str, word_level=True, flags=None) -> "WhisperResultMatches":
    """
    Search the segments/words of this result with a regular expression.

    Parameters
    ----------
    pattern : str
        RegEx pattern to search for.
    word_level : bool, default True
        Whether to search at word-level.
    flags : optional
        RegEx flags.

    Returns
    -------
    stable_whisper.result.WhisperResultMatches
        An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match ``pattern``.
    """
    # Wrap the whole result first, then delegate the search to the wrapper.
    all_segments = WhisperResultMatches(self)
    return all_segments.find(pattern, word_level=word_level, flags=flags)
|
|
|
@property
def text(self):
    """Text of all segments concatenated in order, without any separator added."""
    return ''.join([segment.text for segment in self.segments])
|
|
|
@property
def regroup_history(self):
    """Recorded regrouping history, in the string form accepted by ``parse_regroup_algo``."""
    history = self._regroup_history
    return history
|
|
|
@property
def nonspeech_sections(self):
    """Read-only access to the value stored in ``_nonspeech_sections``."""
    sections = self._nonspeech_sections
    return sections
|
|
|
def show_regroup_history(self):
    """
    Print details of all regrouping operations that have been performed on the data.
    """
    history = self._regroup_history
    if not history:
        print('Result has no history.')
    # An empty history parses into no operations, so nothing more is printed in that case.
    for operation in self.parse_regroup_algo(history):
        description = operation[-1]
        print(f'.{description}')
|
|
|
def __len__(self):
    """Number of segments in the result."""
    segment_count = len(self.segments)
    return segment_count
|
|
|
def unlock_all_segments(self):
    """
    Call ``unlock_all_words()`` on every segment and return the current instance.
    """
    for segment in self.segments:
        segment.unlock_all_words()
    return self
|
|
|
def reset(self):
    """
    Restore all values to that at initialization.

    The language and the segments are rebuilt from ``ori_dict`` and the regrouping history is cleared.
    """
    original = self.ori_dict
    self.language = original.get('language')
    self._regroup_history = ''

    raw_segments = original.get('segments')
    if raw_segments:
        self.segments = [Segment(**raw_segment) for raw_segment in raw_segments]
    else:
        self.segments = []

    if self._forced_order:
        self.force_order()

    # The flag passed on is whether at least one rebuilt segment has words.
    any_segment_has_words = any(segment.has_words for segment in self.segments)
    self.remove_no_word_segments(any_segment_has_words)
    self.update_all_segs_with_words()
|
|
|
@property
def has_words(self):
    """Whether every segment has words (``True`` when there are no segments)."""
    for segment in self.segments:
        if not segment.has_words:
            return False
    return True
|
|
|
# Output helpers bound under shorter names; the right-hand names are not defined in the visible
# part of this file -- presumably they come from the star import of ``.text_output`` (confirm).
to_srt_vtt = result_to_srt_vtt
to_ass = result_to_ass
to_tsv = result_to_tsv
to_txt = result_to_txt
save_as_json = save_as_json
|
|
|
|
|
class SegmentMatch:
    """
    One match, described by the segments it covers and, optionally, the matched words within them.
    """

    def __init__(
            self,
            segments: Union[List[Segment], Segment],
            _word_indices: List[List[int]] = None,
            _text_match: str = None
    ):
        # A single segment is accepted and wrapped into a list.
        if isinstance(segments, Segment):
            segments = [segments]
        self.segments = segments
        self.word_indices = _word_indices if _word_indices is not None else []
        # ``word_indices[k]`` lists the positions of the matched words within ``segments[k]``.
        self.words = [
            self.segments[seg_pos].words[word_pos]
            for seg_pos, word_positions in enumerate(self.word_indices)
            for word_pos in word_positions
        ]
        if self.words:
            self.text = ''.join(word.word for word in self.words)
        else:
            # Without matched words, the text is that of the whole segments.
            self.text = ''.join(segment.text for segment in self.segments)
        self.text_match = _text_match

    @property
    def start(self):
        """Start of the first matched word, else of the first segment, else ``None``."""
        if self.words:
            return self.words[0].start
        if self.segments:
            return self.segments[0].start
        return None

    @property
    def end(self):
        """End of the last matched word, else of the last segment, else ``None``."""
        if self.words:
            return self.words[-1].end
        if self.segments:
            return self.segments[-1].end
        return None

    def __len__(self):
        """Number of segments covered by the match."""
        return len(self.segments)

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return str(self.__dict__)
|
|
|
|
|
class WhisperResultMatches:
    """
    RegEx matches for WhisperResults.

    Each match is kept together with the indices of the segments it covers, so that a further
    :meth:`find` searches only within the segments of the previous matches.
    """

    def __init__(
            self,
            matches: Union[List['SegmentMatch'], 'WhisperResult'],
            _segment_indices: List[List[int]] = None
    ):
        if isinstance(matches, WhisperResult):
            # Start from a result: every segment becomes a match of its own.
            self.matches = list(map(SegmentMatch, matches.segments))
            self._segment_indices = [[i] for i in range(len(matches.segments))]
        else:
            self.matches = matches
            # One list of segment indices per match, with one index per segment of the match.
            assert _segment_indices is not None
            assert len(self.matches) == len(_segment_indices)
            assert all(len(match.segments) == len(_segment_indices[i]) for i, match in enumerate(self.matches))
            self._segment_indices = _segment_indices

    @property
    def segment_indices(self):
        """Indices of the segments covered by each match."""
        return self._segment_indices

    def _curr_seg_groups(self) -> List[List[Tuple[int, 'Segment']]]:
        """
        Group the segments of all matches into runs of consecutive segment indices.

        Each segment appears once, as an ``(index, segment)`` tuple, even if several matches cover it.
        Only segments with consecutive indices share a group.
        """
        seg_groups, curr_segs = [], []
        curr_max = -1
        for seg_indices, match in zip(self._segment_indices, self.matches):
            for i, seg in zip(sorted(seg_indices), match.segments):
                if i <= curr_max:
                    # Segment was already collected from an earlier match.
                    continue
                # Close the current run BEFORE adding a segment that does not directly follow it.
                if curr_segs and i - 1 != curr_max:
                    seg_groups.append(curr_segs)
                    curr_segs = []
                curr_segs.append((i, seg))
                curr_max = i

        if curr_segs:
            seg_groups.append(curr_segs)
        return seg_groups

    def find(self, pattern: str, word_level=True, flags=None) -> "WhisperResultMatches":
        """
        Find segments/words and timestamps with regular expression.

        Parameters
        ----------
        pattern : str
            RegEx pattern to search for.
        word_level : bool, default True
            Whether to search at word-level.
        flags : optional
            RegEx flags.

        Returns
        -------
        stable_whisper.result.WhisperResultMatches
            An instance of :class:`stable_whisper.result.WhisperResultMatches` with word/segment that match ``pattern``.
        """
        seg_groups = self._curr_seg_groups()
        matches: List['SegmentMatch'] = []
        match_seg_indices: List[List[int]] = []
        if word_level and not all(seg.has_words for match in self.matches for seg in match.segments):
            warnings.warn('Cannot perform word-level search with segment(s) missing word timestamps.')
            word_level = False

        for segs in seg_groups:
            # ``idxs`` holds, for every character of ``text``, the (segment index, word index)
            # the character belongs to; the word index is None for segment-level searches.
            if word_level:
                idxs = list(chain.from_iterable(
                    [(i, j)]*len(word.word) for (i, seg) in segs for j, word in enumerate(seg.words)
                ))
                text = ''.join(word.word for (_, seg) in segs for word in seg.words)
            else:
                idxs = list(chain.from_iterable([(i, None)]*len(seg.text) for (i, seg) in segs))
                text = ''.join(seg.text for (_, seg) in segs)
            assert len(idxs) == len(text)
            for curr_match in re.finditer(pattern, text, flags=flags or 0):
                start, end = curr_match.span()
                curr_idxs = idxs[start: end]
                curr_seg_idxs = sorted(set(i[0] for i in curr_idxs))
                if word_level:
                    curr_word_idxs = [
                        sorted(set(j for i, j in curr_idxs if i == seg_idx))
                        for seg_idx in curr_seg_idxs
                    ]
                else:
                    curr_word_idxs = None
                matches.append(SegmentMatch(
                    segments=[s for i, s in segs if i in curr_seg_idxs],
                    _word_indices=curr_word_idxs,
                    _text_match=curr_match.group()
                ))
                match_seg_indices.append(curr_seg_idxs)
        return WhisperResultMatches(matches, match_seg_indices)

    def __len__(self):
        """Number of matches."""
        return len(self.matches)

    def __bool__(self):
        return self.__len__() != 0

    def __getitem__(self, idx):
        return self.matches[idx]
|
|