mebubo committed on
Commit
3d1d657
·
1 Parent(s): c12f2c8

Working stopping on word boundary

Browse files
Files changed (4) hide show
  1. expand.py +20 -12
  2. expand_llm.py +17 -1
  3. expand_test.py +5 -2
  4. run.py +4 -1
expand.py CHANGED
@@ -1,6 +1,6 @@
1
  from collections import defaultdict
2
  from dataclasses import dataclass, field
3
- from typing import Protocol, Self
4
 
5
  @dataclass
6
  class Expansion:
@@ -46,16 +46,19 @@ class ExpansionResult:
46
  class ExpansionResultBatch:
47
  items: list[ExpansionResult]
48
 
49
- def compute_new_series(result: ExpansionOneResult) -> list[Series]:
50
- results = []
51
  for expansion in result.expansions:
52
- results.append(Series(
53
- id=result.series.id,
54
- tokens=result.series.tokens,
55
- expansions=result.series.expansions + [expansion],
56
- budget=result.series.budget
57
- ))
58
- return results
 
 
 
59
 
60
  def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
61
  # check that ids in original_series are unique
@@ -74,8 +77,11 @@ def compute_expansions(original_series: list[Series], expanded_series: list[Seri
74
  results.append(expansion_result)
75
  return ExpansionResultBatch(items=results)
76
 
 
 
 
77
  # A compound operation that we can implement generically, relying on an ExpanderOneBatch
78
- def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
79
  completed_series: list[Series] = []
80
  current_batch = batch
81
  while len(current_batch.items) > 0:
@@ -86,6 +92,8 @@ def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
86
  if len(item.expansions) == 0:
87
  completed_series.append(item.series)
88
  else:
89
- current_batch_items.extend(compute_new_series(item))
 
 
90
  current_batch = Batch(items=current_batch_items)
91
  return compute_expansions(batch.items, completed_series)
 
1
  from collections import defaultdict
2
  from dataclasses import dataclass, field
3
+ from typing import Callable, Protocol, Self
4
 
5
  @dataclass
6
  class Expansion:
 
46
  class ExpansionResultBatch:
47
  items: list[ExpansionResult]
48
 
49
+ def compute_new_series(result: ExpansionOneResult, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
50
+ new_series_batch = []
51
  for expansion in result.expansions:
52
+ if not stopping_criterion(result.series, expansion):
53
+ new_series = Series(
54
+ id=result.series.id,
55
+ tokens=result.series.tokens,
56
+ expansions=result.series.expansions + [expansion],
57
+ budget=result.series.budget
58
+ )
59
+ new_series_batch.append(new_series)
60
+ completed_series = [result.series] if len(new_series_batch) == 0 else []
61
+ return new_series_batch, completed_series
62
 
63
  def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
64
  # check that ids in original_series are unique
 
77
  results.append(expansion_result)
78
  return ExpansionResultBatch(items=results)
79
 
80
+ def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
81
+ return series.get_remaining_budget() + expansion.cost < 0
82
+
83
  # A compound operation that we can implement generically, relying on an ExpanderOneBatch
84
+ def expand(batch: Batch, expander: ExpanderOneBatch, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> ExpansionResultBatch:
85
  completed_series: list[Series] = []
86
  current_batch = batch
87
  while len(current_batch.items) > 0:
 
92
  if len(item.expansions) == 0:
93
  completed_series.append(item.series)
94
  else:
95
+ new_series, completed = compute_new_series(item, completion_criterion)
96
+ completed_series.extend(completed)
97
+ current_batch_items.extend(new_series)
98
  current_batch = Batch(items=current_batch_items)
99
  return compute_expansions(batch.items, completed_series)
expand_llm.py CHANGED
@@ -15,6 +15,22 @@ class ExpanderOneBatchLLM:
15
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
16
  results = []
17
  for s, next_tokens in zip(batch.items, next_tokens):
18
- expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens if cost + s.get_remaining_budget() >= 0]
19
  results.append(ExpansionOneResult(series=s, expansions=expansions))
20
  return ExpansionOneResultBatch(items=results)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
16
  results = []
17
  for s, next_tokens in zip(batch.items, next_tokens):
18
+ expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens]
19
  results.append(ExpansionOneResult(series=s, expansions=expansions))
20
  return ExpansionOneResultBatch(items=results)
21
+
22
+ def create_stopping_criterion_llm(tokenizer: Tokenizer) -> Callable[[Series, Expansion], bool]:
23
+ def stopping_criterion(series: Series, expansion: Expansion) -> bool:
24
+ d = default_completion_criterion(series, expansion)
25
+ if d:
26
+ return d
27
+ token_str = tokenizer.decode([expansion.token])
28
+ starts_with_space = token_str.startswith(" ")
29
+ print(f"-----{token_str}-----, {starts_with_space=}")
30
+ is_first_token = len(series.expansions) == 0
31
+ if is_first_token and not starts_with_space:
32
+ return True
33
+ if not is_first_token and starts_with_space:
34
+ return True
35
+ return False
36
+ return stopping_criterion
expand_test.py CHANGED
@@ -14,7 +14,7 @@ def expand_series(series: Series) -> list[Expansion]:
14
  l = len(all_tokens)
15
  items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
16
  candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
17
- return [c for c in candidates if c.cost + series.get_remaining_budget() >= 0]
18
 
19
  class HardcodedExpanderOneBatch(ExpanderOneBatch):
20
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
@@ -30,7 +30,10 @@ def test_expander_zero_budget():
30
  s = Series(id=0, tokens=[1], budget=0.0)
31
  expanded = expander.expand(Batch(items=[s]))
32
  expected = ExpansionOneResultBatch(
33
- items=[ExpansionOneResult(series=s, expansions=[])]
 
 
 
34
  )
35
  assert expected == expanded
36
 
 
14
  l = len(all_tokens)
15
  items = [s[l] for s in possible_sequences if s[:l] == all_tokens and len(s) > l]
16
  candidates = [Expansion(token=l, cost=-1.0) for l in dict.fromkeys(items)]
17
+ return candidates
18
 
19
  class HardcodedExpanderOneBatch(ExpanderOneBatch):
20
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
 
30
  s = Series(id=0, tokens=[1], budget=0.0)
31
  expanded = expander.expand(Batch(items=[s]))
32
  expected = ExpansionOneResultBatch(
33
+ items=[ExpansionOneResult(series=s, expansions=[
34
+ Expansion(token=21, cost=-1.0),
35
+ Expansion(token=22, cost=-1.0),
36
+ ])]
37
  )
38
  assert expected == expanded
39
 
run.py CHANGED
@@ -35,7 +35,10 @@ for i, x in enumerate(contexts):
35
  batch = Batch(items=series)
36
 
37
  #%%
38
- expanded = expand(batch, expander)
 
 
 
39
 
40
  # %%
41
  def print_expansions(expansions: ExpansionResultBatch):
 
35
  batch = Batch(items=series)
36
 
37
  #%%
38
+ stopping_criterion = create_stopping_criterion_llm(tokenizer)
39
+
40
+ #%%
41
+ expanded = expand(batch, expander, stopping_criterion)
42
 
43
  # %%
44
  def print_expansions(expansions: ExpansionResultBatch):