# SPO/examples/di/InfiAgent-DABench/run_InfiAgent-DABench.py
import asyncio
import json

from DABench import DABench

from metagpt.logs import logger
from metagpt.roles.di.data_interpreter import DataInterpreter


async def get_prediction(agent, requirement):
"""Helper function to obtain a prediction from a new instance of the agent.
This function runs the agent with the provided requirement and extracts the prediction
from the result. If an error occurs during processing, it logs the error and returns None.
Args:
agent: The agent instance used to generate predictions.
requirement: The input requirement for which the prediction is to be made.
Returns:
The predicted result if successful, otherwise None.
"""
try:
# Run the agent with the given requirement and await the result
result = await agent.run(requirement)
# Parse the result to extract the prediction from the JSON response
prediction_json = json.loads(str(result).split("Current Plan")[1].split("## Current Task")[0])
prediction = prediction_json[-1]["result"] # Extract the last result from the parsed JSON
return prediction # Return the extracted prediction
except Exception as e:
# Log an error message if an exception occurs during processing
logger.info(f"Error processing requirement: {requirement}. Error: {e}")
return None # Return None in case of an error
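

# A sketch of the output layout the parser in get_prediction assumes
# (hypothetical excerpt, not verbatim DataInterpreter output):
#
#   ## Current Plan
#   [{"task_id": "1", "instruction": "...", "result": "answer=42"}]
#   ## Current Task
#   ...
#
# json.loads is applied to the slice between "Current Plan" and
# "## Current Task", and the last task's "result" becomes the prediction.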


async def evaluate_all(agent, k):
    """Evaluate all tasks in DABench using the specified baseline agent.

    Tasks are divided into groups of size k and each group is executed
    concurrently.

    Args:
        agent: The baseline agent used for making predictions.
        k (int): The number of tasks to process concurrently in each group.
    """
    bench = DABench()  # Instance of DABench providing prompts, answers, and scoring
    id_list, predictions = [], []  # Task IDs and their corresponding predictions
    tasks = []  # Prediction coroutines, one per DABench task

    # Build one prediction coroutine per task in DABench
    for key in bench.answers:
        requirement = bench.generate_formatted_prompt(key)  # Formatted prompt for the current task
        tasks.append(get_prediction(agent, requirement))
        id_list.append(key)

    # Process the tasks in groups of size k, executing each group concurrently
    for i in range(0, len(tasks), k):
        current_group = tasks[i : i + k]
        group_predictions = await asyncio.gather(*current_group)
        predictions.extend(group_predictions)

    # Drop failed predictions together with their IDs so both lists stay aligned
    valid = [(task_id, pred) for task_id, pred in zip(id_list, predictions) if pred is not None]
    id_list = [task_id for task_id, _ in valid]
    predictions = [pred for _, pred in valid]

    # Evaluate all valid predictions and log the result
    logger.info(bench.eval_all(id_list, predictions))
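

# Batching illustration: with 12 tasks and k=5, the loop above awaits
# tasks[0:5], then tasks[5:10], then tasks[10:12]; tasks within a slice run
# concurrently via asyncio.gather, while the slices themselves run in order.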


def main(k=5):
    """Main function to run the evaluation process."""
    agent = DataInterpreter()  # Create the DataInterpreter agent used for all tasks
    asyncio.run(evaluate_all(agent, k))  # Drive the async evaluation to completion


if __name__ == "__main__":
    main()