| import copy | |
| import numpy as np | |
| from collections import namedtuple | |
| from typing import Union, Optional, Callable | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from ding.hpc_rl import hpc_wrapper | |
| from ding.rl_utils.value_rescale import value_transform, value_inv_transform | |
| from ding.torch_utils import to_tensor | |
| q_1step_td_data = namedtuple('q_1step_td_data', ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']) | |
| def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray: | |
assert abs(gamma - 1.) < 1e-5, "gamma should be 1.0, as in the original Decision Transformer paper"
| disc_cumsum = np.zeros_like(x) | |
| disc_cumsum[-1] = x[-1] | |
| for t in reversed(range(x.shape[0] - 1)): | |
| disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1] | |
| return disc_cumsum | |
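# A minimal usage sketch for discount_cumsum (illustrative values; assumes gamma=1.0 as asserted above):
# >>> x = np.array([1., 1., 1.], dtype=np.float32)
# >>> discount_cumsum(x, gamma=1.0)
# array([3., 2., 1.], dtype=float32)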
| def q_1step_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
1 step td_error, supporting both single-agent and multi-agent cases.
| Arguments: | |
| - data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): 1step td error | |
| Shapes: | |
| - data (:obj:`q_1step_td_data`): the q_1step_td_data containing\ | |
| ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - act (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_act (:obj:`torch.LongTensor`): :math:`(B, )` | |
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| Examples: | |
| >>> action_dim = 4 | |
| >>> data = q_1step_td_data( | |
| >>> q=torch.randn(3, action_dim), | |
| >>> next_q=torch.randn(3, action_dim), | |
| >>> act=torch.randint(0, action_dim, (3,)), | |
| >>> next_act=torch.randint(0, action_dim, (3,)), | |
| >>> reward=torch.randn(3), | |
| >>> done=torch.randint(0, 2, (3,)).bool(), | |
| >>> weight=torch.ones(3), | |
| >>> ) | |
| >>> loss = q_1step_td_error(data, 0.99) | |
| """ | |
| q, next_q, act, next_act, reward, done, weight = data | |
| assert len(act.shape) == 1, act.shape | |
| assert len(reward.shape) == 1, reward.shape | |
| batch_range = torch.arange(act.shape[0]) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| q_s_a = q[batch_range, act] | |
| target_q_s_a = next_q[batch_range, next_act] | |
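# 1-step TD target: y = r + gamma * (1 - done) * Q_target(s', a'), where a' is given by next_act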
| target_q_s_a = gamma * (1 - done) * target_q_s_a + reward | |
| return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean() | |
| m_q_1step_td_data = namedtuple('m_q_1step_td_data', ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']) | |
| def m_q_1step_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| tau: float, | |
| alpha: float, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Munchausen td_error for DQN algorithm, support 1 step td error. | |
| Arguments: | |
| - data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - tau (:obj:`float`): Entropy factor for Munchausen DQN | |
| - alpha (:obj:`float`): Discount factor for Munchausen term | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
Returns:
- loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
- td_error_per_sample (:obj:`torch.Tensor`): 1step td error, 1-dim tensor
- action_gap (:obj:`torch.Tensor`): The mean gap between the largest and the second largest action value
- clipfrac (:obj:`torch.Tensor`): Per-sample indicator of whether the Munchausen add-on was clipped
| Shapes: | |
| - data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing\ | |
| ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - act (:obj:`torch.LongTensor`): :math:`(B, )` | |
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| Examples: | |
| >>> action_dim = 4 | |
| >>> data = m_q_1step_td_data( | |
| >>> q=torch.randn(3, action_dim), | |
| >>> target_q=torch.randn(3, action_dim), | |
| >>> next_q=torch.randn(3, action_dim), | |
| >>> act=torch.randint(0, action_dim, (3,)), | |
| >>> reward=torch.randn(3), | |
| >>> done=torch.randint(0, 2, (3,)), | |
| >>> weight=torch.ones(3), | |
| >>> ) | |
| >>> loss = m_q_1step_td_error(data, 0.99, 0.01, 0.01) | |
| """ | |
| q, target_q, next_q, act, reward, done, weight = data | |
| lower_bound = -1 | |
| assert len(act.shape) == 1, act.shape | |
| assert len(reward.shape) == 1, reward.shape | |
| batch_range = torch.arange(act.shape[0]) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| q_s_a = q[batch_range, act] | |
# calculate the Munchausen addon
# replay_log_policy
target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)
logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
log_pi = target_q - target_v_s - tau * logsum
act_get = act.unsqueeze(-1)
# note: log_pi already carries the tau factor, i.e. it equals tau * log pi(a|s) (cf. tau_log_pi_next below)
munchausen_addon = log_pi.gather(1, act_get)
munchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)
# replay_next_log_policy
target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
# numerically stable softmax == replay_next_policy
pi_target = F.softmax((next_q - target_v_s_next) / tau, dim=1)
# Munchausen target: y = r + alpha * clip(tau * log pi(a|s), lower_bound, 1)
#                      + gamma * (1 - done) * sum_a' pi(a'|s') * (Q(s', a') - tau * log pi(a'|s'))
target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1)
target_q_s_a = reward.unsqueeze(-1) + munchausen_term + target_q_s_a
| td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1) | |
| # calculate action_gap and clipfrac | |
| with torch.no_grad(): | |
| top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0] | |
| action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean() | |
| clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound) | |
| clipfrac = torch.as_tensor(clipped).float() | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac | |
| q_v_1step_td_data = namedtuple('q_v_1step_td_data', ['q', 'v', 'act', 'reward', 'done', 'weight']) | |
| def q_v_1step_td_error( | |
| data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none') | |
| ) -> torch.Tensor: | |
"""
Overview:
1 step td_error between the q value and the v value, used in the discrete SAC algorithm.
| Arguments: | |
| - data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor | |
| Shapes: | |
| - data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing\ | |
| ['q', 'v', 'act', 'reward', 'done', 'weight'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - v (:obj:`torch.FloatTensor`): :math:`(B, )` | |
| - act (:obj:`torch.LongTensor`): :math:`(B, )` | |
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| Examples: | |
| >>> action_dim = 4 | |
| >>> data = q_v_1step_td_data( | |
| >>> q=torch.randn(3, action_dim), | |
| >>> v=torch.randn(3), | |
| >>> act=torch.randint(0, action_dim, (3,)), | |
| >>> reward=torch.randn(3), | |
| >>> done=torch.randint(0, 2, (3,)), | |
| >>> weight=torch.ones(3), | |
| >>> ) | |
| >>> loss = q_v_1step_td_error(data, 0.99) | |
| """ | |
| q, v, act, reward, done, weight = data | |
| if len(act.shape) == 1: | |
| assert len(reward.shape) == 1, reward.shape | |
| batch_range = torch.arange(act.shape[0]) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| q_s_a = q[batch_range, act] | |
| target_q_s_a = gamma * (1 - done) * v + reward | |
| else: | |
| assert len(reward.shape) == 1, reward.shape | |
| batch_range = torch.arange(act.shape[0]) | |
| actor_range = torch.arange(act.shape[1]) | |
| batch_actor_range = torch.arange(act.shape[0] * act.shape[1]) | |
| if weight is None: | |
| weight = torch.ones_like(act) | |
| temp_q = q.reshape(act.shape[0] * act.shape[1], -1) | |
| temp_act = act.reshape(act.shape[0] * act.shape[1]) | |
| q_s_a = temp_q[batch_actor_range, temp_act] | |
| q_s_a = q_s_a.reshape(act.shape[0], act.shape[1]) | |
| target_q_s_a = gamma * (1 - done).unsqueeze(1) * v + reward.unsqueeze(1) | |
| td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample | |
| def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| size = list(x.shape) + [1 for _ in range(len(target.shape) - len(x.shape))] | |
| return x.view(*size) | |
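# Illustrative sketch: view_similar right-pads x with singleton dims so it can broadcast against target, e.g.
# >>> view_similar(torch.ones(3), torch.ones(3, 4, 5)).shape
# torch.Size([3, 1, 1])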
| nstep_return_data = namedtuple('nstep_return_data', ['reward', 'next_value', 'done']) | |
| def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None): | |
| ''' | |
| Overview: | |
Calculate nstep return for the DQN algorithm, supporting both single-agent and multi-agent cases.
| Arguments: | |
| - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num | |
| - value_gamma (:obj:`torch.Tensor`): Discount factor for value | |
| Returns: | |
| - return (:obj:`torch.Tensor`): nstep return | |
| Shapes: | |
| - data (:obj:`nstep_return_data`): the nstep_return_data containing\ | |
| ['reward', 'next_value', 'done'] | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
- next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| Examples: | |
| >>> data = nstep_return_data( | |
| >>> reward=torch.randn(3, 3), | |
| >>> next_value=torch.randn(3), | |
| >>> done=torch.randint(0, 2, (3,)), | |
| >>> ) | |
| >>> loss = nstep_return(data, 0.99, 3) | |
| ''' | |
| reward, next_value, done = data | |
| assert reward.shape[0] == nstep | |
| device = reward.device | |
| if isinstance(gamma, float): | |
| reward_factor = torch.ones(nstep).to(device) | |
| for i in range(1, nstep): | |
| reward_factor[i] = gamma * reward_factor[i - 1] | |
| reward_factor = view_similar(reward_factor, reward) | |
| return_tmp = reward.mul(reward_factor).sum(0) | |
| if value_gamma is None: | |
| return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done) | |
| else: | |
| return_ = return_tmp + value_gamma * next_value * (1 - done) | |
| elif isinstance(gamma, list): | |
| # if gamma is list, for NGU policy case | |
| reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device) | |
| for i in range(1, nstep + 1): | |
| reward_factor[i] = torch.stack(gamma, dim=0).to(device) * reward_factor[i - 1] | |
| reward_factor = view_similar(reward_factor, reward) | |
| return_tmp = reward.mul(reward_factor[:nstep]).sum(0) | |
| return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done) | |
| else: | |
| raise TypeError("The type of gamma should be float or list") | |
| return return_ | |
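# A minimal sketch of the per-sample gamma branch (used by the NGU policy); the toy values are assumptions:
# >>> data = nstep_return_data(reward=torch.randn(3, 4), next_value=torch.randn(4), done=torch.zeros(4))
# >>> gammas = [torch.tensor(0.99) for _ in range(4)]  # one discount factor per sample in the batch
# >>> ret = nstep_return(data, gammas, nstep=3)  # shape (4, )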
| dist_1step_td_data = namedtuple( | |
| 'dist_1step_td_data', ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight'] | |
| ) | |
| def dist_1step_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| v_min: float, | |
| v_max: float, | |
| n_atom: int, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
1 step td_error for distributional q-learning based algorithm
Arguments:
- data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
- gamma (:obj:`float`): Discount factor
- v_min (:obj:`float`): The min value of support
- v_max (:obj:`float`): The max value of support
- n_atom (:obj:`int`): The num of atom
Returns:
- loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
Shapes:
- data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing\
['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
| - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom] | |
| - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` | |
| - act (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_act (:obj:`torch.LongTensor`): :math:`(B, )` | |
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| Examples: | |
| >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True) | |
| >>> next_dist = torch.randn(4, 3, 51).abs() | |
| >>> act = torch.randint(0, 3, (4,)) | |
| >>> next_act = torch.randint(0, 3, (4,)) | |
| >>> reward = torch.randn(4) | |
| >>> done = torch.randint(0, 2, (4,)) | |
| >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None) | |
| >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51) | |
| """ | |
| dist, next_dist, act, next_act, reward, done, weight = data | |
| device = reward.device | |
| assert len(reward.shape) == 1, reward.shape | |
| support = torch.linspace(v_min, v_max, n_atom).to(device) | |
| delta_z = (v_max - v_min) / (n_atom - 1) | |
| if len(act.shape) == 1: | |
| reward = reward.unsqueeze(-1) | |
| done = done.unsqueeze(-1) | |
| batch_size = act.shape[0] | |
| batch_range = torch.arange(batch_size) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| next_dist = next_dist[batch_range, next_act].detach() | |
| else: | |
| reward = reward.unsqueeze(-1).repeat(1, act.shape[1]) | |
| done = done.unsqueeze(-1).repeat(1, act.shape[1]) | |
| batch_size = act.shape[0] * act.shape[1] | |
| batch_range = torch.arange(act.shape[0] * act.shape[1]) | |
| action_dim = dist.shape[2] | |
| dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) | |
| reward = reward.reshape(act.shape[0] * act.shape[1], -1) | |
| done = done.reshape(act.shape[0] * act.shape[1], -1) | |
| next_dist = next_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) | |
| next_act = next_act.reshape(act.shape[0] * act.shape[1]) | |
| next_dist = next_dist[batch_range, next_act].detach() | |
| next_dist = next_dist.reshape(act.shape[0] * act.shape[1], -1) | |
| act = act.reshape(act.shape[0] * act.shape[1]) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| target_z = reward + (1 - done) * gamma * support | |
| target_z = target_z.clamp(min=v_min, max=v_max) | |
| b = (target_z - v_min) / delta_z | |
| l = b.floor().long() | |
| u = b.ceil().long() | |
| # Fix disappearing probability mass when l = b = u (b is int) | |
| l[(u > 0) * (l == u)] -= 1 | |
| u[(l < (n_atom - 1)) * (l == u)] += 1 | |
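# Categorical projection (as in C51): spread the probability mass of each target atom onto its two
# neighbouring atoms of the fixed support, proportionally to the distance from each neighbour.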
| proj_dist = torch.zeros_like(next_dist) | |
| offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size, | |
| n_atom).long().to(device) | |
| proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)) | |
| proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)) | |
| log_p = torch.log(dist[batch_range, act]) | |
| loss = -(log_p * proj_dist * weight).sum(-1).mean() | |
| return loss | |
| dist_nstep_td_data = namedtuple( | |
'dist_nstep_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
| ) | |
| def shape_fn_dntd(args, kwargs): | |
| r""" | |
| Overview: | |
| Return dntd shape for hpc | |
| Returns: | |
| shape: [T, B, N, n_atom] | |
| """ | |
| if len(args) <= 0: | |
| tmp = [kwargs['data'].reward.shape[0]] | |
| tmp.extend(list(kwargs['data'].dist.shape)) | |
| else: | |
| tmp = [args[0].reward.shape[0]] | |
| tmp.extend(list(args[0].dist.shape)) | |
| return tmp | |
| def dist_nstep_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| v_min: float, | |
| v_max: float, | |
| n_atom: int, | |
| nstep: int = 1, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
Multistep (1 step or n step) td_error for distributional q-learning based algorithm, supporting both\
single-agent and multi-agent cases.
| Arguments: | |
| - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| Shapes: | |
| - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing\ | |
| ['dist', 'next_n_dist', 'act', 'reward', 'done', 'weight'] | |
| - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom] | |
| - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` | |
| - act (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| Examples: | |
| >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True) | |
| >>> next_n_dist = torch.randn(4, 3, 51).abs() | |
| >>> done = torch.randn(4) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
| >>> reward = torch.randn(5, 4) | |
| >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None) | |
| >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5) | |
| """ | |
| dist, next_n_dist, act, next_n_act, reward, done, weight = data | |
| device = reward.device | |
| reward_factor = torch.ones(nstep).to(device) | |
| for i in range(1, nstep): | |
| reward_factor[i] = gamma * reward_factor[i - 1] | |
| reward = torch.matmul(reward_factor, reward) | |
| support = torch.linspace(v_min, v_max, n_atom).to(device) | |
| delta_z = (v_max - v_min) / (n_atom - 1) | |
| if len(act.shape) == 1: | |
| reward = reward.unsqueeze(-1) | |
| done = done.unsqueeze(-1) | |
| batch_size = act.shape[0] | |
| batch_range = torch.arange(batch_size) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| elif isinstance(weight, float): | |
| weight = torch.tensor(weight) | |
| next_n_dist = next_n_dist[batch_range, next_n_act].detach() | |
| else: | |
| reward = reward.unsqueeze(-1).repeat(1, act.shape[1]) | |
| done = done.unsqueeze(-1).repeat(1, act.shape[1]) | |
| batch_size = act.shape[0] * act.shape[1] | |
| batch_range = torch.arange(act.shape[0] * act.shape[1]) | |
| action_dim = dist.shape[2] | |
| dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) | |
| reward = reward.reshape(act.shape[0] * act.shape[1], -1) | |
| done = done.reshape(act.shape[0] * act.shape[1], -1) | |
| next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1) | |
| next_n_act = next_n_act.reshape(act.shape[0] * act.shape[1]) | |
| next_n_dist = next_n_dist[batch_range, next_n_act].detach() | |
| next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], -1) | |
| act = act.reshape(act.shape[0] * act.shape[1]) | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| elif isinstance(weight, float): | |
| weight = torch.tensor(weight) | |
| if value_gamma is None: | |
| target_z = reward + (1 - done) * (gamma ** nstep) * support | |
| elif isinstance(value_gamma, float): | |
| value_gamma = torch.tensor(value_gamma).unsqueeze(-1) | |
| target_z = reward + (1 - done) * value_gamma * support | |
| else: | |
| value_gamma = value_gamma.unsqueeze(-1) | |
| target_z = reward + (1 - done) * value_gamma * support | |
| target_z = target_z.clamp(min=v_min, max=v_max) | |
| b = (target_z - v_min) / delta_z | |
| l = b.floor().long() | |
| u = b.ceil().long() | |
| # Fix disappearing probability mass when l = b = u (b is int) | |
| l[(u > 0) * (l == u)] -= 1 | |
| u[(l < (n_atom - 1)) * (l == u)] += 1 | |
| proj_dist = torch.zeros_like(next_n_dist) | |
| offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size, | |
| n_atom).long().to(device) | |
| proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1)) | |
| proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1)) | |
| assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist) | |
| log_p = torch.log(dist[batch_range, act]) | |
| if len(weight.shape) == 1: | |
| weight = weight.unsqueeze(-1) | |
| td_error_per_sample = -(log_p * proj_dist).sum(-1) | |
| loss = -(log_p * proj_dist * weight).sum(-1).mean() | |
| return loss, td_error_per_sample | |
| v_1step_td_data = namedtuple('v_1step_td_data', ['v', 'next_v', 'reward', 'done', 'weight']) | |
| def v_1step_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa | |
| ) -> torch.Tensor: | |
| ''' | |
| Overview: | |
1 step td_error for value based algorithm
| Arguments: | |
| - data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor | |
| Shapes: | |
| - data (:obj:`v_1step_td_data`): the v_1step_td_data containing\ | |
| ['v', 'next_v', 'reward', 'done', 'weight'] | |
| - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ] | |
| - next_v (:obj:`torch.FloatTensor`): :math:`(B, )` | |
- reward (:obj:`torch.FloatTensor`): :math:`(B, )`
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| Examples: | |
| >>> v = torch.randn(5).requires_grad_(True) | |
| >>> next_v = torch.randn(5) | |
| >>> reward = torch.rand(5) | |
| >>> done = torch.zeros(5) | |
| >>> data = v_1step_td_data(v, next_v, reward, done, None) | |
| >>> loss, td_error_per_sample = v_1step_td_error(data, 0.99) | |
| ''' | |
| v, next_v, reward, done, weight = data | |
| if weight is None: | |
| weight = torch.ones_like(v) | |
| if len(v.shape) == len(reward.shape): | |
| if done is not None: | |
| target_v = gamma * (1 - done) * next_v + reward | |
| else: | |
| target_v = gamma * next_v + reward | |
| else: | |
| if done is not None: | |
| target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1) | |
| else: | |
| target_v = gamma * next_v + reward.unsqueeze(1) | |
| td_error_per_sample = criterion(v, target_v.detach()) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample | |
| v_nstep_td_data = namedtuple('v_nstep_td_data', ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']) | |
| def v_nstep_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| nstep: int = 1, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa | |
| ) -> torch.Tensor: | |
| r""" | |
| Overview: | |
Multistep (n step) td_error for value based algorithm
| Arguments: | |
- data (:obj:`v_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| Shapes: | |
- data (:obj:`v_nstep_td_data`): The v_nstep_td_data containing\
['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
- v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
- next_n_v (:obj:`torch.FloatTensor`): :math:`(B, )`
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step\ | |
| we use value_gamma as the gamma discount value for next_v rather than gamma**n_step | |
| Examples: | |
| >>> v = torch.randn(5).requires_grad_(True) | |
| >>> next_v = torch.randn(5) | |
| >>> reward = torch.rand(5, 5) | |
| >>> done = torch.zeros(5) | |
| >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99) | |
| >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5) | |
| """ | |
| v, next_n_v, reward, done, weight, value_gamma = data | |
| if weight is None: | |
| weight = torch.ones_like(v) | |
| target_v = nstep_return(nstep_return_data(reward, next_n_v, done), gamma, nstep, value_gamma) | |
| td_error_per_sample = criterion(v, target_v.detach()) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample | |
| q_nstep_td_data = namedtuple( | |
| 'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'] | |
| ) | |
| dqfd_nstep_td_data = namedtuple( | |
| 'dqfd_nstep_td_data', [ | |
| 'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight', 'new_n_q_one_step', | |
| 'next_n_action_one_step', 'is_expert' | |
| ] | |
| ) | |
| def shape_fn_qntd(args, kwargs): | |
| r""" | |
| Overview: | |
| Return qntd shape for hpc | |
| Returns: | |
| shape: [T, B, N] | |
| """ | |
| if len(args) <= 0: | |
| tmp = [kwargs['data'].reward.shape[0]] | |
| tmp.extend(list(kwargs['data'].q.shape)) | |
| else: | |
| tmp = [args[0].reward.shape[0]] | |
| tmp.extend(list(args[0].q.shape)) | |
| return tmp | |
| def q_nstep_td_error( | |
| data: namedtuple, | |
| gamma: Union[float, list], | |
| nstep: int = 1, | |
| cum_reward: bool = False, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none'), | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Multistep (1 step or n step) td_error for q-learning based algorithm | |
| Arguments: | |
| - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data | |
| - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'reward', 'done'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` | |
| Examples: | |
| >>> next_q = torch.randn(4, 3) | |
| >>> done = torch.randn(4) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
>>> nstep = 3
| >>> q = torch.randn(4, 3).requires_grad_(True) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) | |
| >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, weight = data | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| if len(action.shape) == 1: # single agent case | |
| action = action.unsqueeze(-1) | |
| elif len(action.shape) > 1: # MARL case | |
| reward = reward.unsqueeze(-1) | |
| weight = weight.unsqueeze(-1) | |
| done = done.unsqueeze(-1) | |
| if value_gamma is not None: | |
| value_gamma = value_gamma.unsqueeze(-1) | |
| q_s_a = q.gather(-1, action).squeeze(-1) | |
| target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1) | |
| if cum_reward: | |
| if value_gamma is None: | |
| target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) | |
| td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample | |
| def bdq_nstep_td_error( | |
| data: namedtuple, | |
| gamma: Union[float, list], | |
| nstep: int = 1, | |
| cum_reward: bool = False, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none'), | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
Multistep (1 step or n step) td_error for the BDQ algorithm, proposed in the paper "Action Branching \
Architectures for Deep Reinforcement Learning" (link: https://arxiv.org/pdf/1711.08946). \
The original paper only provides the 1-step TD-error; here we extend it to the n-step case.
| Arguments: | |
| - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data | |
| - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \ | |
| ['q', 'next_n_q', 'action', 'reward', 'done'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, D)` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` | |
| Examples: | |
| >>> action_per_branch = 3 | |
| >>> next_q = torch.randn(8, 6, action_per_branch) | |
| >>> done = torch.randn(8) | |
| >>> action = torch.randint(0, action_per_branch, size=(8, 6)) | |
| >>> next_action = torch.randint(0, action_per_branch, size=(8, 6)) | |
>>> nstep = 3
| >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True) | |
| >>> reward = torch.rand(nstep, 8) | |
| >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) | |
| >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, weight = data | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| reward = reward.unsqueeze(-1) | |
| done = done.unsqueeze(-1) | |
| if value_gamma is not None: | |
| value_gamma = value_gamma.unsqueeze(-1) | |
| q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1) | |
| target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1) | |
| if cum_reward: | |
| if value_gamma is None: | |
| target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) | |
| td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) | |
| td_error_per_sample = td_error_per_sample.mean(-1) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample | |
| def shape_fn_qntd_rescale(args, kwargs): | |
| r""" | |
| Overview: | |
| Return qntd_rescale shape for hpc | |
| Returns: | |
| shape: [T, B, N] | |
| """ | |
| if len(args) <= 0: | |
| tmp = [kwargs['data'].reward.shape[0]] | |
| tmp.extend(list(kwargs['data'].q.shape)) | |
| else: | |
| tmp = [args[0].reward.shape[0]] | |
| tmp.extend(list(args[0].q.shape)) | |
| return tmp | |
| def q_nstep_td_error_with_rescale( | |
| data: namedtuple, | |
| gamma: Union[float, list], | |
| nstep: int = 1, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none'), | |
| trans_fn: Callable = value_transform, | |
| inv_trans_fn: Callable = value_inv_transform, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Multistep (1 step or n step) td_error with value rescaling | |
| Arguments: | |
| - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| - trans_fn (:obj:`Callable`): Value transfrom function, default to value_transform\ | |
| (refer to rl_utils/value_rescale.py) | |
| - inv_trans_fn (:obj:`Callable`): Value inverse transfrom function, default to value_inv_transform\ | |
| (refer to rl_utils/value_rescale.py) | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'reward', 'done'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| Examples: | |
| >>> next_q = torch.randn(4, 3) | |
| >>> done = torch.randn(4) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
>>> nstep = 3
| >>> q = torch.randn(4, 3).requires_grad_(True) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) | |
| >>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, weight = data | |
| assert len(action.shape) == 1, action.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_range = torch.arange(action.shape[0]) | |
| q_s_a = q[batch_range, action] | |
| target_q_s_a = next_n_q[batch_range, next_n_action] | |
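# Value-rescale trick: invert the transform on the target Q, compute the n-step return in the original
# value space, then re-apply the transform (see trans_fn/inv_trans_fn in rl_utils/value_rescale.py).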
| target_q_s_a = inv_trans_fn(target_q_s_a) | |
| target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) | |
| target_q_s_a = trans_fn(target_q_s_a) | |
| td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample | |
| def dqfd_nstep_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| lambda_n_step_td: float, | |
| lambda_supervised_loss: float, | |
| margin_function: float, | |
| lambda_one_step_td: float = 1., | |
| nstep: int = 1, | |
| cum_reward: bool = False, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none'), | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
Multistep n step td_error + 1 step td_error + supervised margin loss for DQfD
| Arguments: | |
| - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): discount factor | |
| - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data | |
| - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
- nstep (:obj:`int`): nstep num, default set to 1
| Returns: | |
| - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor | |
| - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\ | |
| + supervised margin loss, 1-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): the q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\ | |
| , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` | |
| - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - is_expert (:obj:`int`) : 0 or 1 | |
| Examples: | |
| >>> next_q = torch.randn(4, 3) | |
| >>> done = torch.randn(4) | |
| >>> done_1 = torch.randn(4) | |
| >>> next_q_one_step = torch.randn(4, 3) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action_one_step = torch.randint(0, 3, size=(4, )) | |
| >>> is_expert = torch.ones((4)) | |
| >>> nstep = 3 | |
| >>> q = torch.randn(4, 3).requires_grad_(True) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = dqfd_nstep_td_data( | |
| >>> q, next_q, action, next_action, reward, done, done_1, None, | |
| >>> next_q_one_step, next_action_one_step, is_expert | |
| >>> ) | |
| >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error( | |
| >>> data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, | |
| >>> margin_function=0.8, nstep=nstep | |
| >>> ) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \ | |
| is_expert = data # set is_expert flag(expert 1, agent 0) | |
| assert len(action.shape) == 1, action.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_range = torch.arange(action.shape[0]) | |
| q_s_a = q[batch_range, action] | |
| target_q_s_a = next_n_q[batch_range, next_n_action] | |
| target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step] | |
| # calculate n-step TD-loss | |
| if cum_reward: | |
| if value_gamma is None: | |
| target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) | |
| td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) | |
| # calculate 1-step TD-loss | |
| nstep = 1 | |
| reward = reward[0].unsqueeze(0) # get the one-step reward | |
| value_gamma = None | |
| if cum_reward: | |
| if value_gamma is None: | |
| target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step) | |
| else: | |
| target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step) | |
| else: | |
| target_q_s_a_one_step = nstep_return( | |
| nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma | |
| ) | |
| td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach()) | |
| device = q_s_a.device | |
| device_cpu = torch.device('cpu') | |
| # calculate the supervised loss | |
| l = margin_function * torch.ones_like(q).to(device_cpu) # q shape (B, A), action shape (B, ) | |
| l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu)) | |
| # along the first dimension. for the index of the action, fill the corresponding position in l with 0 | |
| JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a) | |
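# JE is the DQfD large-margin supervised loss: max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E), applied to expert data only,
# where l(a_E, a) equals margin_function for a != a_E and 0 for a == a_E.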
| return ( | |
| ( | |
| ( | |
| lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + | |
| lambda_supervised_loss * JE | |
| ) * weight | |
| ).mean(), lambda_n_step_td * td_error_per_sample.abs() + | |
| lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(), | |
| (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean()) | |
| ) | |
| def dqfd_nstep_td_error_with_rescale( | |
| data: namedtuple, | |
| gamma: float, | |
| lambda_n_step_td: float, | |
| lambda_supervised_loss: float, | |
| lambda_one_step_td: float, | |
| margin_function: float, | |
| nstep: int = 1, | |
| cum_reward: bool = False, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none'), | |
| trans_fn: Callable = value_transform, | |
| inv_trans_fn: Callable = value_inv_transform, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
Multistep n step td_error + 1 step td_error + supervised margin loss for DQfD
| Arguments: | |
| - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data | |
| - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
- nstep (:obj:`int`): nstep num, default set to 1
| Returns: | |
| - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor | |
| - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\ | |
| + supervised margin loss, 1-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\ | |
| , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` | |
| - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - is_expert (:obj:`int`) : 0 or 1 | |
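Examples:
>>> # A minimal sketch mirroring the dqfd_nstep_td_error example above; tensor values are illustrative.
>>> next_q = torch.randn(4, 3)
>>> done = torch.randn(4)
>>> done_1 = torch.randn(4)
>>> next_q_one_step = torch.randn(4, 3)
>>> action = torch.randint(0, 3, size=(4, ))
>>> next_action = torch.randint(0, 3, size=(4, ))
>>> next_action_one_step = torch.randint(0, 3, size=(4, ))
>>> is_expert = torch.ones((4))
>>> nstep = 3
>>> q = torch.randn(4, 3).requires_grad_(True)
>>> reward = torch.rand(nstep, 4)
>>> data = dqfd_nstep_td_data(
>>> q, next_q, action, next_action, reward, done, done_1, None,
>>> next_q_one_step, next_action_one_step, is_expert
>>> )
>>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error_with_rescale(
>>> data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, lambda_one_step_td=1,
>>> margin_function=0.8, nstep=nstep
>>> )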
| """ | |
| q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \ | |
| is_expert = data # set is_expert flag(expert 1, agent 0) | |
| assert len(action.shape) == 1, action.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_range = torch.arange(action.shape[0]) | |
| q_s_a = q[batch_range, action] | |
| target_q_s_a = next_n_q[batch_range, next_n_action] | |
| target_q_s_a = inv_trans_fn(target_q_s_a) # rescale | |
| target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step] | |
| target_q_s_a_one_step = inv_trans_fn(target_q_s_a_one_step) # rescale | |
| # calculate n-step TD-loss | |
| if cum_reward: | |
| if value_gamma is None: | |
| target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done) | |
| else: | |
| target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done) | |
| else: | |
| # to use value_gamma in n-step TD-loss | |
| target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma) | |
| target_q_s_a = trans_fn(target_q_s_a) # rescale | |
| td_error_per_sample = criterion(q_s_a, target_q_s_a.detach()) | |
| # calculate 1-step TD-loss | |
| nstep = 1 | |
| reward = reward[0].unsqueeze(0) # get the one-step reward | |
| value_gamma = None # This is very important, to use gamma in 1-step TD-loss | |
| if cum_reward: | |
| if value_gamma is None: | |
| target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step) | |
| else: | |
| target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step) | |
| else: | |
| target_q_s_a_one_step = nstep_return( | |
| nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma | |
| ) | |
| target_q_s_a_one_step = trans_fn(target_q_s_a_one_step) # rescale | |
| td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach()) | |
| device = q_s_a.device | |
| device_cpu = torch.device('cpu') | |
| # calculate the supervised loss | |
| l = margin_function * torch.ones_like(q).to(device_cpu) # q shape (B, A), action shape (B, ) | |
| l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu)) | |
| # along the first dimension. for the index of the action, fill the corresponding position in l with 0 | |
| JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a) | |
| return ( | |
| ( | |
| ( | |
| lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample + | |
| lambda_supervised_loss * JE | |
| ) * weight | |
| ).mean(), lambda_n_step_td * td_error_per_sample.abs() + | |
| lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(), | |
| (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean()) | |
| ) | |
| qrdqn_nstep_td_data = namedtuple( | |
| 'qrdqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight'] | |
| ) | |
| def qrdqn_nstep_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| nstep: int = 1, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
Multistep (1 step or n step) td_error for QRDQN
Arguments:
- data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'reward', 'done'] | |
- q (:obj:`torch.FloatTensor`): :math:`(B, N, tau)` i.e. [batch_size, action_dim, num_quantiles]
- next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N, tau')`
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| Examples: | |
| >>> next_q = torch.randn(4, 3, 3) | |
| >>> done = torch.randn(4) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
| >>> nstep = 3 | |
| >>> q = torch.randn(4, 3, 3).requires_grad_(True) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, 3, None) | |
| >>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, tau, weight = data | |
| assert len(action.shape) == 1, action.shape | |
| assert len(next_n_action.shape) == 1, next_n_action.shape | |
| assert len(done.shape) == 1, done.shape | |
| assert len(q.shape) == 3, q.shape | |
| assert len(next_n_q.shape) == 3, next_n_q.shape | |
| assert len(reward.shape) == 2, reward.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_range = torch.arange(action.shape[0]) | |
| # shape: batch_size x num x 1 | |
| q_s_a = q[batch_range, action, :].unsqueeze(2) | |
| # shape: batch_size x 1 x num | |
| target_q_s_a = next_n_q[batch_range, next_n_action, :].unsqueeze(1) | |
| assert reward.shape[0] == nstep | |
| reward_factor = torch.ones(nstep).to(reward) | |
| for i in range(1, nstep): | |
| reward_factor[i] = gamma * reward_factor[i - 1] | |
| # shape: batch_size | |
| reward = torch.matmul(reward_factor, reward) | |
| # shape: batch_size x 1 x num | |
| if value_gamma is None: | |
| target_q_s_a = reward.unsqueeze(-1).unsqueeze(-1) + (gamma ** nstep | |
| ) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1) | |
| else: | |
| target_q_s_a = reward.unsqueeze(-1).unsqueeze( | |
| -1 | |
| ) + value_gamma.unsqueeze(-1).unsqueeze(-1) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1) | |
| # shape: batch_size x num x num | |
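# Quantile regression loss: |tau - 1{target_q_s_a - q_s_a <= 0}| * smooth_l1(target_q_s_a, q_s_a),
# summed over target quantiles and averaged over the predicted quantiles.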
| u = F.smooth_l1_loss(target_q_s_a, q_s_a, reduction="none") | |
| # shape: batch_size | |
| loss = (u * (tau - (target_q_s_a - q_s_a).detach().le(0.).float()).abs()).sum(-1).mean(1) | |
| return (loss * weight).mean(), loss | |
| def q_nstep_sql_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| alpha: float, | |
| nstep: int = 1, | |
| cum_reward: bool = False, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| criterion: torch.nn.modules = nn.MSELoss(reduction='none'), | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Multistep (1 step or n step) td_error for q-learning based algorithm | |
| Arguments: | |
| - data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
- alpha (:obj:`float`): A parameter to weight the entropy term in the policy equation
| - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data | |
| - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value | |
| - criterion (:obj:`torch.nn.modules`): Loss function criterion | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'reward', 'done'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )` | |
| Examples: | |
| >>> next_q = torch.randn(4, 3) | |
| >>> done = torch.randn(4) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
| >>> nstep = 3 | |
| >>> q = torch.randn(4, 3).requires_grad_(True) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None) | |
| >>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, weight = data | |
| assert len(action.shape) == 1, action.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_range = torch.arange(action.shape[0]) | |
| q_s_a = q[batch_range, action] | |
| # target_q_s_a = next_n_q[batch_range, next_n_action] | |
| target_v = alpha * torch.logsumexp( | |
| next_n_q / alpha, 1 | |
| ) # target_v = alpha * torch.log(torch.sum(torch.exp(next_n_q / alpha), 1)) | |
target_v[target_v == float("Inf")] = 20
target_v[target_v == float("-Inf")] = -20
# These clamps guard against numerical explosion when alpha is not well tuned;
# for an appropriate alpha they can be removed.
record_target_v = copy.deepcopy(target_v)
| if cum_reward: | |
| if value_gamma is None: | |
| target_v = reward + (gamma ** nstep) * target_v * (1 - done) | |
| else: | |
| target_v = reward + value_gamma * target_v * (1 - done) | |
| else: | |
| target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma) | |
| td_error_per_sample = criterion(q_s_a, target_v.detach()) | |
| return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v | |
| iqn_nstep_td_data = namedtuple( | |
| 'iqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight'] | |
| ) | |
| def iqn_nstep_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| nstep: int = 1, | |
| kappa: float = 1.0, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
Multistep (1 step or n step) td_error in IQN, \
| referenced paper Implicit Quantile Networks for Distributional Reinforcement Learning \ | |
| <https://arxiv.org/pdf/1806.06923.pdf> | |
| Arguments: | |
| - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
- kappa (:obj:`float`): Threshold of Huber loss, default set to 1.0
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| Shapes: | |
| - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'reward', 'done'] | |
- q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [num_quantiles, batch_size, action_dim]
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| Examples: | |
| >>> next_q = torch.randn(3, 4, 3) | |
| >>> done = torch.randn(4) | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
| >>> nstep = 3 | |
| >>> q = torch.randn(3, 4, 3).requires_grad_(True) | |
| >>> replay_quantile = torch.randn([3, 4, 1]) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None) | |
| >>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, replay_quantiles, weight = data | |
| assert len(action.shape) == 1, action.shape | |
| assert len(next_n_action.shape) == 1, next_n_action.shape | |
| assert len(done.shape) == 1, done.shape | |
| assert len(q.shape) == 3, q.shape | |
| assert len(next_n_q.shape) == 3, next_n_q.shape | |
| assert len(reward.shape) == 2, reward.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_size = done.shape[0] | |
| tau = q.shape[0] | |
| tau_prime = next_n_q.shape[0] | |
| action = action.repeat([tau, 1]).unsqueeze(-1) | |
| next_n_action = next_n_action.repeat([tau_prime, 1]).unsqueeze(-1) | |
| # shape: batch_size x tau x 1 | |
| q_s_a = torch.gather(q, -1, action).permute([1, 0, 2]) | |
| # shape: batch_size x tau_prime x 1 | |
| target_q_s_a = torch.gather(next_n_q, -1, next_n_action).permute([1, 0, 2]) | |
| assert reward.shape[0] == nstep | |
| device = torch.device("cuda" if reward.is_cuda else "cpu") | |
| reward_factor = torch.ones(nstep).to(device) | |
| for i in range(1, nstep): | |
| reward_factor[i] = gamma * reward_factor[i - 1] | |
| reward = torch.matmul(reward_factor, reward) | |
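| # reward is now the discounted sum of the n-step rewards: sum_{i=0}^{nstep-1} gamma^i * r_i, shape (B, ) | |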
| if value_gamma is None: | |
| target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1) | |
| else: | |
| target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done | |
| ).unsqueeze(-1) | |
| target_q_s_a = target_q_s_a.unsqueeze(-1) | |
| # shape: batch_size x tau' x tau x 1. | |
| bellman_errors = (target_q_s_a[:, :, None, :] - q_s_a[:, None, :, :]) | |
| # The huber loss (see Section 2.3 of the paper) is defined via two cases: | |
| huber_loss = torch.where( | |
| bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa) | |
| ) | |
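| # Quantile regression Huber loss from the IQN paper: rho^kappa_tau(u) = |tau - 1{u < 0}| * L_kappa(u) / kappa, | |
| # computed below with u = bellman_errors and tau = replay_quantiles, then summed over tau and averaged over tau'. | |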
| # Reshape replay_quantiles to batch_size x num_tau_samples x 1 | |
| replay_quantiles = replay_quantiles.reshape([tau, batch_size, 1]).permute([1, 0, 2]) | |
| # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. | |
| replay_quantiles = replay_quantiles[:, None, :, :].repeat([1, tau_prime, 1, 1]) | |
| # shape: batch_size x tau_prime x tau x 1. | |
| quantile_huber_loss = (torch.abs(replay_quantiles - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa | |
| # shape: batch_size | |
| loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0] | |
| return (loss * weight).mean(), loss | |
| fqf_nstep_td_data = namedtuple( | |
| 'fqf_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight'] | |
| ) | |
| def fqf_nstep_td_error( | |
| data: namedtuple, | |
| gamma: float, | |
| nstep: int = 1, | |
| kappa: float = 1.0, | |
| value_gamma: Optional[torch.Tensor] = None, | |
| ) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Multistep (1 step or n step) td_error used in FQF, \ | |
| referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \ | |
| <https://arxiv.org/pdf/1911.02140.pdf> | |
| Arguments: | |
| - data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss | |
| - gamma (:obj:`float`): Discount factor | |
| - nstep (:obj:`int`): nstep num, default set to 1 | |
| - kappa (:obj:`float`): Threshold of the Huber loss, default set to 1.0 | |
| - value_gamma (:obj:`torch.Tensor` or None): Discount factor for the n-step return; when provided, it overrides gamma ** nstep | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor | |
| - td_error_per_sample (:obj:`torch.Tensor`): nstep td error per sample, of shape :math:`(B, )` | |
| Shapes: | |
| - data (:obj:`fqf_nstep_td_data`): The fqf_nstep_td_data containing\ | |
| ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight'] | |
| - q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim] | |
| - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)` | |
| - action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )` | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep) | |
| - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep | |
| - quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)` | |
| Examples: | |
| >>> next_q = torch.randn(4, 3, 3) | |
| >>> done = torch.randint(0, 2, (4, )).float() | |
| >>> action = torch.randint(0, 3, size=(4, )) | |
| >>> next_action = torch.randint(0, 3, size=(4, )) | |
| >>> nstep = 3 | |
| >>> q = torch.randn(4, 3, 3).requires_grad_(True) | |
| >>> quantiles_hats = torch.rand([4, 3]) | |
| >>> reward = torch.rand(nstep, 4) | |
| >>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None) | |
| >>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep) | |
| """ | |
| q, next_n_q, action, next_n_action, reward, done, quantiles_hats, weight = data | |
| assert len(action.shape) == 1, action.shape | |
| assert len(next_n_action.shape) == 1, next_n_action.shape | |
| assert len(done.shape) == 1, done.shape | |
| assert len(q.shape) == 3, q.shape | |
| assert len(next_n_q.shape) == 3, next_n_q.shape | |
| assert len(reward.shape) == 2, reward.shape | |
| if weight is None: | |
| weight = torch.ones_like(action) | |
| batch_size = done.shape[0] | |
| tau = q.shape[1] | |
| tau_prime = next_n_q.shape[1] | |
| # shape: batch_size x tau x 1 | |
| q_s_a = evaluate_quantile_at_action(q, action) | |
| # shape: batch_size x tau_prime x 1 | |
| target_q_s_a = evaluate_quantile_at_action(next_n_q, next_n_action) | |
| assert reward.shape[0] == nstep | |
| reward_factor = torch.ones(nstep).to(reward.device) | |
| for i in range(1, nstep): | |
| reward_factor[i] = gamma * reward_factor[i - 1] | |
| reward = torch.matmul(reward_factor, reward) # [batch_size] | |
| if value_gamma is None: | |
| target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1) | |
| else: | |
| target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done | |
| ).unsqueeze(-1) | |
| target_q_s_a = target_q_s_a.unsqueeze(-1) | |
| # shape: batch_size x tau' x tau x 1. | |
| bellman_errors = (target_q_s_a.unsqueeze(2) - q_s_a.unsqueeze(1)) | |
| # shape: batch_size x tau' x tau x 1 | |
| huber_loss = F.smooth_l1_loss(target_q_s_a.unsqueeze(2), q_s_a.unsqueeze(1), reduction="none") | |
| # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1. | |
| quantiles_hats = quantiles_hats[:, None, :, None].repeat([1, tau_prime, 1, 1]) | |
| # shape: batch_size x tau_prime x tau x 1. | |
| quantile_huber_loss = (torch.abs(quantiles_hats - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa | |
| # shape: batch_size | |
| loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0] | |
| return (loss * weight).mean(), loss | |
| def evaluate_quantile_at_action(q_s, actions): | |
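| """ | |
| Overview: | |
| Gather the quantile values of the given actions, i.e. pick q_s[b, :, actions[b]] for every sample b. | |
| Arguments: | |
| - q_s (:obj:`torch.FloatTensor`): :math:`(B, num_quantiles, N)`, quantile values for all actions | |
| - actions (:obj:`torch.LongTensor`): :math:`(B, )`, the selected actions | |
| Returns: | |
| - q_s_a (:obj:`torch.FloatTensor`): :math:`(B, num_quantiles, 1)`, quantile values of the selected actions | |
| """ | |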
| assert q_s.shape[0] == actions.shape[0] | |
| batch_size, num_quantiles = q_s.shape[:2] | |
| # Expand actions into (batch_size, num_quantiles, 1). | |
| action_index = actions[:, None, None].expand(batch_size, num_quantiles, 1) | |
| # Calculate quantile values at specified actions. | |
| q_s_a = q_s.gather(dim=2, index=action_index) | |
| return q_s_a | |
| def fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions): | |
| """ | |
| Overview: | |
| Calculate the fraction loss in FQF, \ | |
| referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \ | |
| <https://arxiv.org/pdf/1911.02140.pdf> | |
| Arguments: | |
| - q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles-1, action_dim)` | |
| - q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)` | |
| - quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles+1)` | |
| - actions (:obj:`torch.LongTensor`): :math:`(batch_size, )` | |
| Returns: | |
| - fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor | |
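| Examples: | |
| >>> # A minimal sketch with random tensors; shapes follow the docstring above, and the quantile | |
| >>> # fractions are assumed to be sorted in (0, 1) with 0 and 1 appended at the ends. | |
| >>> batch_size, num_quantiles, action_dim = 4, 32, 3 | |
| >>> q_tau_i = torch.randn(batch_size, num_quantiles - 1, action_dim) | |
| >>> q_value = torch.randn(batch_size, num_quantiles, action_dim).requires_grad_(True) | |
| >>> taus = torch.sort(torch.rand(batch_size, num_quantiles - 1), dim=1)[0].requires_grad_(True) | |
| >>> quantiles = torch.cat([torch.zeros(batch_size, 1), taus, torch.ones(batch_size, 1)], dim=1) | |
| >>> actions = torch.randint(0, action_dim, (batch_size, )) | |
| >>> fraction_loss = fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions) | |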
| """ | |
| assert q_value.requires_grad | |
| batch_size = q_value.shape[0] | |
| num_quantiles = q_value.shape[1] | |
| with torch.no_grad(): | |
| sa_quantiles = evaluate_quantile_at_action(q_tau_i, actions) | |
| assert sa_quantiles.shape == (batch_size, num_quantiles - 1, 1) | |
| q_s_a_hats = evaluate_quantile_at_action(q_value, actions) # [batch_size, num_quantiles, 1] | |
| assert q_s_a_hats.shape == (batch_size, num_quantiles, 1) | |
| assert not q_s_a_hats.requires_grad | |
| # NOTE: Proposition 1 in the paper requires that F^{-1} be non-decreasing. | |
| # We relax this requirement and compute gradients of the quantiles even when | |
| # F^{-1} is not non-decreasing. | |
| values_1 = sa_quantiles - q_s_a_hats[:, :-1] | |
| signs_1 = sa_quantiles > torch.cat([q_s_a_hats[:, :1], sa_quantiles[:, :-1]], dim=1) | |
| assert values_1.shape == signs_1.shape | |
| values_2 = sa_quantiles - q_s_a_hats[:, 1:] | |
| signs_2 = sa_quantiles < torch.cat([sa_quantiles[:, 1:], q_s_a_hats[:, -1:]], dim=1) | |
| assert values_2.shape == signs_2.shape | |
| gradient_of_taus = (torch.where(signs_1, values_1, -values_1) + | |
| torch.where(signs_2, values_2, -values_2)).view(batch_size, num_quantiles - 1) | |
| assert not gradient_of_taus.requires_grad | |
| assert gradient_of_taus.shape == quantiles[:, 1:-1].shape | |
| # Gradients of the network parameters and corresponding loss | |
| # are calculated using chain rule. | |
| fraction_loss = (gradient_of_taus * quantiles[:, 1:-1]).sum(dim=1).mean() | |
| return fraction_loss | |
| td_lambda_data = namedtuple('td_lambda_data', ['value', 'reward', 'weight']) | |
| def shape_fn_td_lambda(args, kwargs): | |
| r""" | |
| Overview: | |
| Return td_lambda shape for hpc | |
| Returns: | |
| shape: [T, B] | |
| """ | |
| if len(args) <= 0: | |
| tmp = kwargs['data'].reward.shape[0] | |
| else: | |
| tmp = args[0].reward.shape | |
| return tmp | |
| def td_lambda_error(data: namedtuple, gamma: float = 0.9, lambda_: float = 0.8) -> torch.Tensor: | |
| """ | |
| Overview: | |
| Computing TD(lambda) loss given constant gamma and lambda. | |
| There is no special handling for terminal state values; if some state has reached the terminal, | |
| just fill in zeros for values and rewards beyond the terminal | |
| (*including the terminal state*: values[terminal] should also be 0). | |
| Arguments: | |
| - data (:obj:`namedtuple`): td_lambda input data with fields ['value', 'reward', 'weight'] | |
| - gamma (:obj:`float`): Constant discount factor gamma, should be in [0, 1], defaults to 0.9 | |
| - lambda_ (:obj:`float`): Constant lambda, should be in [0, 1], defaults to 0.8 | |
| Returns: | |
| - loss (:obj:`torch.Tensor`): Computed MSE loss, averaged over the batch | |
| Shapes: | |
| - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch,\ | |
| which is the estimation of the state value at step 0 to T | |
| - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, the rewards from time step 0 to T-1 | |
| - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight | |
| - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor | |
| Examples: | |
| >>> T, B = 8, 4 | |
| >>> value = torch.randn(T + 1, B).requires_grad_(True) | |
| >>> reward = torch.rand(T, B) | |
| >>> loss = td_lambda_error(td_lambda_data(value, reward, None)) | |
| """ | |
| value, reward, weight = data | |
| if weight is None: | |
| weight = torch.ones_like(reward) | |
| with torch.no_grad(): | |
| return_ = generalized_lambda_returns(value, reward, gamma, lambda_) | |
| # discard the value at T as it should be considered in the next slice | |
| loss = 0.5 * (F.mse_loss(return_, value[:-1], reduction='none') * weight).mean() | |
| return loss | |
| def generalized_lambda_returns( | |
| bootstrap_values: torch.Tensor, | |
| rewards: torch.Tensor, | |
| gammas: float, | |
| lambda_: float, | |
| done: Optional[torch.Tensor] = None | |
| ) -> torch.Tensor: | |
| r""" | |
| Overview: | |
| Functional equivalent to trfl.value_ops.generalized_lambda_returns | |
| https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74 | |
| Passing in a number instead of a tensor makes the value constant for all samples in the batch | |
| Arguments: | |
| - bootstrap_values (:obj:`torch.Tensor` or :obj:`float`): | |
| estimation of the value at step 0 to *T*, of size [T_traj+1, batchsize] | |
| - rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize] | |
| - gammas (:obj:`torch.Tensor` or :obj:`float`): | |
| Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize] | |
| - lambda_ (:obj:`torch.Tensor` or :obj:`float`): Determines the mix of bootstrapping | |
| vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize] | |
| - done (:obj:`torch.Tensor` or :obj:`float`): | |
| Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize] | |
| Returns: | |
| - return (:obj:`torch.Tensor`): Computed lambda return value | |
| for each state from 0 to T-1, of size [T_traj, batchsize] | |
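| Examples: | |
| >>> # A minimal sketch with random tensors for the expected shapes (T_traj=8, batchsize=4). | |
| >>> T, B = 8, 4 | |
| >>> bootstrap_values = torch.randn(T + 1, B) | |
| >>> rewards = torch.rand(T, B) | |
| >>> return_ = generalized_lambda_returns(bootstrap_values, rewards, 0.99, 0.95) | |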
| """ | |
| if not isinstance(gammas, torch.Tensor): | |
| gammas = gammas * torch.ones_like(rewards) | |
| if not isinstance(lambda_, torch.Tensor): | |
| lambda_ = lambda_ * torch.ones_like(rewards) | |
| bootstrap_values_tp1 = bootstrap_values[1:, :] | |
| return multistep_forward_view(bootstrap_values_tp1, rewards, gammas, lambda_, done) | |
| def multistep_forward_view( | |
| bootstrap_values: torch.Tensor, | |
| rewards: torch.Tensor, | |
| gammas: float, | |
| lambda_: float, | |
| done: Optional[torch.Tensor] = None | |
| ) -> torch.Tensor: | |
| r""" | |
| Overview: | |
| Same as trfl.sequence_ops.multistep_forward_view | |
| Implementing (12.18) in Sutton & Barto | |
| ``` | |
| result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T] | |
| for t in 0...T-2 : | |
| result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1]) | |
| ``` | |
| Assuming the first dim of the input tensors corresponds to the timestep and the second dim to the batch | |
| Arguments: | |
| - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize] | |
| - rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize] | |
| - gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize] | |
| - lambda_ (:obj:`torch.Tensor`): Determines the mix of bootstrapping vs further accumulation of \ | |
| multistep returns at each timestep, of size [T_traj, batchsize]; the element for T-1 is ignored \ | |
| and effectively set to 0, as there is no information about future rewards. | |
| - done (:obj:`torch.Tensor` or :obj:`float`): | |
| Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize] | |
| Returns: | |
| - ret (:obj:`torch.Tensor`): Computed lambda return value \ | |
| for each state from 0 to T-1, of size [T_traj, batchsize] | |
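| Examples: | |
| >>> # A minimal sketch with random tensors; bootstrap_values here are the values at step 1 to T. | |
| >>> T, B = 8, 4 | |
| >>> bootstrap_values = torch.randn(T, B) | |
| >>> rewards = torch.rand(T, B) | |
| >>> gammas = 0.99 * torch.ones(T, B) | |
| >>> lambda_ = 0.95 * torch.ones(T, B) | |
| >>> ret = multistep_forward_view(bootstrap_values, rewards, gammas, lambda_) | |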
| """ | |
| result = torch.empty_like(rewards) | |
| if done is None: | |
| done = torch.zeros_like(rewards) | |
| # Forced cutoff at the last one | |
| result[-1, :] = rewards[-1, :] + (1 - done[-1, :]) * gammas[-1, :] * bootstrap_values[-1, :] | |
| discounts = gammas * lambda_ | |
| for t in reversed(range(rewards.size()[0] - 1)): | |
| result[t, :] = rewards[t, :] + (1 - done[t, :]) * \ | |
| ( | |
| discounts[t, :] * result[t + 1, :] + | |
| (gammas[t, :] - discounts[t, :]) * bootstrap_values[t, :] | |
| ) | |
| return result | |