import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils.layer import BasicBlock
from einops import rearrange
import pickle
from .timm_transformer.transformer import Block as mytimmBlock
class MDM(nn.Module):
    def __init__(self, args):
        super().__init__()
        njoints = 768
        nfeats = 1
        latent_dim = 512
        ff_size = 1024
        num_layers = 8
        num_heads = 4
        dropout = 0.1
        ablation = None
        activation = "gelu"
        legacy = False
        data_rep = 'rot6d'
        dataset = 'amass'
        audio_feat_dim = 64
        emb_trans_dec = False
        audio_rep = ''
        n_seed = 8
        cond_mode = ''
        kargs = {}

        if args.vqvae_type == 'rvqvae':
            njoints = 1536
        elif args.vqvae_type == 'novqvae':
            njoints = 312
        self.args = args
        self.legacy = legacy
        self.njoints = njoints
        self.nfeats = nfeats
        self.data_rep = data_rep
        self.latent_dim = latent_dim
        self.ff_size = ff_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout
        self.ablation = ablation
        self.activation = activation
        self.action_emb = kargs.get('action_emb', None)

        self.input_feats = self.njoints * self.nfeats

        # Probability of dropping the condition during training (classifier-free guidance).
        self.cond_mask_prob = kargs.get('cond_mask_prob', 0.3)
        # mask_cond_audio below reads self.cond_mask_prob_audio, which was never set in the
        # original __init__; default it here so that path does not raise an AttributeError.
        self.cond_mask_prob_audio = kargs.get('cond_mask_prob_audio', 0.3)
        self.use_motionclip = args.use_motionclip
        if args.audio_rep == 'onset+amplitude':
            self.WavEncoder = WavEncoder(args.audio_f, audio_in=2)
            self.audio_feat_dim = args.audio_f

        # Text path: word ids -> pretrained 300-d word embeddings (frozen if t_fix_pre) -> audio_f dims.
        self.text_encoder_body = nn.Linear(300, args.audio_f)
        with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
            self.lang_model = pickle.load(f)
        pre_trained_embedding = self.lang_model.word_embedding_weights
        self.text_pre_encoder_body = nn.Embedding.from_pretrained(
            torch.FloatTensor(pre_trained_embedding), freeze=args.t_fix_pre)
        self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
        self.emb_trans_dec = emb_trans_dec
        self.cond_mode = cond_mode
        self.num_head = 8

        self.mytimmblocks = nn.ModuleList([
            mytimmBlock(dim=self.latent_dim, num_heads=self.num_heads,
                        mlp_ratio=self.ff_size // self.latent_dim, drop_path=self.dropout)
            # dim corresponds to the dimension of the input x; attn_heads should be 12,
            # and was set to 1 here only to make debugging the pipeline easier (original note).
            for _ in range(self.num_layers)])

        self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)

        self.n_seed = n_seed
        self.style_dim = 64
        self.embed_style = nn.Linear(6, self.style_dim)
        self.embed_text = nn.Linear(self.input_feats * 4, self.latent_dim)

        self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim,
                                            self.njoints, self.nfeats)
        self.rel_pos = SinusoidalEmbeddings(self.latent_dim // self.num_head)
        self.input_process = InputProcess(self.data_rep, self.input_feats, self.latent_dim)
        self.input_process2 = nn.Linear(self.latent_dim * 2 + self.audio_feat_dim, self.latent_dim)
        if self.use_motionclip:
            self.input_process3 = nn.Linear(self.latent_dim + 512, self.latent_dim)
        self.mix_audio_text = nn.Linear(args.audio_f + args.word_f, 256)
    def mask_cond(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob > 0.:
            # 1 -> use null cond, 0 -> use real cond
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond

    def mask_cond_audio(self, cond, force_mask=False):
        bs, d = cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_mask_prob_audio > 0.:
            # 1 -> use null cond, 0 -> use real cond
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob_audio).view(bs, 1)
            return cond * (1. - mask)
        else:
            return cond
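    # mask_cond / mask_cond_audio implement condition dropout for classifier-free
    # guidance. A minimal sampling-time sketch (the guidance scale `s` and the
    # surrounding sampler are assumptions, not part of this file):
    #   eps_cond   = model(x_t, t, y)
    #   eps_uncond = model(x_t, t, {**y, 'uncond': True})
    #   eps        = eps_uncond + s * (eps_cond - eps_uncond)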
    def forward(self, x, timesteps, y=None, uncond_info=False):
        """
        x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
        timesteps: [batch_size] (int)
        y: dict of conditions ('seed', 'audio', 'word', optionally 'style_feature', 'uncond')
        """
        _, _, _, noise_length = x.shape
        y = y.copy()

        bs, njoints, nfeats, nframes = x.shape  # 300, 1141, 1, 88
        emb_t = self.embed_timestep(timesteps)  # [1, bs, d], (1, 2, 256)

        force_mask = y.get('uncond', False)  # False
        # force_mask = uncond_info
        if self.n_seed != 0:
            embed_text = self.embed_text(y['seed'].reshape(bs, -1))  # (bs, 256-64)
            emb_seed = embed_text

        audio_feat = self.WavEncoder(y['audio']).permute(1, 0, 2)
        text_feat = self.text_pre_encoder_body(y['word'])
        text_feat = self.text_encoder_body(text_feat).permute(1, 0, 2)
        at_feat = torch.cat([audio_feat, text_feat], dim=2)
        at_feat = self.mix_audio_text(at_feat)
        at_feat = F.avg_pool1d(at_feat.permute(1, 2, 0), self.args.vqvae_squeeze_scale).permute(2, 0, 1)

        # This part is a test for the timm transformer blocks
        x = x.reshape(bs, njoints * nfeats, 1, nframes)  # [300, 1141, 1, 88] -> [300, 1141, 1, 88]
        # self-attention
        x_ = self.input_process(x)  # [300, 1141, 1, 88] -> [88, 300, 256]
        # local-cross-attention
        xseq = torch.cat((x_, at_feat), axis=2)  # [88, 300, 256], [88, 300, 64] -> [88, 300, 320]
        # all frames
        embed_style_2 = (emb_seed + emb_t).repeat(nframes, 1, 1)  # [300, 256], [1, 300, 256] -> [88, 300, 256]
        xseq = torch.cat((embed_style_2, xseq), axis=2)  # -> [88, 300, 576]
        xseq = self.input_process2(xseq)  # [88, 300, 576] -> [88, 300, 256]
        if self.use_motionclip:
            xseq = torch.cat((xseq, self.mask_cond(y['style_feature'], force_mask).unsqueeze(0).repeat(nframes, 1, 1)), axis=2)
            xseq = self.input_process3(xseq)

        # The next ~10 lines are all positional encoding; it seems to help a little,
        # though it may be an illusion (original author's note).
        xseq = xseq.permute(1, 0, 2)  # [88, 300, 256] -> [300, 88, 256]
        xseq = xseq.view(bs, nframes, self.num_head, -1)  # [300, 88, 256] -> [300, 88, 8, 32]
        xseq = xseq.permute(0, 2, 1, 3)  # [300, 88, 8, 32] -> [300, 8, 88, 32]
        xseq = xseq.reshape(bs * self.num_head, nframes, -1)  # [300, 8, 88, 32] -> [2400, 88, 32]
        pos_emb = self.rel_pos(xseq)  # (88, 32)
        xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb)  # [2400, 88, 32]
        xseq_rpe = xseq.reshape(bs, self.num_head, nframes, -1)  # [300, 8, 88, 32]
        xseq = xseq_rpe.permute(0, 2, 1, 3)  # [300, 8, 88, 32] -> [300, 88, 8, 32]
        xseq = xseq.view(bs, nframes, -1)  # [300, 88, 8, 32] -> [300, 88, 256]

        for block in self.mytimmblocks:
            xseq = block(xseq)
        xseq = xseq.permute(1, 0, 2)  # [300, 88, 256] -> [88, 300, 256]

        output = xseq
        output = self.output_process(output)  # [88, 300, 256] -> [300, 1141, 1, 88]
        return output[..., :noise_length]
def apply_rotary(x, sinusoidal_pos):
    sin, cos = sinusoidal_pos
    x1, x2 = x[..., 0::2], x[..., 1::2]
    # When rotating query/key, a plain cat is enough, because the subsequent matrix
    # multiply sums over this dimension (it only matters that each position of the
    # last dim of query and key lines up):
    # torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # When rotating value, stack then flatten is required, because the trained model
    # expects the last dim to interleave the two halves element by element.
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)  # (5000, 128)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (5000, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # not used in the final model; the `pe` buffer itself is indexed by TimestepEmbedder below
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
class TimestepEmbedder(nn.Module):
    def __init__(self, latent_dim, sequence_pos_encoder):
        super().__init__()
        self.latent_dim = latent_dim
        self.sequence_pos_encoder = sequence_pos_encoder

        time_embed_dim = self.latent_dim
        self.time_embed = nn.Sequential(
            nn.Linear(self.latent_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

    def forward(self, timesteps):
        return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
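
# Shape sketch for TimestepEmbedder above (the example timesteps are illustrative):
#   t = torch.tensor([12, 85])                      # [batch_size]
#   self.sequence_pos_encoder.pe[t]                 # [batch_size, 1, latent_dim]
#   self.time_embed(...).permute(1, 0, 2)           # [1, batch_size, latent_dim], added to the seed embedding in forward()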
class InputProcess(nn.Module):
    def __init__(self, data_rep, input_feats, latent_dim):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.data_rep == 'rot_vel':
            self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)

    def forward(self, x):
        bs, njoints, nfeats, nframes = x.shape
        x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints * nfeats)

        if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            x = self.poseEmbedding(x)  # [seqlen, bs, d]
            return x
        elif self.data_rep == 'rot_vel':
            first_pose = x[[0]]  # [1, bs, 150]
            first_pose = self.poseEmbedding(first_pose)  # [1, bs, d]
            vel = x[1:]  # [seqlen-1, bs, 150]
            vel = self.velEmbedding(vel)  # [seqlen-1, bs, d]
            return torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, d]
        else:
            raise ValueError
class OutputProcess(nn.Module):
    def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
        super().__init__()
        self.data_rep = data_rep
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        self.njoints = njoints
        self.nfeats = nfeats
        self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
        if self.data_rep == 'rot_vel':
            self.velFinal = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, output):
        nframes, bs, d = output.shape
        if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            output = self.poseFinal(output)  # [88, 300, 256] -> [88, 300, 1141]
        elif self.data_rep == 'rot_vel':
            first_pose = output[[0]]  # [1, bs, d]
            first_pose = self.poseFinal(first_pose)  # [1, bs, 150]
            vel = output[1:]  # [seqlen-1, bs, d]
            vel = self.velFinal(vel)  # [seqlen-1, bs, 150]
            output = torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, 150]
        else:
            raise ValueError
        output = output.reshape(nframes, bs, self.njoints, self.nfeats)
        output = output.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]
        return output
class WavEncoder(nn.Module):
    def __init__(self, out_dim, audio_in=1):
        super().__init__()
        self.out_dim = out_dim
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1700, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True),
        )

    def forward(self, wav_data):
        if wav_data.dim() == 2:
            wav_data = wav_data.unsqueeze(1)
        else:
            wav_data = wav_data.transpose(1, 2)
        out = self.feat_extractor(wav_data)
        return out.transpose(1, 2)
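
# WavEncoder note: the downsampling BasicBlocks progressively shorten the audio along
# time (assuming their 3rd/4th positional arguments are kernel size and stride), so the
# output is a frame-rate feature sequence of shape [batch, frames, out_dim] that
# forward() concatenates with the text features.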
class SinusoidalEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n = x.shape[-2]
        t = torch.arange(n, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim=-1)


def rotate_half(x):
    x = rearrange(x, 'b ... (r d) -> b (...) r d', r=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, freqs):
    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
    return q, k
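
# Usage sketch for the rotary helpers above, mirroring forward() (shapes are examples):
#   xseq  = torch.randn(bs * num_head, nframes, head_dim)    # e.g. (2400, 88, 32)
#   freqs = SinusoidalEmbeddings(head_dim)(xseq)              # (nframes, head_dim)
#   q, k  = apply_rotary_pos_emb(xseq, xseq, freqs)           # same shapes as xseq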
if __name__ == '__main__':
    '''
    cd ./main/model
    python mdm.py
    '''
    from types import SimpleNamespace

    n_frames = 240
    n_seed = 4  # embed_text above expects input_feats * 4, i.e. a 4-frame seed

    # MDM.__init__ takes a single `args` object; the values below are illustrative
    # assumptions for a smoke test (a real run also needs weights/vocab.pkl under
    # args.data_path for the pretrained word embeddings).
    args = SimpleNamespace(vqvae_type='rvqvae', use_motionclip=False,
                           audio_rep='onset+amplitude', audio_f=256, word_f=256,
                           data_path='./', t_fix_pre=True, vqvae_squeeze_scale=1)
    model = MDM(args)

    x = torch.randn(2, 1536, 1, 88)  # njoints = 1536 for vqvae_type='rvqvae'
    t = torch.tensor([12, 85])

    model_kwargs_ = {'y': {}}
    model_kwargs_['y']['mask'] = (torch.zeros([1, 1, 1, n_frames]) < 1)  # [..., n_seed:]
    # Placeholder conditions; the audio/word lengths must line up with WavEncoder's
    # downsampling and with nframes for a real forward pass.
    model_kwargs_['y']['audio'] = torch.randn(2, 88 * 540, 2)  # raw onset+amplitude stream (length is a guess)
    model_kwargs_['y']['word'] = torch.zeros(2, 88, dtype=torch.long)
    model_kwargs_['y']['style'] = torch.randn(2, 6)
    model_kwargs_['y']['mask_local'] = torch.ones(2, 88).bool()
    model_kwargs_['y']['seed'] = x[..., 0:n_seed]
    y = model(x, t, model_kwargs_['y'])
    print(y.shape)