| from dataclasses import dataclass | |
@dataclass
class Word:
    """A run of consecutive tokens grouped into one word.

    Attributes:
        tokens: The token strings that make up the word, in order.
        text: The concatenation of ``tokens``.
        logprob: The sum of the log-probabilities of ``tokens``.
        first_token_index: Position of the word's first token in the
            token sequence it was split from.
    """

    # Token strings, not ids: they are joined with "".join to build ``text``.
    tokens: list[str]
    text: str
    logprob: float
    first_token_index: int
def split_into_words(tokens, log_probs) -> list[Word]:
    """Group consecutive tokens into words and total their log-probabilities.

    A token extends the word currently being built only when it does not
    begin with the character U+2581 and is purely alphabetic
    (``str.isalpha``). Any other token closes the word in progress, if
    there is one, and becomes the first token of a new word.

    Args:
        tokens: Token strings, in order.
        log_probs: Log-probabilities paired with ``tokens`` through ``zip``;
            if the two differ in length, the surplus entries are ignored.

    Returns:
        One ``Word`` per group, in order of appearance. Each carries its
        tokens, their concatenation, the sum of their log-probabilities,
        and the index of its first token. Empty input gives an empty list.
    """
    # NOTE(review): U+2581 looks like a tokenizer's word-start marker;
    # confirm against the tokenizer that produces ``tokens``.
    word_start_marker = chr(9601)

    def build(pieces, scores, start):
        # Single place where a finished group is turned into a Word.
        return Word(pieces, "".join(pieces), sum(scores), start)

    finished = []
    pieces = []
    scores = []
    start = 0  # stays 0 when the very first token is a continuation piece

    for position, (token, logprob) in enumerate(zip(tokens, log_probs)):
        opens_new_word = token.startswith(word_start_marker) or not token.isalpha()
        if not opens_new_word:
            pieces.append(token)
            scores.append(logprob)
            continue

        # Boundary token: emit whatever was accumulated, then start afresh.
        if pieces:
            finished.append(build(pieces, scores, start))
        pieces, scores, start = [token], [logprob], position

    # The final group has no following boundary token to trigger its emission.
    if pieces:
        finished.append(build(pieces, scores, start))
    return finished