Snapshot
- app.py +3 -3
- text_processing.py +13 -11
app.py
CHANGED
@@ -21,7 +21,7 @@ def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tup
     attention_mask = cast(torch.Tensor, inputs["attention_mask"])
     return input_ids, attention_mask

-def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[
+def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[int, float]]:
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
     # B x T x V
@@ -31,8 +31,8 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     # T - 1
     token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
     # T - 1
-    tokens:
-    return list(zip(tokens, token_log_probs.tolist()))
+    tokens: torch.Tensor = input_ids[0][1:]
+    return list(zip(tokens.tolist(), token_log_probs.tolist()))


 def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
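Taken together, the two hunks switch calculate_log_probabilities from returning decoded token strings to returning (token_id, log_prob) pairs. The middle of the function is unchanged and not shown; the snippet below is a minimal sketch with random tensors, not the app's real model, and it assumes those hidden lines log-softmax the logits and drop the final position, which is what the visible "# T - 1" comments and the input_ids[0][1:] indexing imply. It illustrates the shifted pairing: a causal LM's prediction at position t scores the token at position t + 1, so the IDs that accompany the log-probabilities are the input IDs with the first token dropped.

```python
import torch

# Toy stand-ins for the model output; shapes mirror the "# B x T x V" comment.
torch.manual_seed(0)
vocab_size, seq_len = 10, 5
input_ids = torch.randint(vocab_size, (1, seq_len))        # pretend tokenized prompt, B x T
logits = torch.randn(1, seq_len, vocab_size)               # pretend model logits, B x T x V

# Assumed equivalent of the unchanged middle of the function:
# normalise to log-probabilities and drop the last position (it predicts a token we never see).
log_probs = torch.log_softmax(logits, dim=-1)[:, :-1, :]   # B x (T-1) x V

# Same indexing as the diff: position t is scored against the token at t + 1.
token_log_probs = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]  # T - 1
tokens = input_ids[0][1:]                                   # token IDs, no longer decoded strings
pairs: list[tuple[int, float]] = list(zip(tokens.tolist(), token_log_probs.tolist()))
print(pairs)  # e.g. [(4, -2.31), (9, -2.67), ...]
```

Returning IDs rather than strings defers all decoding to the consumer, which is exactly what the reworked split_into_words in text_processing.py now does.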
text_processing.py
CHANGED
@@ -1,30 +1,32 @@
 from dataclasses import dataclass
+from tokenizers import Tokenizer

 @dataclass
 class Word:
-    tokens: list[
+    tokens: list[int]
     text: str
     logprob: float
     first_token_index: int

-def split_into_words(token_probs: list[tuple[
-    words = []
-    current_word = []
-    current_log_probs = []
-    current_word_first_token_index = 0
+def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
+    words: list[Word] = []
+    current_word: list[int] = []
+    current_log_probs: list[float] = []
+    current_word_first_token_index: int = 0

-    for i, (
+    for i, (token_id, logprob) in enumerate(token_probs):
+        token: str = tokenizer.decode([token_id])
         if not token.startswith(chr(9601)) and token.isalpha():
-            current_word.append(
+            current_word.append(token_id)
             current_log_probs.append(logprob)
         else:
             if current_word:
-                words.append(Word(current_word,
-            current_word = [
+                words.append(Word(current_word, tokenizer.decode(current_word), sum(current_log_probs), current_word_first_token_index))
+            current_word = [token_id]
             current_log_probs = [logprob]
             current_word_first_token_index = i

     if current_word:
-        words.append(Word(current_word,
+        words.append(Word(current_word, tokenizer.decode(current_word), sum(current_log_probs), current_word_first_token_index))

     return words
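The reworked split_into_words groups token IDs into Word objects and decodes them with the tokenizers.Tokenizer it now receives: a piece that is alphabetic and does not start with chr(9601) ("▁") extends the current word, anything else flushes it. The sketch below is illustrative only; the StubTokenizer, its piece table, and the sample IDs and log-probs are invented so the example runs without loading a real tokenizer (in the app one would come from the tokenizers library, e.g. Tokenizer.from_pretrained). The stub's decode keeps the "▁" markers visible, whereas a real decoder would typically render them as spaces.

```python
from text_processing import split_into_words

class StubTokenizer:
    """Hypothetical stand-in for tokenizers.Tokenizer: maps a few IDs to
    SentencePiece-style pieces and concatenates them on decode()."""
    _pieces = {1: "\u2581The", 2: "\u2581quick", 3: "ly", 4: "."}

    def decode(self, ids: list[int]) -> str:
        return "".join(self._pieces[i] for i in ids)

# (token_id, log_prob) pairs, in the shape calculate_log_probabilities now returns (values invented).
token_probs = [(1, -0.1), (2, -0.7), (3, -0.3), (4, -1.2)]

for w in split_into_words(token_probs, StubTokenizer()):
    print(w.tokens, repr(w.text), round(w.logprob, 2), w.first_token_index)
# [1]    '▁The'      -0.1  0   <- word-initial piece starts a new word
# [2, 3] '▁quickly'  -1.0  1   <- "ly" is alphabetic with no ▁, so it extends "quick"
# [4]    '.'         -1.2  3   <- non-alphabetic piece, flushed by the trailing if after the loop
```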