Spaces:
Running
Running
| from typing import Any, Dict, List, Optional, Tuple, Type, Union | |
| import gym | |
| import numpy as np | |
| import torch as th | |
| from torch.nn import functional as F | |
| from stable_baselines3.common import logger | |
| from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm | |
| from stable_baselines3.common.preprocessing import maybe_transpose | |
| from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule | |
| from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update | |
| from stable_baselines3.dqn.policies import DQNPolicy | |
class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN)

    Paper: https://arxiv.org/abs/1312.5602, https://www.nature.com/articles/nature14236
    Default hyperparameters are taken from the nature paper,
    except for the optimizer and learning rate that were taken from Stable Baselines defaults.

    :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
    :param env: The environment to learn from (if registered in Gym, can be str)
    :param learning_rate: The learning rate, it can be a function
        of the current progress remaining (from 1 to 0)
    :param buffer_size: size of the replay buffer
    :param learning_starts: how many steps of the model to collect transitions for before learning starts
    :param batch_size: Minibatch size for each gradient update
    :param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
    :param gamma: the discount factor
    :param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
        like ``(5, "step")`` or ``(2, "episode")``.
    :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
        Set to ``-1`` means to do as many gradient steps as steps done in the environment
        during the rollout.
    :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
        at a cost of more complexity.
        See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: update the target network every ``target_update_interval``
        environment steps.
    :param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
    :param exploration_initial_eps: initial value of random action probability
    :param exploration_final_eps: final value of random action probability
    :param max_grad_norm: The maximum value for the gradient clipping
    :param tensorboard_log: the log location for tensorboard (if None, no logging)
    :param create_eval_env: Whether to create a second environment that will be
        used for evaluating the agent periodically. (Only available when passing string for the environment)
    :param policy_kwargs: additional arguments to be passed to the policy on creation
    :param verbose: the verbosity level: 0 no output, 1 info, 2 debug
    :param seed: Seed for the pseudo random generators
    :param device: Device (cpu, cuda, ...) on which the code should be run.
        Setting it to auto, the code will be run on the GPU if possible.
    :param _init_setup_model: Whether or not to build the network at the creation of the instance
    """

    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1000000,
        learning_starts: int = 50000,
        batch_size: Optional[int] = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            DQNPolicy,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            action_noise=None,  # No action noise
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            verbose=verbose,
            device=device,
            create_eval_env=create_eval_env,
            seed=seed,
            sde_support=False,
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Discrete,),
        )

        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.target_update_interval = target_update_interval
        self.max_grad_norm = max_grad_norm
        # "epsilon" for the epsilon-greedy exploration
        self.exploration_rate = 0.0
        # Linear schedule will be defined in `_setup_model()`
        self.exploration_schedule = None
        # Aliases to the policy's networks, filled in by `_create_aliases()`
        self.q_net, self.q_net_target = None, None

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """
        Run the parent setup, alias the Q-networks and build the exploration schedule.

        The schedule is created with ``get_linear_fn`` from the initial epsilon,
        the final epsilon and the exploration fraction given to the constructor.
        """
        super()._setup_model()
        self._create_aliases()
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
        )

    def _create_aliases(self) -> None:
        """
        Expose the policy's online and target Q-networks as attributes of the algorithm.
        """
        self.q_net = self.policy.q_net
        self.q_net_target = self.policy.q_net_target

    def _on_step(self) -> None:
        """
        Update the exploration rate and target network if needed.
        This method is called in ``collect_rollouts()`` after each step in the environment.
        """
        # Sync the target network every `target_update_interval` timesteps
        if self.num_timesteps % self.target_update_interval == 0:
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        logger.record("rollout/exploration rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Perform ``gradient_steps`` updates of the Q-network.

        Each update samples a minibatch from the replay buffer, builds the
        1-step TD target with the target network and minimizes the Huber loss
        between that target and the current Q-value of the stored action.

        :param gradient_steps: number of gradient updates to perform
        :param batch_size: number of transitions sampled for each update
        """
        # Update learning rate according to schedule
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # The target must not contribute to the gradient
            with th.no_grad():
                # Compute the next Q-values using the target network
                next_q_values = self.q_net_target(replay_data.next_observations)
                # Follow greedy policy: use the one with the highest value
                next_q_values, _ = next_q_values.max(dim=1)
                # Avoid potential broadcast issue
                next_q_values = next_q_values.reshape(-1, 1)
                # 1-step TD target
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates
            current_q_values = self.q_net(replay_data.observations)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(current_q_values, dim=1, index=replay_data.actions.long())

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        # ``np.mean([])`` emits a RuntimeWarning and yields NaN, so only report
        # the loss when at least one gradient step was actually performed.
        if losses:
            logger.record("train/loss", np.mean(losses))

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Overrides the base_class predict function to include epsilon-greedy exploration.

        :param observation: the input observation
        :param state: The last states (can be None, used in recurrent policies)
        :param mask: The last masks (can be None, used in recurrent policies)
        :param deterministic: Whether or not to return deterministic actions.
        :return: the model's action and the next state
            (used in recurrent policies)
        """
        # A single random draw decides whether the whole call explores
        if not deterministic and np.random.rand() < self.exploration_rate:
            if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
                # One independently sampled action per observation in the batch
                n_batch = observation.shape[0]
                action = np.array([self.action_space.sample() for _ in range(n_batch)])
            else:
                action = np.array(self.action_space.sample())
        else:
            action, state = self.policy.predict(observation, state, mask, deterministic)
        # When exploring, `state` is returned exactly as it was passed in
        return action, state

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "DQN",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:
        """
        Forward every argument unchanged to the parent class ``learn()``.

        This override only supplies the default argument values used for this
        class (for instance ``tb_log_name="DQN"``).

        :return: whatever the parent ``learn()`` returns
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )

    def _excluded_save_params(self) -> List[str]:
        """
        Return the parent's excluded parameter names plus the Q-network aliases.

        ``q_net`` and ``q_net_target`` are only aliases of attributes of
        ``self.policy`` (see ``_create_aliases()``).
        """
        return super()._excluded_save_params() + ["q_net", "q_net_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        Return the attribute names holding torch state for saving.

        :return: a list with the policy and its optimizer (named ``state_dicts``
            locally) and an empty second list
            (NOTE(review): the meaning of the second list is defined by the base class - confirm there)
        """
        state_dicts = ["policy", "policy.optimizer"]
        return state_dicts, []