Spaces:
Sleeping
Sleeping
| import ding.config | |
| from .a2c import A2CAgent | |
| from .c51 import C51Agent | |
| from .ddpg import DDPGAgent | |
| from .dqn import DQNAgent | |
| from .pg import PGAgent | |
| from .ppof import PPOF | |
| from .ppo_offpolicy import PPOOffPolicyAgent | |
| from .sac import SACAgent | |
| from .sql import SQLAgent | |
| from .td3 import TD3Agent | |
# Registry mapping each public algorithm name to the agent class that implements it.
# The names double as attribute names under ``ding.config.example``.
supported_algo = {
    "A2C": A2CAgent,
    "C51": C51Agent,
    "DDPG": DDPGAgent,
    "DQN": DQNAgent,
    "PG": PGAgent,
    "PPOF": PPOF,
    "PPOOffPolicy": PPOOffPolicyAgent,
    "SAC": SACAgent,
    "SQL": SQLAgent,
    "TD3": TD3Agent,
}
# Algorithm names in registry (insertion) order. Note these are mixed case,
# e.g. "PPOOffPolicy", so case-insensitive callers must normalize both sides.
supported_algo_list = list(supported_algo)
def env_supported(algo: str = None) -> list:
    """
    Return the list of environment ids supported by DI-engine.

    Arguments:
        - algo: Optional algorithm name, matched case-insensitively against
          ``supported_algo_list`` (e.g. ``"dqn"`` or ``"ppooffpolicy"``).

    Returns:
        - If ``algo`` is given: the keys of
          ``ding.config.example.<ALGO>.supported_env``, in their original order.
        - Otherwise: the union of the env ids of every supported algorithm,
          sorted so the listing is deterministic.

    Raises:
        - ValueError: if ``algo`` does not name a supported algorithm.
    """
    if algo is not None:
        target = algo.upper()
        # The canonical (mixed-case) registry name is also the attribute name
        # under ``ding.config.example``, so one lookup serves every algorithm.
        for name in supported_algo_list:
            if name.upper() == target:
                return list(getattr(ding.config.example, name).supported_env.keys())
        raise ValueError("The algo {} is not supported by di-engine.".format(algo))

    all_envs = set()
    for name in supported_algo_list:
        all_envs.update(getattr(ding.config.example, name).supported_env.keys())
    # Iteration order of a set of strings varies between interpreter runs
    # (hash randomization), so sort for a reproducible result.
    return sorted(all_envs)
# Union of the env ids supported by any algorithm, computed once at import
# time. This is a list (the return value of ``env_supported``), not a dict.
supported_env = env_supported(algo=None)
def algo_supported(env_id: str = None) -> list:
    """
    Return the list of algorithm names supported by DI-engine.

    Arguments:
        - env_id: Optional environment id, matched case-insensitively against
          the keys of each algorithm's ``ding.config.example.<ALGO>.supported_env``.

    Returns:
        - If ``env_id`` is given: the names of the algorithms that support it,
          in ``supported_algo_list`` order.
        - Otherwise: a copy of ``supported_algo_list``.

    Raises:
        - ValueError: if ``env_id`` is given but no algorithm supports it.
    """
    if env_id is None:
        # Hand out a copy so callers cannot mutate the module-level registry list.
        return list(supported_algo_list)

    target = env_id.upper()
    algo = [
        name for name in supported_algo_list
        if any(item.upper() == target for item in getattr(ding.config.example, name).supported_env.keys())
    ]
    if not algo:
        raise ValueError("The env {} is not supported by di-engine.".format(env_id))
    return algo
def is_supported(env_id: str = None, algo: str = None) -> bool:
    """
    Check if the env-algo pair is supported by DI-engine.

    Both names are matched case-insensitively. Either argument may be omitted:
    with only ``algo`` the algorithm alone is checked; with only ``env_id`` the
    env is checked against the envs supported by any algorithm.

    Arguments:
        - env_id: Optional environment id.
        - algo: Optional algorithm name.

    Returns:
        - True if the given env and/or algo is supported (and, when both are
          given, the algo supports that env); otherwise False.

    Raises:
        - ValueError: if neither ``env_id`` nor ``algo`` is given.
    """
    if env_id is None and algo is None:
        raise ValueError("Please specify the env or algo.")

    # ``supported_algo_list`` holds mixed-case names such as "PPOOffPolicy",
    # so normalize both sides before comparing.
    if algo is not None and algo.upper() not in {name.upper() for name in supported_algo_list}:
        return False
    if env_id is None:
        return True

    target_env = env_id.upper()
    # ``supported_env`` is the list built by ``env_supported()`` at import time,
    # not a dict, so iterate it directly rather than calling ``.keys()``.
    if target_env not in {item.upper() for item in supported_env}:
        return False
    if algo is None:
        return True
    # ``env_supported`` returns env ids in their original case; normalize them
    # so a mixed-case id like "LunarLander-v2" still matches.
    return target_env in {item.upper() for item in env_supported(algo)}