from collections import namedtuple

import torch
import torch.nn.functional as F
from torch.distributions import Independent, Normal

a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])


def a2c_error(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for discrete action space.
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss items, all of which are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.LongTensor`): :math:`(B, )`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        >>>     logit=torch.randn(2, 3),
        >>>     action=torch.randint(0, 3, (2, )),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value)
    # Build the categorical policy distribution directly from the raw logits.
    dist = torch.distributions.categorical.Categorical(logits=logit)
    logp = dist.log_prob(action)
    # Entropy bonus: the caller typically subtracts it from the total loss to encourage exploration.
    entropy_loss = (dist.entropy() * weight).mean()
    # Policy-gradient term: log-probability of the taken action, scaled by the advantage estimate.
    policy_loss = -(logp * adv * weight).mean()
    # Value regression towards the return target (MSE is symmetric, so the argument order is immaterial).
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)
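

# A minimal usage sketch (not part of the original module), showing how the three
# loss terms returned by ``a2c_error`` are usually combined into one scalar for
# backpropagation. The coefficients 0.5 and 0.01 are illustrative hyperparameters,
# not values prescribed by this file.
def _a2c_error_usage_sketch() -> torch.Tensor:
    logit = torch.randn(4, 5, requires_grad=True)
    value = torch.randn(4, requires_grad=True)
    data = a2c_data(
        logit=logit,
        action=torch.randint(0, 5, (4, )),
        value=value,
        adv=torch.randn(4, ),
        return_=torch.randn(4, ),
        weight=None,
    )
    loss = a2c_error(data)
    # Minimize the policy and value losses, maximize entropy (hence the minus sign).
    total_loss = loss.policy_loss + 0.5 * loss.value_loss - 0.01 * loss.entropy_loss
    total_loss.backward()
    return total_loss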


def a2c_error_continuous(data: namedtuple) -> namedtuple:
    """
    Overview:
        Implementation of A2C (Advantage Actor-Critic) (arXiv:1602.01783) for continuous action space.
    Arguments:
        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
    Returns:
        - a2c_loss (:obj:`namedtuple`): the a2c loss items, all of which are differentiable 0-dim tensors
    Shapes:
        - logit (:obj:`dict` of :obj:`torch.FloatTensor`): ``mu`` and ``sigma``, each :math:`(B, N)`, where B is batch size and N is action dim
        - action (:obj:`torch.FloatTensor`): :math:`(B, N)`
        - value (:obj:`torch.FloatTensor`): :math:`(B, )`
        - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
        - return_ (:obj:`torch.FloatTensor`): :math:`(B, )`
        - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
        - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
        - value_loss (:obj:`torch.FloatTensor`): :math:`()`
        - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
    Examples:
        >>> data = a2c_data(
        >>>     logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
        >>>     action=torch.randn(2, 3),
        >>>     value=torch.randn(2, ),
        >>>     adv=torch.randn(2, ),
        >>>     return_=torch.randn(2, ),
        >>>     weight=torch.ones(2, ),
        >>> )
        >>> loss = a2c_error_continuous(data)
    """
    logit, action, value, adv, return_, weight = data
    if weight is None:
        weight = torch.ones_like(value)
    # Diagonal Gaussian policy: Independent(..., 1) sums the per-dimension log-probs and entropies.
    dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
    logp = dist.log_prob(action)
    # Entropy bonus: the caller typically subtracts it from the total loss to encourage exploration.
    entropy_loss = (dist.entropy() * weight).mean()
    # Policy-gradient term: log-probability of the taken action, scaled by the advantage estimate.
    policy_loss = -(logp * adv * weight).mean()
    # Value regression towards the return target.
    value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
    return a2c_loss(policy_loss, value_loss, entropy_loss)
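

# A minimal usage sketch for the continuous-action variant (not part of the original
# module). The two-head layout, with an unconstrained ``mu`` and a ``sigma`` kept
# strictly positive via ``softplus``, is an assumption for illustration; any positive
# parameterisation of the standard deviation works.
def _a2c_error_continuous_usage_sketch() -> torch.Tensor:
    batch, action_dim = 4, 3
    mu = torch.randn(batch, action_dim, requires_grad=True)
    sigma_raw = torch.randn(batch, action_dim, requires_grad=True)
    value = torch.randn(batch, requires_grad=True)
    data = a2c_data(
        logit={'mu': mu, 'sigma': F.softplus(sigma_raw)},  # sigma must be > 0 for Normal
        action=torch.randn(batch, action_dim),
        value=value,
        adv=torch.randn(batch, ),
        return_=torch.randn(batch, ),
        weight=None,
    )
    loss = a2c_error_continuous(data)
    # Same combination rule as in the discrete sketch; coefficients are illustrative.
    total_loss = loss.policy_loss + 0.5 * loss.value_loss - 0.01 * loss.entropy_loss
    total_loss.backward()
    return total_loss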