# SynTalker/models/denoiser.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils.layer import BasicBlock
from einops import rearrange
import pickle
from .timm_transformer.transformer import Block as mytimmBlock
class MDM(nn.Module):
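    """MDM-style Transformer denoiser.

    Predicts the denoised motion latents from the noisy input x_t, conditioned
    on the diffusion timestep, per-frame audio and word features, a seed-pose
    snippet, and (optionally, when use_motionclip is set) a style feature.
    """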
def __init__(self, args):
super().__init__()
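        # fixed hyper-parameters; several of these (ablation, dataset,
        # cond_mode, audio_rep, ...) appear to be kept only for interface
        # compatibility with the original MDM code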
njoints=768
nfeats=1
latent_dim=512
ff_size=1024
num_layers=8
num_heads=4
dropout=0.1
ablation=None
activation="gelu"
legacy=False
data_rep='rot6d'
dataset='amass'
audio_feat_dim = 64
emb_trans_dec=False
audio_rep=''
n_seed=8
cond_mode=''
kargs={}
if args.vqvae_type == 'rvqvae':
njoints = 1536
elif args.vqvae_type == 'novqvae':
njoints = 312
        self.args = args
self.legacy = legacy
self.njoints = njoints
self.nfeats = nfeats
self.data_rep = data_rep
self.latent_dim = latent_dim
self.ff_size = ff_size
self.num_layers = num_layers
self.num_heads = num_heads
self.dropout = dropout
self.ablation = ablation
self.activation = activation
self.action_emb = kargs.get('action_emb', None)
self.input_feats = self.njoints * self.nfeats
self.cond_mask_prob = kargs.get('cond_mask_prob', 0.3)
        self.cond_mask_prob_audio = kargs.get('cond_mask_prob_audio', self.cond_mask_prob)  # needed by mask_cond_audio(); was previously unset
self.use_motionclip = args.use_motionclip
        if args.audio_rep == 'onset+amplitude':
            self.WavEncoder = WavEncoder(args.audio_f, audio_in=2)
            self.audio_feat_dim = args.audio_f
        self.text_encoder_body = nn.Linear(300, args.audio_f)
        with open(f"{args.data_path}weights/vocab.pkl", 'rb') as f:
            self.lang_model = pickle.load(f)
        pre_trained_embedding = self.lang_model.word_embedding_weights
        self.text_pre_encoder_body = nn.Embedding.from_pretrained(
            torch.FloatTensor(pre_trained_embedding), freeze=args.t_fix_pre)
self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout)
self.emb_trans_dec = emb_trans_dec
self.cond_mode = cond_mode
        self.num_head = 8  # head count for the rotary-embedding reshape below (distinct from num_heads=4 in the attention blocks)
        self.mytimmblocks = nn.ModuleList([
            # dim matches the feature dim of the input x; attn_heads should be
            # 12 -- 1 was used here only to simplify debugging the pipeline
            mytimmBlock(dim=self.latent_dim, num_heads=self.num_heads,
                        mlp_ratio=self.ff_size // self.latent_dim,
                        drop_path=self.dropout)
            for _ in range(self.num_layers)])
self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
self.n_seed = n_seed
self.style_dim = 64
self.embed_style = nn.Linear(6, self.style_dim)
        self.embed_text = nn.Linear(self.input_feats * 4, self.latent_dim)  # despite the name, this embeds a 4-frame seed pose
self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
self.nfeats)
self.rel_pos = SinusoidalEmbeddings(self.latent_dim // self.num_head)
self.input_process = InputProcess(self.data_rep, self.input_feats , self.latent_dim)
self.input_process2 = nn.Linear(self.latent_dim * 2 + self.audio_feat_dim, self.latent_dim)
        if self.use_motionclip:
            self.input_process3 = nn.Linear(self.latent_dim + 512, self.latent_dim)
        self.mix_audio_text = nn.Linear(args.audio_f + args.word_f, 256)
def mask_cond(self, cond, force_mask=False):
bs, d = cond.shape
if force_mask:
return torch.zeros_like(cond)
elif self.training and self.cond_mask_prob > 0.:
mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(bs, 1) # 1-> use null_cond, 0-> use real cond
return cond * (1. - mask)
else:
return cond
def mask_cond_audio(self, cond, force_mask=False):
bs, d = cond.shape
if force_mask:
return torch.zeros_like(cond)
elif self.training and self.cond_mask_prob_audio > 0.:
mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob_audio).view(bs, 1) # 1-> use null_cond, 0-> use real cond
return cond * (1. - mask)
else:
return cond
def forward(self, x, timesteps, y=None,uncond_info=False):
"""
x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
timesteps: [batch_size] (int)
seed: [batch_size, njoints, nfeats]
"""
        _, _, _, noise_length = x.shape
        y = y.copy()
        bs, njoints, nfeats, nframes = x.shape
        emb_t = self.embed_timestep(timesteps)  # [1, bs, latent_dim]
        force_mask = y.get('uncond', False)
        if self.n_seed != 0:
            embed_text = self.embed_text(y['seed'].reshape(bs, -1))  # seed pose -> [bs, latent_dim]
            emb_seed = embed_text
audio_feat = self.WavEncoder(y['audio']).permute(1, 0, 2)
text_feat = self.text_pre_encoder_body(y['word'])
text_feat = self.text_encoder_body(text_feat).permute(1, 0, 2)
        at_feat = torch.cat([audio_feat, text_feat], dim=2)
        at_feat = self.mix_audio_text(at_feat)
        at_feat = F.avg_pool1d(at_feat.permute(1, 2, 0), self.args.vqvae_squeeze_scale).permute(2, 0, 1)
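        # the pooling above shortens the per-frame audio/text features by
        # vqvae_squeeze_scale so their length matches the motion-latent sequence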
        # Backbone: timm-style Transformer blocks
        x = x.reshape(bs, njoints * nfeats, 1, nframes)  # no-op when nfeats == 1
        # project motion latents to the model dim
        x_ = self.input_process(x)  # -> [nframes, bs, latent_dim]
        # local cross-modal conditioning: append per-frame audio/text features
        xseq = torch.cat((x_, at_feat), axis=2)  # -> [nframes, bs, latent_dim + 256]
        # broadcast the (seed + timestep) embedding over all frames
        embed_style_2 = (emb_seed + emb_t).repeat(nframes, 1, 1)  # -> [nframes, bs, latent_dim]
        xseq = torch.cat((embed_style_2, xseq), axis=2)
        xseq = self.input_process2(xseq)  # fuse -> [nframes, bs, latent_dim]
if self.use_motionclip:
            xseq = torch.cat((xseq, self.mask_cond(y['style_feature'], force_mask).unsqueeze(0).repeat(nframes, 1, 1)), axis=2)
xseq = self.input_process3(xseq)
        # The next ~10 lines apply rotary positional encoding; it seemed to
        # help slightly (possibly placebo).
        xseq = xseq.permute(1, 0, 2)  # [nframes, bs, d] -> [bs, nframes, d]
        xseq = xseq.view(bs, nframes, self.num_head, -1)
        xseq = xseq.permute(0, 2, 1, 3)  # -> [bs, num_head, nframes, head_dim]
        xseq = xseq.reshape(bs * self.num_head, nframes, -1)
        pos_emb = self.rel_pos(xseq)  # [nframes, head_dim]
        xseq, _ = apply_rotary_pos_emb(xseq, xseq, pos_emb)
        xseq_rpe = xseq.reshape(bs, self.num_head, nframes, -1)
        xseq = xseq_rpe.permute(0, 2, 1, 3)  # -> [bs, nframes, num_head, head_dim]
        xseq = xseq.view(bs, nframes, -1)  # -> [bs, nframes, d]
for block in self.mytimmblocks:
xseq = block(xseq)
        xseq = xseq.permute(1, 0, 2)  # [bs, nframes, d] -> [nframes, bs, d]
        output = self.output_process(xseq)  # -> [bs, njoints, nfeats, nframes]
        return output[..., :noise_length]
    @staticmethod
    def apply_rotary(x, sinusoidal_pos):
        sin, cos = sinusoidal_pos
        x1, x2 = x[..., 0::2], x[..., 1::2]
        # When rotating query/key, a plain cat would do:
        #   torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        # since the matmul sums over this dim anyway (the last-dim positions of
        # query and key only need to stay aligned with each other).
        # When rotating value, stack-then-flatten is required, because the
        # trained model interleaves the pairs along the last dim.
        return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2, -1)
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)  # (max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
        # this forward() is not used in the final model; TimestepEmbedder indexes self.pe directly
x = x + self.pe[:x.shape[0], :]
return self.dropout(x)
class TimestepEmbedder(nn.Module):
def __init__(self, latent_dim, sequence_pos_encoder):
super().__init__()
self.latent_dim = latent_dim
self.sequence_pos_encoder = sequence_pos_encoder
time_embed_dim = self.latent_dim
self.time_embed = nn.Sequential(
nn.Linear(self.latent_dim, time_embed_dim),
nn.SiLU(),
nn.Linear(time_embed_dim, time_embed_dim),
)
def forward(self, timesteps):
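        # index the sinusoidal table at the given timesteps, then project:
        # pe[timesteps] is [bs, 1, latent_dim]; the permute yields [1, bs, latent_dim]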
return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
class InputProcess(nn.Module):
def __init__(self, data_rep, input_feats, latent_dim):
super().__init__()
self.data_rep = data_rep
self.input_feats = input_feats
self.latent_dim = latent_dim
self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
if self.data_rep == 'rot_vel':
self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)
def forward(self, x):
bs, njoints, nfeats, nframes = x.shape
x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints*nfeats)
if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
x = self.poseEmbedding(x) # [seqlen, bs, d]
return x
elif self.data_rep == 'rot_vel':
first_pose = x[[0]] # [1, bs, 150]
first_pose = self.poseEmbedding(first_pose) # [1, bs, d]
vel = x[1:] # [seqlen-1, bs, 150]
vel = self.velEmbedding(vel) # [seqlen-1, bs, d]
return torch.cat((first_pose, vel), axis=0) # [seqlen, bs, d]
else:
raise ValueError
class OutputProcess(nn.Module):
def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
super().__init__()
self.data_rep = data_rep
self.input_feats = input_feats
self.latent_dim = latent_dim
self.njoints = njoints
self.nfeats = nfeats
self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
if self.data_rep == 'rot_vel':
self.velFinal = nn.Linear(self.latent_dim, self.input_feats)
def forward(self, output):
nframes, bs, d = output.shape
if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
            output = self.poseFinal(output)  # [nframes, bs, latent_dim] -> [nframes, bs, input_feats]
elif self.data_rep == 'rot_vel':
first_pose = output[[0]] # [1, bs, d]
first_pose = self.poseFinal(first_pose) # [1, bs, 150]
vel = output[1:] # [seqlen-1, bs, d]
vel = self.velFinal(vel) # [seqlen-1, bs, 150]
output = torch.cat((first_pose, vel), axis=0) # [seqlen, bs, 150]
else:
raise ValueError
output = output.reshape(nframes, bs, self.njoints, self.nfeats)
output = output.permute(1, 2, 3, 0) # [bs, njoints, nfeats, nframes]
return output
class WavEncoder(nn.Module):
def __init__(self, out_dim, audio_in=1):
super().__init__()
self.out_dim = out_dim
        self.feat_extractor = nn.Sequential(
            BasicBlock(audio_in, out_dim // 4, 15, 5, first_dilation=1700, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 4, out_dim // 4, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 4, out_dim // 2, 15, 6, first_dilation=0, downsample=True),
            BasicBlock(out_dim // 2, out_dim // 2, 15, 1, first_dilation=7),
            BasicBlock(out_dim // 2, out_dim, 15, 3, first_dilation=0, downsample=True),
        )
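        # assuming the 4th positional argument of BasicBlock is the temporal
        # stride, this stack downsamples time by about 5*6*6*3 = 540 samples
        # per output frame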
def forward(self, wav_data):
        if wav_data.dim() == 2:
            wav_data = wav_data.unsqueeze(1)  # [bs, T] -> [bs, 1, T] mono waveform
        else:
            wav_data = wav_data.transpose(1, 2)  # [bs, T, C] -> [bs, C, T] for Conv1d
out = self.feat_extractor(wav_data)
return out.transpose(1, 2)
class SinusoidalEmbeddings(nn.Module):
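    """Rotary (RoPE) frequency table: forward(x) returns the per-position
    angles, shape [seq_len, dim], consumed by apply_rotary_pos_emb."""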
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
    def forward(self, x):
        n = x.shape[-2]
        t = torch.arange(n, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)  # [n, dim/2]
        return torch.cat((freqs, freqs), dim=-1)  # [n, dim]

def rotate_half(x):
    # split the last dim into two halves and rotate: (x1, x2) -> (-x2, x1)
    x = rearrange(x, 'b ... (r d) -> b (...) r d', r=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, freqs):
q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
return q, k
if __name__ == '__main__':
    '''
    cd ./main/model
    python denoiser.py
    '''
    from types import SimpleNamespace

    # Minimal smoke test (a sketch, not a training entry point). MDM takes a
    # single `args` namespace; the fields below are the ones read in __init__
    # and forward, with illustrative values. `data_path` must point at a
    # directory containing weights/vocab.pkl, and the audio/word lengths must
    # line up with the WavEncoder downsampling and vqvae_squeeze_scale.
    args = SimpleNamespace(
        vqvae_type='rvqvae',             # -> njoints = 1536
        audio_rep='onset+amplitude',     # enables the 2-channel WavEncoder
        audio_f=256, word_f=256,         # dims assumed by mix_audio_text / input_process2
        t_fix_pre=True,
        use_motionclip=False,
        vqvae_squeeze_scale=4,
        data_path='./datasets/beat_cache/',  # hypothetical path; adjust to your setup
    )
    model = MDM(args)

    bs, n_frames = 2, 88
    x = torch.randn(bs, model.njoints, 1, n_frames)  # noisy motion latents x_t
    t = torch.tensor([12, 85])                       # diffusion timesteps
    y = {
        # embed_text consumes input_feats * 4, i.e. a 4-frame seed snippet
        'seed': torch.randn(bs, model.njoints, 1, 4),
        'audio': torch.randn(bs, n_frames * args.vqvae_squeeze_scale * 540, 2),
        'word': torch.zeros(bs, n_frames * args.vqvae_squeeze_scale, dtype=torch.long),
    }
    y_hat = model(x, t, y)
    print(y_hat.shape)  # expected: [bs, njoints, 1, n_frames]