mebubo committed on
Commit
2fb63bf
·
1 Parent(s): 8021de7

Working generic expand

Browse files
Files changed (3) hide show
  1. completions.py +44 -7
  2. expand.py +84 -0
  3. expand_test.py +161 -0
completions.py CHANGED
@@ -1,5 +1,6 @@
1
  #%%
2
  from dataclasses import dataclass
 
3
  import time
4
  import torch
5
  from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
@@ -34,7 +35,7 @@ def split_into_words(token_probs: list[tuple[int, float]], tokenizer: Tokenizer)
34
 
35
  def merge_tokens(a: Tok, b: Tok) -> Tok | None:
36
  if is_beginning_of_word(a.str) and is_continuation_of_word(b.str):
37
- return Tok(b.index, a.ids + b.ids, a.str + b.str, a.logprob * b.logprob)
38
  return None
39
 
40
  converted = [Tok(i, [token_id], tokenizer.decode([token_id]), logprob)
@@ -94,6 +95,20 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
94
  )
95
  return outputs
96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
98
  all_new_words = []
99
  for i in range(num_inputs):
@@ -117,30 +132,24 @@ def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
117
  return model, tokenizer, device
118
 
119
  def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, device: torch.device) -> list[ApiWord]:
120
- #%%
121
  inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
122
 
123
- #%%
124
  token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
125
 
126
- #%%
127
  words = split_into_words(token_probs, tokenizer)
128
  log_prob_threshold = -5.0
129
  low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]
130
 
131
- #%%
132
  contexts = [word.context for _, word in low_prob_words]
133
  inputs = prepare_inputs(contexts, tokenizer, device)
134
  input_ids = inputs["input_ids"]
135
 
136
- #%%
137
  num_samples = 10
138
  start_time = time.time()
139
  outputs = generate_outputs(model, inputs, num_samples)
140
  end_time = time.time()
141
  print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
142
 
143
- #%%
144
  replacements = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
145
 
146
  low_prob_words_with_replacements = { i: (w, r) for (i, w), r in zip(low_prob_words, replacements) }
@@ -152,3 +161,31 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
152
  else:
153
  result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
154
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #%%
2
  from dataclasses import dataclass
3
+ import math
4
  import time
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast, BatchEncoding
 
35
 
36
  def merge_tokens(a: Tok, b: Tok) -> Tok | None:
37
  if is_beginning_of_word(a.str) and is_continuation_of_word(b.str):
38
+ return Tok(a.index, a.ids + b.ids, a.str + b.str, a.logprob + b.logprob)
39
  return None
40
 
41
  converted = [Tok(i, [token_id], tokenizer.decode([token_id]), logprob)
 
95
  )
96
  return outputs
97
 
98
def find_next_tokens(model: PreTrainedModel, inputs: BatchEncoding, tokenizer: Tokenizer, min_p: float) -> list[list[tuple[int, str, float]]]:
    """Return, for every batch row, the next-token candidates above a log-prob threshold.

    Runs a single forward pass and reads the distribution at the final sequence
    position of each row.

    Args:
        model: Causal language model, called with ``input_ids`` and ``attention_mask``.
        inputs: Encoding that provides ``"input_ids"`` and ``"attention_mask"``.
        tokenizer: Maps token ids to token strings via ``convert_ids_to_tokens``.
        min_p: Threshold in log-probability space (despite the name); the
            comparison is strict (``log_prob > min_p``).

    Returns:
        One list per batch row of ``(token_id, token_string, log_prob)`` tuples in
        ascending token-id order. ``log_prob`` is a plain Python float.
    """
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # NOTE(review): reading position -1 assumes the last column of every row is a
    # real token (e.g. left padding) -- confirm against prepare_inputs.
    logits: torch.Tensor = outputs.logits[:, -1, :]
    log_probs: torch.Tensor = torch.log_softmax(logits, dim=-1)

    result = []
    for row in log_probs:
        # Threshold on the tensor instead of looping over the whole vocabulary in
        # Python; nonzero() yields the surviving ids in ascending order.
        keep = torch.nonzero(row > min_p, as_tuple=True)[0]
        token_ids = keep.tolist()
        # tolist() converts to Python floats so the result matches the annotation.
        token_log_probs = row[keep].tolist()
        token_strs = tokenizer.convert_ids_to_tokens(token_ids)
        result.append(list(zip(token_ids, token_strs, token_log_probs)))
    return result
111
+
112
  def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
113
  all_new_words = []
114
  for i in range(num_inputs):
 
132
  return model, tokenizer, device
133
 
134
  def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, device: torch.device) -> list[ApiWord]:
 
135
  inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
136
 
 
137
  token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
138
 
 
139
  words = split_into_words(token_probs, tokenizer)
140
  log_prob_threshold = -5.0
141
  low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]
142
 
 
143
  contexts = [word.context for _, word in low_prob_words]
144
  inputs = prepare_inputs(contexts, tokenizer, device)
145
  input_ids = inputs["input_ids"]
146
 
 
147
  num_samples = 10
148
  start_time = time.time()
149
  outputs = generate_outputs(model, inputs, num_samples)
150
  end_time = time.time()
151
  print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
152
 
 
153
  replacements = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
154
 
155
  low_prob_words_with_replacements = { i: (w, r) for (i, w), r in zip(low_prob_words, replacements) }
 
161
  else:
162
  result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
163
  return result
164
+
165
# %%
# Scratch cells: replay the first steps of check_text on a sample sentence and
# then inspect the output of find_next_tokens.
# NOTE(review): these are top-level statements, so they also execute (including
# load_model()) whenever this module is imported -- confirm that is intended.
model, tokenizer, device = load_model()

#%%
# Sample input containing a deliberate misspelling ("jumpz").
input_text = "The quick brown fox jumpz over"
inputs: BatchEncoding = tokenize(input_text, tokenizer, device)

#%%
token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)

#%%
# Same threshold as check_text: words whose log-probability is below it are kept.
words = split_into_words(token_probs, tokenizer)
log_prob_threshold = -5.0
low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]

#%%
# Re-tokenize the context of each low-probability word as a new batch.
contexts = [word.context for _, word in low_prob_words]
inputs = prepare_inputs(contexts, tokenizer, device)
input_ids = inputs["input_ids"]

#%%
# min_p is a log-probability threshold here.
next_tokens = find_next_tokens(model, inputs, tokenizer, min_p=-5)

#%%
# Bare expression so an interactive cell displays the value.
next_tokens

# %%
expand.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Protocol
4
+
5
+ # import torch
6
+ # from transformers import PreTrainedModel
7
+ # from completions import find_next_tokens, Tokenizer
8
+
9
@dataclass
class Series:
    """A token sequence under expansion, together with its remaining budget."""
    # Identifier copied unchanged onto every series derived from this one.
    id: int
    # Token ids accumulated so far.
    tokens: list[int]
    # Remaining budget; compute_new_series subtracts each expansion's cost from it.
    budget: float
14
+
15
@dataclass
class Batch:
    """A group of series handed to an expander in a single call."""
    # Series to expand in this step.
    items: list[Series]
18
+
19
@dataclass
class ExpansionOne:
    """A single-token extension of a series and what it costs."""
    # Token id to append to the series.
    token: int
    # Amount subtracted from the series budget when this token is appended.
    cost: float
23
+
24
@dataclass
class ExpansionOneResult:
    """The one-token expansions an expander produced for one series."""
    # The series that was expanded.
    series: Series
    # Candidate next tokens; expand() treats an empty list as "series is complete".
    expansions: list[ExpansionOne]
28
+
29
@dataclass
class ExpansionOneResultBatch:
    """Result of expanding a whole Batch by one token."""
    # One result per expanded series.
    items: list[ExpansionOneResult]
32
+
33
# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
class ExpanderOneBatch(Protocol):
    """Structural interface for anything that can expand a batch by one token."""

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        """Return the one-token expansions for the series in ``batch``."""
        ...
36
+
37
@dataclass
class ExpansionResult:
    """All multi-token expansions found for one original series."""
    # The original (unexpanded) series.
    series: Series
    # Token suffixes that extend the original tokens, one list per expansion.
    expansions: list[list[int]]
41
+
42
@dataclass
class ExpansionResultBatch:
    """Final result of expand(): one ExpansionResult per original series."""
    # Results in the order of the original series.
    items: list[ExpansionResult]
45
+
46
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
    """Build one child series per expansion of ``result.series``.

    Each child keeps the parent's id, appends the expansion's token to the
    parent's tokens, and has the expansion's cost deducted from the budget.
    """
    parent = result.series
    return [
        Series(
            id=parent.id,
            tokens=parent.tokens + [step.token],
            budget=parent.budget - step.cost,
        )
        for step in result.expansions
    ]
51
+
52
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
    """Group completed series by id and report what each added to its original.

    For every original series, collects the token lists of the expanded series
    sharing its id, strips the original tokens from the front, and drops entries
    that are not longer than the original.
    """
    # Ids must be unique, otherwise grouping by id would be ambiguous.
    distinct_ids = {series.id for series in original_series}
    assert len(original_series) == len(distinct_ids)

    originals = {series.id: series for series in original_series}

    tokens_by_id: dict[int, list[list[int]]] = defaultdict(list)
    for series in expanded_series:
        tokens_by_id[series.id].append(series.tokens)

    items = []
    for series_id, original in originals.items():
        prefix_len = len(original.tokens)
        # Keep only the part beyond the original tokens.
        suffixes = [
            tokens[prefix_len:]
            for tokens in tokens_by_id[series_id]
            if len(tokens) > prefix_len
        ]
        items.append(ExpansionResult(series=original, expansions=suffixes))
    return ExpansionResultBatch(items=items)
70
+
71
# A compound operation that we can implement generically, relying on an ExpanderOneBatch
def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
    """Repeatedly expand ``batch`` one token at a time until nothing expands.

    A series with no expansions in a step is recorded as complete; every other
    series is replaced by its children for the next step. The loop ends when a
    step produces no children.
    """
    finished: list[Series] = []
    frontier = batch
    while frontier.items:
        next_items = []
        for outcome in expander.expand(frontier).items:
            if outcome.expansions:
                next_items += compute_new_series(outcome)
            else:
                finished.append(outcome.series)
        frontier = Batch(items=next_items)
    return compute_expansions(batch.items, finished)
expand_test.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from expand import Series, ExpanderOneBatch, ExpansionOne, Batch, ExpansionOneResult, ExpansionOneResultBatch, ExpansionResult, ExpansionResultBatch, expand
3
+
4
# Hardcoded token sequences that expand_series walks along.
possible_sequences = [
    [1, 21, 31, 41],
    [1, 21, 31, 42],
    [1, 21, 32, 41, 51],
    [1, 22, 33, 41],
    [1, 22, 34, 41],
]
11
+
12
def expand_series(series: Series) -> list[ExpansionOne]:
    """Return the affordable next tokens for ``series`` from ``possible_sequences``.

    A token qualifies when some hardcoded sequence starts with the series' tokens
    and continues with it. Duplicates are removed in first-seen order, each
    expansion costs 1.0, and those costing more than the budget are dropped.
    """
    prefix_len = len(series.tokens)
    seen: dict[int, None] = {}
    for seq in possible_sequences:
        if seq[:prefix_len] == series.tokens and len(seq) > prefix_len:
            seen.setdefault(seq[prefix_len], None)
    affordable = []
    for token in seen:
        candidate = ExpansionOne(token=token, cost=1.0)
        if candidate.cost <= series.budget:
            affordable.append(candidate)
    return affordable
17
+
18
class HardcodedExpanderOneBatch(ExpanderOneBatch):
    """Expander backed by the hardcoded ``possible_sequences`` table."""

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        """Expand every series in ``batch`` by one token via ``expand_series``."""
        return ExpansionOneResultBatch(
            items=[
                ExpansionOneResult(series=series, expansions=expand_series(series))
                for series in batch.items
            ]
        )
25
+
26
# Shared expander instance used by the tests in this module.
expander = HardcodedExpanderOneBatch()
27
+
28
def test_expander_zero_budget():
    """With a budget of zero the expander offers no expansions."""
    series = Series(id=0, tokens=[1], budget=0.0)
    actual = expander.expand(Batch(items=[series]))
    wanted = ExpansionOneResultBatch(
        items=[ExpansionOneResult(series=series, expansions=[])],
    )
    assert wanted == actual
35
+
36
def test_expander_budget_one():
    """A budget of one affords both continuations of [1]."""
    series = Series(id=0, tokens=[1], budget=1.0)
    actual = expander.expand(Batch(items=[series]))
    next_steps = [ExpansionOne(token=token, cost=1.0) for token in (21, 22)]
    wanted = ExpansionOneResultBatch(
        items=[ExpansionOneResult(series=series, expansions=next_steps)],
    )
    assert wanted == actual
46
+
47
def test_expander_budget_two():
    """A larger budget still yields only one token per expansion step."""
    series = Series(id=0, tokens=[1], budget=2.0)
    actual = expander.expand(Batch(items=[series]))
    next_steps = [ExpansionOne(token=token, cost=1.0) for token in (21, 22)]
    wanted = ExpansionOneResultBatch(
        items=[ExpansionOneResult(series=series, expansions=next_steps)],
    )
    assert wanted == actual
57
+
58
def test_expander_budget_one_no_expansion():
    """A prefix matching no hardcoded sequence has no expansions."""
    series = Series(id=0, tokens=[1, 20], budget=1.0)
    actual = expander.expand(Batch(items=[series]))
    wanted = ExpansionOneResultBatch(
        items=[ExpansionOneResult(series=series, expansions=[])],
    )
    assert wanted == actual
65
+
66
def test_expander_budget_one_two_tokens():
    """A two-token prefix is expanded with the tokens that follow it."""
    series = Series(id=0, tokens=[1, 22], budget=1.0)
    actual = expander.expand(Batch(items=[series]))
    next_steps = [ExpansionOne(token=token, cost=1.0) for token in (33, 34)]
    wanted = ExpansionOneResultBatch(
        items=[ExpansionOneResult(series=series, expansions=next_steps)],
    )
    assert wanted == actual
76
+
77
def test_expander_budget_one_two_tokens_two_series():
    """Each series in a batch is expanded independently, in order."""
    first = Series(id=0, tokens=[1, 21, 31], budget=1.0)
    second = Series(id=1, tokens=[1, 22], budget=1.0)
    actual = expander.expand(Batch(items=[first, second]))

    def one_step(*tokens):
        return [ExpansionOne(token=token, cost=1.0) for token in tokens]

    wanted = ExpansionOneResultBatch(
        items=[
            ExpansionOneResult(series=first, expansions=one_step(41, 42)),
            ExpansionOneResult(series=second, expansions=one_step(33, 34)),
        ],
    )
    assert wanted == actual
94
+
95
def test_expand_01():
    """With budget 1.0 every series is extended by exactly one token."""
    first = Series(id=0, tokens=[1, 21], budget=1.0)
    second = Series(id=1, tokens=[1, 22], budget=1.0)
    expanded = expand(Batch(items=[first, second]), expander)
    wanted = ExpansionResultBatch(
        items=[
            ExpansionResult(
                series=Series(id=0, tokens=[1, 21], budget=1.0),
                expansions=[[31], [32]],
            ),
            ExpansionResult(
                series=Series(id=1, tokens=[1, 22], budget=1.0),
                expansions=[[33], [34]],
            ),
        ],
    )
    assert expanded == wanted
117
+
118
def test_expand_02():
    """Budgets are per series: budget 2.0 adds two tokens, budget 1.0 adds one."""
    first = Series(id=0, tokens=[1, 21], budget=2.0)
    second = Series(id=1, tokens=[1, 22], budget=1.0)
    expanded = expand(Batch(items=[first, second]), expander)
    wanted = ExpansionResultBatch(
        items=[
            ExpansionResult(
                series=Series(id=0, tokens=[1, 21], budget=2.0),
                expansions=[[31, 41], [31, 42], [32, 41]],
            ),
            ExpansionResult(
                series=Series(id=1, tokens=[1, 22], budget=1.0),
                expansions=[[33], [34]],
            ),
        ],
    )
    assert expanded == wanted
141
+
142
def test_expand_03():
    """Expansion stops where sequences end; zero budget yields no expansions."""
    first = Series(id=0, tokens=[1, 21], budget=3.0)
    second = Series(id=1, tokens=[1, 22], budget=0.0)
    expanded = expand(Batch(items=[first, second]), expander)
    wanted = ExpansionResultBatch(
        items=[
            ExpansionResult(
                series=Series(id=0, tokens=[1, 21], budget=3.0),
                expansions=[[31, 41], [31, 42], [32, 41, 51]],
            ),
            ExpansionResult(
                series=Series(id=1, tokens=[1, 22], budget=0.0),
                expansions=[],
            ),
        ],
    )
    assert expanded == wanted