from typing import Optional

import torch
import torch.nn as nn

from mld.models.operator.attention import SkipTransformerEncoder, TransformerEncoderLayer
from mld.models.operator.position_encoding import build_position_encoding


class MldTrajEncoder(nn.Module):
    """Transformer encoder that compresses a (masked) trajectory sequence into `latent_size` global latent tokens."""

    def __init__(self,
                 nfeats: int,
                 latent_dim: list = [1, 256],
                 hidden_dim: Optional[int] = None,
                 force_post_proj: bool = False,
                 ff_size: int = 1024,
                 num_layers: int = 9,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 normalize_before: bool = False,
                 norm_eps: float = 1e-5,
                 activation: str = "gelu",
                 norm_post: bool = True,
                 activation_post: Optional[str] = None,
                 position_embedding: str = "learned") -> None:
        super().__init__()

        # latent_dim is [num_latent_tokens, token_width]; an optional hidden_dim overrides
        # the transformer width, with a post-projection back to latent_dim[-1] when needed.
        self.latent_size = latent_dim[0]
        self.latent_dim = latent_dim[-1] if hidden_dim is None else hidden_dim
        add_post_proj = force_post_proj or (hidden_dim is not None and hidden_dim != latent_dim[-1])
        self.latent_proj = nn.Linear(self.latent_dim, latent_dim[-1]) if add_post_proj else nn.Identity()

        # Per-frame input (nfeats * 3 channels, e.g. xyz per controlled joint) -> model width.
        self.skel_embedding = nn.Linear(nfeats * 3, self.latent_dim)

        self.query_pos_encoder = build_position_encoding(
            self.latent_dim, position_embedding=position_embedding)

        encoder_layer = TransformerEncoderLayer(
            self.latent_dim,
            num_heads,
            ff_size,
            dropout,
            activation,
            normalize_before,
            norm_eps,
        )
        encoder_norm = nn.LayerNorm(self.latent_dim, eps=norm_eps) if norm_post else None
        self.encoder = SkipTransformerEncoder(encoder_layer, num_layers, encoder_norm, activation_post)

        # Learned query tokens prepended to the sequence; their encoder outputs form the latent.
        self.global_motion_token = nn.Parameter(torch.randn(self.latent_size, self.latent_dim))

    def forward(self, features: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # features: (bs, nframes, nfeats * 3); mask: (bs, nframes) bool, True for valid frames.
        bs, nframes, _ = features.shape
        if mask is None:
            # No padding mask supplied: treat every frame as valid.
            mask = torch.ones((bs, nframes), dtype=torch.bool, device=features.device)

        # Embed per-frame features and switch to (nframes, bs, latent_dim) for the encoder.
        x = self.skel_embedding(features)
        x = x.permute(1, 0, 2)

        # Prepend the learned global tokens and extend the mask so they are never padded out.
        dist = torch.tile(self.global_motion_token[:, None, :], (1, bs, 1))
        dist_masks = torch.ones((bs, dist.shape[0]), dtype=torch.bool, device=x.device)
        aug_mask = torch.cat((dist_masks, mask), 1)

        xseq = torch.cat((dist, x), 0)
        xseq = self.query_pos_encoder(xseq)

        # Encode and keep only the global-token positions as the sequence-level latent.
        global_token = self.encoder(xseq, src_key_padding_mask=~aug_mask)[0][:dist.shape[0]]
        global_token = self.latent_proj(global_token)
        global_token = global_token.permute(1, 0, 2)  # (bs, latent_size, latent_dim[-1])
        return global_token
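

# A minimal usage sketch (not part of the original module): the values below
# (nfeats=4 control joints, batch of 2, 196 frames) are illustrative assumptions only;
# it simply checks the expected output shape (bs, latent_size, latent_dim[-1]).
if __name__ == "__main__":
    nfeats, bs, nframes = 4, 2, 196
    encoder = MldTrajEncoder(nfeats=nfeats)

    features = torch.randn(bs, nframes, nfeats * 3)   # xyz-style channels per controlled joint
    mask = torch.ones(bs, nframes, dtype=torch.bool)  # True = valid (non-padded) frame

    latent = encoder(features, mask)
    print(latent.shape)  # torch.Size([2, 1, 256]) with the default latent_dim=[1, 256]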