Working stopping on word boundary
Browse files- expand.py +20 -12
- expand_llm.py +17 -1
- expand_test.py +5 -2
- run.py +4 -1
expand.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from collections import defaultdict
|
2 |
from dataclasses import dataclass, field
|
3 |
-
from typing import Protocol, Self
|
4 |
|
5 |
@dataclass
|
6 |
class Expansion:
|
@@ -46,16 +46,19 @@ class ExpansionResult:
|
|
46 |
class ExpansionResultBatch:
|
47 |
items: list[ExpansionResult]
|
48 |
|
49 |
-
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
|
50 |
-
|
51 |
for expansion in result.expansions:
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
59 |
|
60 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
61 |
# check that ids in original_series are unique
|
@@ -74,8 +77,11 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
|
|
74 |
results.append(expansion_result)
|
75 |
return ExpansionResultBatch(items=results)
|
76 |
|
|
|
|
|
|
|
77 |
# A compound operation that we can implement generically, relying on an ExpanderOneBatch
|
78 |
-
def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
|
79 |
completed_series: list[Series] = []
|
80 |
current_batch = batch
|
81 |
while len(current_batch.items) > 0:
|
@@ -86,6 +92,8 @@ def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
|
|
86 |
if len(item.expansions) == 0:
|
87 |
completed_series.append(item.series)
|
88 |
else:
|
89 |
-
|
|
|
|
|
90 |
current_batch = Batch(items=current_batch_items)
|
91 |
return compute_expansions(batch.items, completed_series)
|
|
|
1 |
from collections import defaultdict
|
2 |
from dataclasses import dataclass, field
|
3 |
+
from typing import Callable, Protocol, Self
|
4 |
|
5 |
@dataclass
|
6 |
class Expansion:
|
|
|
46 |
class ExpansionResultBatch:
|
47 |
items: list[ExpansionResult]
|
48 |
|
49 |
+
def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
|
50 |
+
new_series_batch = []
|
51 |
for expansion in result.expansions:
|
52 |
+
if not stopping_criterion(result.series, expansion):
|
53 |
+
new_series = Series(
|
54 |
+
id=result.series.id,
|
55 |
+
tokens=result.series.tokens,
|
56 |
+
expansions=result.series.expansions + [expansion],
|
57 |
+
budget=result.series.budget
|
58 |
+
)
|
59 |
+
new_series_batch.append(new_series)
|
60 |
+
completed_series = [result.series] if len(new_series_batch) == 0 else []
|
61 |
+
return new_series_batch, completed_series
|
62 |
|
63 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
64 |
# check that ids in original_series are unique
|
|
|
77 |
results.append(expansion_result)
|
78 |
return ExpansionResultBatch(items=results)
|
79 |
|
80 |
+
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
|
81 |
+
return series.get_remaining_budget() + expansion.cost < 0
|
82 |
+
|
83 |
# A compound operation that we can implement generically, relying on an ExpanderOneBatch
|
84 |
+
def expand(batch: Batch, expander: ExpanderOneBatch, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> ExpansionResultBatch:
|
85 |
completed_series: list[Series] = []
|
86 |
current_batch = batch
|
87 |
while len(current_batch.items) > 0:
|
|
|
92 |
if len(item.expansions) == 0:
|
93 |
completed_series.append(item.series)
|
94 |
else:
|
95 |
+
new_series, completed = compute_new_series(item, completion_criterion)
|
96 |
+
completed_series.extend(completed)
|
97 |
+
current_batch_items.extend(new_series)
|
98 |
current_batch = Batch(items=current_batch_items)
|
99 |
return compute_expansions(batch.items, completed_series)
|
expand_llm.py
CHANGED
@@ -15,6 +15,22 @@ class ExpanderOneBatchLLM:
|
|
15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
16 |
results = []
|
17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
18 |
-
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens
|
19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
20 |
return ExpansionOneResultBatch(items=results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
16 |
results = []
|
17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
18 |
+
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
|
19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
20 |
return ExpansionOneResultBatch(items=results)
|
21 |
+
|
22 |
+
def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
|
23 |
+
def stopping_criterion(series: Series, expansion: Expansion) -> bool:
|
24 |
+
d = default_completion_criterion(series, expansion)
|
25 |
+
if d:
|
26 |
+
return d
|
27 |
+
token_str = tokenizer.decode([expansion.token])
|
28 |
+
starts_with_space = token_str.startswith(" ")
|
29 |
+
print(f"-----{token_str}-----, {starts_with_space=}")
|
30 |
+
is_first_token = len(series.expansions) == 0
|
31 |
+
if is_first_token and not starts_with_space:
|
32 |
+
return True
|
33 |
+
if not is_first_token and starts_with_space:
|
34 |
+
return True
|
35 |
+
return False
|
36 |
+
return stopping_criterion
|
expand_test.py
CHANGED
@@ -14,7 +14,7 @@ def expand_series(series: Series) -> list[Expansion]:
|
|
14 |
l = len(all_tokens)
|
15 |
items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
|
16 |
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
17 |
-
return
|
18 |
|
19 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
20 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
@@ -30,7 +30,10 @@ def test_expander_zero_budget():
|
|
30 |
s = Series(id=0, tokens=[1], budget=0.0)
|
31 |
expanded = expander.expand(Batch(items=[s]))
|
32 |
expected = ExpansionOneResultBatch(
|
33 |
-
items=[ExpansionOneResult(series=s, expansions=[
|
|
|
|
|
|
|
34 |
)
|
35 |
assert expected == expanded
|
36 |
|
|
|
14 |
l = len(all_tokens)
|
15 |
items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
|
16 |
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
17 |
+
return candidates
|
18 |
|
19 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
20 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
|
30 |
s = Series(id=0, tokens=[1], budget=0.0)
|
31 |
expanded = expander.expand(Batch(items=[s]))
|
32 |
expected = ExpansionOneResultBatch(
|
33 |
+
items=[ExpansionOneResult(series=s, expansions=[
|
34 |
+
Expansion(token=21, cost=-1.0),
|
35 |
+
Expansion(token=22, cost=-1.0),
|
36 |
+
])]
|
37 |
)
|
38 |
assert expected == expanded
|
39 |
|
run.py
CHANGED
@@ -35,7 +35,10 @@ for i, x in enumerate(contexts):
|
|
35 |
batch = Batch(items=series)
|
36 |
|
37 |
#%%
|
38 |
-
|
|
|
|
|
|
|
39 |
|
40 |
# %%
|
41 |
def print_expansions(expansions: ExpansionResultBatch):
|
|
|
35 |
batch = Batch(items=series)
|
36 |
|
37 |
#%%
|
38 |
+
stopping_criterion = create_stopping_criterion_llm(tokenizer)
|
39 |
+
|
40 |
+
#%%
|
41 |
+
expanded = expand(batch, expander, stopping_criterion)
|
42 |
|
43 |
# %%
|
44 |
def print_expansions(expansions: ExpansionResultBatch):
|