File size: 1,151 Bytes
6735ae4
230a441
6735ae4
 
 
230a441
6735ae4
 
 
 
230a441
 
 
 
 
6735ae4
230a441
 
6735ae4
230a441
6735ae4
 
 
230a441
 
6735ae4
 
 
 
230a441
6735ae4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from dataclasses import dataclass
from tokenizers import Tokenizer

@dataclass
class Word:
    """A word assembled from one or more consecutive subword tokens,
    as produced by split_into_words below."""
    # Token ids that make up this word, in their original order.
    tokens: list[int]
    # Decoded surface text of the word (tokenizer.decode of `tokens`).
    text: str
    # Sum of the per-token log-probabilities for this word's tokens.
    logprob: float
    # Index (into the original token sequence) of this word's first token.
    first_token_index: int
def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
    """Group consecutive subword tokens into Word records.

    A token *continues* the current word when its decoded text is purely
    alphabetic and does not begin with the SentencePiece word-boundary
    marker ("▁", U+2581); any other token flushes the pending word and
    begins a new one. A trailing pending word is flushed at the end.

    Args:
        token_probs: Sequence of (token_id, logprob) pairs.
        tokenizer: Tokenizer used to decode token ids back to text.

    Returns:
        List of Word objects covering the input sequence in order.
    """
    boundary = chr(9601)  # "▁" — SentencePiece word-start marker
    words: list[Word] = []
    pending_ids: list[int] = []
    pending_logprobs: list[float] = []
    start_index: int = 0

    def flush() -> None:
        # Emit the pending token run as a Word, if there is one.
        if pending_ids:
            words.append(
                Word(
                    pending_ids,
                    tokenizer.decode(pending_ids),
                    sum(pending_logprobs),
                    start_index,
                )
            )

    for index, (token_id, logprob) in enumerate(token_probs):
        piece: str = tokenizer.decode([token_id])
        if piece.isalpha() and not piece.startswith(boundary):
            # Continuation piece: extend the word in progress.
            pending_ids.append(token_id)
            pending_logprobs.append(logprob)
            continue
        # Boundary piece: close out the previous word, start a new one here.
        flush()
        pending_ids = [token_id]
        pending_logprobs = [logprob]
        start_index = index

    flush()
    return words