pushing model
- .gitattributes +1 -0
- README.md +85 -0
- events.out.tfevents.1702930553.4090-171.189153.0 +3 -0
- poetry.lock +0 -0
- ppo_fix_continuous_action.cleanrl_model +0 -0
- ppo_fix_continuous_action.py +572 -0
- pyproject.toml +108 -0
- replay.mp4 +0 -0
- videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-0.mp4 +0 -0
- videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-1.mp4 +3 -0
- videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-8.mp4 +0 -0
    	
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-1.mp4 filter=lfs diff=lfs merge=lfs -text
    	
README.md ADDED

@@ -0,0 +1,85 @@
+---
+tags:
+- Walker2d-v4
+- deep-reinforcement-learning
+- reinforcement-learning
+- custom-implementation
+library_name: cleanrl
+model-index:
+- name: PPO
+  results:
+  - task:
+      type: reinforcement-learning
+      name: reinforcement-learning
+    dataset:
+      name: Walker2d-v4
+      type: Walker2d-v4
+    metrics:
+    - type: mean_reward
+      value: 3415.35 +/- 1633.96
+      name: mean_reward
+      verified: false
+---
+
+# (CleanRL) **PPO** Agent Playing **Walker2d-v4**
+
+This is a trained model of a PPO agent playing Walker2d-v4.
+The model was trained by using [CleanRL](https://github.com/vwxyzjn/cleanrl) and the most up-to-date training code can be
+found [here](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_fix_continuous_action.py).
+
+## Get Started
+
+To use this model, please install the `cleanrl` package with the following command:
+
+```
+pip install "cleanrl[ppo_fix_continuous_action]"
+python -m cleanrl_utils.enjoy --exp-name ppo_fix_continuous_action --env-id Walker2d-v4
+```
+
+Please refer to the [documentation](https://docs.cleanrl.dev/get-started/zoo/) for more detail.
+
+
+## Command to reproduce the training
+
+```bash
+curl -OL https://huggingface.co/sdpkjc/Walker2d-v4-ppo_fix_continuous_action-seed2/raw/main/ppo_fix_continuous_action.py
+curl -OL https://huggingface.co/sdpkjc/Walker2d-v4-ppo_fix_continuous_action-seed2/raw/main/pyproject.toml
+curl -OL https://huggingface.co/sdpkjc/Walker2d-v4-ppo_fix_continuous_action-seed2/raw/main/poetry.lock
+poetry install --all-extras
+python ppo_fix_continuous_action.py --save-model --upload-model --hf-entity sdpkjc --env-id Walker2d-v4 --seed 2 --track
+```
+
+# Hyperparameters
+```python
+{'anneal_lr': True,
+ 'batch_size': 2048,
+ 'capture_video': False,
+ 'clip_coef': 0.2,
+ 'clip_vloss': True,
+ 'cuda': True,
+ 'ent_coef': 0.0,
+ 'env_id': 'Walker2d-v4',
+ 'exp_name': 'ppo_fix_continuous_action',
+ 'gae_lambda': 0.95,
+ 'gamma': 0.99,
+ 'hf_entity': 'sdpkjc',
+ 'learning_rate': 0.0003,
+ 'max_grad_norm': 0.5,
+ 'minibatch_size': 64,
+ 'norm_adv': True,
+ 'num_envs': 1,
+ 'num_minibatches': 32,
+ 'num_steps': 2048,
+ 'save_model': True,
+ 'seed': 2,
+ 'target_kl': None,
+ 'torch_deterministic': True,
+ 'total_timesteps': 1000000,
+ 'track': True,
+ 'update_epochs': 10,
+ 'upload_model': True,
+ 'vf_coef': 0.5,
+ 'wandb_project_name': 'cleanRL'}
+```
+
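A note on the hyperparameter dump above: `batch_size` and `minibatch_size` are not passed on the command line; they are derived inside `parse_args()` of the `ppo_fix_continuous_action.py` added later in this commit, and the number of policy updates follows from them. A minimal sketch of that arithmetic with the values of this run (plain variables here are illustrative, not part of the commit):

```python
# Sketch of the derivations in parse_args() and the training loop of
# ppo_fix_continuous_action.py, using this run's values.
num_envs = 1                 # --num-envs
num_steps = 2048             # --num-steps
num_minibatches = 32         # --num-minibatches
total_timesteps = 1_000_000  # --total-timesteps

batch_size = int(num_envs * num_steps)               # 2048 transitions per rollout
minibatch_size = int(batch_size // num_minibatches)  # 64 samples per gradient step
num_updates = total_timesteps // batch_size          # 488 policy updates in total

print(batch_size, minibatch_size, num_updates)  # 2048 64 488
```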
    	
events.out.tfevents.1702930553.4090-171.189153.0 ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:727fc9f06920fbcbba566bbb5ee96e4916effce2d3d335f636518141f3c193e1
+size 552953
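The file above is the Git LFS pointer for the TensorBoard log written by the `SummaryWriter` in the training script. As a sketch (not part of this commit), the logged scalars could be inspected offline with TensorBoard's standard event reader; the tag names come from the `writer.add_scalar` calls in `ppo_fix_continuous_action.py`:

```python
# Sketch only: read the committed TensorBoard log with TensorBoard's event
# reader (assumes the actual event file has been pulled via git-lfs).
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

ea = EventAccumulator("events.out.tfevents.1702930553.4090-171.189153.0")
ea.Reload()  # parse the event file

print(ea.Tags()["scalars"])  # e.g. charts/episodic_return, losses/value_loss, ...
for event in ea.Scalars("charts/episodic_return"):
    print(event.step, event.value)
```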
    	
poetry.lock ADDED

The diff for this file is too large to render. See raw diff.
    	
ppo_fix_continuous_action.cleanrl_model ADDED

Binary file (50.2 kB).
    	
ppo_fix_continuous_action.py ADDED

@@ -0,0 +1,572 @@
+# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_continuous_actionpy
+import argparse
+import copy
+import os
+import random
+import time
+from distutils.util import strtobool
+from typing import Callable
+
+import gymnasium as gym
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.distributions.normal import Normal
+from torch.utils.tensorboard import SummaryWriter
+
+
+def parse_args():
+    # fmt: off
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
+        help="the name of this experiment")
+    parser.add_argument("--seed", type=int, default=1,
+        help="seed of the experiment")
+    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="if toggled, `torch.backends.cudnn.deterministic=False`")
+    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="if toggled, cuda will be enabled by default")
+    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="if toggled, this experiment will be tracked with Weights and Biases")
+    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
+        help="the wandb's project name")
+    parser.add_argument("--wandb-entity", type=str, default=None,
+        help="the entity (team) of wandb's project")
+    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to capture videos of the agent performances (check out `videos` folder)")
+    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to save model into the `runs/{run_name}` folder")
+    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to upload the saved model to huggingface")
+    parser.add_argument("--hf-entity", type=str, default="",
+        help="the user or org name of the model repository from the Hugging Face Hub")
+
+    # Algorithm specific arguments
+    parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
+        help="the id of the environment")
+    parser.add_argument("--total-timesteps", type=int, default=1000000,
+        help="total timesteps of the experiments")
+    parser.add_argument("--learning-rate", type=float, default=3e-4,
+        help="the learning rate of the optimizer")
+    parser.add_argument("--num-envs", type=int, default=1,
+        help="the number of parallel game environments")
+    parser.add_argument("--num-steps", type=int, default=2048,
+        help="the number of steps to run in each environment per policy rollout")
+    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="Toggle learning rate annealing for policy and value networks")
+    parser.add_argument("--gamma", type=float, default=0.99,
+        help="the discount factor gamma")
+    parser.add_argument("--gae-lambda", type=float, default=0.95,
+        help="the lambda for the general advantage estimation")
+    parser.add_argument("--num-minibatches", type=int, default=32,
+        help="the number of mini-batches")
+    parser.add_argument("--update-epochs", type=int, default=10,
+        help="the K epochs to update the policy")
+    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="Toggles advantages normalization")
+    parser.add_argument("--clip-coef", type=float, default=0.2,
+        help="the surrogate clipping coefficient")
+    parser.add_argument("--clip-vloss", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="Toggles whether or not to use a clipped loss for the value function, as per the paper.")
+    parser.add_argument("--ent-coef", type=float, default=0.0,
+        help="coefficient of the entropy")
+    parser.add_argument("--vf-coef", type=float, default=0.5,
+        help="coefficient of the value function")
+    parser.add_argument("--max-grad-norm", type=float, default=0.5,
+        help="the maximum norm for the gradient clipping")
+    parser.add_argument("--target-kl", type=float, default=None,
+        help="the target KL divergence threshold")
+    args = parser.parse_args()
+    args.batch_size = int(args.num_envs * args.num_steps)
+    args.minibatch_size = int(args.batch_size // args.num_minibatches)
+    # fmt: on
+    return args
+
+
+# https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/normalize.py
+class RunningMeanStd(nn.Module):
+    def __init__(self, epsilon=1e-4, shape=()):
+        super().__init__()
+        self.register_buffer("mean", torch.zeros(shape, dtype=torch.float64))
+        self.register_buffer("var", torch.ones(shape, dtype=torch.float64))
+        self.register_buffer("count", torch.tensor(epsilon, dtype=torch.float64))
+
+    def update(self, x):
+        x = torch.as_tensor(x, dtype=torch.float64).to(self.mean.device)
+        batch_mean = torch.mean(x, dim=0).to(self.mean.device)
+        batch_var = torch.var(x, dim=0, unbiased=False).to(self.mean.device)
+        batch_count = x.shape[0]
+
+        self.mean, self.var, self.count = update_mean_var_count_from_moments(
+            self.mean, self.var, self.count, batch_mean, batch_var, batch_count
+        )
+
+
+def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
+    delta = batch_mean - mean
+    tot_count = count + batch_count
+
+    new_mean = mean + delta * batch_count / tot_count
+    m_a = var * count
+    m_b = batch_var * batch_count
+    M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
+    new_var = M2 / tot_count
+    new_count = tot_count
+
+    return new_mean, new_var, new_count
+
+
+class NormalizeObservation(gym.Wrapper, gym.utils.RecordConstructorArgs):
+    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
+        gym.utils.RecordConstructorArgs.__init__(self, epsilon=epsilon)
+        gym.Wrapper.__init__(self, env)
+
+        try:
+            self.num_envs = self.get_wrapper_attr("num_envs")
+            self.is_vector_env = self.get_wrapper_attr("is_vector_env")
+        except AttributeError:
+            self.num_envs = 1
+            self.is_vector_env = False
+
+        if self.is_vector_env:
+            self.obs_rms = RunningMeanStd(shape=self.single_observation_space.shape)
+        else:
+            self.obs_rms = RunningMeanStd(shape=self.observation_space.shape)
+        self.epsilon = epsilon
+
+        self.enable = True
+        self.freeze = False
+
+    def step(self, action):
+        obs, rews, terminateds, truncateds, infos = self.env.step(action)
+        if self.is_vector_env:
+            obs = self.normalize(obs)
+        else:
+            obs = self.normalize(np.array([obs]))[0]
+        return obs, rews, terminateds, truncateds, infos
+
+    def reset(self, **kwargs):
+        obs, info = self.env.reset(**kwargs)
+
+        if self.is_vector_env:
+            return self.normalize(obs), info
+        else:
+            return self.normalize(np.array([obs]))[0], info
+
+    def normalize(self, obs):
+        if not self.freeze:
+            self.obs_rms.update(obs)
+        if self.enable:
+            return (obs - self.obs_rms.mean.cpu().numpy()) / np.sqrt(self.obs_rms.var.cpu().numpy() + self.epsilon)
+        return obs
+
+
+class NormalizeReward(gym.core.Wrapper, gym.utils.RecordConstructorArgs):
+    def __init__(
+        self,
+        env: gym.Env,
+        gamma: float = 0.99,
+        epsilon: float = 1e-8,
+    ):
+        gym.utils.RecordConstructorArgs.__init__(self, gamma=gamma, epsilon=epsilon)
+        gym.Wrapper.__init__(self, env)
+
+        try:
+            self.num_envs = self.get_wrapper_attr("num_envs")
+            self.is_vector_env = self.get_wrapper_attr("is_vector_env")
+        except AttributeError:
+            self.num_envs = 1
+            self.is_vector_env = False
+
+        self.return_rms = RunningMeanStd(shape=())
+        self.returns = np.zeros(self.num_envs)
+        self.gamma = gamma
+        self.epsilon = epsilon
+
+        self.enable = True
+        self.freeze = False
+
+    def step(self, action):
+        obs, rews, terminateds, truncateds, infos = self.env.step(action)
+        if not self.is_vector_env:
+            rews = np.array([rews])
+        self.returns = self.returns * self.gamma * (1 - terminateds) + rews
+        rews = self.normalize(rews)
+        if not self.is_vector_env:
+            rews = rews[0]
+        return obs, rews, terminateds, truncateds, infos
+
+    def reset(self, **kwargs):
+        self.returns = np.zeros(self.num_envs)
+        return self.env.reset(**kwargs)
+
+    def normalize(self, rews):
+        if not self.freeze:
+            self.return_rms.update(self.returns)
+        if self.enable:
+            return rews / np.sqrt(self.return_rms.var.cpu().numpy() + self.epsilon)
+        return rews
+
+    def get_returns(self):
+        return self.returns
+
+
+def evaluate(
+    model_path: str,
+    make_env: Callable,
+    env_id: str,
+    eval_episodes: int,
+    run_name: str,
+    Model: torch.nn.Module,
+    device: torch.device = torch.device("cpu"),
+    capture_video: bool = True,
+):
+    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name)])
+    agent = Model(envs).to(device)
+    agent.load_state_dict(torch.load(model_path, map_location=device))
+    agent.eval()
+    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, capture_video, run_name, agent.obs_rms)])
+
+    obs, _ = envs.reset()
+    episodic_returns = []
+    while len(episodic_returns) < eval_episodes:
+        actions, _, _, _ = agent.get_action_and_value(torch.Tensor(obs).to(device))
+        next_obs, _, _, _, infos = envs.step(actions.cpu().numpy())
+        if "final_info" in infos:
+            for info in infos["final_info"]:
+                if "episode" not in info:
+                    continue
+                print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
+                episodic_returns += [info["episode"]["r"]]
+        obs = next_obs
+
+    return episodic_returns
+
+
+def make_env(env_id, idx, capture_video, run_name, gamma):
+    def thunk():
+        if capture_video:
+            env = gym.make(env_id, render_mode="rgb_array")
+        else:
+            env = gym.make(env_id)
+        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
+        env = gym.wrappers.RecordEpisodeStatistics(env)
+        if capture_video:
+            if idx == 0:
+                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+        env = gym.wrappers.ClipAction(env)
+        env = NormalizeObservation(env)
+        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
+        env = NormalizeReward(env, gamma=gamma)
+        env = gym.wrappers.TransformReward(env, lambda reward: np.clip(reward, -10, 10))
+        return env
+
+    return thunk
+
+
+def make_eval_env(env_id, idx, capture_video, run_name, obs_rms=None):
+    def thunk():
+        if capture_video:
+            env = gym.make(env_id, render_mode="rgb_array")
+        else:
+            env = gym.make(env_id)
+        env = gym.wrappers.FlattenObservation(env)  # deal with dm_control's Dict observation space
+        env = gym.wrappers.RecordEpisodeStatistics(env)
+        if capture_video:
+            if idx == 0:
+                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
+        env = gym.wrappers.ClipAction(env)
+        env = NormalizeObservation(env)
+        if obs_rms is not None:
+            env.obs_rms = copy.deepcopy(obs_rms)
+        env.freeze = True
+        env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10))
+        return env
+
+    return thunk
+
+
+def get_rms(env):
+    obs_rms, return_rms = None, None
+    env_point = env
+    while hasattr(env_point, "env"):
+        if isinstance(env_point, NormalizeObservation):
+            obs_rms = copy.deepcopy(env_point.obs_rms)
+            break
+        env_point = env_point.env
+    else:
+        raise RuntimeError("can't find NormalizeObservation")
+
+    env_point = env
+    while hasattr(env_point, "env"):
+        if isinstance(env_point, NormalizeReward):
+            return_rms = copy.deepcopy(env_point.return_rms)
+            break
+        env_point = env_point.env
+    else:
+        raise RuntimeError("can't find NormalizeReward")
+
+    return obs_rms, return_rms
+
+
+def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+    torch.nn.init.orthogonal_(layer.weight, std)
+    torch.nn.init.constant_(layer.bias, bias_const)
+    return layer
+
+
+class Agent(nn.Module):
+    def __init__(self, envs):
+        super().__init__()
+        self.critic = nn.Sequential(
+            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
+            nn.Tanh(),
+            layer_init(nn.Linear(64, 64)),
+            nn.Tanh(),
+            layer_init(nn.Linear(64, 1), std=1.0),
+        )
+        self.actor_mean = nn.Sequential(
+            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
+            nn.Tanh(),
+            layer_init(nn.Linear(64, 64)),
+            nn.Tanh(),
+            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
+        )
+        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))
+        self.obs_rms = RunningMeanStd(shape=envs.single_observation_space.shape)
+
+    def get_value(self, x):
+        return self.critic(x)
+
+    def get_action_and_value(self, x, action=None):
+        action_mean = self.actor_mean(x)
+        action_logstd = self.actor_logstd.expand_as(action_mean)
+        action_std = torch.exp(action_logstd)
+        probs = Normal(action_mean, action_std)
+        if action is None:
+            action = probs.sample()
+        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)
+
+
| 352 | 
            +
            if __name__ == "__main__":
         | 
| 353 | 
            +
                args = parse_args()
         | 
| 354 | 
            +
                run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
         | 
| 355 | 
            +
                if args.track:
         | 
| 356 | 
            +
                    import wandb
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                    wandb.init(
         | 
| 359 | 
            +
                        project=args.wandb_project_name,
         | 
| 360 | 
            +
                        entity=args.wandb_entity,
         | 
| 361 | 
            +
                        sync_tensorboard=True,
         | 
| 362 | 
            +
                        config=vars(args),
         | 
| 363 | 
            +
                        name=run_name,
         | 
| 364 | 
            +
                        monitor_gym=True,
         | 
| 365 | 
            +
                        save_code=True,
         | 
| 366 | 
            +
                    )
         | 
| 367 | 
            +
                writer = SummaryWriter(f"runs/{run_name}")
         | 
| 368 | 
            +
                writer.add_text(
         | 
| 369 | 
            +
                    "hyperparameters",
         | 
| 370 | 
            +
                    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
         | 
| 371 | 
            +
                )
         | 
| 372 | 
            +
             | 
| 373 | 
            +
                # TRY NOT TO MODIFY: seeding
         | 
| 374 | 
            +
                random.seed(args.seed)
         | 
| 375 | 
            +
                np.random.seed(args.seed)
         | 
| 376 | 
            +
                torch.manual_seed(args.seed)
         | 
| 377 | 
            +
                torch.backends.cudnn.deterministic = args.torch_deterministic
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                # env setup
         | 
| 382 | 
            +
                envs = gym.vector.SyncVectorEnv(
         | 
| 383 | 
            +
                    [make_env(args.env_id, i, args.capture_video, run_name, args.gamma) for i in range(args.num_envs)]
         | 
| 384 | 
            +
                )
         | 
| 385 | 
            +
                assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                agent = Agent(envs).to(device)
         | 
| 388 | 
            +
                optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                # ALGO Logic: Storage setup
         | 
| 391 | 
            +
                obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
         | 
| 392 | 
            +
                actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
         | 
| 393 | 
            +
                logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
         | 
| 394 | 
            +
                rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
         | 
| 395 | 
            +
                dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
         | 
| 396 | 
            +
                values = torch.zeros((args.num_steps, args.num_envs)).to(device)
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                # TRY NOT TO MODIFY: start the game
         | 
| 399 | 
            +
                global_step = 0
         | 
| 400 | 
            +
                start_time = time.time()
         | 
| 401 | 
            +
                next_obs, _ = envs.reset(seed=args.seed)
         | 
| 402 | 
            +
                next_obs = torch.Tensor(next_obs).to(device)
         | 
| 403 | 
            +
                next_done = torch.zeros(args.num_envs).to(device)
         | 
| 404 | 
            +
                num_updates = args.total_timesteps // args.batch_size
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                for update in range(1, num_updates + 1):
         | 
| 407 | 
            +
                    # Annealing the rate if instructed to do so.
         | 
| 408 | 
            +
                    if args.anneal_lr:
         | 
| 409 | 
            +
                        frac = 1.0 - (update - 1.0) / num_updates
         | 
| 410 | 
            +
                        lrnow = frac * args.learning_rate
         | 
| 411 | 
            +
                        optimizer.param_groups[0]["lr"] = lrnow
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                    for step in range(0, args.num_steps):
         | 
| 414 | 
            +
                        global_step += 1 * args.num_envs
         | 
| 415 | 
            +
                        obs[step] = next_obs
         | 
| 416 | 
            +
                        dones[step] = next_done
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                        # ALGO LOGIC: action logic
         | 
| 419 | 
            +
                        with torch.no_grad():
         | 
| 420 | 
            +
                            action, logprob, _, value = agent.get_action_and_value(next_obs)
         | 
| 421 | 
            +
                            values[step] = value.flatten()
         | 
| 422 | 
            +
                        actions[step] = action
         | 
| 423 | 
            +
                        logprobs[step] = logprob
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                        # TRY NOT TO MODIFY: execute the game and log data.
         | 
| 426 | 
            +
                        next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy())
         | 
| 427 | 
            +
                        done = np.logical_or(terminations, truncations)
         | 
| 428 | 
            +
                        rewards[step] = torch.tensor(reward).to(device).view(-1)
         | 
| 429 | 
            +
                        next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                        # https://github.com/DLR-RM/stable-baselines3/pull/658
         | 
| 432 | 
            +
                        for idx, trunc in enumerate(truncations):
         | 
| 433 | 
            +
                            if trunc:
         | 
| 434 | 
            +
                                real_next_obs = infos["final_observation"][idx]
         | 
| 435 | 
            +
                                with torch.no_grad():
         | 
| 436 | 
            +
                                    terminal_value = agent.get_value(torch.Tensor(real_next_obs).to(device)).reshape(1, -1)[0][0]
         | 
| 437 | 
            +
                                rewards[step][idx] += args.gamma * terminal_value
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                        # Only print when at least 1 env is done
         | 
| 440 | 
            +
                        if "final_info" not in infos:
         | 
| 441 | 
            +
                            continue
         | 
| 442 | 
            +
             | 
| 443 | 
            +
                        for info in infos["final_info"]:
         | 
| 444 | 
            +
                            # Skip the envs that are not done
         | 
| 445 | 
            +
                            if info is None:
         | 
| 446 | 
            +
                                continue
         | 
| 447 | 
            +
                            print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
         | 
| 448 | 
            +
                            writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
         | 
| 449 | 
            +
                            writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    # bootstrap value if not done
         | 
| 452 | 
            +
                    with torch.no_grad():
         | 
| 453 | 
            +
                        next_value = agent.get_value(next_obs).reshape(1, -1)
         | 
| 454 | 
            +
                        advantages = torch.zeros_like(rewards).to(device)
         | 
| 455 | 
            +
                        lastgaelam = 0
         | 
| 456 | 
            +
                        for t in reversed(range(args.num_steps)):
         | 
| 457 | 
            +
                            if t == args.num_steps - 1:
         | 
| 458 | 
            +
                                nextnonterminal = 1.0 - next_done
         | 
| 459 | 
            +
                                nextvalues = next_value
         | 
| 460 | 
            +
                            else:
         | 
| 461 | 
            +
                                nextnonterminal = 1.0 - dones[t + 1]
         | 
| 462 | 
            +
                                nextvalues = values[t + 1]
         | 
| 463 | 
            +
                            delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
         | 
| 464 | 
            +
                            advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
         | 
| 465 | 
            +
                        returns = advantages + values
         | 
| 466 | 
            +
             | 
| 467 | 
            +
                    # flatten the batch
         | 
| 468 | 
            +
                    b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
         | 
| 469 | 
            +
                    b_logprobs = logprobs.reshape(-1)
         | 
| 470 | 
            +
                    b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
         | 
| 471 | 
            +
                    b_advantages = advantages.reshape(-1)
         | 
| 472 | 
            +
                    b_returns = returns.reshape(-1)
         | 
| 473 | 
            +
                    b_values = values.reshape(-1)
         | 
| 474 | 
            +
             | 
| 475 | 
            +
                    # Optimizing the policy and value network
         | 
| 476 | 
            +
                    b_inds = np.arange(args.batch_size)
         | 
| 477 | 
            +
                    clipfracs = []
         | 
| 478 | 
            +
                    for epoch in range(args.update_epochs):
         | 
| 479 | 
            +
                        np.random.shuffle(b_inds)
         | 
| 480 | 
            +
                        for start in range(0, args.batch_size, args.minibatch_size):
         | 
| 481 | 
            +
                            end = start + args.minibatch_size
         | 
| 482 | 
            +
                            mb_inds = b_inds[start:end]
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                            _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions[mb_inds])
         | 
| 485 | 
            +
                            logratio = newlogprob - b_logprobs[mb_inds]
         | 
| 486 | 
            +
                            ratio = logratio.exp()
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                            with torch.no_grad():
         | 
| 489 | 
            +
                                # calculate approx_kl http://joschu.net/blog/kl-approx.html
         | 
| 490 | 
            +
                                old_approx_kl = (-logratio).mean()
         | 
| 491 | 
            +
                                approx_kl = ((ratio - 1) - logratio).mean()
         | 
| 492 | 
            +
                                clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]
         | 
                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/old_approx_kl", old_approx_kl.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    if args.save_model:
        agent.obs_rms = copy.deepcopy(get_rms(envs.envs[0])[0])
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        torch.save(agent.state_dict(), model_path)
        print(f"model saved to {model_path}")

        episodic_returns = evaluate(
            model_path,
            make_eval_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=Agent,
            device=device,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "PPO", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
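For quick reference (not part of the committed file), here is a minimal sketch of re-running the saved checkpoint outside training. It mirrors the evaluate(...) call above; the import path and the <run_name> placeholder are assumptions, so adjust them to your checkout.

# Hedged sketch: assumes Agent, evaluate, and make_eval_env are importable from the
# single-file script added in this commit; <run_name> is the run directory printed during training.
import torch

from ppo_fix_continuous_action import Agent, evaluate, make_eval_env  # assumed import path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
episodic_returns = evaluate(
    "runs/<run_name>/ppo_fix_continuous_action.cleanrl_model",  # path saved by the script above
    make_eval_env,
    "Walker2d-v4",
    eval_episodes=10,
    run_name="<run_name>-eval",
    Model=Agent,
    device=device,
)
print(episodic_returns)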
    	
        pyproject.toml
    ADDED
    
@@ -0,0 +1,108 @@
[tool.poetry]
name = "cleanrl"
version = "1.1.0"
description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
authors = ["Costa Huang <[email protected]>"]
packages = [
    { include = "cleanrl" },
    { include = "cleanrl_utils" },
]
keywords = ["reinforcement", "machine", "learning", "research"]
license = "MIT"
readme = "README.md"

[tool.poetry.dependencies]
python = ">=3.7.1,<3.11"
tensorboard = "^2.10.0"
wandb = "^0.13.11"
gym = "0.23.1"
torch = ">=1.12.1"
stable-baselines3 = "1.2.0"
gymnasium = ">=0.28.1"
moviepy = "^1.0.3"
pygame = "2.1.0"
huggingface-hub = "^0.11.1"
rich = "<12.0"
tenacity = "^8.2.2"

ale-py = {version = "0.7.4", optional = true}
AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2", optional = true}
opencv-python = {version = "^4.6.0.66", optional = true}
procgen = {version = "^0.10.7", optional = true}
pytest = {version = "^7.1.3", optional = true}
mujoco = {version = "<=2.3.3", optional = true}
imageio = {version = "^2.14.1", optional = true}
free-mujoco-py = {version = "^2.1.6", optional = true}
mkdocs-material = {version = "^8.4.3", optional = true}
markdown-include = {version = "^0.7.0", optional = true}
openrlbenchmark = {version = "^0.1.1b4", optional = true}
jax = {version = "^0.3.17", optional = true}
jaxlib = {version = "^0.3.15", optional = true}
flax = {version = "^0.6.0", optional = true}
optuna = {version = "^3.0.1", optional = true}
optuna-dashboard = {version = "^0.7.2", optional = true}
envpool = {version = "^0.6.4", optional = true}
PettingZoo = {version = "1.18.1", optional = true}
SuperSuit = {version = "3.4.0", optional = true}
multi-agent-ale-py = {version = "0.1.11", optional = true}
boto3 = {version = "^1.24.70", optional = true}
awscli = {version = "^1.25.71", optional = true}
shimmy = {version = ">=1.0.0", extras = ["dm-control"], optional = true}

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.20.0"


[tool.poetry.group.isaacgym]
optional = true
[tool.poetry.group.isaacgym.dependencies]
isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry", python = ">=3.7.1,<3.10"}
isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.poetry.extras]
atari = ["ale-py", "AutoROM", "opencv-python"]
procgen = ["procgen"]
plot = ["pandas", "seaborn"]
pytest = ["pytest"]
mujoco = ["mujoco", "imageio"]
mujoco_py = ["free-mujoco-py"]
jax = ["jax", "jaxlib", "flax"]
docs = ["mkdocs-material", "markdown-include", "openrlbenchmark"]
envpool = ["envpool"]
optuna = ["optuna", "optuna-dashboard"]
pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
cloud = ["boto3", "awscli"]
dm_control = ["shimmy", "mujoco"]

# dependencies for algorithm variant (useful when you want to run a specific algorithm)
dqn = []
dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
dqn_jax = ["jax", "jaxlib", "flax"]
dqn_atari_jax = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax" # jax
]
c51 = []
c51_atari = ["ale-py", "AutoROM", "opencv-python"]
c51_jax = ["jax", "jaxlib", "flax"]
c51_atari_jax = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax" # jax
]
ppo_atari_envpool_xla_jax_scan = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax", # jax
    "envpool", # envpool
]
qdagger_dqn_atari_impalacnn = [
    "ale-py", "AutoROM", "opencv-python"
]
qdagger_dqn_atari_jax_impalacnn = [
    "ale-py", "AutoROM", "opencv-python", # atari
    "jax", "jaxlib", "flax", # jax
]
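Since huggingface-hub is declared as a dependency above, here is a minimal sketch (not part of this commit) of pulling the pushed checkpoint back down. The repo id follows the {env_id}-{exp_name}-seed{seed} scheme used by push_to_hub in the script; <hf_entity> is a placeholder for the account the model was uploaded to.

from huggingface_hub import hf_hub_download

# "<hf_entity>" is a placeholder; the rest of the repo id follows the naming scheme in the script,
# and the filename matches the checkpoint included in this commit.
model_file = hf_hub_download(
    repo_id="<hf_entity>/Walker2d-v4-ppo_fix_continuous_action-seed2",
    filename="ppo_fix_continuous_action.cleanrl_model",
)
print(model_file)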
    	
        replay.mp4
    ADDED
    
Binary file (303 kB)
    	
        videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-0.mp4
    ADDED
    
Binary file (491 kB)
    	
        videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-1.mp4
    ADDED
    
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f58cf3c084fa6da76a6557aacaa4914c09d47fb734a79866d018c72c64cce783
size 1379514
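The three lines above are a Git LFS pointer: the video itself is stored out of band and identified by its SHA-256 digest and byte size. A small sketch for checking a locally downloaded copy against the pointer; the local path assumes the file sits at the same location as in this repo.

import hashlib
from pathlib import Path

# Path inside a local checkout of this repo; adjust if the file lives elsewhere.
video = Path("videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-1.mp4")
data = video.read_bytes()
print(hashlib.sha256(data).hexdigest() == "f58cf3c084fa6da76a6557aacaa4914c09d47fb734a79866d018c72c64cce783")
print(len(data) == 1379514)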
    	
        videos/Walker2d-v4__ppo_fix_continuous_action__2__1702930540-eval/rl-video-episode-8.mp4
    ADDED
    
Binary file (303 kB)
