Spaces:
Sleeping
Sleeping
| """ | |
| Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc. | |
| MCMC is adapted from https://github.com/google-research/ibc. | |
| """ | |
| from typing import Callable, Tuple | |
| from functools import wraps | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from abc import ABC, abstractmethod | |
| from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY | |
| from ding.torch_utils import unsqueeze_repeat | |
| from ding.model.wrapper import IModelWrapper | |
| from ding.model.common import RegressionHead | |
def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict):
    """
    Overview:
        Create a stochastic optimizer instance from a registry config.
    Arguments:
        - device (:obj:`str`): Device on which the optimizer's tensors are created.
        - stochastic_optimizer_config (:obj:`dict`): Config containing a ``type`` key \
            naming the registered optimizer, plus its keyword arguments.
    Returns:
        - optimizer: The optimizer built by ``STOCHASTIC_OPTIMIZER_REGISTRY``.
    """
    # Work on a shallow copy so the caller's config dict is not mutated:
    # popping "type" from the original would make a second call with the
    # same config fail with a KeyError.
    config = dict(stochastic_optimizer_config)
    return STOCHASTIC_OPTIMIZER_REGISTRY.build(config.pop("type"), device=device, **config)
def no_ebm_grad():
    """
    Overview:
        Decorator factory that disables gradients of the energy based model \
        while the wrapped function runs, then re-enables them.

    .. note:: The EBM must be the LAST positional argument of the wrapped function.
    """

    def ebm_disable_grad_wrapper(func: Callable):

        # ``wraps`` preserves the wrapped function's name/docstring for
        # debugging and introspection.
        @wraps(func)
        def wrapper(*args, **kwargs):
            ebm = args[-1]
            assert isinstance(ebm, (IModelWrapper, nn.Module)),\
                'Make sure ebm is the last positional arguments.'
            ebm.requires_grad_(False)
            try:
                result = func(*args, **kwargs)
            finally:
                # Re-enable gradients even if ``func`` raises, so a transient
                # error does not leave the model permanently frozen.
                ebm.requires_grad_(True)
            return result

        return wrapper

    return ebm_disable_grad_wrapper
class StochasticOptimizer(ABC):
    """
    Overview:
        Base class for stochastic optimizers used by implicit behavioral cloning.
        Subclasses must implement ``sample`` (negative sampling for InfoNCE) and
        ``infer`` (action optimization at evaluation time).
    Interface:
        ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer``
    """

    def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw action samples from the uniform random distribution over the \
            action bounds and tile observations to the same shape as action samples.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation.
            - num_samples (:obj:`int`): The number of negative samples.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Observations tiled.
            - action (:obj:`torch.Tensor`): Action sampled.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - num_samples (:obj:`int`): :math:`N`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

        .. note:: ``set_action_bounds`` must have been called first so that \
            ``self.action_bounds`` and ``self.device`` are available.
        """
        size = (obs.shape[0], num_samples, self.action_bounds.shape[1])
        low, high = self.action_bounds[0, :], self.action_bounds[1, :]
        action_samples = low + (high - low) * torch.rand(size).to(self.device)
        tiled_obs = unsqueeze_repeat(obs, num_samples, 1)
        return tiled_obs, action_samples

    @staticmethod
    def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Return one action for each batch with highest probability (lowest energy).
            This is a ``staticmethod``: callers invoke it as
            ``self._get_best_action_sample(obs, action_samples, ebm)``, so without
            the decorator the instance would be bound to ``obs``.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation (tiled).
            - action_samples (:obj:`torch.Tensor`): Candidate actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best action per batch element.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        """
        # (B, N)
        energies = ebm.forward(obs, action_samples)
        probs = F.softmax(-1.0 * energies, dim=-1)
        # (B, ) index of the lowest-energy sample per batch element
        best_idxs = probs.argmax(dim=-1)
        return action_samples[torch.arange(action_samples.size(0)), best_idxs]

    def set_action_bounds(self, action_bounds: np.ndarray):
        """
        Overview:
            Set action bounds calculated from the dataset statistics.
        Arguments:
            - action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \
                where action_bounds[0] is lower bound and action_bounds[1] is upper bound.
        Shapes:
            - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`.
        """
        self.action_bounds = torch.as_tensor(action_bounds, dtype=torch.float32).to(self.device)

    @abstractmethod
    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

        .. note:: In the case of derivative-free optimization, this function will simply call _sample.
        """
        raise NotImplementedError

    @abstractmethod
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        """
        raise NotImplementedError
class DFO(StochasticOptimizer):
    """
    Overview:
        Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``
    """

    def __init__(
        self,
        noise_scale: float = 0.33,
        noise_shrink: float = 0.5,
        iters: int = 3,
        train_samples: int = 8,
        inference_samples: int = 16384,
        device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the Derivative-Free Optimizer.
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate applied each iteration.
            - iters (:obj:`int`): Number of resampling iterations.
            - train_samples (:obj:`int`): Number of negative samples for training.
            - inference_samples (:obj:`int`): Number of candidate samples for inference.
            - device (:obj:`str`): Device.
        """
        self.action_bounds = None
        self.noise_scale = noise_scale
        self.noise_shrink = noise_shrink
        self.iters = iters
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        self.device = device

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Draw action samples from the uniform random distribution \
            and tile observations to the same shape as action samples.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model (unused here; \
                kept so all optimizers share the same interface).
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observation.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        return self._sample(obs, self.train_samples)

    @no_ebm_grad()
    @torch.no_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation \
            by iteratively resampling low-energy candidates with shrinking noise.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.

        .. note:: Inference only evaluates the EBM (no backprop is needed), so it \
            runs under ``torch.no_grad()`` with EBM gradients disabled; otherwise an \
            autograd graph over ``inference_samples`` candidates is built for nothing.
        """
        noise_scale = self.noise_scale
        # (B, N, O), (B, N, A)
        obs, action_samples = self._sample(obs, self.inference_samples)
        for _ in range(self.iters):
            # (B, N)
            energies = ebm.forward(obs, action_samples)
            probs = F.softmax(-1.0 * energies, dim=-1)
            # Resample with replacement, favoring low-energy candidates.
            idxs = torch.multinomial(probs, self.inference_samples, replacement=True)
            action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs]
            # Add noise and clip to target bounds.
            action_samples = action_samples + torch.randn_like(action_samples) * noise_scale
            action_samples = action_samples.clamp(min=self.action_bounds[0, :], max=self.action_bounds[1, :])
            noise_scale *= self.noise_shrink
        # Return target with highest probability.
        return self._get_best_action_sample(obs, action_samples, ebm)
class AutoRegressiveDFO(DFO):
    """
    Overview:
        AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``infer``
    """

    def __init__(
        self,
        noise_scale: float = 0.33,
        noise_shrink: float = 0.5,
        iters: int = 3,
        train_samples: int = 8,
        inference_samples: int = 4096,
        device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the AutoRegressive Derivative-Free Optimizer.
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate applied each iteration.
            - iters (:obj:`int`): Number of resampling iterations.
            - train_samples (:obj:`int`): Number of negative samples for training.
            - inference_samples (:obj:`int`): Number of candidate samples for inference.
            - device (:obj:`str`): Device.
        """
        super().__init__(noise_scale, noise_shrink, iters, train_samples, inference_samples, device)

    @no_ebm_grad()
    @torch.no_grad()
    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation, \
            refining one action dimension at a time using the autoregressive EBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model; its forward is \
                expected to return per-dimension energies of shape :math:`(B, N, A)`.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.

        .. note:: Inference only evaluates the EBM, so it runs under \
            ``torch.no_grad()`` with EBM gradients disabled.
        """
        noise_scale = self.noise_scale
        # (B, N, O), (B, N, A)
        obs, action_samples = self._sample(obs, self.inference_samples)
        for _ in range(self.iters):
            # j: action_dim index, refined autoregressively
            for j in range(action_samples.shape[-1]):
                # (B, N) energy of the prefix a[..., :j+1]
                energies = ebm.forward(obs, action_samples)[..., j]
                probs = F.softmax(-1.0 * energies, dim=-1)
                # Resample with replacement.
                idxs = torch.multinomial(probs, self.inference_samples, replacement=True)
                action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs]
                # Add noise and clip to target bounds, only on dimension j.
                action_samples[..., j] = action_samples[..., j] + torch.randn_like(action_samples[..., j]) * noise_scale
                action_samples[..., j] = action_samples[..., j].clamp(
                    min=self.action_bounds[0, j], max=self.action_bounds[1, j]
                )
            noise_scale *= self.noise_shrink
        # (B, N) energy of the full action (last autoregressive output)
        energies = ebm.forward(obs, action_samples)[..., -1]
        probs = F.softmax(-1.0 * energies, dim=-1)
        # (B, )
        best_idxs = probs.argmax(dim=-1)
        return action_samples[torch.arange(action_samples.size(0)), best_idxs]
class MCMC(StochasticOptimizer):
    """
    Overview:
        MCMC method as stochastic optimizers in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``, ``grad_penalty``
    """

    class BaseScheduler(ABC):
        """
        Overview:
            Base class for learning rate scheduler.
        Interface:
            ``get_rate``
        """

        @abstractmethod
        def get_rate(self, index):
            """
            Overview:
                Abstract method for getting learning rate.
            Arguments:
                - index (:obj:`int`): Current iteration.
            """
            raise NotImplementedError

    class ExponentialScheduler(BaseScheduler):
        """
        Overview:
            Exponential learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, decay):
            """
            Overview:
                Initialize the ExponentialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - decay (:obj:`float`): Decay rate.
            """
            self._decay = decay
            self._latest_lr = init

        def get_rate(self, index):
            """
            Overview:
                Get learning rate. Assumes calling sequentially.
            Arguments:
                - index (:obj:`int`): Current iteration (ignored; the rate decays \
                    once per call).
            """
            del index
            lr = self._latest_lr
            self._latest_lr *= self._decay
            return lr

    class PolynomialScheduler(BaseScheduler):
        """
        Overview:
            Polynomial learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, final, power, num_steps):
            """
            Overview:
                Initialize the PolynomialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - final (:obj:`float`): Final learning rate.
                - power (:obj:`float`): Power of polynomial.
                - num_steps (:obj:`int`): Number of steps.
            """
            self._init = init
            self._final = final
            self._power = power
            self._num_steps = num_steps

        def get_rate(self, index):
            """
            Overview:
                Get learning rate for index.
            Arguments:
                - index (:obj:`int`): Current iteration; ``-1`` returns the initial rate.
            """
            if index == -1:
                return self._init
            return (
                (self._init - self._final) * ((1 - (float(index) / float(self._num_steps - 1))) ** (self._power))
            ) + self._final

    def __init__(
        self,
        iters: int = 100,
        use_langevin_negative_samples: bool = True,
        train_samples: int = 8,
        inference_samples: int = 512,
        stepsize_scheduler: dict = None,
        optimize_again: bool = True,
        again_stepsize_scheduler: dict = None,
        device: str = 'cpu',
        # langevin_step
        noise_scale: float = 0.5,
        grad_clip=None,
        delta_action_clip: float = 0.5,
        add_grad_penalty: bool = True,
        grad_norm_type: str = 'inf',
        grad_margin: float = 1.0,
        grad_loss_weight: float = 1.0,
        **kwargs,
    ):
        """
        Overview:
            Initialize the MCMC.
        Arguments:
            - iters (:obj:`int`): Number of Langevin iterations.
            - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - stepsize_scheduler (:obj:`dict`): Step size scheduler config for the Langevin \
                sampler; defaults to ``dict(init=0.5, final=1e-5, power=2.0)``. \
                ``num_steps`` is filled in at run time from ``iters``.
            - optimize_again (:obj:`bool`): Whether to run a second optimization.
            - again_stepsize_scheduler (:obj:`dict`): Step size scheduler config for the \
                second optimization; defaults to ``dict(init=1e-5, final=1e-5, power=2.0)``.
            - device (:obj:`str`): Device.
            - noise_scale (:obj:`float`): Initial noise scale.
            - grad_clip (:obj:`Optional[float]`): Gradient clip threshold; ``None`` disables clipping.
            - delta_action_clip (:obj:`float`): Action-update clip, as a fraction of \
                half the action range.
            - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty.
            - grad_norm_type (:obj:`str`): Gradient norm type ('1', '2' or 'inf').
            - grad_margin (:obj:`float`): Gradient margin.
            - grad_loss_weight (:obj:`float`): Gradient loss weight.
        """
        self.iters = iters
        self.use_langevin_negative_samples = use_langevin_negative_samples
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        # Per-instance copies of the scheduler configs: these dicts are mutated
        # later (``num_steps`` is written in ``_langevin_action_given_obs`` and
        # ``infer``), so a dict shared via a mutable default argument would leak
        # state across MCMC instances.
        if stepsize_scheduler is None:
            stepsize_scheduler = dict(init=0.5, final=1e-5, power=2.0)
        if again_stepsize_scheduler is None:
            again_stepsize_scheduler = dict(init=1e-5, final=1e-5, power=2.0)
        self.stepsize_scheduler = dict(stepsize_scheduler)
        self.again_stepsize_scheduler = dict(again_stepsize_scheduler)
        self.optimize_again = optimize_again
        self.device = device
        self.noise_scale = noise_scale
        self.grad_clip = grad_clip
        self.delta_action_clip = delta_action_clip
        self.add_grad_penalty = add_grad_penalty
        self.grad_norm_type = grad_norm_type
        self.grad_margin = grad_margin
        self.grad_loss_weight = grad_loss_weight

    @staticmethod
    def _gradient_wrt_act(
            obs: torch.Tensor,
            action: torch.Tensor,
            ebm: nn.Module,
            create_graph: bool = False,
    ) -> torch.Tensor:
        """
        Overview:
            Calculate gradient of the total energy w.r.t action.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - create_graph (:obj:`bool`): Whether to create graph. Set to ``True`` \
                when a second order derivative is needed, i.e. d(de/da)/d_param \
                in the gradient penalty.
        Returns:
            - grad (:obj:`torch.Tensor`): Gradient w.r.t action.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        action.requires_grad_(True)
        energy = ebm.forward(obs, action).sum()
        grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0]
        action.requires_grad_(False)
        return grad

    def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Calculate gradient penalty: penalize energy gradients (w.r.t action) \
            whose norm exceeds ``grad_margin``.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - loss (:obj:`torch.Tensor`): Gradient penalty (scalar), or ``0.`` when disabled.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`.
        """
        if not self.add_grad_penalty:
            return 0.
        # (B, N+1, A), this gradient is differentiable w.r.t model parameters
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True)

        def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor:
            # de_dact: (B, N+1, A) -> (B, N+1)
            grad_norm_type_to_ord = {
                '1': 1,
                '2': 2,
                'inf': float('inf'),
            }
            ord = grad_norm_type_to_ord[grad_norm_type]
            return torch.linalg.norm(de_dact, ord, dim=-1)

        # (B, N+1)
        grad_norms = compute_grad_norm(self.grad_norm_type, de_dact)
        # Hinge at the margin, then square (only the excess is penalized).
        grad_norms = grad_norms - self.grad_margin
        grad_norms = grad_norms.clamp(min=0., max=1e10)
        grad_norms = grad_norms.pow(2)
        grad_loss = grad_norms.mean()
        return grad_loss * self.grad_loss_weight

    # can not use @torch.no_grad() during the inference
    # because we need to calculate gradient w.r.t inputs as MCMC updates.
    def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Run one Langevin MCMC step: move actions along the negative energy \
            gradient plus isotropic noise, with per-step and absolute clipping.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - stepsize (:obj:`float`): Scalar step size for this iteration.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - action (:obj:`torch.Tensor`): Updated actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        l_lambda = 1.0
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm)
        if self.grad_clip:
            de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip)
        gradient_scale = 0.5
        de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale)
        delta_action = stepsize * de_dact
        # Clip the update to a fraction of half the action range.
        delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0])
        delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip)
        # Descend the energy landscape, then project back into the action bounds.
        action = action - delta_action
        action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1])
        return action

    def _langevin_action_given_obs(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        ebm: nn.Module,
        scheduler: BaseScheduler = None
    ) -> torch.Tensor:
        """
        Overview:
            Run Langevin MCMC for ``self.iters`` steps.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Initial actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - scheduler (:obj:`BaseScheduler`): Step size scheduler; when ``None``, \
                a :class:`PolynomialScheduler` is built from ``self.stepsize_scheduler``.
        Returns:
            - action (:obj:`torch.Tensor`): Optimized actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        if not scheduler:
            self.stepsize_scheduler['num_steps'] = self.iters
            scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler)
        stepsize = scheduler.get_rate(-1)
        for i in range(self.iters):
            action = self._langevin_step(obs, action, stepsize, ebm)
            stepsize = scheduler.get_rate(i)
        return action

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss, \
            optionally refined by Langevin MCMC.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        obs, uniform_action_samples = self._sample(obs, self.train_samples)
        if not self.use_langevin_negative_samples:
            return obs, uniform_action_samples
        langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm)
        return obs, langevin_action_samples

    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation \
            via Langevin MCMC, optionally followed by a second fine-tuning pass.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        """
        # (B, N, O), (B, N, A)
        obs, uniform_action_samples = self._sample(obs, self.inference_samples)
        action_samples = self._langevin_action_given_obs(
            obs,
            uniform_action_samples,
            ebm,
        )
        # Run a second optimization, a trick for more precise inference
        if self.optimize_again:
            self.again_stepsize_scheduler['num_steps'] = self.iters
            action_samples = self._langevin_action_given_obs(
                obs,
                action_samples,
                ebm,
                scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler),
            )
        # action_samples: (B, N, A)
        return self._get_best_action_sample(obs, action_samples, ebm)
class EBM(nn.Module):
    """
    Overview:
        Energy based model: maps each (observation, action) pair to a scalar energy.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
        **kwargs,
    ):
        """
        Overview:
            Build the energy network: a linear embedding of the concatenated \
            (obs, action) input followed by a regression head producing one scalar.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_shape + action_shape, hidden_size),
            nn.ReLU(),
            RegressionHead(
                hidden_size,
                1,
                hidden_layer_num,
                final_tanh=False,
            ),
        )

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of EBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> ebm = EBM(4, 5)
            >>> pred = ebm(obs, action)
        """
        obs_action = torch.cat([obs, action], dim=-1)
        head_out = self.net(obs_action)
        return head_out['pred']
class AutoregressiveEBM(nn.Module):
    """
    Overview:
        Autoregressive energy based model: one sub-EBM per action dimension, \
        where the i-th sub-EBM scores the action prefix ``action[..., :i+1]``.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
    ):
        """
        Overview:
            Initialize the AutoregressiveEBM.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        # Sub-EBM i consumes the first (i + 1) action dimensions.
        self.ebm_list = nn.ModuleList(
            EBM(obs_shape, i + 1, hidden_size, hidden_layer_num) for i in range(action_shape)
        )

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of AutoregressiveEBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A), one energy \
                per action prefix length.
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> arebm = AutoregressiveEBM(4, 5)
            >>> pred = arebm(obs, action)
        """
        output_list = [ebm(obs, action[..., :i + 1]) for i, ebm in enumerate(self.ebm_list)]
        # ``dim`` is the canonical torch.stack keyword; ``axis`` is only a
        # NumPy-compatibility alias.
        return torch.stack(output_list, dim=-1)