HaWoR / lib /models /components /pose_transformer.py
ThunderVVV's picture
update
5f028d6
from inspect import isfunction
from typing import Callable, Optional
import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from torch import nn
from .t_cond_mlp import (
AdaptiveLayerNorm1D,
FrequencyEmbedder,
normalization_layer,
)
# from .vit import Attention, FeedForward
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
class PreNorm(nn.Module):
def __init__(self, dim: int, fn: Callable, norm: str = "layer", norm_cond_dim: int = -1):
super().__init__()
self.norm = normalization_layer(norm, dim, norm_cond_dim)
self.fn = fn
def forward(self, x: torch.Tensor, *args, **kwargs):
if isinstance(self.norm, AdaptiveLayerNorm1D):
return self.fn(self.norm(x, *args), **kwargs)
else:
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout=0.0):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head**-0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out
else nn.Identity()
)
def forward(self, x):
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class CrossAttention(nn.Module):
def __init__(self, dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head**-0.5
self.attend = nn.Softmax(dim=-1)
self.dropout = nn.Dropout(dropout)
context_dim = default(context_dim, dim)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_out = (
nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
if project_out
else nn.Identity()
)
def forward(self, x, context=None):
context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
q = self.to_q(x)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), [q, k, v])
dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
attn = self.attend(dots)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class Transformer(nn.Module):
def __init__(
self,
dim: int,
depth: int,
heads: int,
dim_head: int,
mlp_dim: int,
dropout: float = 0.0,
norm: str = "layer",
norm_cond_dim: int = -1,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
ff = FeedForward(dim, mlp_dim, dropout=dropout)
self.layers.append(
nn.ModuleList(
[
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
]
)
)
def forward(self, x: torch.Tensor, *args):
for attn, ff in self.layers:
x = attn(x, *args) + x
x = ff(x, *args) + x
return x
class TransformerCrossAttn(nn.Module):
def __init__(
self,
dim: int,
depth: int,
heads: int,
dim_head: int,
mlp_dim: int,
dropout: float = 0.0,
norm: str = "layer",
norm_cond_dim: int = -1,
context_dim: Optional[int] = None,
):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
sa = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)
ca = CrossAttention(
dim, context_dim=context_dim, heads=heads, dim_head=dim_head, dropout=dropout
)
ff = FeedForward(dim, mlp_dim, dropout=dropout)
self.layers.append(
nn.ModuleList(
[
PreNorm(dim, sa, norm=norm, norm_cond_dim=norm_cond_dim),
PreNorm(dim, ca, norm=norm, norm_cond_dim=norm_cond_dim),
PreNorm(dim, ff, norm=norm, norm_cond_dim=norm_cond_dim),
]
)
)
def forward(self, x: torch.Tensor, *args, context=None, context_list=None):
if context_list is None:
context_list = [context] * len(self.layers)
if len(context_list) != len(self.layers):
raise ValueError(f"len(context_list) != len(self.layers) ({len(context_list)} != {len(self.layers)})")
for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
x = self_attn(x, *args) + x
x = cross_attn(x, *args, context=context_list[i]) + x
x = ff(x, *args) + x
return x
class DropTokenDropout(nn.Module):
def __init__(self, p: float = 0.1):
super().__init__()
if p < 0 or p > 1:
raise ValueError(
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
)
self.p = p
def forward(self, x: torch.Tensor):
# x: (batch_size, seq_len, dim)
if self.training and self.p > 0:
zero_mask = torch.full_like(x[0, :, 0], self.p).bernoulli().bool()
# TODO: permutation idx for each batch using torch.argsort
if zero_mask.any():
x = x[:, ~zero_mask, :]
return x
class ZeroTokenDropout(nn.Module):
def __init__(self, p: float = 0.1):
super().__init__()
if p < 0 or p > 1:
raise ValueError(
"dropout probability has to be between 0 and 1, " "but got {}".format(p)
)
self.p = p
def forward(self, x: torch.Tensor):
# x: (batch_size, seq_len, dim)
if self.training and self.p > 0:
zero_mask = torch.full_like(x[:, :, 0], self.p).bernoulli().bool()
# Zero-out the masked tokens
x[zero_mask, :] = 0
return x
class TransformerEncoder(nn.Module):
def __init__(
self,
num_tokens: int,
token_dim: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
dim_head: int = 64,
dropout: float = 0.0,
emb_dropout: float = 0.0,
emb_dropout_type: str = "drop",
emb_dropout_loc: str = "token",
norm: str = "layer",
norm_cond_dim: int = -1,
token_pe_numfreq: int = -1,
):
super().__init__()
if token_pe_numfreq > 0:
token_dim_new = token_dim * (2 * token_pe_numfreq + 1)
self.to_token_embedding = nn.Sequential(
Rearrange("b n d -> (b n) d", n=num_tokens, d=token_dim),
FrequencyEmbedder(token_pe_numfreq, token_pe_numfreq - 1),
Rearrange("(b n) d -> b n d", n=num_tokens, d=token_dim_new),
nn.Linear(token_dim_new, dim),
)
else:
self.to_token_embedding = nn.Linear(token_dim, dim)
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
if emb_dropout_type == "drop":
self.dropout = DropTokenDropout(emb_dropout)
elif emb_dropout_type == "zero":
self.dropout = ZeroTokenDropout(emb_dropout)
else:
raise ValueError(f"Unknown emb_dropout_type: {emb_dropout_type}")
self.emb_dropout_loc = emb_dropout_loc
self.transformer = Transformer(
dim, depth, heads, dim_head, mlp_dim, dropout, norm=norm, norm_cond_dim=norm_cond_dim
)
def forward(self, inp: torch.Tensor, *args, **kwargs):
x = inp
if self.emb_dropout_loc == "input":
x = self.dropout(x)
x = self.to_token_embedding(x)
if self.emb_dropout_loc == "token":
x = self.dropout(x)
b, n, _ = x.shape
x += self.pos_embedding[:, :n]
if self.emb_dropout_loc == "token_afterpos":
x = self.dropout(x)
x = self.transformer(x, *args)
return x
class TransformerDecoder(nn.Module):
def __init__(
self,
num_tokens: int,
token_dim: int,
dim: int,
depth: int,
heads: int,
mlp_dim: int,
dim_head: int = 64,
dropout: float = 0.0,
emb_dropout: float = 0.0,
emb_dropout_type: str = 'drop',
norm: str = "layer",
norm_cond_dim: int = -1,
context_dim: Optional[int] = None,
skip_token_embedding: bool = False,
):
super().__init__()
if not skip_token_embedding:
self.to_token_embedding = nn.Linear(token_dim, dim)
else:
self.to_token_embedding = nn.Identity()
if token_dim != dim:
raise ValueError(
f"token_dim ({token_dim}) != dim ({dim}) when skip_token_embedding is True"
)
self.pos_embedding = nn.Parameter(torch.randn(1, num_tokens, dim))
if emb_dropout_type == "drop":
self.dropout = DropTokenDropout(emb_dropout)
elif emb_dropout_type == "zero":
self.dropout = ZeroTokenDropout(emb_dropout)
elif emb_dropout_type == "normal":
self.dropout = nn.Dropout(emb_dropout)
self.transformer = TransformerCrossAttn(
dim,
depth,
heads,
dim_head,
mlp_dim,
dropout,
norm=norm,
norm_cond_dim=norm_cond_dim,
context_dim=context_dim,
)
def forward(self, inp: torch.Tensor, *args, context=None, context_list=None):
x = self.to_token_embedding(inp)
b, n, _ = x.shape
x = self.dropout(x)
x += self.pos_embedding[:, :n]
x = self.transformer(x, *args, context=context, context_list=context_list)
return x