Working stopping on word boundary
Browse files- expand.py +20 -12
- expand_llm.py +17 -1
- expand_test.py +5 -2
- run.py +4 -1
expand.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from collections import defaultdict
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
-
from typing import Protocol, Self
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class Expansion:
|
|
@@ -46,16 +46,19 @@ class ExpansionResult:
|
|
| 46 |
class ExpansionResultBatch:
|
| 47 |
items: list[ExpansionResult]
|
| 48 |
|
| 49 |
-
def compute_new_series(result: ExpansionOneResult) -> list[Series]:
|
| 50 |
-
|
| 51 |
for expansion in result.expansions:
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
| 61 |
# check that ids in original_series are unique
|
|
@@ -74,8 +77,11 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
|
|
| 74 |
results.append(expansion_result)
|
| 75 |
return ExpansionResultBatch(items=results)
|
| 76 |
|
|
|
|
|
|
|
|
|
|
| 77 |
# A compound operation that we can implement generically, relying on an ExpanderOneBatch
|
| 78 |
-
def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
|
| 79 |
completed_series: list[Series] = []
|
| 80 |
current_batch = batch
|
| 81 |
while len(current_batch.items) > 0:
|
|
@@ -86,6 +92,8 @@ def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
|
|
| 86 |
if len(item.expansions) == 0:
|
| 87 |
completed_series.append(item.series)
|
| 88 |
else:
|
| 89 |
-
|
|
|
|
|
|
|
| 90 |
current_batch = Batch(items=current_batch_items)
|
| 91 |
return compute_expansions(batch.items, completed_series)
|
|
|
|
| 1 |
from collections import defaultdict
|
| 2 |
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Callable, Protocol, Self
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class Expansion:
|
|
|
|
| 46 |
class ExpansionResultBatch:
|
| 47 |
items: list[ExpansionResult]
|
| 48 |
|
| 49 |
+
def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
|
| 50 |
+
new_series_batch = []
|
| 51 |
for expansion in result.expansions:
|
| 52 |
+
if not stopping_criterion(result.series, expansion):
|
| 53 |
+
new_series = Series(
|
| 54 |
+
id=result.series.id,
|
| 55 |
+
tokens=result.series.tokens,
|
| 56 |
+
expansions=result.series.expansions + [expansion],
|
| 57 |
+
budget=result.series.budget
|
| 58 |
+
)
|
| 59 |
+
new_series_batch.append(new_series)
|
| 60 |
+
completed_series = [result.series] if len(new_series_batch) == 0 else []
|
| 61 |
+
return new_series_batch, completed_series
|
| 62 |
|
| 63 |
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
|
| 64 |
# check that ids in original_series are unique
|
|
|
|
| 77 |
results.append(expansion_result)
|
| 78 |
return ExpansionResultBatch(items=results)
|
| 79 |
|
| 80 |
+
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
|
| 81 |
+
return series.get_remaining_budget() + expansion.cost < 0
|
| 82 |
+
|
| 83 |
# A compound operation that we can implement generically, relying on an ExpanderOneBatch
|
| 84 |
+
def expand(batch: Batch, expander: ExpanderOneBatch, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> ExpansionResultBatch:
|
| 85 |
completed_series: list[Series] = []
|
| 86 |
current_batch = batch
|
| 87 |
while len(current_batch.items) > 0:
|
|
|
|
| 92 |
if len(item.expansions) == 0:
|
| 93 |
completed_series.append(item.series)
|
| 94 |
else:
|
| 95 |
+
new_series, completed = compute_new_series(item, completion_criterion)
|
| 96 |
+
completed_series.extend(completed)
|
| 97 |
+
current_batch_items.extend(new_series)
|
| 98 |
current_batch = Batch(items=current_batch_items)
|
| 99 |
return compute_expansions(batch.items, completed_series)
|
expand_llm.py
CHANGED
|
@@ -15,6 +15,22 @@ class ExpanderOneBatchLLM:
|
|
| 15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
| 16 |
results = []
|
| 17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
| 18 |
-
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens
|
| 19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
| 20 |
return ExpansionOneResultBatch(items=results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
| 16 |
results = []
|
| 17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
| 18 |
+
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
|
| 19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
| 20 |
return ExpansionOneResultBatch(items=results)
|
| 21 |
+
|
| 22 |
+
def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
|
| 23 |
+
def stopping_criterion(series: Series, expansion: Expansion) -> bool:
|
| 24 |
+
d = default_completion_criterion(series, expansion)
|
| 25 |
+
if d:
|
| 26 |
+
return d
|
| 27 |
+
token_str = tokenizer.decode([expansion.token])
|
| 28 |
+
starts_with_space = token_str.startswith(" ")
|
| 29 |
+
print(f"-----{token_str}-----, {starts_with_space=}")
|
| 30 |
+
is_first_token = len(series.expansions) == 0
|
| 31 |
+
if is_first_token and not starts_with_space:
|
| 32 |
+
return True
|
| 33 |
+
if not is_first_token and starts_with_space:
|
| 34 |
+
return True
|
| 35 |
+
return False
|
| 36 |
+
return stopping_criterion
|
expand_test.py
CHANGED
|
@@ -14,7 +14,7 @@ def expand_series(series: Series) -> list[Expansion]:
|
|
| 14 |
l = len(all_tokens)
|
| 15 |
items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
|
| 16 |
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
| 17 |
-
return
|
| 18 |
|
| 19 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
| 20 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
@@ -30,7 +30,10 @@ def test_expander_zero_budget():
|
|
| 30 |
s = Series(id=0, tokens=[1], budget=0.0)
|
| 31 |
expanded = expander.expand(Batch(items=[s]))
|
| 32 |
expected = ExpansionOneResultBatch(
|
| 33 |
-
items=[ExpansionOneResult(series=s, expansions=[
|
|
|
|
|
|
|
|
|
|
| 34 |
)
|
| 35 |
assert expected == expanded
|
| 36 |
|
|
|
|
| 14 |
l = len(all_tokens)
|
| 15 |
items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
|
| 16 |
candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
|
| 17 |
+
return candidates
|
| 18 |
|
| 19 |
class HardcodedExpanderOneBatch(ExpanderOneBatch):
|
| 20 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
|
|
|
| 30 |
s = Series(id=0, tokens=[1], budget=0.0)
|
| 31 |
expanded = expander.expand(Batch(items=[s]))
|
| 32 |
expected = ExpansionOneResultBatch(
|
| 33 |
+
items=[ExpansionOneResult(series=s, expansions=[
|
| 34 |
+
Expansion(token=21, cost=-1.0),
|
| 35 |
+
Expansion(token=22, cost=-1.0),
|
| 36 |
+
])]
|
| 37 |
)
|
| 38 |
assert expected == expanded
|
| 39 |
|
run.py
CHANGED
|
@@ -35,7 +35,10 @@ for i, x in enumerate(contexts):
|
|
| 35 |
batch = Batch(items=series)
|
| 36 |
|
| 37 |
#%%
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
# %%
|
| 41 |
def print_expansions(expansions: ExpansionResultBatch):
|
|
|
|
| 35 |
batch = Batch(items=series)
|
| 36 |
|
| 37 |
#%%
|
| 38 |
+
stopping_criterion = create_stopping_criterion_llm(tokenizer)
|
| 39 |
+
|
| 40 |
+
#%%
|
| 41 |
+
expanded = expand(batch, expander, stopping_criterion)
|
| 42 |
|
| 43 |
# %%
|
| 44 |
def print_expansions(expansions: ExpansionResultBatch):
|