Spaces:
Running
Running
File size: 6,279 Bytes
504b2e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
from datetime import datetime
import importlib
import os
import pathlib
import queue
import random
import threading
import time
import tqdm
from agent.game_agent import game_agent
from utils.game_utils import seed_everything
import wandb
from provider.ProviderFactory import ProviderFactory
from utils.calculate_log import calculate_statistics, extract_scores
from utils.config import Config
from utils.encoding_utils import encode_data_to_base64_path
from utils.file_utils import assemble_project_path, get_all_files, img_to_gif, run_path_construct
from utils.json_utils import parse_semi_formatted_text
from utils.lmm_utils import assemble_prompt
from utils.planner_utils import _extract_keys_from_template
import pickle
# Module-level configuration object; GamePipelineRunner reads environment
# settings from it and entry() stores per-run options on it.
config = Config()
class GamePipelineRunner():
    """Play a configured game with an agent that chooses the actions.

    The runner builds a per-run output directory, creates the LLM provider
    and the agent, and exposes ``run`` to play one game and collect its
    scores together with the provider's token usage.
    """

    def __init__(self,
                 args):
        """Prepare the output directory and create the provider and agent.

        Args:
            args: Parsed options. Reads ``output_dir``, ``levelConfig`` and
                ``llmProviderConfig``; ``generationConfig`` is optional.
        """
        # TODO add more messages in run path. e.g. guide/instruction/..
        run_path = run_path_construct(
            args.output_dir,
            config.env_short_name,
            args.levelConfig.split('/')[-1].split(".")[0],
            args.llmProviderConfig.split('/')[-1].split(".")[0],
            datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")
        )

        self.record_path = args.output_dir
        self.output_dir = assemble_project_path(run_path)
        pathlib.Path(self.output_dir).mkdir(parents=True, exist_ok=True)

        self.env_name = config.env_name
        self.game_module = config.game_module
        self.game_class = config.game_class

        self.llm_provider_config_path = args.llmProviderConfig
        self.llm_provider = None
        self.generation_config_path = getattr(args, 'generationConfig', None)

        self.agent = None
        # Set by run(); declared here so the attribute always exists.
        self.game = None

        self.reset()

    def reset(self):
        """Reset the provider (creating it on first use) and (re)attach the agent."""
        if self.llm_provider:
            self.llm_provider.reset()
        else:
            self.llm_provider = ProviderFactory.getProvider(
                self.llm_provider_config_path, self.generation_config_path
            )

        if self.agent is None:
            # Also the path taken after pipeline_shutdown() dropped the agent.
            self.agent = game_agent(self.llm_provider)
        else:
            self.agent.reset_provider(self.llm_provider)

    def input_listener(self, event):
        """Produce actions for the running game until stopped or game over.

        Passed to the game's ``run`` method by ``run``. Each time the game's
        ``new_action_event`` is clear, a new action is generated, stored in
        ``self.game.current_action`` and the event is set.

        Args:
            event: Object with ``is_set()``; the loop exits once it is set.
        """
        # Also stop as soon as the game reports that it is over.
        while not event.is_set() and not self.game.over:
            if not self.game.new_action_event.is_set():
                # Generate an action and notify the game loop.
                game_info = self.game.get_game_info()
                self.agent.update_game_info(game_info)

                if not self.llm_provider:
                    # No provider available: fall back to a random valid action.
                    action = random.choice(self.game.valid_actions)
                    time.sleep(0.5)
                else:
                    action = self.agent.execute_action()

                self.game.current_action = action
                # Setting the event signals that a new action is available.
                self.game.new_action_event.set()
            else:
                # The event is still set; wait briefly instead of busy-looping.
                time.sleep(0.02)

    def run(self):
        """Play one game and return its results.

        Returns:
            A tuple ``(scores, game_frames)``: the dict returned by the game's
            ``get_score()`` merged with the provider's token usage (when a
            provider exists), and the game's ``game_frames``.
        """
        game_module = importlib.import_module(self.game_module)
        game_class = getattr(game_module, self.game_class)
        self.game = game_class(self.output_dir)

        self.game.run(self.input_listener)
        scores = self.game.get_score()

        # input_listener supports running without a provider, so do not
        # assume one exists when merging token usage into the scores.
        if self.llm_provider is not None:
            scores.update(self.llm_provider.get_tokens_usage())

        return scores, self.game.game_frames

    def pipeline_shutdown(self):
        """Save the agent's response record if configured, then drop the agent."""
        # Guard against a repeated shutdown, where the agent is already gone.
        if config.save_response and self.agent is not None:
            self.agent.save_response_record()
        self.agent = None
        print("GamePipelineRunner shutdown.")
def _average_score(scores):
    """Return the mean of the "score" entries rounded to 2 places (0 if empty)."""
    if not scores:
        return 0
    return round(sum(item["score"] for item in scores) / len(scores), 2)


def _dump_scores(scores, path):
    """Pickle ``scores`` to ``path``, closing the file when done."""
    with open(path, "wb") as record_file:
        pickle.dump(scores, record_file)


def entry(args, run_name=""):
    """Run the remaining test rounds and return all per-round score dicts.

    Scores are persisted to a pickle record after every round. If a record
    for ``run_name`` already exists, it is loaded and the loop continues
    from the number of rounds it already contains.

    Args:
        args: Parsed options. Reads ``output_dir``, ``save_response`` and
            ``test_rounds`` here; also passed to ``GamePipelineRunner``.
        run_name: Used in the record file name, as the progress-bar
            description, and as the prefix of the logged metric names.

    Returns:
        The list of score dicts, one per completed round.
    """
    config.output_dir = args.output_dir
    config.save_response = args.save_response
    config.extra_config = vars(args)
    print("Extra Config:", config.extra_config)

    runner = GamePipelineRunner(args)

    pickle_record_path = os.path.join(
        runner.record_path, config.env_short_name, f"record_{run_name}.pickle"
    )
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.dirname(pickle_record_path), exist_ok=True)

    if os.path.exists(pickle_record_path):
        # NOTE(security): pickle.load can execute arbitrary code; this is
        # only appropriate for record files written by this program.
        with open(pickle_record_path, "rb") as record_file:
            scores = pickle.load(record_file)
    else:
        scores = []
        _dump_scores(scores, pickle_record_path)

    with tqdm.tqdm(total=args.test_rounds, initial=len(scores), desc=run_name) as pbar:
        for test_round in range(len(scores), args.test_rounds):
            runner.reset()
            # The round index is used as the seed for this round.
            seed_everything(test_round)

            config.pbar = pbar
            config.now_avg_score = _average_score(scores)

            score_dict, game_frames = runner.run()
            pbar.update()
            scores.append(score_dict)
            runner.pipeline_shutdown()

            game_frames_gif = os.path.join(
                runner.output_dir,
                datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f") + ".gif",
            )
            img_to_gif(game_frames, game_frames_gif)

            # Log this round's value and the running average for every key.
            log_dict = {"round": test_round}
            for key, value in score_dict.items():
                log_dict[run_name + "_" + key] = value
                log_dict[run_name + "_avg_" + key] = (
                    sum(score[key] for score in scores) / len(scores)
                )
            log_dict["image"] = wandb.Image(game_frames_gif)
            wandb.log(log_dict)

            # Persist after every round so an interrupted run can resume.
            _dump_scores(scores, pickle_record_path)

            config.now_avg_score = _average_score(scores)
            pbar.set_postfix(
                avg_score=config.now_avg_score
            )

    return scores