|
"""configuration and setup utils""" |
|
|
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Hashable, cast |
|
|
|
import coolname |
|
import rich |
|
from omegaconf import OmegaConf |
|
from rich.syntax import Syntax |
|
import shutup |
|
from rich.logging import RichHandler |
|
import logging |
|
|
|
from . import tree_export |
|
from . import copytree, preproc_data, serialize |
|
|
|
# Must run before anything below so warning output is muted from the start.
shutup.mute_warnings()

# Root logging goes through rich's handler; only the bare message is formatted
# here, the handler is left to render the rest.
logging.basicConfig(
    level="WARNING",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler()],
)

# Package-wide logger shared by the functions in this module.
logger = logging.getLogger("aide")
logger.setLevel(logging.WARNING)
|
|
|
|
|
""" these dataclasses are just for type hinting, the actual config is in config.yaml """ |
|
|
|
|
|
@dataclass
class StageConfig:
    """Per-stage settings; used for the `code`, `feedback` and `report` entries."""

    model: str  # presumably the model identifier for this stage — confirm in config.yaml
    temp: float  # presumably the sampling temperature — confirm in config.yaml
|
|
|
|
|
@dataclass
class SearchConfig:
    """Search settings nested under `AgentConfig.search`."""

    # NOTE(review): field meanings below are inferred from their names only;
    # confirm against config.yaml and the agent implementation.
    max_debug_depth: int  # presumably a cap on consecutive debug attempts
    debug_prob: float  # presumably the probability of choosing a debug step
    num_drafts: int  # presumably the number of initial drafts
|
|
|
|
|
@dataclass
class AgentConfig:
    """Agent settings nested under `Config.agent`."""

    # NOTE(review): the scalar fields are not read anywhere in this module;
    # their meanings are inferred from names — confirm against the agent code.
    steps: int
    k_fold_validation: int
    expose_prediction: bool
    data_preview: bool

    # Model settings for two agent stages.
    code: StageConfig
    feedback: StageConfig

    search: SearchConfig
|
|
|
|
|
@dataclass
class ExecConfig:
    """Execution settings nested under `Config.exec`."""

    # NOTE(review): none of these are read in this module; units and exact
    # semantics should be confirmed against the code that consumes them.
    timeout: int  # presumably seconds — TODO confirm
    agent_file_name: str
    format_tb_ipython: bool
|
|
|
|
|
@dataclass
class Config(Hashable):
    """Top-level configuration schema.

    Values come from config.yaml merged with CLI overrides (see `_load_cfg`);
    `prep_cfg` merges the loaded config into this schema.
    """

    data_dir: Path  # required; resolved to an absolute path by `prep_cfg`
    desc_file: Path | None  # optional task description file, read by `load_task_desc`

    # Inline task description, used by `load_task_desc` when `desc_file` is unset.
    goal: str | None
    eval: str | None

    # Top-level directories on input; `prep_cfg` rewrites both to the
    # per-run subdirectory named after `exp_name`.
    log_dir: Path
    workspace_dir: Path

    # Both are consumed by `prep_agent_workspace`.
    preprocess_data: bool
    copy_data: bool  # when False, `copytree` is called with use_symlinks=True

    exp_name: str  # `prep_cfg` fills in a random slug if empty and prefixes a run index

    exec: ExecConfig
    generate_report: bool
    report: StageConfig
    agent: AgentConfig
|
|
|
|
|
def _get_next_logindex(dir: Path) -> int:
    """Get the next available index for a log directory.

    Entries of `dir` are expected to be named `<index>-<name>`. Entries whose
    part before the first "-" is not an integer are ignored.

    Args:
        dir: existing directory whose entries are scanned.

    Returns:
        One more than the largest index found, or 0 if there is none.
    """
    max_index = -1
    for p in dir.iterdir():
        try:
            # Parse first, compare afterwards: folding both into one walrus
            # expression binds the comparison's bool instead of the index.
            current_index = int(p.name.split("-")[0])
        except ValueError:
            # Not an indexed entry; skip it.
            continue
        max_index = max(max_index, current_index)
    return max_index + 1
|
|
|
|
|
def _load_cfg(
    path: Path = Path(__file__).parent / "config.yaml", use_cli_args=True
) -> Config:
    """Load the YAML config at `path`, optionally merging CLI overrides on top.

    When `use_cli_args` is true, the result of `OmegaConf.from_cli()` is merged
    over the file's values; otherwise the file's config is returned as loaded.
    """
    file_cfg = OmegaConf.load(path)
    if not use_cli_args:
        return file_cfg
    return OmegaConf.merge(file_cfg, OmegaConf.from_cli())
|
|
|
|
|
def load_cfg(path: Path = Path(__file__).parent / "config.yaml") -> Config:
    """Load the config from `path` plus CLI args, then prepare it with `prep_cfg`."""
    raw_cfg = _load_cfg(path)
    return prep_cfg(raw_cfg)
|
|
|
|
|
def prep_cfg(cfg: Config) -> Config:
    """Validate a loaded config and derive this run's directories.

    Resolves `data_dir` and `desc_file` to absolute paths, creates the
    top-level log and workspace directories, picks a run index that is free in
    both, and rewrites `exp_name`, `log_dir` and `workspace_dir` to point at
    the per-run `<index>-<exp_name>` entries. The result is merged into the
    `Config` structured schema before being returned.

    Args:
        cfg: config as loaded from YAML/CLI; it is modified in place before
            the schema merge.

    Returns:
        The schema-merged config.

    Raises:
        ValueError: if `data_dir` is missing, or if neither `desc_file` nor
            `goal` is provided.
    """
    if cfg.data_dir is None:
        raise ValueError("`data_dir` must be provided.")

    if cfg.desc_file is None and cfg.goal is None:
        raise ValueError(
            "You must provide either a description of the task goal (`goal=...`) or a path to a plaintext file containing the description (`desc_file=...`)."
        )

    # `data_dir` may be a `str` or a `Path` (which has no `startswith`), so
    # the prefix check is done on its string form.
    if str(cfg.data_dir).startswith("example_tasks/"):
        # Bundled example tasks are addressed relative to the directory two
        # levels above this file.
        cfg.data_dir = Path(__file__).parent.parent / cfg.data_dir
    cfg.data_dir = Path(cfg.data_dir).resolve()

    if cfg.desc_file is not None:
        cfg.desc_file = Path(cfg.desc_file).resolve()

    top_log_dir = Path(cfg.log_dir).resolve()
    top_log_dir.mkdir(parents=True, exist_ok=True)

    top_workspace_dir = Path(cfg.workspace_dir).resolve()
    top_workspace_dir.mkdir(parents=True, exist_ok=True)

    # Use one index that is unused in both trees so the log and workspace
    # directories of a run share the same name.
    ind = max(_get_next_logindex(top_log_dir), _get_next_logindex(top_workspace_dir))
    cfg.exp_name = cfg.exp_name or coolname.generate_slug(3)
    cfg.exp_name = f"{ind}-{cfg.exp_name}"

    cfg.log_dir = (top_log_dir / cfg.exp_name).resolve()
    cfg.workspace_dir = (top_workspace_dir / cfg.exp_name).resolve()

    # Merge into the structured schema built from the `Config` dataclass.
    cfg_schema: Config = OmegaConf.structured(Config)
    cfg = OmegaConf.merge(cfg_schema, cfg)

    return cast(Config, cfg)
|
|
|
|
|
def print_cfg(cfg: Config) -> None:
    """Print the config as syntax-highlighted YAML using rich."""
    yaml_text = OmegaConf.to_yaml(cfg)
    highlighted = Syntax(yaml_text, "yaml", theme="paraiso-dark")
    rich.print(highlighted)
|
|
|
|
|
def load_task_desc(cfg: Config):
    """Return the task description from `cfg.desc_file` or from `goal`/`eval`.

    Returns:
        The file's text when `desc_file` is set; otherwise a dict with a
        "Task goal" entry and, if `eval` is set, a "Task evaluation" entry.

    Raises:
        ValueError: if neither `desc_file` nor `goal` is provided.
    """
    desc_path = cfg.desc_file

    # No description file: build the description from the inline fields.
    if desc_path is None:
        if cfg.goal is None:
            raise ValueError(
                "`goal` (and optionally `eval`) must be provided if a task description file is not provided."
            )
        sections = {"Task goal": cfg.goal}
        if cfg.eval is not None:
            sections["Task evaluation"] = cfg.eval
        return sections

    # The file takes precedence; tell the user their inline fields are unused.
    if cfg.goal is not None or cfg.eval is not None:
        logger.warning(
            "Ignoring goal and eval args because task description file is provided."
        )

    with open(desc_path) as desc_fh:
        return desc_fh.read()
|
|
|
|
|
def prep_agent_workspace(cfg: Config):
    """Create the workspace's `input` and `working` directories and populate `input`.

    `cfg.data_dir` is passed to `copytree` with `use_symlinks` set to the
    opposite of `cfg.copy_data`; `preproc_data` then runs on the input
    directory when `cfg.preprocess_data` is set.
    """
    input_dir = cfg.workspace_dir / "input"
    working_dir = cfg.workspace_dir / "working"
    for subdir in (input_dir, working_dir):
        subdir.mkdir(parents=True, exist_ok=True)

    copytree(cfg.data_dir, input_dir, use_symlinks=not cfg.copy_data)

    if cfg.preprocess_data:
        preproc_data(input_dir)
|
|
|
|
|
def save_run(cfg: Config, journal):
    """Write the run's artifacts into `cfg.log_dir`.

    Written in order: `journal.json`, `config.yaml`, `tree_plot.html`, and
    `best_solution.py` (the code of `journal.get_best_node(only_good=False)`).
    """
    out_dir = cfg.log_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # Journal, then the config, then the tree visualisation.
    serialize.dump_json(journal, out_dir / "journal.json")
    OmegaConf.save(config=cfg, f=out_dir / "config.yaml")
    tree_export.generate(cfg, journal, out_dir / "tree_plot.html")

    # Finally the code of the best node found so far.
    best = journal.get_best_node(only_good=False)
    with open(out_dir / "best_solution.py", "w") as solution_fh:
        solution_fh.write(best.code)
|
|