SPO / examples /aflow /optimize.py
XiangJinYu's picture
add metagpt
fe5c39d verified
raw
history blame
4.69 kB
# -*- coding: utf-8 -*-
# @Date : 8/23/2024 20:00 PM
# @Author : didi
# @Desc : Entrance of AFlow.
import argparse
from dataclasses import dataclass
from typing import Dict, List

from metagpt.configs.models_config import ModelsConfig
from metagpt.ext.aflow.data.download_data import download
from metagpt.ext.aflow.scripts.optimizer import Optimizer
@dataclass
class ExperimentConfig:
    """Benchmark-specific settings for one AFlow optimization run.

    A plain data holder; ``@dataclass`` replaces the hand-written ``__init__``
    and additionally provides ``__repr__`` and ``__eq__`` for debugging.

    Attributes:
        dataset: Benchmark name (also used as the key in ``EXPERIMENT_CONFIGS``).
        question_type: Task category the optimizer is tuned for ("qa", "math", or "code").
        operators: Names of the operators the optimizer may compose into workflows.
    """

    dataset: str
    question_type: str
    operators: List[str]
# Per-benchmark spec table: (dataset name, question type, operator names).
# The dataset name doubles as the lookup key below.
_DATASET_SPECS = [
    ("DROP", "qa", ("Custom", "AnswerGenerate", "ScEnsemble")),
    ("HotpotQA", "qa", ("Custom", "AnswerGenerate", "ScEnsemble")),
    ("MATH", "math", ("Custom", "ScEnsemble", "Programmer")),
    ("GSM8K", "math", ("Custom", "ScEnsemble", "Programmer")),
    ("MBPP", "code", ("Custom", "CustomCodeGenerate", "ScEnsemble", "Test")),
    ("HumanEval", "code", ("Custom", "CustomCodeGenerate", "ScEnsemble", "Test")),
]

# Supported experiments, keyed by dataset name. ``list(ops)`` gives every
# config its own independent operator list, as the original literals did.
EXPERIMENT_CONFIGS: Dict[str, ExperimentConfig] = {
    name: ExperimentConfig(dataset=name, question_type=qtype, operators=list(ops))
    for name, qtype, ops in _DATASET_SPECS
}
def parse_args():
    """Build and parse the command-line arguments for the AFlow optimizer.

    Returns:
        argparse.Namespace: the parsed arguments; ``dataset`` is constrained
        to the keys of ``EXPERIMENT_CONFIGS``.
    """
    parser = argparse.ArgumentParser(description="AFlow Optimizer")
    parser.add_argument(
        "--dataset",
        type=str,
        choices=list(EXPERIMENT_CONFIGS.keys()),
        required=True,
        help="Dataset type",
    )
    parser.add_argument("--sample", type=int, default=4, help="Sample count")
    parser.add_argument(
        "--optimized_path",
        type=str,
        default="metagpt/ext/aflow/scripts/optimized",
        help="Optimized result save path",
    )
    parser.add_argument("--initial_round", type=int, default=1, help="Initial round")
    parser.add_argument("--max_rounds", type=int, default=20, help="Max iteration rounds")
    # BUGFIX: the original used ``type=bool``, but ``bool("False")`` is True --
    # any non-empty value (including "False") silently enabled convergence
    # checking. Parse the string explicitly, matching --if_first_optimize.
    parser.add_argument(
        "--check_convergence",
        type=lambda x: x.lower() == "true",
        default=True,
        help="Whether to enable early stop",
    )
    parser.add_argument("--validation_rounds", type=int, default=5, help="Validation rounds")
    parser.add_argument(
        "--if_first_optimize",
        type=lambda x: x.lower() == "true",
        default=True,
        help="Whether to download dataset for the first time",
    )
    parser.add_argument(
        "--opt_model_name",
        type=str,
        default="claude-3-5-sonnet-20240620",
        help="Specifies the name of the model used for optimization tasks.",
    )
    parser.add_argument(
        "--exec_model_name",
        type=str,
        default="gpt-4o-mini",
        help="Specifies the name of the model used for execution tasks.",
    )
    return parser.parse_args()
def _resolve_llm_config(models_config, model_name: str, role: str, flag: str):
    """Return the LLM config for *model_name*, or raise with fix-it guidance.

    *role* ("optimization"/"execution") and *flag* (the CLI option name) are
    interpolated into the error text so the message matches the failing flag.
    """
    llm_config = models_config.get(model_name)
    if llm_config is None:
        raise ValueError(
            f"The {role} model '{model_name}' was not found in the 'models' section of the configuration file. "
            f"Please add it to the configuration file or specify a valid model using the {flag} flag. "
        )
    return llm_config


if __name__ == "__main__":
    args = parse_args()
    experiment = EXPERIMENT_CONFIGS[args.dataset]

    models = ModelsConfig.default()
    opt_llm_config = _resolve_llm_config(models, args.opt_model_name, "optimization", "--opt_model_name")
    exec_llm_config = _resolve_llm_config(models, args.exec_model_name, "execution", "--exec_model_name")

    # Fetch datasets and seed workflow rounds (skipped after the first run).
    download(["datasets", "initial_rounds"], if_first_download=args.if_first_optimize)

    optimizer = Optimizer(
        dataset=experiment.dataset,
        question_type=experiment.question_type,
        opt_llm_config=opt_llm_config,
        exec_llm_config=exec_llm_config,
        check_convergence=args.check_convergence,
        operators=experiment.operators,
        optimized_path=args.optimized_path,
        sample=args.sample,
        initial_round=args.initial_round,
        max_rounds=args.max_rounds,
        validation_rounds=args.validation_rounds,
    )

    # "Graph" mode optimizes the workflow; switch to "Test" to evaluate it.
    optimizer.optimize("Graph")
    # optimizer.optimize("Test")