mebubo committed
Commit 0235f77 · Parent(s): 8f6137e

Use the flagging threshold to filter out uninteresting tokens

Files changed (2):
  1. completions.py +1 -1
  2. expand_llm.py +3 -3
completions.py CHANGED
@@ -92,7 +92,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
 
 contexts = [word.context for _, word in low_prob_words]
 
-expander = LLMBatchExpander(model, tokenizer)
+expander = LLMBatchExpander(model, tokenizer, threshold=log_prob_threshold)
 
 #%%
 series = []
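The call site assumes a `log_prob_threshold` already in scope in completions.py; its definition is outside this hunk. A minimal sketch of the intended wiring, using the value of the constant removed from expand_llm.py below (the actual upstream value may differ):

# Sketch only: log_prob_threshold is defined elsewhere in completions.py;
# -10.0 mirrors the constant previously hardcoded in find_next_tokens.
log_prob_threshold = -10.0

# The threshold used to flag low-probability words is now also the cutoff
# for candidate tokens inside the expander.
expander = LLMBatchExpander(model, tokenizer, threshold=log_prob_threshold)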
expand_llm.py CHANGED
@@ -6,7 +6,7 @@ import time
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
-def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
+def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, threshold: float) -> list[list[tuple[int, float]]]:
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
     print("Running inference")
@@ -21,7 +21,6 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
     start_time = time.time()
     result = []
     print(f"Resulting tensor: {log_probs.shape}")
-    threshold = -10.0
     for probs in log_probs:
         # Filter out low probability tokens for efficiency
         above_threshold = torch.where(probs > threshold)
@@ -39,10 +38,11 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
 class LLMBatchExpander(BatchExpander):
     model: PreTrainedModel
     tokenizer: Tokenizer
+    threshold: float
 
     def expand(self, batch: Batch) -> BatchCandidates:
         inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
-        next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
+        next_tokens = find_next_tokens(self.model, inputs, self.threshold)
         start_time = time.time()
         results = []
         print(f"Batch size: {len(batch.items)}, next tokens size: {len(next_tokens)}")