File size: 2,031 Bytes
e53fda1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import argparse
import importlib
import os
import threading
import time

import pandas as pd
from runner.game_runner import entry
from utils.config import Config
import configparser
import os

# Module-level configuration object (project `utils.config.Config`). The
# __main__ block below loads the env and level config files into it before
# calling main().
config = Config()

def main(args):
    """Resolve the LMM model identifier from the LLM provider config file.

    Reads the INI file named by ``args.llmProviderConfig`` and looks in its
    ``[lmm]`` section for ``model_path``. If present, slashes in that value
    are replaced with underscores (e.g. ``"org/model"`` -> ``"org_model"``).
    If ``model_path`` is absent, ``model_name`` is used verbatim instead.

    Args:
        args: Parsed CLI namespace; only ``llmProviderConfig`` is read here.

    Returns:
        The resolved model name string. (Previously computed but discarded;
        returning it lets callers use it without changing existing calls.)

    Raises:
        FileNotFoundError: if the provider config file could not be read.
        configparser.NoSectionError: if the file has no ``[lmm]`` section.
        configparser.NoOptionError: if neither ``model_path`` nor
            ``model_name`` is set in ``[lmm]``.
    """
    lmm_config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it actually parsed. Fail here with a clear message rather
    # than later with a misleading NoSectionError.
    if not lmm_config.read(args.llmProviderConfig):
        raise FileNotFoundError(
            f"LLM provider config not found or unreadable: {args.llmProviderConfig}"
        )

    try:
        model_name = lmm_config.get('lmm', 'model_path')
    except configparser.NoOptionError:
        # Only a missing option triggers the fallback. A missing section or a
        # malformed value still propagates instead of being swallowed.
        model_name = lmm_config.get('lmm', 'model_name')
    else:
        model_name = model_name.replace("/", "_")

    return model_name


def get_args_parser():
    """Return the argparse parser for the Cradle agent runner CLI.

    Every option is optional and has a default, so ``parse_args([])``
    succeeds. Options are registered in table order, which is also the
    order they appear in ``--help``.
    """
    # (flag, value type, default, help text)
    option_table = (
        ("--llmProviderConfig", str, "./config/gpt_server_config.ini",
         "The path to the LLM provider config file."),
        ("--gameEnvConfig", str, "./config/env_config/env_config_race.json",
         "The path to the environment config file."),
        ("--levelConfig", str, "./config/level_config/racegame/level1.json",
         "The path to the level config file."),
        ("--generationConfig", str, "./config/generation_config.ini",
         "The path to the swift generation config file."),
        ("--test_rounds", int, 1,
         "Rounds to test the game."),
        ("--output_dir", str, "./runs",
         "The path to output the results and log."),
    )

    cli = argparse.ArgumentParser("Cradle Agent Runner")
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli


def get_local_rank():
    """Return the process's local rank from the environment, or None.

    ``LOCAL_RANK`` is checked first, then ``OMPI_COMM_WORLD_LOCAL_RANK``.
    The first variable that is set is parsed with ``int()``; a non-integer
    value raises ``ValueError``. Returns ``None`` when neither is set.
    """
    for env_var in ("LOCAL_RANK", "OMPI_COMM_WORLD_LOCAL_RANK"):
        raw_value = os.environ.get(env_var)
        if raw_value is not None:
            return int(raw_value)
    return None


if __name__ == '__main__':
    local_rank = get_local_rank()

    if local_rank:
        # A non-zero local rank does no work. It only reports its rank and
        # then falls off the end of the script.
        print(local_rank)
        print("process killed.")
    else:
        # Local rank 0, or no rank variable set at all: load configs and run.
        args = get_args_parser().parse_args()
        config.load_env_config(args.gameEnvConfig)
        config.load_level_config(args.levelConfig)
        main(args)