# gpted/expand.py

from collections import defaultdict
from dataclasses import dataclass
from typing import Protocol
# import torch
# from transformers import PreTrainedModel
# from completions import find_next_tokens, Tokenizer


# A candidate token sequence, identified by id, together with the budget it has
# left for further expansion.
@dataclass
class Series:
    id: int
    tokens: list[int]
    budget: float


@dataclass
class Batch:
    items: list[Series]


# A single proposed continuation token and the cost of taking it.
@dataclass
class ExpansionOne:
    token: int
    cost: float


# All single-token expansions proposed for one series in one step.
@dataclass
class ExpansionOneResult:
    series: Series
    expansions: list[ExpansionOne]


@dataclass
class ExpansionOneResultBatch:
    items: list[ExpansionOneResult]


# A fundamental operation (expanding each series in a batch by one token) that we can
# implement both using an LLM and using a list of hardcoded sequences, for testing.
class ExpanderOneBatch(Protocol):
    def expand(self, batch: Batch) -> ExpansionOneResultBatch: ...
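

# Illustrative only: a minimal hardcoded ExpanderOneBatch sketch for testing, along the
# lines suggested by the comment above. The class name and its prefix-lookup table are
# assumptions, not part of this module.
class HardcodedExpanderOneBatch:
    def __init__(self, continuations: dict[tuple[int, ...], list[ExpansionOne]]):
        # Maps a token prefix to the single-token expansions allowed after it.
        self.continuations = continuations

    def expand(self, batch: Batch) -> ExpansionOneResultBatch:
        results = []
        for s in batch.items:
            candidates = self.continuations.get(tuple(s.tokens), [])
            # Offer only expansions that still fit within the series' remaining budget.
            affordable = [e for e in candidates if e.cost <= s.budget]
            results.append(ExpansionOneResult(series=s, expansions=affordable))
        return ExpansionOneResultBatch(items=results)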


# The final result for one series: every completed continuation, with the
# original tokens trimmed off.
@dataclass
class ExpansionResult:
    series: Series
    expansions: list[list[int]]


@dataclass
class ExpansionResultBatch:
    items: list[ExpansionResult]


def compute_new_series(result: ExpansionOneResult) -> list[Series]:
    # Turn each single-token expansion into a new Series with the token appended
    # and its cost deducted from the remaining budget.
    results = []
    for expansion in result.expansions:
        results.append(Series(id=result.series.id, tokens=result.series.tokens + [expansion.token], budget=result.series.budget - expansion.cost))
    return results
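

# For example (hypothetical values): expanding Series(id=3, tokens=[5], budget=1.5) with
# ExpansionOne(token=7, cost=1.0) yields Series(id=3, tokens=[5, 7], budget=0.5).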


def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> ExpansionResultBatch:
    # Ids in original_series must be unique.
    assert len(original_series) == len({s.id for s in original_series})
    # Index the original series by id.
    original_series_by_id = {s.id: s for s in original_series}
    # Group the expanded token sequences by the id of the series they came from.
    expanded_series_by_id: dict[int, list[list[int]]] = defaultdict(list)
    for s in expanded_series:
        expanded_series_by_id[s.id].append(s.tokens)
    results = []
    for id, s in original_series_by_id.items():
        expansions = expanded_series_by_id[id]
        # Trim the original prefix off each expansion, dropping any that added nothing.
        l = len(s.tokens)
        trimmed_expansions = [e[l:] for e in expansions if len(e) > l]
        expansion_result = ExpansionResult(series=s, expansions=trimmed_expansions)
        results.append(expansion_result)
    return ExpansionResultBatch(items=results)
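

# A small worked example (hypothetical values): given an original Series(id=1, tokens=[5], ...)
# and completed series with tokens [5, 7] and [5, 8, 9] sharing id=1, compute_expansions
# returns expansions [[7], [8, 9]] for that series; a completed series whose tokens are still
# just [5] is dropped, since it added nothing beyond the original prefix.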


# A compound operation that we can implement generically, relying on an ExpanderOneBatch:
# keep expanding every series until the expander proposes nothing more for it.
def expand(batch: Batch, expander: ExpanderOneBatch) -> ExpansionResultBatch:
    completed_series: list[Series] = []
    current_batch = batch
    while len(current_batch.items) > 0:
        current_batch_items = []
        expanded = expander.expand(current_batch)
        for item in expanded.items:
            if len(item.expansions) == 0:
                # No further expansion proposed: this series is complete.
                completed_series.append(item.series)
            else:
                # Fan out into one new series per proposed token.
                current_batch_items.extend(compute_new_series(item))
        current_batch = Batch(items=current_batch_items)
    return compute_expansions(batch.items, completed_series)
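

# Illustrative usage sketch (not part of the original module): drive expand() with a tiny
# hand-written expander. The token ids, costs, and budgets below are made-up values.
if __name__ == "__main__":
    class _DemoExpander:
        # Propose token 7 after [5] and token 9 after [5, 7]; nothing otherwise.
        TABLE = {(5,): [ExpansionOne(token=7, cost=1.0)],
                 (5, 7): [ExpansionOne(token=9, cost=1.0)]}

        def expand(self, batch: Batch) -> ExpansionOneResultBatch:
            items = [ExpansionOneResult(
                series=s,
                expansions=[e for e in self.TABLE.get(tuple(s.tokens), []) if e.cost <= s.budget],
            ) for s in batch.items]
            return ExpansionOneResultBatch(items=items)

    result = expand(Batch(items=[Series(id=0, tokens=[5], budget=1.5)]), _DemoExpander())
    # With budget 1.5, only the first step (cost 1.0) is affordable, so the single
    # completed expansion for this series is [7].
    print(result)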