Spaces:
Sleeping
Sleeping
| import torch | |
| import treetensor.torch as ttorch | |
| from torch.distributions import Normal, Independent | |
class ArgmaxSampler:
    '''
    Overview:
        Greedy sampler: picks, along the last dimension, the position holding the largest logit.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Select the index of the largest entry along the last dimension of ``logit``.
        Arguments:
            - logit (:obj:`torch.Tensor`): The tensor of scores to select from.
        Returns:
            - action (:obj:`torch.Tensor`): The indices of the maximum values; the last \
                dimension of ``logit`` is reduced away.
        '''
        best_index = logit.argmax(dim=-1)
        return best_index
class MultinomialSampler:
    '''
    Overview:
        Stochastic sampler: draws an index from the categorical distribution defined by the logits.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Build a ``Categorical`` distribution from ``logit`` and draw one sample from it.
        Arguments:
            - logit (:obj:`torch.Tensor`): The logits passed to ``Categorical(logits=...)``.
        Returns:
            - action (:obj:`torch.Tensor`): The sampled index (indices).
        '''
        return torch.distributions.Categorical(logits=logit).sample()
class MuSampler:
    '''
    Overview:
        Deterministic sampler for outputs that carry a ``mu`` field: the action is ``mu`` itself.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Read the ``mu`` field of ``logit`` and return it unchanged.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor; it must expose a ``mu`` attribute.
        Returns:
            - action (:obj:`torch.Tensor`): The value stored in ``logit.mu``.
        '''
        mu = logit.mu
        return mu
class ReparameterizationSampler:
    '''
    Overview:
        Stochastic sampler for outputs that carry ``mu`` and ``sigma`` fields: draws a \
        reparameterized sample from the corresponding Normal distribution.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Build ``Normal(logit.mu, logit.sigma)``, wrap it with ``Independent`` (one \
            reinterpreted batch dimension) and return the result of ``rsample()``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor; it must expose ``mu`` and \
                ``sigma`` attributes.
        Returns:
            - action (:obj:`torch.Tensor`): The reparameterized sample.
        '''
        mean, std = logit.mu, logit.sigma
        gaussian = Independent(Normal(mean, std), 1)
        return gaussian.rsample()
class HybridStochasticSampler:
    '''
    Overview:
        Stochastic sampler for hybrid outputs: samples the discrete ``action_type`` from a \
        categorical distribution and the ``action_args`` from a Normal distribution via ``rsample()``.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Sample ``action_type`` from ``Categorical(logits=logit.action_type)`` and draw \
            ``action_args`` with ``rsample()`` from \
            ``Independent(Normal(logit.action_args.mu, logit.action_args.sigma), 1)``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor; it must expose ``action_type`` \
                as well as ``action_args.mu`` and ``action_args.sigma``.
        Returns:
            - action (:obj:`ttorch.Tensor`): A tree tensor with the keys ``action_type`` and \
                ``action_args``.
        '''
        # The discrete part is drawn first, then the continuous part.
        action_type = torch.distributions.Categorical(logits=logit.action_type).sample()

        args_logit = logit.action_args
        args_dist = Independent(Normal(args_logit.mu, args_logit.sigma), 1)
        action_args = args_dist.rsample()

        sampled = {
            'action_type': action_type,
            'action_args': action_args,
        }
        return ttorch.as_tensor(sampled)
class HybridDeterminsticSampler:
    '''
    Overview:
        Deterministic sampler for hybrid outputs: takes the argmax of ``action_type`` and the \
        ``mu`` of ``action_args``.

        NOTE: the class name is misspelled ("Determinstic"). It is kept as-is so existing \
        callers keep working; the correctly spelled alias ``HybridDeterministicSampler`` is \
        defined right after the class.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the argmax (along the last dimension) of ``logit.action_type`` together \
            with ``logit.action_args.mu``.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): The input tree tensor; it must expose ``action_type`` \
                and ``action_args.mu``.
        Returns:
            - action (:obj:`ttorch.Tensor`): A tree tensor with the keys ``action_type`` and \
                ``action_args``.
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


# Correctly spelled alias; the misspelled original name stays available for backward compatibility.
HybridDeterministicSampler = HybridDeterminsticSampler