File size: 7,065 Bytes
2fc6b05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
"""
The journal is the core datastructure in AIDE that contains:
- the generated code samples
- information about how the code samples relate to each other (the tree structure)
- code execution results
- evaluation information such as metrics
...
"""
import time
import uuid
from dataclasses import dataclass, field
from typing import Literal, Optional
from dataclasses_json import DataClassJsonMixin
from .interpreter import ExecutionResult
from .utils.metric import MetricValue
from .utils.response import trim_long_string
@dataclass(eq=False)
class Node(DataClassJsonMixin):
    """One node of the solution tree.

    A node bundles a piece of generated code with the results of running it
    and the evaluation attached to it. Nodes are linked through ``parent`` /
    ``children``; equality and hashing are based solely on ``id`` (hence
    ``eq=False`` on the dataclass and the explicit dunder methods below).
    """

    # ---- code & plan ----
    code: str
    plan: str = field(default=None, kw_only=True)  # type: ignore

    # ---- general attrs ----
    # index of the node inside its container; assigned by the container's append()
    step: int = field(default=None, kw_only=True)  # type: ignore
    # random hex identifier, used for equality and hashing
    id: str = field(default_factory=lambda: uuid.uuid4().hex, kw_only=True)
    # creation timestamp (seconds, as returned by time.time())
    ctime: float = field(default_factory=lambda: time.time(), kw_only=True)
    parent: Optional["Node"] = field(default=None, kw_only=True)
    children: set["Node"] = field(default_factory=set, kw_only=True)

    # ---- execution info ----
    # raw terminal output chunks; read them through the `term_out` property
    _term_out: list[str] = field(default=None, kw_only=True)  # type: ignore
    exec_time: float = field(default=None, kw_only=True)  # type: ignore
    exc_type: str | None = field(default=None, kw_only=True)
    exc_info: dict | None = field(default=None, kw_only=True)
    exc_stack: list[tuple] | None = field(default=None, kw_only=True)

    # ---- evaluation ----
    # post-execution result analysis (findings/feedback)
    analysis: str = field(default=None, kw_only=True)  # type: ignore
    metric: MetricValue = field(default=None, kw_only=True)  # type: ignore
    # whether the agent decided that the code is buggy
    # -> always True if exc_type is not None or no valid metric
    is_buggy: bool = field(default=None, kw_only=True)  # type: ignore

    def __post_init__(self) -> None:
        """Register this node as a child of its parent, if it has one."""
        if self.parent is None:
            return
        self.parent.children.add(self)

    @property
    def stage_name(self) -> Literal["draft", "debug", "improve"]:
        """
        Return the stage this node belongs to:
        - "draft" if the node has no parent (an initial solution draft)
        - "debug" if the parent node is marked as buggy
        - "improve" otherwise
        """
        parent = self.parent
        if parent is None:
            return "draft"
        if parent.is_buggy:
            return "debug"
        return "improve"

    def absorb_exec_result(self, exec_result: ExecutionResult):
        """Copy the outcome of running this node's code onto the node."""
        # The terminal output is stored under a private name; the remaining
        # attributes keep the name they have on the execution result.
        self._term_out = exec_result.term_out
        for attr in ("exec_time", "exc_type", "exc_info", "exc_stack"):
            setattr(self, attr, getattr(exec_result, attr))

    @property
    def term_out(self) -> str:
        """Return the joined terminal output, passed through `trim_long_string`."""
        joined = "".join(self._term_out)
        return trim_long_string(joined)

    @property
    def is_leaf(self) -> bool:
        """Return True when the node has no children in the solution tree."""
        if self.children:
            return False
        return True

    def __eq__(self, other):
        # Two nodes are equal exactly when they carry the same id.
        if not isinstance(other, Node):
            return False
        return self.id == other.id

    def __hash__(self):
        # Must agree with __eq__, so hash the id only.
        return hash(self.id)

    @property
    def debug_depth(self) -> int:
        """
        Number of consecutive debug steps ending at this node:
        - 0 if the node is not a debug node (parent is not buggy)
        - 1 if the parent is buggy but the grandparent isn't
        - n if there were n consecutive debugging steps
        """
        depth = 0
        node = self
        # Walk up the tree for as long as we stay on a chain of debug nodes;
        # a "debug" node always has a parent, so the walk cannot hit None.
        while node.stage_name == "debug":
            depth += 1
            node = node.parent  # type: ignore
        return depth
@dataclass
class InteractiveSession(DataClassJsonMixin):
    """
    An ordered list of nodes recorded during one interactive session
    (the agent working against a Jupyter notebook-like interface).
    """

    nodes: list[Node] = field(default_factory=list)
    completed: bool = False

    def append(self, node: Node) -> None:
        """Add `node` to the session, stamping it with its position as `step`."""
        position = len(self.nodes)
        node.step = position
        self.nodes.append(node)

    def generate_nb_trace(self, include_prompt, comment_headers=True) -> str:
        """Render the session as an IPython-style In/Out transcript.

        Args:
            include_prompt: when truthy and the session is non-empty, finish
                with an empty ``In [k]:`` header for the next cell.
            comment_headers: when truthy, prefix every header with ``"## "``.

        Returns:
            The transcript with leading and trailing whitespace stripped.
        """
        if comment_headers:
            header_prefix = "## "
        else:
            header_prefix = ""

        pieces = []
        for node in self.nodes:
            # Cells are numbered from 1, while `step` counts from 0.
            pieces.extend(
                [
                    f"\n{header_prefix}In [{node.step+1}]:\n",
                    node.code,
                    f"\n{header_prefix}Out [{node.step+1}]:\n",
                    node.term_out,
                ]
            )

        if include_prompt and self.nodes:
            last = self.nodes[-1]
            pieces.append(f"\n{header_prefix}In [{last.step+2}]:\n")

        transcript = "\n".join(pieces)
        return transcript.strip()
@dataclass
class Journal(DataClassJsonMixin):
    """A collection of nodes representing the solution tree."""

    nodes: list[Node] = field(default_factory=list)
    # eda: InteractiveSession = field(default_factory=lambda: InteractiveSession())

    def __getitem__(self, idx: int) -> Node:
        """Return the node stored at position `idx`."""
        return self.nodes[idx]

    def __len__(self) -> int:
        """Return the number of nodes in the journal."""
        return len(self.nodes)

    def append(self, node: Node) -> None:
        """Append a new node to the journal, setting its `step` to its index."""
        node.step = len(self.nodes)
        self.nodes.append(node)

    @property
    def draft_nodes(self) -> list[Node]:
        """Return the nodes representing initial coding drafts (nodes without a parent)."""
        return [n for n in self.nodes if n.parent is None]

    @property
    def buggy_nodes(self) -> list[Node]:
        """Return a list of nodes that are considered buggy by the agent."""
        return [n for n in self.nodes if n.is_buggy]

    @property
    def good_nodes(self) -> list[Node]:
        """Return a list of nodes that are not considered buggy by the agent.

        Nodes whose `is_buggy` is still None (its default) are included,
        because the filter keeps every node with a falsy `is_buggy`.
        """
        return [n for n in self.nodes if not n.is_buggy]

    def get_metric_history(self) -> list[MetricValue]:
        """Return the metric of every node in the journal, in journal order."""
        return [n.metric for n in self.nodes]

    def get_best_node(self, only_good=True) -> None | Node:
        """Return the best solution found so far (node with the highest metric).

        Args:
            only_good: if True, only nodes from `good_nodes` are candidates;
                otherwise every node in the journal is considered.

        Returns:
            The candidate with the largest `metric`, or None when there are no
            candidates. The ordering is whatever comparison the metric objects
            implement.
        """
        candidates = self.good_nodes if only_good else self.nodes
        # `default=None` covers the empty case for both branches; a plain
        # max() would raise ValueError on an empty journal with only_good=False.
        return max(candidates, key=lambda n: n.metric, default=None)

    def generate_summary(self, include_code: bool = False) -> str:
        """Generate a summary of the journal for the agent.

        Args:
            include_code: if True, each entry also contains the node's code.

        Returns:
            One entry per good node (design, optional code, results and
            validation metric), joined by a dashed separator line. An empty
            string is returned when there are no good nodes.
        """
        summary = []
        for n in self.good_nodes:
            summary_part = f"Design: {n.plan}\n"
            if include_code:
                summary_part += f"Code: {n.code}\n"
            summary_part += f"Results: {n.analysis}\n"
            summary_part += f"Validation Metric: {n.metric.value}\n"
            summary.append(summary_part)
        return "\n-------------------------------\n".join(summary)
|