Fix
Browse files- expand_llm.py +2 -2
- run.py +1 -1
expand_llm.py
CHANGED
@@ -11,10 +11,10 @@ class ExpanderOneBatchLLM:
|
|
11 |
tokenizer: Tokenizer
|
12 |
|
13 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
14 |
-
inputs = prepare_inputs([s.
|
15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
16 |
results = []
|
17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
18 |
-
expansions = [Expansion(token=token, cost=
|
19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
20 |
return ExpansionOneResultBatch(items=results)
|
|
|
11 |
tokenizer: Tokenizer
|
12 |
|
13 |
def expand(self, batch: Batch) -> ExpansionOneResultBatch:
|
14 |
+
inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
|
15 |
next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
|
16 |
results = []
|
17 |
for s, next_tokens in zip(batch.items, next_tokens):
|
18 |
+
expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens if cost + s.get_remaining_budget() >= 0]
|
19 |
results.append(ExpansionOneResult(series=s, expansions=expansions))
|
20 |
return ExpansionOneResultBatch(items=results)
|
run.py
CHANGED
@@ -29,7 +29,7 @@ expander = ExpanderOneBatchLLM(model, tokenizer)
|
|
29 |
#%%
|
30 |
series = []
|
31 |
for i, x in enumerate(contexts):
|
32 |
-
series.append(Series(id=i, tokens=x, budget=
|
33 |
|
34 |
#%%
|
35 |
batch = Batch(items=series)
|
|
|
29 |
#%%
|
30 |
series = []
|
31 |
for i, x in enumerate(contexts):
|
32 |
+
series.append(Series(id=i, tokens=x, budget=5.0))
|
33 |
|
34 |
#%%
|
35 |
batch = Batch(items=series)
|