# aideml / aide / utils / config.py
# Dixing Xu
# Init...
# 2fc6b05 unverified
"""configuration and setup utils"""
from dataclasses import dataclass
from pathlib import Path
from typing import Hashable, cast
import coolname
import rich
from omegaconf import OmegaConf
from rich.syntax import Syntax
import shutup
from rich.logging import RichHandler
import logging
from . import tree_export
from . import copytree, preproc_data, serialize
# Module-level side effects: these run once, at import time.
# Presumably silences Python warnings process-wide (see the `shutup` package) — confirm.
shutup.mute_warnings()
# Configure the root logger: WARNING level, message-only format, and rich's
# RichHandler as the only handler.
logging.basicConfig(
    level="WARNING", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]
)
# Logger named "aide", used by the functions below; its level is also set to WARNING.
logger = logging.getLogger("aide")
logger.setLevel(logging.WARNING)
""" these dataclasses are just for type hinting, the actual config is in config.yaml """
@dataclass
class StageConfig:
    """Type-hint schema for a model/temperature pair.

    Used as the type of `AgentConfig.code`, `AgentConfig.feedback` and
    `Config.report`.
    """

    # presumably the identifier of the model to query — confirm against config.yaml
    model: str
    # presumably the sampling temperature for that model — confirm against config.yaml
    temp: float
@dataclass
class SearchConfig:
    """Type-hint schema for the `search` section of `AgentConfig`."""

    # NOTE(review): the meanings below are inferred from the field names only;
    # the code that consumes them is not in this file — confirm.
    # presumably the maximum number of consecutive debugging attempts
    max_debug_depth: int
    # presumably the probability of choosing to debug a node
    debug_prob: float
    # presumably the number of initial draft solutions to generate
    num_drafts: int
@dataclass
class AgentConfig:
    """Type-hint schema for the `agent` section of `Config`."""

    # NOTE(review): scalar field meanings are inferred from their names; the
    # consumers are not in this file — confirm.
    # presumably the number of agent steps to run
    steps: int
    # presumably the number of folds for k-fold validation
    k_fold_validation: int
    expose_prediction: bool
    data_preview: bool

    # model/temperature settings for the `code` stage
    code: StageConfig
    # model/temperature settings for the `feedback` stage
    feedback: StageConfig

    # nested search settings (see SearchConfig)
    search: SearchConfig
@dataclass
class ExecConfig:
    """Type-hint schema for the `exec` section of `Config`."""

    # NOTE(review): meanings are inferred from the field names; the consumers
    # are not in this file — confirm.
    # presumably the execution time limit (units not visible here)
    timeout: int
    # presumably the file name under which the agent's code is written
    agent_file_name: str
    # presumably toggles IPython-style traceback formatting
    format_tb_ipython: bool
@dataclass
class Config(Hashable):
    """Top-level config schema.

    `prep_cfg` merges the loaded config into `OmegaConf.structured(Config)`
    to validate it against these fields.
    """

    # input data directory; required (`prep_cfg` raises if it is None)
    data_dir: Path
    # path to a task description file; `prep_cfg` requires this or `goal`
    desc_file: Path | None

    # task goal text, used by `load_task_desc` when no `desc_file` is given
    goal: str | None
    # optional evaluation text added to the task description by `load_task_desc`
    eval: str | None

    # after `prep_cfg`: <top-level log dir>/<exp_name>
    log_dir: Path
    # after `prep_cfg`: <top-level workspace dir>/<exp_name>
    workspace_dir: Path

    # when true, `prep_agent_workspace` runs `preproc_data` on the input dir
    preprocess_data: bool
    # when false, `prep_agent_workspace` asks `copytree` to use symlinks
    copy_data: bool

    # after `prep_cfg`: "<index>-<name>"; a random slug is generated if empty
    exp_name: str

    # nested execution settings (see ExecConfig)
    exec: ExecConfig
    # presumably toggles report generation — consumer not in this file
    generate_report: bool
    # model/temperature settings for the `report` stage
    report: StageConfig
    # nested agent settings (see AgentConfig)
    agent: AgentConfig
def _get_next_logindex(dir: Path) -> int:
    """Get the next available index for a log directory.

    Entries of ``dir`` are expected to be named ``<index>-<name>``. The text
    before the first ``-`` is parsed as an integer; entries whose prefix is
    not an integer are ignored.

    Args:
        dir: Existing directory whose direct children are scanned.

    Returns:
        One more than the largest index found, or ``0`` if no entry carries
        a numeric prefix.
    """
    max_index = -1
    for p in dir.iterdir():
        try:
            # The parentheses matter: `x := a > b` binds the *comparison
            # result* (a bool) to x, which previously turned max_index into
            # True and capped the returned index at 2.
            if (current_index := int(p.name.split("-")[0])) > max_index:
                max_index = current_index
        except ValueError:
            # Prefix is not a number, so this is not an indexed entry; skip it.
            continue
    return max_index + 1
def _load_cfg(
    path: Path = Path(__file__).parent / "config.yaml", use_cli_args=True
) -> Config:
    """Load the raw config from a YAML file, optionally merged with CLI args.

    Args:
        path: YAML file to load; defaults to `config.yaml` next to this module.
        use_cli_args: When true, the result of `OmegaConf.from_cli()` is
            merged on top of the file contents.

    Returns:
        The loaded (and possibly merged) OmegaConf config, not yet validated
        or path-resolved.
    """
    file_cfg = OmegaConf.load(path)
    if not use_cli_args:
        return file_cfg
    # CLI values are passed last to `merge`, after the file contents.
    return OmegaConf.merge(file_cfg, OmegaConf.from_cli())
def load_cfg(path: Path = Path(__file__).parent / "config.yaml") -> Config:
    """Load the config from a YAML file plus CLI args and prepare it for a run.

    Args:
        path: YAML file to load; defaults to `config.yaml` next to this module.

    Returns:
        The config after `prep_cfg` has validated it, resolved its paths and
        created the top-level log and workspace directories.
    """
    raw_cfg = _load_cfg(path)
    return prep_cfg(raw_cfg)
def prep_cfg(cfg: Config) -> Config:
    """Validate a loaded config, resolve its paths and name the experiment.

    Steps, in order:
      1. Require `data_dir`, and at least one of `desc_file` / `goal`.
      2. Resolve `data_dir` (values starting with ``example_tasks/`` are taken
         relative to the parent of this module's directory) and `desc_file`.
      3. Create the top-level log and workspace directories if missing.
      4. Prefix `exp_name` (or a generated slug) with the next free index and
         point `log_dir` / `workspace_dir` at the per-experiment subfolders.
      5. Merge the result into the `Config` schema to validate it.

    Args:
        cfg: Config as loaded from YAML/CLI. It is modified in place.

    Returns:
        The config merged with the structured `Config` schema.

    Raises:
        ValueError: If `data_dir` is missing, or neither `desc_file` nor
            `goal` is provided.
    """
    if cfg.data_dir is None:
        raise ValueError("`data_dir` must be provided.")

    if cfg.desc_file is None and cfg.goal is None:
        raise ValueError(
            "You must provide either a description of the task goal (`goal=...`) or a path to a plaintext file containing the description (`desc_file=...`)."
        )

    # `data_dir` is a string when it comes from YAML/CLI, but the schema
    # declares it as a Path; Path objects have no `startswith`, so compare
    # on the string form to support both.
    if str(cfg.data_dir).startswith("example_tasks/"):
        cfg.data_dir = Path(__file__).parent.parent / cfg.data_dir
    cfg.data_dir = Path(cfg.data_dir).resolve()

    if cfg.desc_file is not None:
        cfg.desc_file = Path(cfg.desc_file).resolve()

    top_log_dir = Path(cfg.log_dir).resolve()
    top_log_dir.mkdir(parents=True, exist_ok=True)

    top_workspace_dir = Path(cfg.workspace_dir).resolve()
    top_workspace_dir.mkdir(parents=True, exist_ok=True)

    # generate experiment name and prefix with consecutive index; taking the
    # max over both directories keeps the index unused in either of them
    ind = max(_get_next_logindex(top_log_dir), _get_next_logindex(top_workspace_dir))
    cfg.exp_name = cfg.exp_name or coolname.generate_slug(3)
    cfg.exp_name = f"{ind}-{cfg.exp_name}"

    cfg.log_dir = (top_log_dir / cfg.exp_name).resolve()
    cfg.workspace_dir = (top_workspace_dir / cfg.exp_name).resolve()

    # validate the config by merging it into the structured schema
    cfg_schema: Config = OmegaConf.structured(Config)
    cfg = OmegaConf.merge(cfg_schema, cfg)

    return cast(Config, cfg)
def print_cfg(cfg: Config) -> None:
    """Print the config to the terminal as syntax-highlighted YAML."""
    yaml_text = OmegaConf.to_yaml(cfg)
    highlighted = Syntax(yaml_text, "yaml", theme="paraiso-dark")
    rich.print(highlighted)
def load_task_desc(cfg: Config):
    """Build the task description from the description file or the config.

    Args:
        cfg: Config providing `desc_file`, `goal` and `eval`.

    Returns:
        The full text of `cfg.desc_file` when it is set; otherwise a dict
        with a "Task goal" entry and, if `cfg.eval` is set, a
        "Task evaluation" entry.

    Raises:
        ValueError: If neither `desc_file` nor `goal` is provided.
    """
    # no description file: assemble the description from the goal and eval args
    if cfg.desc_file is None:
        if cfg.goal is None:
            raise ValueError(
                "`goal` (and optionally `eval`) must be provided if a task description file is not provided."
            )
        description = {"Task goal": cfg.goal}
        if cfg.eval is not None:
            description["Task evaluation"] = cfg.eval
        return description

    # a description file takes precedence; tell the user their args are unused
    if cfg.goal is not None or cfg.eval is not None:
        logger.warning(
            "Ignoring goal and eval args because task description file is provided."
        )

    with open(cfg.desc_file) as desc_fh:
        return desc_fh.read()
def prep_agent_workspace(cfg: Config):
    """Create the agent's workspace folders and populate the input folder.

    Creates `<workspace_dir>/input` and `<workspace_dir>/working`, passes
    `cfg.data_dir` and the input folder to `copytree` (with symlinks
    requested unless `cfg.copy_data` is set), and runs `preproc_data` on the
    input folder when `cfg.preprocess_data` is set.
    """
    input_dir = cfg.workspace_dir / "input"
    working_dir = cfg.workspace_dir / "working"
    for folder in (input_dir, working_dir):
        folder.mkdir(parents=True, exist_ok=True)

    copytree(cfg.data_dir, input_dir, use_symlinks=not cfg.copy_data)

    if cfg.preprocess_data:
        preproc_data(input_dir)
def save_run(cfg: Config, journal):
    """Write the artifacts of a run into `cfg.log_dir`.

    Produces, in this order: `journal.json`, `config.yaml`, `tree_plot.html`
    and `best_solution.py` (the code of the journal's best node).

    Args:
        cfg: Prepared config; `log_dir` is created if it does not exist.
        journal: Run journal; must provide `get_best_node(only_good=...)`.
    """
    out_dir = cfg.log_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # journal, then the config that produced it
    serialize.dump_json(journal, out_dir / "journal.json")
    OmegaConf.save(config=cfg, f=out_dir / "config.yaml")

    # tree + code visualization
    tree_export.generate(cfg, journal, out_dir / "tree_plot.html")

    # best found solution; `only_good=False` presumably also admits nodes not
    # marked as good — confirm against the journal implementation
    top_node = journal.get_best_node(only_good=False)
    with open(out_dir / "best_solution.py", "w") as solution_fh:
        solution_fh.write(top_node.code)