# gpted / expand.py
# mebubo's picture
# println profiling
# 8f33f3e
from collections import defaultdict
from dataclasses import dataclass, field
import time
from typing import Callable, Protocol, Self
@dataclass
class Expansion:
    """One candidate token together with the cost of taking it.

    Costs are added to a series' budget (see ``Series.get_remaining_budget``),
    so a negative cost uses budget up.
    """

    # Token contributed to a series when this expansion is applied.
    token: int
    # Amount added to the owning series' budget.
    cost: float
@dataclass
class Series:
    """A token sequence under expansion, plus the expansions applied so far."""

    # Identifier; series derived from one another keep the same id.
    id: int
    # Tokens the series started with.
    tokens: list[int]
    # Starting budget; expansion costs are added on top of it.
    budget: float
    # Expansions applied so far, in order. Each instance gets its own list.
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return a new list: the starting tokens followed by every expansion's token."""
        appended = [step.token for step in self.expansions]
        return self.tokens + appended

    def get_remaining_budget(self) -> float:
        """Return the starting budget plus the sum of all expansion costs."""
        total_cost = sum(step.cost for step in self.expansions)
        return self.budget + total_cost
@dataclass
class Batch:
    """A group of series handed to a ``BatchExpander`` in a single call."""

    # Series to expand together.
    items: list[Series]
@dataclass
class TokenCandidates:
    """The candidate next expansions proposed for one series."""

    # Series the candidates belong to.
    series: Series
    # Alternative expansions, each a possible continuation of ``series``.
    expansions: list[Expansion]
@dataclass
class BatchCandidates:
    """Result of expanding a batch: one ``TokenCandidates`` entry per item."""

    # Per-series candidate expansions.
    items: list[TokenCandidates]
# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
class BatchExpander(Protocol):
    """Structural interface for anything that can propose expansions for a batch."""

    def expand(self, batch: Batch) -> BatchCandidates:
        """Return the candidate expansions for the series in ``batch``."""
        ...
@dataclass
class CompletedSequence:
    """An original series paired with every expansion path found for it."""

    # The series as it was before expansion started.
    series: Series
    # One inner list per finished path; each path is an ordered list of expansions.
    expansions: list[list[Expansion]]
@dataclass
class CompletedBatch:
    """Final result of expanding a batch: one ``CompletedSequence`` per original series."""

    # Completed sequences.
    items: list[CompletedSequence]
def compute_new_series(result: TokenCandidates, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
    """Split one series' candidates into continued series and completed series.

    Args:
        result: A series together with its candidate expansions.
        stopping_criterion: Called as ``stopping_criterion(series, expansion)``
            once per candidate, in order; a true result rejects that candidate.

    Returns:
        A pair ``(continued, completed)``. ``continued`` holds one new series
        per accepted candidate: the parent's id, tokens and budget with the
        candidate appended to a copy of the parent's expansions. ``completed``
        is ``[result.series]`` when no candidate was accepted, otherwise empty.
    """
    parent = result.series
    continued = [
        Series(
            id=parent.id,
            tokens=parent.tokens,
            expansions=parent.expansions + [candidate],
            budget=parent.budget,
        )
        for candidate in result.expansions
        if not stopping_criterion(parent, candidate)
    ]
    # The parent only counts as finished when nothing continues it.
    finished = [] if continued else [parent]
    return continued, finished
def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
    """Attach the expansion paths of finished series to the series they started from.

    Args:
        original_series: The starting series; their ids must be unique.
        expanded_series: Finished series, matched to the originals by id.

    Returns:
        A ``CompletedBatch`` with one ``CompletedSequence`` per original series,
        in the originals' order. Each carries the non-empty expansion lists of
        all finished series with the same id (an empty list if there are none).
    """
    # Ids must be unique, otherwise grouping by id would be ambiguous.
    distinct_ids = {series.id for series in original_series}
    assert len(original_series) == len(distinct_ids)

    originals_by_id = {}
    for series in original_series:
        originals_by_id[series.id] = series

    # Collect the expansion paths per id, skipping series that were never expanded.
    paths_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for finished in expanded_series:
        if not finished.expansions:
            continue
        paths_by_id[finished.id].append(finished.expansions)

    return CompletedBatch(
        items=[
            CompletedSequence(series=original, expansions=paths_by_id[series_id])
            for series_id, original in originals_by_id.items()
        ]
    )
def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
    """Return True when applying ``expansion`` would push the series' budget below zero.

    A result of exactly zero does not count as exhausted.
    """
    budget_after_expansion = series.get_remaining_budget() + expansion.cost
    return budget_after_expansion < 0
# A compound operation that we can implement generically, relying on a BatchExpander
def expand(batch: Batch, expander: BatchExpander, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> CompletedBatch:
    """Repeatedly expand a batch until no series can be continued.

    Each round asks ``expander`` for candidates for the current batch, turns
    the accepted candidates into the next batch, and sets aside series that
    cannot be continued. Progress and per-phase timings are printed to stdout.

    Args:
        batch: The starting series.
        expander: Supplies candidate expansions for a batch.
        completion_criterion: Called as ``completion_criterion(series, expansion)``;
            a true result rejects that expansion for that series.

    Returns:
        The expansion paths of all finished series, grouped under the series of
        ``batch`` they started from.
    """
    completed_series: list[Series] = []
    current_batch = batch
    while current_batch.items:
        print(f"Expanding {len(current_batch.items)} series")
        # perf_counter is the clock intended for measuring short durations;
        # time.time() follows the wall clock, which can be adjusted mid-run.
        start_time = time.perf_counter()
        expanded = expander.expand(current_batch)
        print(f"Expanded, took {time.perf_counter() - start_time} seconds")

        print("Computing new batch")
        start_time = time.perf_counter()
        next_items: list[Series] = []
        for item in expanded.items:
            if not item.expansions:
                # No candidates at all: the series is finished as it stands.
                completed_series.append(item.series)
                continue
            new_series, completed = compute_new_series(item, completion_criterion)
            completed_series.extend(completed)
            next_items.extend(new_series)
        current_batch = Batch(items=next_items)
        print(f"Computed, took {time.perf_counter() - start_time} seconds")
    return compute_expansions(batch.items, completed_series)