from datetime import datetime import importlib import os import pathlib import queue import random import threading import time import tqdm from agent.game_agent import game_agent from utils.game_utils import seed_everything import wandb from provider.ProviderFactory import ProviderFactory from utils.calculate_log import calculate_statistics, extract_scores from utils.config import Config from utils.encoding_utils import encode_data_to_base64_path from utils.file_utils import assemble_project_path, get_all_files, img_to_gif, run_path_construct from utils.json_utils import parse_semi_formatted_text from utils.lmm_utils import assemble_prompt from utils.planner_utils import _extract_keys_from_template import pickle config = Config() class GamePipelineRunner(): def __init__(self, args): # TODO add more mesages in run path. e.g. guide/instruction/.. run_path = run_path_construct( args.output_dir, config.env_short_name, args.levelConfig.split('/')[-1].split(".")[0], args.llmProviderConfig.split('/')[-1].split(".")[0], datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") ) self.record_path = args.output_dir self.output_dir = assemble_project_path(run_path) pathlib.Path(self.output_dir).mkdir(parents=True, exist_ok=True) self.env_name = config.env_name self.game_module = config.game_module self.game_class = config.game_class self.llm_provider_config_path = args.llmProviderConfig self.llm_provider = None self.generation_config_path = args.generationConfig if hasattr(args, 'generationConfig') else None self.agent = None self.reset() def reset(self): if self.llm_provider: self.llm_provider.reset() else: self.llm_provider = ProviderFactory.getProvider(self.llm_provider_config_path, self.generation_config_path) if self.agent is None: self.agent = game_agent(self.llm_provider) else: self.agent.reset_provider(self.llm_provider) def input_listener(self, event): # self.reset() count = 0 while not event.is_set() and not self.game.over: # 添加运行状态检查 if not self.game.new_action_event.is_set(): # 生成动作并通知主线程 game_info = self.game.get_game_info() self.agent.update_game_info(game_info) if not self.llm_provider: action = random.choice(self.game.valid_actions) time.sleep(0.5) else: action = self.agent.execute_action() self.game.current_action = action self.game.new_action_event.set() # 设置事件表示有新动作 else: count += 1 time.sleep(0.02) def run(self): game_module = importlib.import_module(self.game_module) game_class = getattr(game_module, self.game_class) self.game = game_class(self.output_dir) self.game.run(self.input_listener) scores = self.game.get_score() token_usage = self.llm_provider.get_tokens_usage() # merge scores and token_usage scores.update(token_usage) return scores, self.game.game_frames def pipeline_shutdown(self): if config.save_response: self.agent.save_response_record() self.agent = None print("GamePipelineRunner shutdown.") def entry(args, run_name=""): config.output_dir = args.output_dir config.save_response = args.save_response config.extra_config = vars(args) print("Extra Config:", config.extra_config) pipelineRunner = GamePipelineRunner(args) pickle_record_path = os.path.join(pipelineRunner.record_path, config.env_short_name, f"record_{run_name}.pickle") directory = os.path.dirname(pickle_record_path) if not os.path.exists(directory): os.makedirs(directory) if os.path.exists(pickle_record_path): scores = pickle.load(open(pickle_record_path, "rb")) else: scores = [] pickle.dump(scores, open(pickle_record_path, "wb")) with tqdm.tqdm(total=args.test_rounds, initial=len(scores), desc=run_name) as pbar: for test_round in range(len(scores), args.test_rounds): pipelineRunner.reset() seed_everything(test_round) config.pbar = pbar now_avg_score = 0 if len(scores) == 0 else sum([score["score"] for score in scores])/len(scores) now_avg_score = round(now_avg_score, 2) config.now_avg_score = now_avg_score score_dict, game_frames = pipelineRunner.run() pbar.update() scores.append(score_dict) pipelineRunner.pipeline_shutdown() game_frames_gif = os.path.join(pipelineRunner.output_dir, datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") + ".gif") img_to_gif(game_frames, game_frames_gif) log_dict = {"round" : test_round} # 按dict 统计分数 for key in score_dict: log_dict.update({ run_name + "_" + key: score_dict[key], run_name + "_avg_" + key: sum([score[key] for score in scores])/len(scores) }) log_dict.update({"image": wandb.Image(game_frames_gif)}) wandb.log(log_dict) pickle.dump(scores, open(pickle_record_path, "wb")) now_avg_score = 0 if len(scores) == 0 else sum([score["score"] for score in scores])/len(scores) now_avg_score = round(now_avg_score, 2) config.now_avg_score = now_avg_score pbar.set_postfix( avg_score=config.now_avg_score ) return scores