|
""" |
|
The journal is the core datastructure in AIDE that contains: |
|
- the generated code samples |
|
- information how code samples relate to each other (the tree structure) |
|
- code execution results |
|
- evaluation information such as metrics |
|
... |
|
""" |
|
|
|
import time |
|
import uuid |
|
from dataclasses import dataclass, field |
|
from typing import Literal, Optional |
|
|
|
from dataclasses_json import DataClassJsonMixin |
|
from .interpreter import ExecutionResult |
|
from .utils.metric import MetricValue |
|
from .utils.response import trim_long_string |
|
|
|
|
|
@dataclass(eq=False) |
|
class Node(DataClassJsonMixin): |
|
"""A single node in the solution tree. Contains code, execution results, and evaluation information.""" |
|
|
|
|
|
code: str |
|
plan: str = field(default=None, kw_only=True) |
|
|
|
|
|
step: int = field(default=None, kw_only=True) |
|
id: str = field(default_factory=lambda: uuid.uuid4().hex, kw_only=True) |
|
ctime: float = field(default_factory=lambda: time.time(), kw_only=True) |
|
parent: Optional["Node"] = field(default=None, kw_only=True) |
|
children: set["Node"] = field(default_factory=set, kw_only=True) |
|
|
|
|
|
_term_out: list[str] = field(default=None, kw_only=True) |
|
exec_time: float = field(default=None, kw_only=True) |
|
exc_type: str | None = field(default=None, kw_only=True) |
|
exc_info: dict | None = field(default=None, kw_only=True) |
|
exc_stack: list[tuple] | None = field(default=None, kw_only=True) |
|
|
|
|
|
|
|
analysis: str = field(default=None, kw_only=True) |
|
metric: MetricValue = field(default=None, kw_only=True) |
|
|
|
|
|
is_buggy: bool = field(default=None, kw_only=True) |
|
|
|
def __post_init__(self) -> None: |
|
if self.parent is not None: |
|
self.parent.children.add(self) |
|
|
|
@property |
|
def stage_name(self) -> Literal["draft", "debug", "improve"]: |
|
""" |
|
Return the stage of the node: |
|
- "stage" if the node is an initial solution draft |
|
- "debug" if the node is the result of a debugging step |
|
- "improve" if the node is the result of an improvement step |
|
""" |
|
if self.parent is None: |
|
return "draft" |
|
return "debug" if self.parent.is_buggy else "improve" |
|
|
|
def absorb_exec_result(self, exec_result: ExecutionResult): |
|
"""Absorb the result of executing the code from this node.""" |
|
self._term_out = exec_result.term_out |
|
self.exec_time = exec_result.exec_time |
|
self.exc_type = exec_result.exc_type |
|
self.exc_info = exec_result.exc_info |
|
self.exc_stack = exec_result.exc_stack |
|
|
|
@property |
|
def term_out(self) -> str: |
|
"""Get the terminal output of the code execution (after truncating it).""" |
|
return trim_long_string("".join(self._term_out)) |
|
|
|
@property |
|
def is_leaf(self) -> bool: |
|
"""Check if the node is a leaf node in the solution tree.""" |
|
return not self.children |
|
|
|
def __eq__(self, other): |
|
return isinstance(other, Node) and self.id == other.id |
|
|
|
def __hash__(self): |
|
return hash(self.id) |
|
|
|
@property |
|
def debug_depth(self) -> int: |
|
""" |
|
Length of the current debug path |
|
- 0 if the node is not a debug node (parent is not buggy) |
|
- 1 if the parent is buggy but the skip parent isn't |
|
- n if there were n consecutive debugging steps |
|
""" |
|
if self.stage_name != "debug": |
|
return 0 |
|
return self.parent.debug_depth + 1 |
|
|
|
|
|
@dataclass |
|
class InteractiveSession(DataClassJsonMixin): |
|
""" |
|
A collection of nodes for an interaction session |
|
(when the agent interacts with a Jupyter notebook-like interface). |
|
""" |
|
|
|
nodes: list[Node] = field(default_factory=list) |
|
completed: bool = False |
|
|
|
def append(self, node: Node) -> None: |
|
node.step = len(self.nodes) |
|
self.nodes.append(node) |
|
|
|
def generate_nb_trace(self, include_prompt, comment_headers=True) -> str: |
|
"""Generate a trace of the interactive session in IPython format.""" |
|
trace = [] |
|
header_prefix = "## " if comment_headers else "" |
|
for n in self.nodes: |
|
trace.append(f"\n{header_prefix}In [{n.step+1}]:\n") |
|
trace.append(n.code) |
|
trace.append(f"\n{header_prefix}Out [{n.step+1}]:\n") |
|
trace.append(n.term_out) |
|
|
|
if include_prompt and self.nodes: |
|
trace.append(f"\n{header_prefix}In [{self.nodes[-1].step+2}]:\n") |
|
|
|
return "\n".join(trace).strip() |
|
|
|
|
|
@dataclass |
|
class Journal(DataClassJsonMixin): |
|
"""A collection of nodes representing the solution tree.""" |
|
|
|
nodes: list[Node] = field(default_factory=list) |
|
|
|
|
|
def __getitem__(self, idx: int) -> Node: |
|
return self.nodes[idx] |
|
|
|
def __len__(self) -> int: |
|
"""Return the number of nodes in the journal.""" |
|
return len(self.nodes) |
|
|
|
def append(self, node: Node) -> None: |
|
"""Append a new node to the journal.""" |
|
node.step = len(self.nodes) |
|
self.nodes.append(node) |
|
|
|
@property |
|
def draft_nodes(self) -> list[Node]: |
|
"""Return a list of nodes representing intial coding drafts""" |
|
return [n for n in self.nodes if n.parent is None] |
|
|
|
@property |
|
def buggy_nodes(self) -> list[Node]: |
|
"""Return a list of nodes that are considered buggy by the agent.""" |
|
return [n for n in self.nodes if n.is_buggy] |
|
|
|
@property |
|
def good_nodes(self) -> list[Node]: |
|
"""Return a list of nodes that are not considered buggy by the agent.""" |
|
return [n for n in self.nodes if not n.is_buggy] |
|
|
|
def get_metric_history(self) -> list[MetricValue]: |
|
"""Return a list of all metric values in the journal.""" |
|
return [n.metric for n in self.nodes] |
|
|
|
def get_best_node(self, only_good=True) -> None | Node: |
|
"""Return the best solution found so far (node with the highest validation metric).""" |
|
if only_good: |
|
nodes = self.good_nodes |
|
if not nodes: |
|
return None |
|
else: |
|
nodes = self.nodes |
|
return max(nodes, key=lambda n: n.metric) |
|
|
|
def generate_summary(self, include_code: bool = False) -> str: |
|
"""Generate a summary of the journal for the agent.""" |
|
summary = [] |
|
for n in self.good_nodes: |
|
summary_part = f"Design: {n.plan}\n" |
|
if include_code: |
|
summary_part += f"Code: {n.code}\n" |
|
summary_part += f"Results: {n.analysis}\n" |
|
summary_part += f"Validation Metric: {n.metric.value}\n" |
|
summary.append(summary_part) |
|
return "\n-------------------------------\n".join(summary) |
|
|