aideml / aide /journal.py
Dixing Xu
Init...
2fc6b05 unverified
"""
The journal is the core datastructure in AIDE that contains:
- the generated code samples
- information about how code samples relate to each other (the tree structure)
- code execution results
- evaluation information such as metrics
...
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Literal, Optional
from dataclasses_json import DataClassJsonMixin
from .interpreter import ExecutionResult
from .utils.metric import MetricValue
from .utils.response import trim_long_string
@dataclass(eq=False)
class Node(DataClassJsonMixin):
    """A single node in the solution tree.

    Contains code, execution results, and evaluation information.

    Equality and hashing are based solely on ``id`` (the dataclass-generated
    ``__eq__`` is disabled via ``eq=False``), which is what allows nodes to be
    stored in the ``children`` set of their parent.
    """

    # ---- code & plan ----
    code: str
    plan: str = field(default=None, kw_only=True)  # type: ignore

    # ---- general attrs ----
    # position of the node in its container; assigned by the container's `append`
    step: int = field(default=None, kw_only=True)  # type: ignore
    # random hex identifier; the sole basis for __eq__ / __hash__
    id: str = field(default_factory=lambda: uuid.uuid4().hex, kw_only=True)
    # creation time as returned by time.time()
    ctime: float = field(default_factory=lambda: time.time(), kw_only=True)
    parent: Optional["Node"] = field(default=None, kw_only=True)
    # filled in by the children's __post_init__
    children: set["Node"] = field(default_factory=set, kw_only=True)

    # ---- execution info ----
    # all of these are populated by `absorb_exec_result`
    _term_out: list[str] = field(default=None, kw_only=True)  # type: ignore
    exec_time: float = field(default=None, kw_only=True)  # type: ignore
    exc_type: str | None = field(default=None, kw_only=True)
    exc_info: dict | None = field(default=None, kw_only=True)
    exc_stack: list[tuple] | None = field(default=None, kw_only=True)

    # ---- evaluation ----
    # post-execution result analysis (findings/feedback)
    analysis: str = field(default=None, kw_only=True)  # type: ignore
    metric: MetricValue = field(default=None, kw_only=True)  # type: ignore
    # whether the agent decided that the code is buggy
    # -> always True if exc_type is not None or no valid metric
    is_buggy: bool = field(default=None, kw_only=True)  # type: ignore

    def __post_init__(self) -> None:
        """Register this node as a child of its parent, if it has one."""
        if self.parent is not None:
            self.parent.children.add(self)

    @property
    def stage_name(self) -> Literal["draft", "debug", "improve"]:
        """
        Return the stage of the node:
        - "draft" if the node is an initial solution draft (it has no parent)
        - "debug" if the node is the result of a debugging step (parent is buggy)
        - "improve" if the node is the result of an improvement step
        """
        if self.parent is None:
            return "draft"
        return "debug" if self.parent.is_buggy else "improve"

    def absorb_exec_result(self, exec_result: ExecutionResult):
        """Absorb the result of executing the code from this node.

        Copies the terminal output, execution time and exception details
        from `exec_result` onto this node.
        """
        self._term_out = exec_result.term_out
        self.exec_time = exec_result.exec_time
        self.exc_type = exec_result.exc_type
        self.exc_info = exec_result.exc_info
        self.exc_stack = exec_result.exc_stack

    @property
    def term_out(self) -> str:
        """Get the terminal output of the code execution (after truncating it).

        A node without recorded output (`_term_out` is None, its default) is
        treated as having produced no output instead of raising a TypeError.
        """
        return trim_long_string("".join(self._term_out or []))

    @property
    def is_leaf(self) -> bool:
        """Check if the node is a leaf node in the solution tree."""
        return not self.children

    def __eq__(self, other):
        """Two nodes are equal iff they share the same `id`."""
        return isinstance(other, Node) and self.id == other.id

    def __hash__(self):
        """Hash on `id` so that it stays consistent with `__eq__`."""
        return hash(self.id)

    @property
    def debug_depth(self) -> int:
        """
        Length of the current debug path
        - 0 if the node is not a debug node (parent is not buggy)
        - 1 if the parent is buggy but the skip parent isn't
        - n if there were n consecutive debugging steps
        """
        if self.stage_name != "debug":
            return 0
        # stage_name == "debug" implies that a parent exists
        return self.parent.debug_depth + 1  # type: ignore
@dataclass
class InteractiveSession(DataClassJsonMixin):
    """
    An ordered list of nodes recorded during one interactive session
    (the agent working against a Jupyter notebook-like interface).
    """

    nodes: list[Node] = field(default_factory=list)
    completed: bool = False

    def append(self, node: Node) -> None:
        """Add `node` to the end of the session and stamp it with its index."""
        position = len(self.nodes)
        node.step = position
        self.nodes.append(node)

    def generate_nb_trace(self, include_prompt, comment_headers=True) -> str:
        """Render the session as an IPython-style In/Out transcript.

        Cell headers are prefixed with "## " when `comment_headers` is true.
        When `include_prompt` is truthy and the session is non-empty, a
        trailing empty "In [...]" header for the next cell is added.
        """
        prefix = "## " if comment_headers else ""
        pieces = []
        for cell in self.nodes:
            # cells are numbered from 1, steps from 0
            pieces += [
                f"\n{prefix}In [{cell.step+1}]:\n",
                cell.code,
                f"\n{prefix}Out [{cell.step+1}]:\n",
                cell.term_out,
            ]

        if include_prompt and self.nodes:
            last_cell = self.nodes[-1]
            pieces.append(f"\n{prefix}In [{last_cell.step+2}]:\n")

        joined = "\n".join(pieces)
        return joined.strip()
@dataclass
class Journal(DataClassJsonMixin):
    """A collection of nodes representing the solution tree."""

    # nodes in insertion order; a node's `step` is its index in this list
    nodes: list[Node] = field(default_factory=list)
    # eda: InteractiveSession = field(default_factory=lambda: InteractiveSession())

    def __getitem__(self, idx: int) -> Node:
        """Return the node stored at position `idx`."""
        return self.nodes[idx]

    def __len__(self) -> int:
        """Return the number of nodes in the journal."""
        return len(self.nodes)

    def append(self, node: Node) -> None:
        """Append a new node to the journal and set its `step` to its index."""
        node.step = len(self.nodes)
        self.nodes.append(node)

    @property
    def draft_nodes(self) -> list[Node]:
        """Return a list of nodes representing initial coding drafts (nodes without a parent)."""
        return [n for n in self.nodes if n.parent is None]

    @property
    def buggy_nodes(self) -> list[Node]:
        """Return a list of nodes that are considered buggy by the agent."""
        return [n for n in self.nodes if n.is_buggy]

    @property
    def good_nodes(self) -> list[Node]:
        """Return a list of nodes that are not considered buggy by the agent.

        Note that this is a plain falsiness test on `is_buggy`, so nodes whose
        `is_buggy` is still None are included as well.
        """
        return [n for n in self.nodes if not n.is_buggy]

    def get_metric_history(self) -> list[MetricValue]:
        """Return a list of all metric values in the journal, in node order."""
        return [n.metric for n in self.nodes]

    def get_best_node(self, only_good=True) -> None | Node:
        """Return the best solution found so far (node with the highest validation metric).

        Args:
            only_good: if True, only nodes from `good_nodes` are considered;
                otherwise every node in the journal is a candidate.

        Returns:
            The candidate with the maximal `metric`, or None when there is no
            candidate at all (this includes an empty journal with
            `only_good=False`, which previously raised a ValueError).
        """
        candidates = self.good_nodes if only_good else self.nodes
        # `default=None` covers the empty case for both branches
        return max(candidates, key=lambda n: n.metric, default=None)

    def generate_summary(self, include_code: bool = False) -> str:
        """Generate a summary of the journal for the agent.

        One section is produced per good node, containing its plan,
        optionally its code, its analysis and its metric value; sections
        are separated by a dashed line.
        """
        summary = []
        for n in self.good_nodes:
            summary_part = f"Design: {n.plan}\n"
            if include_code:
                summary_part += f"Code: {n.code}\n"
            summary_part += f"Results: {n.analysis}\n"
            summary_part += f"Validation Metric: {n.metric.value}\n"
            summary.append(summary_part)
        return "\n-------------------------------\n".join(summary)