| """ | |
| this extremely minimal Decision Transformer model is based on | |
| the following causal transformer (GPT) implementation: | |
| Misha Laskin's tweet: | |
| https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA | |
| and its corresponding notebook: | |
| https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing | |
| ** the above colab notebook has a bug while applying masked_fill | |
| which is fixed in the following code | |
| """ | |
import math
from typing import Union, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.utils import SequenceType


class MaskedCausalAttention(nn.Module):
    """
    Overview:
        The implementation of masked causal attention in decision transformer. The input of this module is a \
        sequence of several tokens. The hidden embedding calculated for the i-th token only depends on input \
        tokens 0 to i (itself included), which is enforced by applying a causal mask to the attention map. \
        Thus, this module is called masked causal attention.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        """
        Overview:
            Initialize the MaskedCausalAttention Model according to input arguments.
        Arguments:
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - max_T (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
        """
        super().__init__()

        self.n_heads = n_heads
        self.max_T = max_T

        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)

        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)

        # register buffer makes sure mask does not get updated
        # during backpropagation
        self.register_buffer('mask', mask)
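        # For example, with max_T = 3 the registered ``mask`` buffer (shape (1, 1, 3, 3)) is:
        #   [[[[1., 0., 0.],
        #      [1., 1., 0.],
        #      [1., 1., 1.]]]]
        # i.e. position i may only attend to positions 0..i.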

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            MaskedCausalAttention forward computation graph, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
        Examples:
            >>> inputs = torch.randn(2, 4, 64)
            >>> model = MaskedCausalAttention(64, 5, 4, 0.1)
            >>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([2, 4, 64])
        """
        B, T, C = x.shape  # batch size, seq length, hidden dim (h_dim)
        N, D = self.n_heads, C // self.n_heads  # N = num heads, D = attention dim per head

        # rearrange q, k, v as (B, N, T, D)
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # weights (B, N, T, T)
        weights = q @ k.transpose(2, 3) / math.sqrt(D)
        # causal mask applied to weights
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
        # normalize weights, all -inf -> 0 after softmax
        normalized_weights = F.softmax(weights, dim=-1)

        # attention (B, N, T, D)
        attention = self.att_drop(normalized_weights @ v)

        # gather heads and project (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        out = self.proj_drop(self.proj_net(attention))
        return out


class Block(nn.Module):
    """
    Overview:
        The implementation of a transformer block in decision transformer.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        """
        Overview:
            Initialize the Block Model according to input arguments.
        Arguments:
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - max_T (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
        """
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, 4 * h_dim),
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Forward computation graph of the decision transformer block, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
        Examples:
            >>> inputs = torch.randn(2, 4, 64)
            >>> model = Block(64, 5, 4, 0.1)
            >>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([2, 4, 64])
        """
        # Attention -> LayerNorm -> MLP -> LayerNorm
        x = x + self.attention(x)  # residual
        x = self.ln1(x)
        x = x + self.mlp(x)  # residual
        x = self.ln2(x)
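        # The commented-out lines below are the pre-LN (GPT-style) variant, where LayerNorm is
        # applied to the sub-layer input rather than to the residual sum; kept here for reference.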
        # x = x + self.attention(self.ln1(x))
        # x = x + self.mlp(self.ln2(x))
        return x


class DecisionTransformer(nn.Module):
    """
    Overview:
        The implementation of decision transformer.
    Interfaces:
        ``__init__``, ``forward``, ``configure_optimizers``
    """

    def __init__(
            self,
            state_dim: Union[int, SequenceType],
            act_dim: int,
            n_blocks: int,
            h_dim: int,
            context_len: int,
            n_heads: int,
            drop_p: float,
            max_timestep: int = 4096,
            state_encoder: Optional[nn.Module] = None,
            continuous: bool = False
    ):
| """ | |
| Overview: | |
| Initialize the DecisionTransformer Model according to input arguments. | |
| Arguments: | |
| - obs_shape (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128 or (4, 84, 84). | |
| - act_dim (:obj:`int`): The dimension of actions, such as 6. | |
| - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3. | |
| - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128. | |
| - context_len (:obj:`int`): The max context length of the attention, such as 6. | |
| - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8. | |
| - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1. | |
| - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096. | |
| - state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set to \ | |
| None, the raw state will be pushed into the transformer. | |
| - continuous (:obj:`bool`): Whether the action space is continuous, defaults to be ``False``. | |
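        Examples (illustrative constructions; all hyper-parameter values below are arbitrary):
            >>> # state-vector input with discrete actions (default)
            >>> model = DecisionTransformer(state_dim=17, act_dim=6, n_blocks=3, h_dim=128, \
                context_len=20, n_heads=8, drop_p=0.1)
            >>> # the same model with a continuous action space
            >>> model = DecisionTransformer(state_dim=17, act_dim=6, n_blocks=3, h_dim=128, \
                context_len=20, n_heads=8, drop_p=0.1, continuous=True)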
| """ | |
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        # transformer blocks
        input_seq_len = 3 * context_len

        # projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.drop = nn.Dropout(drop_p)

        self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim))

        if state_encoder is None:
            self.state_encoder = None
            blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
            self.embed_rtg = torch.nn.Linear(1, h_dim)
            self.embed_state = torch.nn.Linear(state_dim, h_dim)
            self.predict_rtg = torch.nn.Linear(h_dim, 1)
            self.predict_state = torch.nn.Linear(h_dim, state_dim)
            if continuous:
                # continuous actions
                self.embed_action = torch.nn.Linear(act_dim, h_dim)
                use_action_tanh = True  # True for continuous actions
            else:
                # discrete actions
                self.embed_action = torch.nn.Embedding(act_dim, h_dim)
                use_action_tanh = False  # False for discrete actions
            self.predict_action = nn.Sequential(
                *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
            )
        else:
            blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
            self.state_encoder = state_encoder
            self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh())
            self.head = nn.Linear(h_dim, act_dim, bias=False)
            self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh())

        self.transformer = nn.Sequential(*blocks)

    def forward(
            self,
            timesteps: torch.Tensor,
            states: torch.Tensor,
            actions: torch.Tensor,
            returns_to_go: torch.Tensor,
            tar: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation graph of the decision transformer, input the sequences of timesteps, states, \
            actions and returns-to-go, and return the predicted states, actions and returns-to-go.
        Arguments:
            - timesteps (:obj:`torch.Tensor`): The timestep for input sequence.
            - states (:obj:`torch.Tensor`): The sequence of states.
            - actions (:obj:`torch.Tensor`): The sequence of actions.
            - returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go.
            - tar (:obj:`Optional[int]`): Whether the action for the last timestep is provided. Only used when \
                ``state_encoder`` is not ``None``; if it is ``None``, the last action token is dropped from the \
                input sequence.
        Returns:
            - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \
                they are correspondingly the predicted states, predicted actions and predicted return-to-go.
        Examples:
            >>> B, T = 4, 6
            >>> state_dim = 3
            >>> act_dim = 2
            >>> DT_model = DecisionTransformer(\
                state_dim=state_dim,\
                act_dim=act_dim,\
                n_blocks=3,\
                h_dim=8,\
                context_len=T,\
                n_heads=2,\
                drop_p=0.1,\
            )
            >>> timesteps = torch.randint(0, 100, [B, T], dtype=torch.long)  # B x T
            >>> states = torch.randn([B, T, state_dim])  # B x T x state_dim
            >>> actions = torch.randint(0, act_dim, [B, T, 1])
            >>> action_target = torch.randint(0, act_dim, [B, T, 1])
            >>> returns_to_go = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
            >>> traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
            >>> actions = actions.squeeze(-1)
            >>> state_preds, action_preds, return_preds = DT_model.forward(\
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\
                )
            >>> assert state_preds.shape == torch.Size([B, T, state_dim])
            >>> assert return_preds.shape == torch.Size([B, T, 1])
            >>> assert action_preds.shape == torch.Size([B, T, act_dim])
        """
        B, T = states.shape[0], states.shape[1]
        if self.state_encoder is None:
            time_embeddings = self.embed_timestep(timesteps)

            # time embeddings are treated similar to positional embeddings
            state_embeddings = self.embed_state(states) + time_embeddings
            action_embeddings = self.embed_action(actions) + time_embeddings
            returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings

            # stack rtg, states and actions and reshape sequence as
            # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...)
            t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
                              dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
            h = self.embed_ln(t_p)
            # transformer and prediction
            h = self.transformer(h)
            # get h reshaped such that its size = (B x 3 x T x h_dim) and
            # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
            # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
            # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
            # that is, for each timestep (t) we have 3 output embeddings from the transformer,
            # each conditioned on all previous timesteps plus
            # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence.
            h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)

            return_preds = self.predict_rtg(h[:, 2])  # predict next rtg given r, s, a
            state_preds = self.predict_state(h[:, 2])  # predict next state given r, s, a
            action_preds = self.predict_action(h[:, 1])  # predict action given r, s
        else:
            state_embeddings = self.state_encoder(
                states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
            )  # (batch * block_size, h_dim)
            state_embeddings = state_embeddings.reshape(B, T, self.h_dim)  # (batch, block_size, h_dim)
            returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
            action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1))  # (batch, block_size, h_dim)
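            # Interleave the tokens as (r_0, s_0, a_0, r_1, s_1, a_1, ...). When ``tar`` is None,
            # the sequence is one token shorter because the last action token is not included.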
            token_embeddings = torch.zeros(
                (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device
            )
            token_embeddings[:, ::3, :] = returns_embeddings
            token_embeddings[:, 1::3, :] = state_embeddings
            token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :]
            all_global_pos_emb = torch.repeat_interleave(
                self.global_pos_emb, B, dim=0
            )  # batch_size, traj_length, h_dim

            position_embeddings = torch.gather(
                all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)
            ) + self.pos_emb[:, :token_embeddings.shape[1], :]

            t_p = token_embeddings + position_embeddings

            h = self.drop(t_p)
            h = self.transformer(h)
            h = self.embed_ln(h)
            logits = self.head(h)

            return_preds = None
            state_preds = None
            action_preds = logits[:, 1::3, :]  # only keep predictions from state_embeddings

        return state_preds, action_preds, return_preds

    def configure_optimizers(
            self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95)
    ) -> torch.optim.Optimizer:
        """
        Overview:
            This function returns an optimizer given the input arguments. \
            We are separating out all parameters of the model into two buckets: those that will experience \
            weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        Arguments:
            - weight_decay (:obj:`float`): The weight decay of the optimizer.
            - learning_rate (:obj:`float`): The learning rate of the optimizer.
            - betas (:obj:`Tuple[float, float]`): The betas for Adam optimizer.
        Outputs:
            - optimizer (:obj:`torch.optim.Optimizer`): The desired optimizer.
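        Examples (an illustrative sketch; the argument values below are arbitrary):
            >>> model = DecisionTransformer(state_dim=3, act_dim=2, n_blocks=2, h_dim=8, \
                context_len=4, n_heads=2, drop_p=0.1)
            >>> optimizer = model.configure_optimizers(weight_decay=1e-4, learning_rate=3e-4)
            >>> assert isinstance(optimizer, torch.optim.AdamW)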
| """ | |
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        # whitelist_weight_modules = (torch.nn.Linear, )
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')
        no_decay.add('global_pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0,\
            "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0
            },
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer