mebubo commited on
Commit
c12f2c8
·
1 Parent(s): 6f4449d
Files changed (2) hide show
  1. expand_llm.py +2 -2
  2. run.py +1 -1
expand_llm.py CHANGED
@@ -11,10 +11,10 @@ class ExpanderOneBatchLLM:
11
  tokenizer: Tokenizer
12
 
13
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
14
- inputs = prepare_inputs([s.tokens for s in batch.items], self.tokenizer, self.model.device)
15
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
16
  results = []
17
  for s, next_tokens in zip(batch.items, next_tokens):
18
- expansions = [Expansion(token=token, cost=logprob) for token, logprob in next_tokens if logprob + s.get_remaining_budget() >= 0]
19
  results.append(ExpansionOneResult(series=s, expansions=expansions))
20
  return ExpansionOneResultBatch(items=results)
 
11
  tokenizer: Tokenizer
12
 
13
  def expand(self, batch: Batch) -> ExpansionOneResultBatch:
14
+ inputs = prepare_inputs([s.get_all_tokens() for s in batch.items], self.tokenizer, self.model.device)
15
  next_tokens = find_next_tokens(self.model, inputs, self.tokenizer)
16
  results = []
17
  for s, next_tokens in zip(batch.items, next_tokens):
18
+ expansions = [Expansion(token=token, cost=cost) for token, cost in next_tokens if cost + s.get_remaining_budget() >= 0]
19
  results.append(ExpansionOneResult(series=s, expansions=expansions))
20
  return ExpansionOneResultBatch(items=results)
run.py CHANGED
@@ -29,7 +29,7 @@ expander = ExpanderOneBatchLLM(model, tokenizer)
29
  #%%
30
  series = []
31
  for i, x in enumerate(contexts):
32
- series.append(Series(id=i, tokens=x, budget=7.0))
33
 
34
  #%%
35
  batch = Batch(items=series)
 
29
  #%%
30
  series = []
31
  for i, x in enumerate(contexts):
32
+ series.append(Series(id=i, tokens=x, budget=5.0))
33
 
34
  #%%
35
  batch = Batch(items=series)