Use the flagging threshold to filter out uninteresting tokens
- completions.py +1 -1
- expand_llm.py +3 -3
completions.py
@@ -92,7 +92,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
 
 contexts = [word.context for _, word in low_prob_words]
 
-expander = LLMBatchExpander(model, tokenizer)
+expander = LLMBatchExpander(model, tokenizer, threshold=log_prob_threshold)
 
 #%%
 series = []
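The threshold is no longer hard-coded inside expand_llm.py: the call site forwards its own log_prob_threshold (defined earlier in completions.py, outside this hunk), so the same cutoff that flagged low-probability words also gates the expansion candidates. A minimal, self-contained sketch of what that cutoff does; the tensor values are made up, and -10.0 is only assumed to match the previously hard-coded default removed below:

import torch

# Illustrative log-probabilities over a tiny 5-token vocabulary.
log_probs = torch.tensor([-0.5, -3.2, -12.0, -9.9, -15.1])
log_prob_threshold = -10.0  # assumption: mirrors the old hard-coded value

# Keep only token ids whose log-probability clears the threshold,
# the same torch.where pattern used by find_next_tokens below.
(above,) = torch.where(log_probs > log_prob_threshold)
candidates = [(int(i), float(log_probs[i])) for i in above]
print(candidates)  # token ids 0, 1 and 3 survive the cutoff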
expand_llm.py
@@ -6,7 +6,7 @@ import time
 
 type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
-def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer) -> list[list[tuple[int, float]]]:
+def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, threshold: float) -> list[list[tuple[int, float]]]:
     input_ids = inputs["input_ids"]
     attention_mask = inputs["attention_mask"]
     print("Running inference")
@@ -21,7 +21,6 @@ def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: T
     start_time = time.time()
     result = []
     print(f"Resulting tensor: {log_probs.shape}")
-    threshold = -10.0
     for probs in log_probs:
         # Filter out low probability tokens for efficiency
         above_threshold = torch.where(probs > threshold)
@@ -39,10 +38,11 @@ def prepare_inputs(contexts: list[list[int]], tokenizer: Tokenizer, device: torc
 class LLMBatchExpander(BatchExpander):
     model: PreTrainedModel
     tokenizer: Tokenizer
+    threshold: float
 
     def expand(self, batch: Batch) -> BatchCandidates:
         inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
-        next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
+        next_tokens = find_next_tokens(self.model, inputs, self.threshold)
         start_time = time.time()
         results = []
         print(f"Batch size: {len(batch.items)}, next tokens size: {len(next_tokens)}")
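Taken together, the two files now thread a single threshold from the call site through LLMBatchExpander into find_next_tokens. A hedged sketch of the resulting shape; the class is reduced to a plain dataclass, the model call is stubbed out with random logits, and only the threshold plumbing mirrors the diff:

from dataclasses import dataclass

import torch


def find_next_tokens_sketch(log_probs: torch.Tensor, threshold: float) -> list[list[tuple[int, float]]]:
    # Per sequence, collect (token_id, log_prob) pairs above the threshold,
    # matching the torch.where filter in find_next_tokens.
    result = []
    for probs in log_probs:
        (above,) = torch.where(probs > threshold)
        result.append([(int(i), float(probs[i])) for i in above])
    return result


@dataclass
class ThresholdExpanderSketch:
    threshold: float  # mirrors the new LLMBatchExpander.threshold field

    def expand(self, log_probs: torch.Tensor) -> list[list[tuple[int, float]]]:
        # The real expand() tokenizes the batch and runs the model first;
        # here the caller supplies the log-probs directly.
        return find_next_tokens_sketch(log_probs, self.threshold)


# Illustrative batch: 2 sequences over a 4-token vocabulary.
log_probs = torch.log_softmax(torch.randn(2, 4), dim=-1)
expander = ThresholdExpanderSketch(threshold=-10.0)
print(expander.expand(log_probs))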