Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| from collections import namedtuple | |
| from ding.rl_utils.isw import compute_importance_weights | |
def compute_q_retraces(
        q_values: torch.Tensor,
        v_pred: torch.Tensor,
        rewards: torch.Tensor,
        actions: torch.Tensor,
        weights: torch.Tensor,
        ratio: torch.Tensor,
        gamma: float = 0.9
) -> torch.Tensor:
    r"""
    Overview:
        Compute Retrace Q targets with a backward recursion over an unrolled trajectory of length T.
        With :math:`c_t = \min(1, \rho_t(a_t))` (the importance ratio of the taken action, clipped at 1):

        - :math:`Q^{ret}_T = V_T`
        - :math:`Q^{ret}_{T-1} = r_{T-1} + \gamma w_{T-1} V_T`
        - :math:`Q^{ret}_t = r_t + \gamma w_t \left[c_{t+1}(Q^{ret}_{t+1} - Q_{t+1}(a_{t+1})) + V_{t+1}\right]`
          for :math:`t < T - 1`

        Consequently ``ratio[0]`` never influences the result.
    Arguments:
        - q_values (:obj:`torch.Tensor`): Q-values of every action for T + 1 steps.
        - v_pred (:obj:`torch.Tensor`): state-value predictions for T + 1 steps.
        - rewards (:obj:`torch.Tensor`): per-step rewards.
        - actions (:obj:`torch.Tensor`): indices of the taken actions; any integer dtype is accepted \
            (it is cast to int64 because ``torch.gather`` requires int64 indices).
        - weights (:obj:`torch.Tensor`): per-step multiplier on the discounted bootstrap term; a 0 at \
            step t makes :math:`Q^{ret}_t = r_t`.
        - ratio (:obj:`torch.Tensor`): per-action importance ratios; only the taken actions' entries are used.
        - gamma (:obj:`float`): discount factor.
    Returns:
        - q_retraces (:obj:`torch.Tensor`): Retrace targets; the last step equals ``v_pred[-1]``.
    Shapes:
        - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete \
            action dim.
        - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
        - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
        - actions (:obj:`torch.Tensor`): :math:`(T, B)`
        - weights (:obj:`torch.Tensor`): :math:`(T, B)`
        - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
        - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
    Examples:
        >>> T=2
        >>> B=3
        >>> N=4
        >>> q_values=torch.randn(T+1, B, N)
        >>> v_pred=torch.randn(T+1, B, 1)
        >>> rewards=torch.randn(T, B)
        >>> actions=torch.randint(0, N, (T, B))
        >>> weights=torch.ones(T, B)
        >>> ratio=torch.randn(T, B, N)
        >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)

    .. note::
        q_retrace operation doesn't need to compute gradient, just executes forward computation.
        The whole computation runs under ``torch.no_grad()``, so the returned targets are detached
        from the autograd graph even when the inputs require grad.
    """
    T = q_values.size(0) - 1
    # Targets must not propagate gradients back into q_values / v_pred.
    with torch.no_grad():
        rewards = rewards.unsqueeze(-1)  # (T, B, 1)
        # torch.gather only accepts int64 indices; .long() is a no-op for int64 input.
        actions = actions.long().unsqueeze(-1)  # (T, B, 1)
        weights = weights.unsqueeze(-1)  # (T, B, 1)

        q_retraces = torch.zeros_like(v_pred)  # (T + 1, B, 1)
        q_retraces[-1] = v_pred[-1]

        # Q(s_t, a_t) for the taken actions, stored in v_pred's dtype; the last row is never read.
        q_taken = torch.zeros_like(v_pred)  # (T + 1, B, 1)
        q_taken[:-1] = q_values[:-1].gather(-1, actions)
        # Truncated importance weights c_t = min(1, rho_t(a_t)).
        trace_coef = ratio.gather(-1, actions).clamp(max=1.0)  # (T, B, 1)

        # Bootstrap term for step t, computed from step t + 1; starts as V_T.
        bootstrap = v_pred[-1]  # (B, 1)
        for t in reversed(range(T)):
            q_retraces[t] = rewards[t] + gamma * weights[t] * bootstrap
            bootstrap = trace_coef[t] * (q_retraces[t] - q_taken[t]) + v_pred[t]
    return q_retraces  # (T + 1, B, 1)