import string
import torch
import numpy as np
from typing import TYPE_CHECKING, List, Callable, Optional, Union
from itertools import chain
from whisper.audio import TOKENS_PER_SECOND, N_SAMPLES_PER_TOKEN
from whisper.timing import WordTiming, median_filter, dtw, merge_punctuations
if TYPE_CHECKING:
from whisper.tokenizer import Tokenizer
from whisper.model import Whisper
# modified version of whisper.timing.find_alignment
def find_alignment_stable(
model: "Whisper",
tokenizer: "Tokenizer",
text_tokens: List[int],
mel: torch.Tensor,
num_samples: int,
*,
medfilt_width: int = 7,
qk_scale: float = 1.0,
ts_num: int = 0,
ts_noise: float = 0.1,
token_split=None,
        audio_features: Optional[torch.Tensor] = None
) -> List[WordTiming]:
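    """
    Modified version of ``whisper.timing.find_alignment``.

    Runs the decoder over ``text_tokens`` while hooks capture the cross-attention
    weights of the model's alignment heads, then applies dynamic time warping to
    map each token to an audio frame. Compared to the upstream function, this one
    can reuse precomputed ``audio_features``, accept a custom word segmentation
    via ``token_split``, and average the attention over ``ts_num`` noisy copies
    of the audio features (``ts_noise`` sets the noise scale).
    """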
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
if audio_features is None:
audio_features = model.encoder(mel.unsqueeze(0))
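        # optionally decode over ts_num extra copies of the audio features
        # perturbed with multiplicative noise; their attention weights are
        # averaged together with the clean copy to stabilize the alignment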
if ts_num:
if ts_noise is None:
ts_noise = 0.1
extra_audio_features = audio_features.repeat_interleave(ts_num, 0)
torch.manual_seed(0)
audio_features = torch.cat([audio_features,
extra_audio_features *
(1 - (torch.rand_like(extra_audio_features) * ts_noise))],
dim=0)
logits = model.decoder(tokens.unsqueeze(0).repeat_interleave(audio_features.shape[0], 0),
audio_features)
else:
logits = model.decoder(tokens.unsqueeze(0), audio_features)
logits = logits[0]
sampled_logits = logits[len(tokenizer.sot_sequence):, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
# heads * tokens * frames
weights = torch.cat([QKs[_l][:, _h] for _l, _h in model.alignment_heads.indices().T], dim=0)
weights = weights[:, :, : round(num_samples / N_SAMPLES_PER_TOKEN)]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence): -1]
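    # DTW over the negated attention matrix yields the monotonic token-to-frame path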
text_indices, time_indices = dtw(-matrix)
if token_split is None:
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
else:
words, word_tokens = token_split
words.append(tokenizer.decode([tokenizer.eot]))
word_tokens.append([tokenizer.eot])
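    # word_boundaries[k] is the index of the first token of word k; jumps mark
    # the positions in the DTW path where the aligned token index advances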
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps].clip(min=0) / TOKENS_PER_SECOND
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]
def _split_tokens(tokens: List[int], tokenizer: "Tokenizer"):
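    """
    Group ``tokens`` into words by decoding them incrementally and matching the
    result against the fully decoded text. For space-delimited languages, a
    group that neither starts with a space nor is bare punctuation is appended
    to the previous word.
    """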
split_by_space = getattr(tokenizer, 'language_code', tokenizer.language) not in {"zh", "ja", "th", "lo", "my"}
text = tokenizer.decode_with_timestamps(tokens)
words = []
word_tokens = []
curr_tokens = []
is_append = False
for token in tokens:
curr_tokens.append(token)
curr_text = tokenizer.decode(curr_tokens)
is_whole = token >= tokenizer.eot
if not is_whole:
is_whole = text[:len(curr_text)] == curr_text
if is_whole and split_by_space:
is_append = not (curr_text.startswith(" ") or curr_text.strip() in string.punctuation)
if is_whole:
if is_append and len(words) != 0:
words[-1] += curr_text
word_tokens[-1].extend(curr_tokens)
else:
words.append(curr_text)
word_tokens.append(curr_tokens)
text = text[len(curr_text):]
curr_tokens = []
if len(curr_tokens) != 0:
words.append(curr_text if len(text) == 0 else text)
word_tokens.append(curr_tokens)
elif len(text) != 0:
words[-1] += text
return words, word_tokens
def split_word_tokens(segments: List[dict],
tokenizer: "Tokenizer",
*,
                      padding: Optional[Union[str, int]] = None,
                      split_callback: Optional[Callable] = None):
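    """
    Split the tokens of each segment into words and flatten everything into a
    single token sequence for alignment. When ``padding`` is given, its tokens
    are inserted between segments and recorded as a ``None`` word placeholder,
    to be stripped from the alignment afterward with ``pop_empty_alignment``.
    """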
if padding is not None:
if isinstance(padding, str):
padding = tokenizer.encode(padding)
else:
padding = [padding]
tokens = []
seg_indices = []
words = []
word_tokens = []
for i, s in enumerate(segments):
temp_word_tokens = [t for t in s['tokens'] if not isinstance(t, int) or t < tokenizer.eot]
curr_words, curr_word_tokens = (
_split_tokens(temp_word_tokens, tokenizer)
if split_callback is None else
split_callback(temp_word_tokens, tokenizer)
)
assert len(curr_words) == len(curr_word_tokens), \
f'word count and token group count do not match, {len(curr_words)} and {len(curr_word_tokens)}'
if (
padding is not None and
                curr_word_tokens[0] != padding and
                (len(word_tokens) == 0 or word_tokens[-1] != padding)
):
tokens.extend(padding)
words.append(None)
word_tokens.append(padding)
seg_indices.extend([i] * len(curr_words))
tokens.extend(list(chain.from_iterable(curr_word_tokens)))
words.extend(curr_words)
word_tokens.extend(curr_word_tokens)
return tokens, (words, word_tokens), seg_indices
def pop_empty_alignment(alignment: List[WordTiming]):
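    """Remove the padding placeholders (entries whose word is None) from
    ``alignment`` in place and return them in their original order."""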
return list(reversed([alignment.pop(i) for i in reversed(range(len(alignment))) if alignment[i].word is None]))
# modified version of whisper.timing.add_word_timestamps
def add_word_timestamps_stable(
*,
segments: List[dict],
model: "Whisper",
tokenizer: "Tokenizer",
mel: torch.Tensor,
num_samples: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
        audio_features: Optional[torch.Tensor] = None,
ts_num: int = 0,
ts_noise: float = 0.1,
min_word_dur: float = 0.1,
        split_callback: Optional[Callable] = None,
gap_padding: Optional[str] = ' ...',
**kwargs,
):
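    """
    Modified version of ``whisper.timing.add_word_timestamps``.

    Aligns all segments in one pass with ``find_alignment_stable`` and stores a
    ``'words'`` list of dicts (word, start, end, probability, tokens) on each
    segment. If ``gap_padding`` was used and any resulting word is shorter than
    ``min_word_dur``, the alignment is redone without padding.
    """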
if len(segments) == 0:
return
if min_word_dur is None:
min_word_dur = 0
if prepend_punctuations is None:
prepend_punctuations = "\"'“¿([{-"
if append_punctuations is None:
append_punctuations = "\"'.。,,!!??::”)]}、"
def align():
for seg in segments:
seg['words'] = []
text_tokens, token_split, seg_indices = split_word_tokens(segments, tokenizer,
padding=gap_padding, split_callback=split_callback)
alignment = find_alignment_stable(model, tokenizer, text_tokens, mel, num_samples,
**kwargs,
token_split=token_split,
audio_features=audio_features,
ts_num=ts_num,
ts_noise=ts_noise)
alt_beginning_alignment = pop_empty_alignment(alignment)
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"]
assert len(alignment) == len(seg_indices)
assert (gap_padding is None or len(segments) == len(alt_beginning_alignment))
for i, timing in zip(seg_indices, alignment):
if len(timing.tokens) != 0:
start = timing.start
end = timing.end
if (
len(segments[i]['words']) == 0 and
((end - start) < min_word_dur) and
len(alt_beginning_alignment)
):
start = alt_beginning_alignment[i].start
segments[i]['words'].append(
dict(
word=timing.word,
start=round(time_offset + start, 3),
end=round(time_offset + end, 3),
probability=timing.probability,
tokens=timing.tokens
)
)
align()
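    # if the gap padding produced degenerate words (shorter than min_word_dur),
    # realign without it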
if (
gap_padding is not None and
any(
(word['end'] - word['start']) < min_word_dur
for seg in segments
for word in seg['words']
)
):
gap_padding = None
align()
for segment in segments:
if len(words := segment["words"]) > 0:
# adjust the segment-level timestamps based on the word-level timestamps
segment["start"] = words[0]["start"]
segment["end"] = words[-1]["end"]