# V-MAGE-EVAL-DEMO/runner/game_runner.py
# Repository page header: author Fengx1n, commit "first commit" (504b2e4).
from datetime import datetime
import importlib
import os
import pathlib
import queue
import random
import threading
import time
import tqdm
from agent.game_agent import game_agent
from utils.game_utils import seed_everything
import wandb
from provider.ProviderFactory import ProviderFactory
from utils.calculate_log import calculate_statistics, extract_scores
from utils.config import Config
from utils.encoding_utils import encode_data_to_base64_path
from utils.file_utils import assemble_project_path, get_all_files, img_to_gif, run_path_construct
from utils.json_utils import parse_semi_formatted_text
from utils.lmm_utils import assemble_prompt
from utils.planner_utils import _extract_keys_from_template
import pickle
# Shared configuration object for this module: entry() writes the per-run
# settings (output_dir, save_response, extra_config, pbar, now_avg_score) onto
# it, and GamePipelineRunner reads the environment/game names from it.
config = Config()
class GamePipelineRunner():
    """Run a configured game with an agent that chooses the actions.

    The runner builds a per-run output directory, creates (or resets) the LLM
    provider and the agent, instantiates the game class named in the module
    configuration, and feeds agent actions to the game via ``input_listener``.
    """

    def __init__(self,
                 args):
        """Prepare the output directory and create the provider and agent.

        Args:
            args: Parsed run arguments. This method reads ``output_dir``,
                ``levelConfig``, ``llmProviderConfig`` and, when present,
                ``generationConfig``.
        """
        # TODO add more messages in run path. e.g. guide/instruction/..
        # The run path is built from the env short name, the level-config and
        # provider-config file names (text before the first dot), and a
        # microsecond timestamp.
        run_path = run_path_construct(
            args.output_dir,
            config.env_short_name,
            args.levelConfig.split('/')[-1].split(".")[0],
            args.llmProviderConfig.split('/')[-1].split(".")[0],
            datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
        )

        self.record_path = args.output_dir
        self.output_dir = assemble_project_path(run_path)
        pathlib.Path(self.output_dir).mkdir(parents=True, exist_ok=True)

        self.env_name = config.env_name
        self.game_module = config.game_module
        self.game_class = config.game_class

        self.llm_provider_config_path = args.llmProviderConfig
        self.llm_provider = None
        # generationConfig is optional on the args object.
        self.generation_config_path = getattr(args, 'generationConfig', None)

        self.agent = None
        # Set by run(); declared here so the attribute always exists.
        self.game = None

        self.reset()

    def reset(self):
        """Reset the provider and agent, creating them if they do not exist."""
        if self.llm_provider:
            self.llm_provider.reset()
        else:
            self.llm_provider = ProviderFactory.getProvider(
                self.llm_provider_config_path, self.generation_config_path)

        if self.agent is None:
            self.agent = game_agent(self.llm_provider)
        else:
            self.agent.reset_provider(self.llm_provider)

    def input_listener(self, event):
        """Produce actions for the game until it ends or ``event`` is set.

        Whenever the game's ``new_action_event`` is clear, the current game
        info is handed to the agent, an action is chosen and stored on
        ``game.current_action``, and ``new_action_event`` is set. While the
        event is still set, the loop polls every 20 ms.

        Args:
            event: Object exposing ``is_set()``; the loop exits once it is set.
        """
        # Keep going only while not stopped and the game is still running.
        while not event.is_set() and not self.game.over:
            if self.game.new_action_event.is_set():
                # The previous action has not been consumed yet.
                time.sleep(0.02)
                continue

            # Generate an action and notify the consumer of the event.
            game_info = self.game.get_game_info()
            self.agent.update_game_info(game_info)

            if not self.llm_provider:
                # No provider available: fall back to a random valid action.
                action = random.choice(self.game.valid_actions)
                time.sleep(0.5)
            else:
                action = self.agent.execute_action()

            self.game.current_action = action
            # Signal that a new action is available.
            self.game.new_action_event.set()

    def run(self):
        """Play one game and return its scores and recorded frames.

        Returns:
            A ``(scores, game_frames)`` tuple. ``scores`` is the mapping
            returned by the game's ``get_score()`` merged with the provider's
            token usage; ``game_frames`` is the game's ``game_frames``
            attribute.
        """
        game_module = importlib.import_module(self.game_module)
        game_class = getattr(game_module, self.game_class)

        self.game = game_class(self.output_dir)
        self.game.run(self.input_listener)

        scores = self.game.get_score()
        # input_listener supports running without a provider, so do not assume
        # one exists when collecting token usage.
        token_usage = self.llm_provider.get_tokens_usage() if self.llm_provider else {}

        # Merge scores and token usage into a single mapping.
        scores.update(token_usage)
        return scores, self.game.game_frames

    def pipeline_shutdown(self):
        """Optionally save the agent's response record, then drop the agent."""
        # The agent may already be None after an earlier shutdown.
        if config.save_response and self.agent is not None:
            self.agent.save_response_record()
        self.agent = None
        print("GamePipelineRunner shutdown.")
def _load_scores(record_path):
    """Return the score list stored at ``record_path``, creating it if absent."""
    if os.path.exists(record_path):
        # NOTE(review): unpickling can execute arbitrary code; only resume
        # from record files produced by trusted runs.
        with open(record_path, "rb") as record_file:
            return pickle.load(record_file)

    scores = []
    _save_scores(record_path, scores)
    return scores


def _save_scores(record_path, scores):
    """Pickle ``scores`` to ``record_path``, closing the file afterwards."""
    with open(record_path, "wb") as record_file:
        pickle.dump(scores, record_file)


def _average_score(scores):
    """Return the mean of the ``"score"`` entries rounded to 2 places (0 if empty)."""
    if not scores:
        return 0
    return round(sum(score["score"] for score in scores) / len(scores), 2)


def entry(args, run_name=""):
    """Run the remaining test rounds for one configuration and return the scores.

    Finished rounds are stored in ``record_<run_name>.pickle`` under
    ``<output_dir>/<env_short_name>``, so an interrupted run resumes from the
    first round that has no recorded score. After each round the frames are
    written to a GIF, per-key and running-average values are logged to wandb,
    and the pickle record is rewritten.

    Args:
        args: Parsed run arguments. This function reads ``output_dir``,
            ``save_response`` and ``test_rounds``; the whole namespace is
            stored on the module configuration and passed to the runner.
        run_name: Label used for the record file name, the progress bar and
            the wandb metric prefixes.

    Returns:
        The list of per-round score dictionaries, including any loaded from
        an existing record file.
    """
    config.output_dir = args.output_dir
    config.save_response = args.save_response
    config.extra_config = vars(args)

    print("Extra Config:", config.extra_config)

    pipelineRunner = GamePipelineRunner(args)

    pickle_record_path = os.path.join(
        pipelineRunner.record_path, config.env_short_name, f"record_{run_name}.pickle")
    os.makedirs(os.path.dirname(pickle_record_path), exist_ok=True)

    scores = _load_scores(pickle_record_path)

    with tqdm.tqdm(total=args.test_rounds, initial=len(scores), desc=run_name) as pbar:
        # Start after the rounds already present in the record.
        for test_round in range(len(scores), args.test_rounds):
            pipelineRunner.reset()
            # The round index is the seed passed to seed_everything.
            seed_everything(test_round)

            config.pbar = pbar
            config.now_avg_score = _average_score(scores)

            score_dict, game_frames = pipelineRunner.run()
            pbar.update()

            scores.append(score_dict)
            pipelineRunner.pipeline_shutdown()

            game_frames_gif = os.path.join(
                pipelineRunner.output_dir,
                datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") + ".gif")
            img_to_gif(game_frames, game_frames_gif)

            log_dict = {"round": test_round}
            # Log each score key together with its running average.
            for key in score_dict:
                log_dict.update({
                    run_name + "_" + key: score_dict[key],
                    run_name + "_avg_" + key: sum(score[key] for score in scores) / len(scores)
                })
            log_dict.update({"image": wandb.Image(game_frames_gif)})
            wandb.log(log_dict)

            # Save after every round so the run can be resumed.
            _save_scores(pickle_record_path, scores)

            config.now_avg_score = _average_score(scores)
            pbar.set_postfix(
                avg_score=config.now_avg_score
            )

    return scores