mebubo committed
Commit 230a441 · 1 Parent(s): 426b33e
Files changed (2)
  1. app.py +3 -3
  2. text_processing.py +13 -11
app.py CHANGED
@@ -21,7 +21,7 @@ def tokenize(input_text: str, tokenizer: Tokenizer, device: torch.device) -> tup
     attention_mask = cast(torch.Tensor, inputs["attention_mask"])
     return input_ids, attention_mask
 
-def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[str, float]]:
+def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> list[tuple[int, float]]:
     with torch.no_grad():
         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
     # B x T x V
@@ -31,8 +31,8 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     # T - 1
     token_log_probs: torch.Tensor = log_probs[0, range(log_probs.shape[1]), input_ids[0][1:]]
     # T - 1
-    tokens: list[str] = tokenizer.convert_ids_to_tokens(input_ids[0])[1:]
-    return list(zip(tokens, token_log_probs.tolist()))
+    tokens: torch.Tensor = input_ids[0][1:]
+    return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
 
 def generate_replacements(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prefix_tokens: list[int], device: torch.device, num_samples: int = 5) -> list[str]:
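For orientation (not part of the diff): calculate_log_probabilities now returns raw token ids paired with their log probabilities instead of decoded token strings, leaving decoding to the caller. A minimal sketch of how a consumer might use the new return value, assuming the model, tokenizer, input_ids and attention_mask variables from app.py are already in scope:

token_probs = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)
for token_id, logprob in token_probs:
    # decode a single id back to text only where a string is actually needed
    print(tokenizer.decode([token_id]), logprob)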
text_processing.py CHANGED
@@ -1,30 +1,32 @@
 from dataclasses import dataclass
+from tokenizers import Tokenizer
 
 @dataclass
 class Word:
-    tokens: list[str]
+    tokens: list[int]
     text: str
     logprob: float
     first_token_index: int
 
-def split_into_words(token_probs: list[tuple[str, float]]) -> list[Word]:
-    words = []
-    current_word = []
-    current_log_probs = []
-    current_word_first_token_index = 0
+def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer) -> list[Word]:
+    words: list[Word] = []
+    current_word: list[int] = []
+    current_log_probs: list[float] = []
+    current_word_first_token_index: int = 0
 
-    for i, (token, logprob) in enumerate(token_probs):
+    for i, (token_id, logprob) in enumerate(token_probs):
+        token: str = tokenizer.decode([token_id])
         if not token.startswith(chr(9601)) and token.isalpha():
-            current_word.append(token)
+            current_word.append(token_id)
             current_log_probs.append(logprob)
         else:
             if current_word:
-                words.append(Word(current_word, "".join(current_word), sum(current_log_probs), current_word_first_token_index))
-            current_word = [token]
+                words.append(Word(current_word, tokenizer.decode(current_word), sum(current_log_probs), current_word_first_token_index))
+            current_word = [token_id]
             current_log_probs = [logprob]
             current_word_first_token_index = i
 
     if current_word:
-        words.append(Word(current_word, "".join(current_word), sum(current_log_probs), current_word_first_token_index))
+        words.append(Word(current_word, tokenizer.decode(current_word), sum(current_log_probs), current_word_first_token_index))
 
     return words
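Taken together, the two files now pass token ids end to end: app.py produces (token_id, logprob) pairs and text_processing.py decodes them through the tokenizer only when building Word objects. A rough usage sketch under assumptions: the model name is a placeholder, the loading code is not part of this commit, and tokenize / calculate_log_probabilities / split_into_words are imported from the two files changed here:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from app import tokenize, calculate_log_probabilities
from text_processing import split_into_words

model_name = "some-sentencepiece-model"  # placeholder; the chr(9601) check assumes a SentencePiece-style tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

input_ids, attention_mask = tokenize("Hello world", tokenizer, device)

# (token_id, logprob) pairs, as returned after this commit
token_probs = calculate_log_probabilities(model, tokenizer, input_ids, attention_mask)

# split_into_words now takes the tokenizer so it can decode ids back into text
words = split_into_words(token_probs, tokenizer)
for word in words:
    print(word.text, word.logprob, word.first_token_index)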