V-MAGE-DEMO / runner.py
Fengx1n's picture
Initial DEMO
e53fda1
import argparse
import importlib
import os
import threading
import time
import pandas as pd
from runner.game_runner import entry
from utils.config import Config
import configparser
import os
# Module-level configuration object; the __main__ block loads the game
# environment and level configs into it (load_env_config / load_level_config)
# before calling main().
config = Config()
def main(args):
    """Resolve the model name from the LMM provider config file.

    Reads the ``[lmm]`` section of the INI file at ``args.llmProviderConfig``.
    ``model_path`` is preferred, with every ``/`` replaced by ``_``. When that
    key is absent, ``model_name`` is used instead.

    Args:
        args: Parsed CLI namespace; only ``llmProviderConfig`` is read here.

    Returns:
        str: The resolved model name.

    Raises:
        FileNotFoundError: If the config file could not be read.
        configparser.NoSectionError: If the file has no ``[lmm]`` section.
        configparser.NoOptionError: If neither ``model_path`` nor
            ``model_name`` is present in ``[lmm]``.
    """
    lmm_config = configparser.ConfigParser()
    # ConfigParser.read() silently skips paths it cannot open and returns the
    # list of files it actually parsed; fail loudly here instead of with a
    # misleading "No section: 'lmm'" error further down.
    if not lmm_config.read(args.llmProviderConfig):
        raise FileNotFoundError(
            f"LLM provider config not found or unreadable: {args.llmProviderConfig}"
        )
    # Only a missing key should trigger the fallback to model_name; a bare
    # ``except:`` would also hide KeyboardInterrupt and malformed values.
    model_path = lmm_config.get('lmm', 'model_path', fallback=None)
    if model_path is not None:
        # Flatten e.g. "org/model" into "org_model" (presumably so it can be
        # used as a file/directory name component — confirm with callers).
        model_name = model_path.replace("/", "_")
    else:
        model_name = lmm_config.get('lmm', 'model_name')
    return model_name
def get_args_parser():
    """Build the command-line parser for the runner.

    Every option is optional and has a default, so parsing an empty argument
    list yields a fully populated namespace.

    Returns:
        argparse.ArgumentParser: Parser whose ``prog`` is "Cradle Agent Runner".
    """
    # (flag, value type, default, help text) — registered in this order.
    option_specs = (
        ("--llmProviderConfig", str, "./config/gpt_server_config.ini",
         "The path to the LLM provider config file."),
        ("--gameEnvConfig", str, "./config/env_config/env_config_race.json",
         "The path to the environment config file."),
        ("--levelConfig", str, "./config/level_config/racegame/level1.json",
         "The path to the level config file."),
        ("--generationConfig", str, "./config/generation_config.ini",
         "The path to the swift generation config file."),
        ("--test_rounds", int, 1,
         "Rounds to test the game."),
        ("--output_dir", str, "./runs",
         "The path to output the results and log."),
    )
    cli_parser = argparse.ArgumentParser("Cradle Agent Runner")
    for flag, value_type, default_value, help_text in option_specs:
        cli_parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli_parser
def get_local_rank():
    """Return this process's local rank from the environment, or None.

    ``LOCAL_RANK`` takes precedence over ``OMPI_COMM_WORLD_LOCAL_RANK`` (the
    Open MPI variable). Returns None when neither is set.

    Raises:
        ValueError: If the chosen variable is not an integer string.
    """
    for env_key in ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"):
        raw_rank = os.environ.get(env_key)
        if raw_rank is not None:
            return int(raw_rank)
    return None
if __name__ == '__main__':
    local_rank = get_local_rank()
    # Only the primary process runs the game: rank 0, or a plain (non-launcher)
    # run where no rank variable is set at all.
    is_primary_process = local_rank is None or local_rank == 0
    if is_primary_process:
        args = get_args_parser().parse_args()
        # Populate the shared module-level config before handing off to main().
        config.load_env_config(args.gameEnvConfig)
        config.load_level_config(args.levelConfig)
        main(args)
    else:
        # Non-primary ranks just report their rank and fall through to exit.
        print(local_rank)
        print("process killed.")