from collections import defaultdict
from dataclasses import dataclass, field
import time
from typing import Callable, Protocol, Self

@dataclass
class Expansion:
    """A single candidate token together with the cost of appending it."""

    # Token proposed as a continuation of a series.
    token: int
    # Amount added to a series' budget when this expansion is taken (see
    # Series.get_remaining_budget). Presumably negative, e.g. a
    # log-probability, so that the budget shrinks -- TODO confirm the sign
    # convention against the expander implementations.
    cost: float

@dataclass
class Series:
    """A token sequence under expansion, with the budget that limits it.

    ``tokens`` holds the initial tokens and ``expansions`` the steps that
    have been appended since, in order. ``id`` is carried over unchanged to
    every series derived from this one, so several series may share an id.
    """

    id: int
    tokens: list[int]
    budget: float
    expansions: list[Expansion] = field(default_factory=list)

    def get_all_tokens(self) -> list[int]:
        """Return the initial tokens followed by each expansion's token.

        A new list is returned; ``tokens`` itself is left untouched.
        """
        appended_tokens = [step.token for step in self.expansions]
        return self.tokens + appended_tokens

    def get_remaining_budget(self) -> float:
        """Return ``budget`` plus the summed cost of all expansions so far."""
        accumulated_cost = sum(step.cost for step in self.expansions)
        return self.budget + accumulated_cost

@dataclass
class Batch:
    """A group of series that is handed to a BatchExpander in one call."""

    # The series to expand together.
    items: list[Series]

@dataclass
class TokenCandidates:
    """The candidate expansions proposed for one series."""

    # The series the candidates belong to.
    series: Series
    # Alternative next steps for `series`; compute_new_series derives one new
    # series per candidate that is not stopped. An empty list is treated by
    # expand() as "this series is complete".
    expansions: list[Expansion]

@dataclass
class BatchCandidates:
    """The result of expanding a Batch: candidates grouped per series."""

    # One entry per series the expander produced candidates for.
    items: list[TokenCandidates]

# A fundamental operation that we can implement both using an LLM and using a list of hardcoded sequences, for testing
class BatchExpander(Protocol):
    """Structural interface for anything that can propose expansions for a batch.

    Any object with a compatible ``expand`` method satisfies this protocol;
    no inheritance is required.
    """

    # Given a batch of series, return the candidate expansions for each one.
    def expand(self, batch: Batch) -> BatchCandidates: ...

@dataclass
class CompletedSequence:
    """An original series together with every completed expansion path found for it."""

    # The series as it was before any expansion.
    series: Series
    # One inner list per completed path; each inner list holds that path's
    # expansions in order. Empty when no path with at least one expansion
    # completed for this series (see compute_expansions).
    expansions: list[list[Expansion]]

@dataclass
class CompletedBatch:
    """The final result of expanding a Batch to completion."""

    # One entry per original series, as built by compute_expansions.
    items: list[CompletedSequence]

def compute_new_series(result: TokenCandidates, stopping_criterion: Callable[[Series, Expansion], bool]) -> tuple[list[Series], list[Series]]:
    """Split one series' candidates into continued and completed series.

    Every candidate for which ``stopping_criterion(series, candidate)`` is
    false yields a new series: same id, tokens and budget as the parent,
    with the candidate appended to a copy of the parent's expansions.

    Returns:
        A pair ``(continued, completed)``. ``completed`` contains the parent
        series only when no candidate survived (including the case of no
        candidates at all); otherwise it is empty.
    """
    parent = result.series
    continued = [
        Series(
            id=parent.id,
            tokens=parent.tokens,
            budget=parent.budget,
            # Concatenation builds a fresh list, so siblings never share state.
            expansions=parent.expansions + [candidate],
        )
        for candidate in result.expansions
        if not stopping_criterion(parent, candidate)
    ]
    if continued:
        return continued, []
    # Nothing could be continued: the parent itself is a finished path.
    return continued, [parent]

def compute_expansions(original_series: list[Series], expanded_series: list[Series]) -> CompletedBatch:
    """Attach each completed expansion path to the original series it came from.

    Series are matched by ``id``. Expanded series without any expansions
    contribute nothing, so an original series whose paths are all empty (or
    absent) gets an empty ``expansions`` list.

    Args:
        original_series: The series as they were before expansion; their ids
            must be unique (checked with an ``assert``).
        expanded_series: The completed series, possibly several per id.

    Returns:
        A CompletedBatch with one CompletedSequence per original series, in
        the order of ``original_series``.
    """
    originals_by_id = {original.id: original for original in original_series}
    # A duplicate id would collapse into a single key and shrink the mapping.
    assert len(original_series) == len(originals_by_id)

    # Collect the expansion path of every non-trivial completed series per id.
    paths_by_id: dict[int, list[list[Expansion]]] = defaultdict(list)
    for finished in expanded_series:
        if finished.expansions:
            paths_by_id[finished.id].append(finished.expansions)

    completed = [
        CompletedSequence(series=original, expansions=paths_by_id[series_id])
        for series_id, original in originals_by_id.items()
    ]
    return CompletedBatch(items=completed)

def default_completion_criterion(series: Series, expansion: Expansion) -> bool:
    """Return True when taking ``expansion`` would push the budget below zero.

    The check is strict: a budget of exactly zero after the expansion does
    not stop the series.
    """
    budget_after_expansion = series.get_remaining_budget() + expansion.cost
    return budget_after_expansion < 0

# A compound operation that we can implement generically, relying on a BatchExpander
def expand(batch: Batch, expander: BatchExpander, completion_criterion: Callable[[Series, Expansion], bool] = default_completion_criterion) -> CompletedBatch:
    """Repeatedly expand a batch until every series has completed.

    Each round passes the still-active series to ``expander`` and turns the
    returned candidates into the next round's series. A series completes when
    the expander returns no candidates for it, or when ``completion_criterion``
    rejects all of its candidates. The loop ends once no active series remain,
    so termination depends on the expander and the criterion eventually
    stopping every path.

    Progress and timing information is printed for every round.

    Args:
        batch: The initial series; their ids must be unique.
        expander: Produces candidate expansions for a batch of series.
        completion_criterion: Called as ``criterion(series, expansion)``;
            returning True stops that candidate from being continued.

    Returns:
        A CompletedBatch pairing each original series with the expansion
        paths that completed for it.
    """
    completed_series: list[Series] = []
    current_batch = batch
    while current_batch.items:
        print(f"Expanding {len(current_batch.items)} series")
        # perf_counter is monotonic, so the reported durations cannot be
        # distorted by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        expanded = expander.expand(current_batch)
        print(f"Expanded, took {time.perf_counter() - start_time} seconds")
        print("Computing new batch")
        start_time = time.perf_counter()
        next_batch_items: list[Series] = []
        for item in expanded.items:
            if not item.expansions:
                # The expander had nothing to offer: this path is finished.
                completed_series.append(item.series)
            else:
                new_series, completed = compute_new_series(item, completion_criterion)
                completed_series.extend(completed)
                next_batch_items.extend(new_series)
        current_batch = Batch(items=next_batch_items)
        print(f"Computed, took {time.perf_counter() - start_time} seconds")
    return compute_expansions(batch.items, completed_series)