# gpted / text_processing.py
from dataclasses import dataclass
from tokenizers import Tokenizer
@dataclass
class Word:
    """A word assembled from one or more consecutive model tokens."""
    tokens: list[int]        # ids of the tokens that make up this word
    text: str                # decoded text of the full token sequence
    logprob: float           # sum of the per-token log-probabilities
    first_token_index: int   # index of the word's first token in the input token sequence
def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
    """Group a sequence of (token_id, logprob) pairs into Word records.

    A token extends the word being built when its individually decoded text
    is purely alphabetic and does not begin with U+2581 (the SentencePiece
    word-boundary marker); any other token finishes the current word and
    begins a new one.  Each Word carries the concatenated decoded text, the
    sum of its tokens' log-probabilities, and the index of its first token.
    Returns an empty list for empty input.
    """
    words: list[Word] = []
    pending_ids: list[int] = []
    pending_logprobs: list[float] = []
    start_index: int = 0

    def flush() -> None:
        # Emit the tokens accumulated so far as one Word (no-op when empty,
        # which also covers the very first iteration).
        if pending_ids:
            words.append(
                Word(
                    pending_ids,
                    tokenizer.decode(pending_ids),
                    sum(pending_logprobs),
                    start_index,
                )
            )

    for index, (token_id, logprob) in enumerate(token_probs):
        piece: str = tokenizer.decode([token_id])
        continues_word = piece.isalpha() and not piece.startswith("\u2581")
        if continues_word:
            pending_ids.append(token_id)
            pending_logprobs.append(logprob)
        else:
            flush()
            pending_ids = [token_id]
            pending_logprobs = [logprob]
            start_index = index

    flush()  # don't drop the word still pending after the last token
    return words