Upload 10 files
Browse files
- utils/config.py +17 -0
- utils/eval_trans.py +598 -0
- utils/losses.py +30 -0
- utils/motion_process.py +59 -0
- utils/paramUtil.py +63 -0
- utils/quaternion.py +423 -0
- utils/rotation_conversions.py +532 -0
- utils/skeleton.py +199 -0
- utils/utils_model.py +66 -0
- utils/word_vectorizer.py +99 -0
utils/config.py
ADDED
@@ -0,0 +1,17 @@
import os

SMPL_DATA_PATH = "./body_models/smpl"

SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')

ROT_CONVENTION_TO_ROT_NUMBER = {
    'legacy': 23,
    'no_hands': 21,
    'full_hands': 51,
    'mitten_hands': 33,
}

GENDERS = ['neutral', 'male', 'female']
NUM_BETAS = 10
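config.py only records where the SMPL assets are expected to live; nothing validates the paths. A minimal sanity-check sketch (illustration only, not part of this upload; it assumes the repository root is the working directory so that utils is importable):

import os

from utils import config

# Report whether each SMPL asset referenced by the config is actually on disk.
for path in (config.SMPL_KINTREE_PATH,
             config.SMPL_MODEL_PATH,
             config.JOINT_REGRESSOR_TRAIN_EXTRA):
    status = "found" if os.path.isfile(path) else "MISSING"
    print(f"{status}: {path}")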
utils/eval_trans.py
ADDED
@@ -0,0 +1,598 @@
import os

import clip
import numpy as np
import torch
from scipy import linalg

import visualization.plot_3d_global as plot_3d
from utils.motion_process import recover_from_ric

import torch
import numpy as np

def tensorborad_add_video_xyz(writer, xyz, nb_iter, tag, nb_vis=4, title_batch=None, outname=None, fps=25):
    # Validate that xyz has the expected dimensions
    if xyz.ndimension() != 4 or xyz.shape[-1] != 3:
        raise ValueError(f"Expected a 4D tensor with shape (batch_size, seq_length, n_joints, 3), but got {xyz.shape}")

    # Convert tensor to NumPy array for drawing the batch
    xyz_numpy = xyz.cpu().numpy()

    # Generate animations using draw_to_batch
    plot_xyz = plot_3d.draw_to_batch(xyz_numpy, title_batch, outname, fps=fps)

    # Ensure the correct TensorBoard shape (batch, seq, channels, height, width)
    plot_xyz = np.transpose(plot_xyz.numpy(), (0, 1, 4, 2, 3))  # Ensure TensorBoard compatibility

    # plot_xyz is a NumPy array at this point, so check .ndim (not the torch-only .ndimension())
    if plot_xyz.ndim != 5:
        raise ValueError(f"Expected a 5D tensor with (batch_size, seq_length, channels, height, width), but got {plot_xyz.shape}")

    # Add the video to TensorBoard
    writer.add_video(tag, plot_xyz, nb_iter, fps=fps)


@torch.no_grad()
def evaluation_vqvae(out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) :
    net.eval()
    nb_sample = 0

    draw_org = []
    draw_pred = []
    draw_text = []

    motion_annotation_list = []
    motion_pred_list = []

    R_precision_real = 0
    R_precision = 0

    nb_sample = 0
    matching_score_real = 0
    matching_score_pred = 0
    for batch in val_loader:
        word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token, name = batch

        motion = motion.cuda()
        et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
        bs, seq = motion.shape[0], motion.shape[1]

        num_joints = 21 if motion.shape[-1] == 251 else 22

        pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda()

        for i in range(bs):
            pose = val_loader.dataset.inv_transform(motion[i:i+1, :m_length[i], :].detach().cpu().numpy())
            pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)

            pred_pose, loss_commit, perplexity = net(motion[i:i+1, :m_length[i]])
            pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
            pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)

            if savenpy:
                np.save(os.path.join(out_dir, name[i]+'_gt.npy'), pose_xyz[:, :m_length[i]].cpu().numpy())
                np.save(os.path.join(out_dir, name[i]+'_pred.npy'), pred_xyz.detach().cpu().numpy())

            pred_pose_eval[i:i+1,:m_length[i],:] = pred_pose

            if i < min(4, bs):
                draw_org.append(pose_xyz)
                draw_pred.append(pred_xyz)
                draw_text.append(caption[i])

        et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length)

        motion_pred_list.append(em_pred)
        motion_annotation_list.append(em)

        temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
        R_precision_real += temp_R
        matching_score_real += temp_match
        temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
        R_precision += temp_R
        matching_score_pred += temp_match

        nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
    diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample

    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
    logger.info(msg)

    if draw:
        writer.add_scalar('./Test/FID', fid, nb_iter)
        writer.add_scalar('./Test/Diversity', diversity, nb_iter)
        writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
        writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
        writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
        writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)

    #if nb_iter % 5000 == 0 :
    #    for ii in range(4):
    #        tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)

    #if nb_iter % 5000 == 0 :
    #    for ii in range(4):
    #        tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)

    if fid < best_fid :
        msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
        logger.info(msg)
        best_fid, best_iter = fid, nb_iter
        if save:
            torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))

    if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
        msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
        logger.info(msg)
        best_div = diversity
        if save:
            torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_div.pth'))

    if R_precision[0] > best_top1 :
        msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
        logger.info(msg)
        best_top1 = R_precision[0]
        if save:
            torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_top1.pth'))

    if R_precision[1] > best_top2 :
        msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
        logger.info(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3 :
        msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
        logger.info(msg)
        best_top3 = R_precision[2]

    if matching_score_pred < best_matching :
        msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
        logger.info(msg)
        best_matching = matching_score_pred
        if save:
            torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_matching.pth'))

    if save:
        torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_last.pth'))

    net.train()
    return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger


@torch.no_grad()
def evaluation_transformer(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model, eval_wrapper, draw = True, save = True, savegif=False) :

    trans.eval()
    nb_sample = 0

    draw_org = []
    draw_pred = []
    draw_text = []
    draw_text_pred = []

    motion_annotation_list = []
    motion_pred_list = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0

    nb_sample = 0
    for i in range(1):
        for batch in val_loader:
            word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch

            bs, seq = pose.shape[:2]
            num_joints = 21 if pose.shape[-1] == 251 else 22

            text = clip.tokenize(clip_text, truncate=True).cuda()

            feat_clip_text = clip_model.encode_text(text).float()
            pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
            pred_len = torch.ones(bs).long()

            for k in range(bs):
                try:
                    index_motion = trans.sample(feat_clip_text[k:k+1], False)
                except:
                    index_motion = torch.ones(1,1).cuda().long()

                pred_pose = net.forward_decoder(index_motion)
                cur_len = pred_pose.shape[1]

                pred_len[k] = min(cur_len, seq)
                pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]

                if draw:
                    pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
                    pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)

                    if i == 0 and k < 4:
                        draw_pred.append(pred_xyz)
                        draw_text_pred.append(clip_text[k])

            et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)

            if i == 0:
                pose = pose.cuda().float()

                et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
                motion_annotation_list.append(em)
                motion_pred_list.append(em_pred)

                if draw:
                    pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
                    pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)

                    for j in range(min(4, bs)):
                        draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
                        draw_text.append(clip_text[j])

                temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
                R_precision_real += temp_R
                matching_score_real += temp_match
                temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
                R_precision += temp_R
                matching_score_pred += temp_match

                nb_sample += bs

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
    diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample

    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
    logger.info(msg)

    if draw:
        writer.add_scalar('./Test/FID', fid, nb_iter)
        writer.add_scalar('./Test/Diversity', diversity, nb_iter)
        writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
        writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
        writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
        writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)

    #if nb_iter % 10000 == 0 :
    #    for ii in range(4):
    #        tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)

    #if nb_iter % 10000 == 0 :
    #    for ii in range(4):
    #        tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)

    if fid < best_fid :
        msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
        logger.info(msg)
        best_fid, best_iter = fid, nb_iter
        if save:
            torch.save({'trans' : trans.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))

    if matching_score_pred < best_matching :
        msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
        logger.info(msg)
        best_matching = matching_score_pred

    if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
        msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
        logger.info(msg)
        best_div = diversity

    if R_precision[0] > best_top1 :
        msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
        logger.info(msg)
        best_top1 = R_precision[0]

    if R_precision[1] > best_top2 :
        msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
        logger.info(msg)
        best_top2 = R_precision[1]

    if R_precision[2] > best_top3 :
        msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
        logger.info(msg)
        best_top3 = R_precision[2]

    if save:
        torch.save({'trans' : trans.state_dict()}, os.path.join(out_dir, 'net_last.pth'))

    trans.train()
    return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger


@torch.no_grad()
def evaluation_transformer_test(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, clip_model, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) :

    trans.eval()
    nb_sample = 0

    draw_org = []
    draw_pred = []
    draw_text = []
    draw_text_pred = []
    draw_name = []

    motion_annotation_list = []
    motion_pred_list = []
    motion_multimodality = []
    R_precision_real = 0
    R_precision = 0
    matching_score_real = 0
    matching_score_pred = 0

    nb_sample = 0

    for batch in val_loader:

        word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch
        bs, seq = pose.shape[:2]
        num_joints = 21 if pose.shape[-1] == 251 else 22

        text = clip.tokenize(clip_text, truncate=True).cuda()

        feat_clip_text = clip_model.encode_text(text).float()
        motion_multimodality_batch = []
        for i in range(30):
            pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
            pred_len = torch.ones(bs).long()

            for k in range(bs):
                try:
                    index_motion = trans.sample(feat_clip_text[k:k+1], True)
                except:
                    index_motion = torch.ones(1,1).cuda().long()

                pred_pose = net.forward_decoder(index_motion)
                cur_len = pred_pose.shape[1]

                pred_len[k] = min(cur_len, seq)
                pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]

                if i == 0 and (draw or savenpy):
                    pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
                    pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)

                    if savenpy:
                        np.save(os.path.join(out_dir, name[k]+'_pred.npy'), pred_xyz.detach().cpu().numpy())

                    if draw:
                        if i == 0:
                            draw_pred.append(pred_xyz)
                            draw_text_pred.append(clip_text[k])
                            draw_name.append(name[k])

            et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)

            motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1))

            if i == 0:
                pose = pose.cuda().float()

                et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
                motion_annotation_list.append(em)
                motion_pred_list.append(em_pred)

                if draw or savenpy:
                    pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
                    pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)

                    if savenpy:
                        for j in range(bs):
                            np.save(os.path.join(out_dir, name[j]+'_gt.npy'), pose_xyz[j][:m_length[j]].unsqueeze(0).cpu().numpy())

                    if draw:
                        for j in range(bs):
                            draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
                            draw_text.append(clip_text[j])

                temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
                R_precision_real += temp_R
                matching_score_real += temp_match
                temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
                R_precision += temp_R
                matching_score_pred += temp_match

                nb_sample += bs

        motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1))

    motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
    motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
    gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
    mu, cov = calculate_activation_statistics(motion_pred_np)

    diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
    diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)

    R_precision_real = R_precision_real / nb_sample
    R_precision = R_precision / nb_sample

    matching_score_real = matching_score_real / nb_sample
    matching_score_pred = matching_score_pred / nb_sample

    multimodality = 0
    motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
    multimodality = calculate_multimodality(motion_multimodality, 10)

    fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

    msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}, multimodality. {multimodality:.4f}"
    logger.info(msg)

    #if draw:
    #    for ii in range(len(draw_org)):
    #        tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_org', nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_gt.gif')] if savegif else None)
    #
    #        tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_pred', nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_pred.gif')] if savegif else None)

    trans.train()
    return fid, best_iter, diversity, R_precision[0], R_precision[1], R_precision[2], matching_score_pred, multimodality, writer, logger


# (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
def euclidean_distance_matrix(matrix1, matrix2):
    """
    Params:
    -- matrix1: N1 x D
    -- matrix2: N2 x D
    Returns:
    -- dist: N1 x N2
        dist[i, j] == distance(matrix1[i], matrix2[j])
    """
    assert matrix1.shape[1] == matrix2.shape[1]
    d1 = -2 * np.dot(matrix1, matrix2.T)    # shape (num_test, num_train)
    d2 = np.sum(np.square(matrix1), axis=1, keepdims=True)    # shape (num_test, 1)
    d3 = np.sum(np.square(matrix2), axis=1)     # shape (num_train, )
    dists = np.sqrt(d1 + d2 + d3)  # broadcasting
    return dists


def calculate_top_k(mat, top_k):
    size = mat.shape[0]
    gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
    bool_mat = (mat == gt_mat)
    correct_vec = False
    top_k_list = []
    for i in range(top_k):
        # print(correct_vec, bool_mat[:, i])
        correct_vec = (correct_vec | bool_mat[:, i])
        # print(correct_vec)
        top_k_list.append(correct_vec[:, None])
    top_k_mat = np.concatenate(top_k_list, axis=1)
    return top_k_mat


def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
    dist_mat = euclidean_distance_matrix(embedding1, embedding2)
    matching_score = dist_mat.trace()
    argmax = np.argsort(dist_mat, axis=1)
    top_k_mat = calculate_top_k(argmax, top_k)
    if sum_all:
        return top_k_mat.sum(axis=0), matching_score
    else:
        return top_k_mat, matching_score


def calculate_multimodality(activation, multimodality_times):
    assert len(activation.shape) == 3
    assert activation.shape[1] > multimodality_times
    num_per_sent = activation.shape[1]

    first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
    second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
    dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
    return dist.mean()


def calculate_diversity(activation, diversity_times):
    assert len(activation.shape) == 2
    assert activation.shape[0] > diversity_times
    num_samples = activation.shape[0]

    first_indices = np.random.choice(num_samples, diversity_times, replace=False)
    second_indices = np.random.choice(num_samples, diversity_times, replace=False)
    dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
    return dist.mean()


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(activations):

    mu = np.mean(activations, axis=0)
    cov = np.cov(activations, rowvar=False)
    return mu, cov


def calculate_frechet_feature_distance(feature_list1, feature_list2):
    feature_list1 = np.stack(feature_list1)
    feature_list2 = np.stack(feature_list2)

    # normalize the scale
    mean = np.mean(feature_list1, axis=0)
    std = np.std(feature_list1, axis=0) + 1e-10
    feature_list1 = (feature_list1 - mean) / std
    feature_list2 = (feature_list2 - mean) / std

    dist = calculate_frechet_distance(
        mu1=np.mean(feature_list1, axis=0),
        sigma1=np.cov(feature_list1, rowvar=False),
        mu2=np.mean(feature_list2, axis=0),
        sigma2=np.cov(feature_list2, rowvar=False),
    )
    return dist
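The evaluation functions above reduce to a few reusable metric helpers: calculate_activation_statistics fits a Gaussian to a set of embeddings, calculate_frechet_distance compares two such Gaussians (FID), and calculate_diversity averages distances between randomly paired samples. A minimal sketch on synthetic features (illustration only, not part of this upload; it assumes the repository and its dependencies such as clip are installed, since importing utils.eval_trans pulls them in, and the 64-dimensional feature size is arbitrary):

import numpy as np

from utils.eval_trans import (calculate_activation_statistics,
                              calculate_diversity,
                              calculate_frechet_distance)

rng = np.random.default_rng(0)

# Toy "motion embeddings": 1000 real and 1000 generated 64-dim feature vectors.
real_feats = rng.normal(size=(1000, 64))
fake_feats = rng.normal(loc=0.1, size=(1000, 64))

gt_mu, gt_cov = calculate_activation_statistics(real_feats)
mu, cov = calculate_activation_statistics(fake_feats)

print("FID:", calculate_frechet_distance(gt_mu, gt_cov, mu, cov))
print("Diversity:", calculate_diversity(fake_feats, 300))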
utils/losses.py
ADDED
@@ -0,0 +1,30 @@
import torch
import torch.nn as nn

class ReConsLoss(nn.Module):
    def __init__(self, recons_loss, nb_joints):
        super(ReConsLoss, self).__init__()

        if recons_loss == 'l1':
            self.Loss = torch.nn.L1Loss()
        elif recons_loss == 'l2' :
            self.Loss = torch.nn.MSELoss()
        elif recons_loss == 'l1_smooth' :
            self.Loss = torch.nn.SmoothL1Loss()

        # 4 global motion associated to root
        # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d)
        # 3 global vel xyz
        # 4 foot contact
        self.nb_joints = nb_joints
        self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4

    def forward(self, motion_pred, motion_gt) :
        loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim])
        return loss

    def forward_vel(self, motion_pred, motion_gt) :
        loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4])
        return loss
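For the 22-joint HumanML3D layout, the comment above implies a feature dimension of (22 - 1) * 12 + 4 + 3 + 4 = 263 (and 251 for the 21-joint KIT layout, which is why eval_trans.py switches on a last dimension of 251). A small usage sketch on random tensors (illustration only, not part of this upload; assumes torch is installed and the repository is on the import path):

import torch

from utils.losses import ReConsLoss

loss_fn = ReConsLoss(recons_loss='l1_smooth', nb_joints=22)
print(loss_fn.motion_dim)  # (22 - 1) * 12 + 4 + 3 + 4 = 263

# Toy batch: 4 sequences, 64 frames, 263 features per frame.
pred = torch.randn(4, 64, 263)
gt = torch.randn(4, 64, 263)

print(loss_fn(pred, gt))              # loss over the first motion_dim channels
print(loss_fn.forward_vel(pred, gt))  # loss over the 4 : (nb_joints - 1) * 3 + 4 slice only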
utils/motion_process.py
ADDED
@@ -0,0 +1,59 @@
import torch
from utils.quaternion import quaternion_to_cont6d, qrot, qinv

def recover_root_rot_pos(data):
    rot_vel = data[..., 0]
    r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
    '''Get Y-axis rotation from rotation velocity'''
    r_rot_ang[..., 1:] = rot_vel[..., :-1]
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
    r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
    '''Add Y-axis rotation to root position'''
    r_pos = qrot(qinv(r_rot_quat), r_pos)

    r_pos = torch.cumsum(r_pos, dim=-2)

    r_pos[..., 1] = data[..., 3]
    return r_rot_quat, r_pos


def recover_from_rot(data, joints_num, skeleton):
    r_rot_quat, r_pos = recover_root_rot_pos(data)

    r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)

    start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
    end_indx = start_indx + (joints_num - 1) * 6
    cont6d_params = data[..., start_indx:end_indx]
    # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
    cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
    cont6d_params = cont6d_params.view(-1, joints_num, 6)

    positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)

    return positions


def recover_from_ric(data, joints_num):
    r_rot_quat, r_pos = recover_root_rot_pos(data)
    positions = data[..., 4:(joints_num - 1) * 3 + 4]
    positions = positions.view(positions.shape[:-1] + (-1, 3))

    '''Add Y-axis rotation to local joints'''
    positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)

    '''Add root XZ to joints'''
    positions[..., 0] += r_pos[..., 0:1]
    positions[..., 2] += r_pos[..., 2:3]

    '''Concate root and joints'''
    positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)

    return positions
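recover_from_ric converts the flattened motion features back into per-joint xyz positions: channel 0 carries the root yaw (angular) velocity, channels 1-2 the root linear velocity on the ground plane, channel 3 the root height, and channels 4 : (joints_num - 1) * 3 + 4 the rotation-invariant joint coordinates that get re-rotated and re-rooted. A shape-only sketch (illustration, not part of this upload; random values are not a plausible motion, and 196 frames / 22 joints / 263 features follow the HumanML3D convention):

import torch

from utils.motion_process import recover_from_ric

joints_num = 22
feats = torch.randn(1, 196, 263)    # (batch, frames, features)

joints = recover_from_ric(feats.float(), joints_num)
print(joints.shape)                 # torch.Size([1, 196, 22, 3])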
utils/paramUtil.py
ADDED
@@ -0,0 +1,63 @@
import numpy as np

# Define a kinematic tree for the skeletal structure
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]

kit_raw_offsets = np.array(
    [
        [0, 0, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [0, 1, 0],
        [1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [-1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, 1],
        [-1, 0, 0],
        [0, -1, 0],
        [0, -1, 0],
        [0, 0, 1],
        [0, 0, 1]
    ]
)

t2m_raw_offsets = np.array([[0,0,0],
                            [1,0,0],
                            [-1,0,0],
                            [0,1,0],
                            [0,-1,0],
                            [0,-1,0],
                            [0,1,0],
                            [0,-1,0],
                            [0,-1,0],
                            [0,1,0],
                            [0,0,1],
                            [0,0,1],
                            [0,1,0],
                            [1,0,0],
                            [-1,0,0],
                            [0,0,1],
                            [0,-1,0],
                            [0,-1,0],
                            [0,-1,0],
                            [0,-1,0],
                            [0,-1,0],
                            [0,-1,0]])

t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]


kit_tgt_skel_id = '03950'

t2m_tgt_skel_id = '000021'
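Each kinematic chain lists joint indices from the root (or a branching joint) outwards, so consecutive entries form parent/child pairs. A small sketch deriving a parent array from t2m_kinematic_chain (chains_to_parents is a hypothetical helper written for illustration, not part of this upload; the root is mapped to itself here, although other codebases use -1):

import numpy as np

from utils.paramUtil import t2m_kinematic_chain

def chains_to_parents(chains, num_joints):
    # parents[j] is the index of joint j's parent; the root keeps its own index.
    parents = np.arange(num_joints)
    for chain in chains:
        for parent, child in zip(chain[:-1], chain[1:]):
            parents[child] = parent
    return parents

print(chains_to_parents(t2m_kinematic_chain, 22))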
utils/quaternion.py
ADDED
@@ -0,0 +1,423 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import numpy as np

_EPS4 = np.finfo(float).eps * 4.0

_FLOAT_EPS = np.finfo(float).eps

# PyTorch-backed implementations
def qinv(q):
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    mask = torch.ones_like(q)
    mask[..., 1:] = -mask[..., 1:]
    return q * mask


def qinv_np(q):
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    return qinv(torch.from_numpy(q).float()).numpy()


def qnormalize(q):
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    return q / torch.norm(q, dim=-1, keepdim=True)


def qmul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r.
    Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
    Returns q*r as a tensor of shape (*, 4).
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    original_shape = q.shape

    # Compute outer product
    terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))

    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
    return torch.stack((w, x, y, z), dim=1).view(original_shape)


def qrot(q, v):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    original_shape = list(v.shape)
    # print(q.shape)
    q = q.contiguous().view(-1, 4)
    v = v.contiguous().view(-1, 3)

    qvec = q[:, 1:]
    uv = torch.cross(qvec, v, dim=1)
    uuv = torch.cross(qvec, uv, dim=1)
    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)


def qeuler(q, order, epsilon=0, deg=True):
    """
    Convert quaternion(s) q to Euler angles.
    Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4

    original_shape = list(q.shape)
    original_shape[-1] = 3
    q = q.view(-1, 4)

    q0 = q[:, 0]
    q1 = q[:, 1]
    q2 = q[:, 2]
    q3 = q[:, 3]

    if order == 'xyz':
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    elif order == 'yzx':
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
    elif order == 'zxy':
        x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == 'xzy':
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
    elif order == 'yxz':
        x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == 'zyx':
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    else:
        raise

    if deg:
        return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
    else:
        return torch.stack((x, y, z), dim=1).view(original_shape)


# Numpy-backed implementations

def qmul_np(q, r):
    q = torch.from_numpy(q).contiguous().float()
    r = torch.from_numpy(r).contiguous().float()
    return qmul(q, r).numpy()


def qrot_np(q, v):
    q = torch.from_numpy(q).contiguous().float()
    v = torch.from_numpy(v).contiguous().float()
    return qrot(q, v).numpy()


def qeuler_np(q, order, epsilon=0, use_gpu=False):
    if use_gpu:
        q = torch.from_numpy(q).cuda().float()
        return qeuler(q, order, epsilon).cpu().numpy()
    else:
        q = torch.from_numpy(q).contiguous().float()
        return qeuler(q, order, epsilon).numpy()


def qfix(q):
    """
    Enforce quaternion continuity across the time dimension by selecting
    the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
    between two consecutive frames.

    Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
    Returns a tensor of the same shape.
    """
    assert len(q.shape) == 3
    assert q.shape[-1] == 4

    result = q.copy()
    dot_products = np.sum(q[1:] * q[:-1], axis=2)
    mask = dot_products < 0
    mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
    result[1:][mask] *= -1
    return result


def euler2quat(e, order, deg=True):
    """
    Convert Euler angles to quaternions.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.view(-1, 3)

    ## if euler angles in degrees
    if deg:
        e = e * np.pi / 180.

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
    ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
    rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)

    result = None
    for coord in order:
        if coord == 'x':
            r = rx
        elif coord == 'y':
            r = ry
        elif coord == 'z':
            r = rz
        else:
            raise
        if result is None:
            result = r
        else:
            result = qmul(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ['xyz', 'yzx', 'zxy']:
        result *= -1

    return result.view(original_shape)


def expmap_to_quaternion(e):
    """
    Convert axis-angle rotations (aka exponential maps) to quaternions.
    Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
    Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 4).
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4
    e = e.reshape(-1, 3)

    theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
    w = np.cos(0.5 * theta).reshape(-1, 1)
    xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
    return np.concatenate((w, xyz), axis=1).reshape(original_shape)


def euler_to_quaternion(e, order):
    """
    Convert Euler angles to quaternions.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.reshape(-1, 3)

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
    ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
    rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)

    result = None
    for coord in order:
        if coord == 'x':
            r = rx
        elif coord == 'y':
            r = ry
        elif coord == 'z':
            r = rz
        else:
            raise
        if result is None:
            result = r
        else:
            result = qmul_np(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ['xyz', 'yzx', 'zxy']:
        result *= -1

    return result.reshape(original_shape)


def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def quaternion_to_matrix_np(quaternions):
    q = torch.from_numpy(quaternions).contiguous().float()
    return quaternion_to_matrix(q).numpy()


def quaternion_to_cont6d_np(quaternions):
    rotation_mat = quaternion_to_matrix_np(quaternions)
    cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
    return cont_6d


def quaternion_to_cont6d(quaternions):
    rotation_mat = quaternion_to_matrix(quaternions)
    cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
    return cont_6d


def cont6d_to_matrix(cont6d):
    assert cont6d.shape[-1] == 6, "The last dimension must be 6"
    x_raw = cont6d[..., 0:3]
    y_raw = cont6d[..., 3:6]

    x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
    z = torch.cross(x, y_raw, dim=-1)
    z = z / torch.norm(z, dim=-1, keepdim=True)

    y = torch.cross(z, x, dim=-1)

    x = x[..., None]
    y = y[..., None]
    z = z[..., None]

    mat = torch.cat([x, y, z], dim=-1)
    return mat


def cont6d_to_matrix_np(cont6d):
    q = torch.from_numpy(cont6d).contiguous().float()
    return cont6d_to_matrix(q).numpy()


def qpow(q0, t, dtype=torch.float):
    ''' q0 : tensor of quaternions
    t: tensor of powers
    '''
    q0 = qnormalize(q0)
    theta0 = torch.acos(q0[..., 0])

    ## if theta0 is close to zero, add epsilon to avoid NaNs
    mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
    theta0 = (1 - mask) * theta0 + mask * 10e-10
    v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)

    if isinstance(t, torch.Tensor):
        q = torch.zeros(t.shape + q0.shape)
        theta = t.view(-1, 1) * theta0.view(1, -1)
    else:  ## if t is a number
        q = torch.zeros(q0.shape)
        theta = t * theta0

    q[..., 0] = torch.cos(theta)
    q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)

    return q.to(dtype)


def qslerp(q0, q1, t):
    '''
    q0: starting quaternion
    q1: ending quaternion
    t: array of points along the way

    Returns:
    Tensor of Slerps: t.shape + q0.shape
    '''

    q0 = qnormalize(q0)
    q1 = qnormalize(q1)
    q_ = qpow(qmul(q1, qinv(q0)), t)

    return qmul(q_,
                q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())


def qbetween(v0, v1):
    '''
    find the quaternion used to rotate v0 to v1
    '''
    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'

    v = torch.cross(v0, v1)
    w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
                                                                                                              keepdim=True)
    return qnormalize(torch.cat([w, v], dim=-1))


def qbetween_np(v0, v1):
    '''
    find the quaternion used to rotate v0 to v1
    '''
    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'

    v0 = torch.from_numpy(v0).float()
    v1 = torch.from_numpy(v1).float()
    return qbetween(v0, v1).numpy()


def lerp(p0, p1, t):
    if not isinstance(t, torch.Tensor):
        t = torch.Tensor([t])

    new_shape = t.shape + p0.shape
    new_view_t = t.shape + torch.Size([1] * len(p0.shape))
    new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
    p0 = p0.view(new_view_p).expand(new_shape)
    p1 = p1.view(new_view_p).expand(new_shape)
    t = t.view(new_view_t).expand(new_shape)

    return p0 + t * (p1 - p0)
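The helpers above use the (w, x, y, z) convention with the real part first. A quick numeric check of qrot, qmul, and qinv (illustration only, not part of this upload; the 90-degree rotation about +Y is an arbitrary test case):

import numpy as np
import torch

from utils.quaternion import qinv, qmul, qrot

# A 90-degree rotation about the +Y axis as a (w, x, y, z) quaternion.
half = torch.tensor(np.pi / 4)
q = torch.stack([torch.cos(half), torch.zeros(()), torch.sin(half), torch.zeros(())]).unsqueeze(0)

v = torch.tensor([[1.0, 0.0, 0.0]])   # the x axis

print(qrot(q, v))        # ~[[0., 0., -1.]]: +x rotates onto -z under a +90 deg yaw
print(qmul(q, qinv(q)))  # ~[[1., 0., 0., 0.]]: q times its inverse is the identity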
utils/rotation_conversions.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
2 |
+
# Check PYTORCH3D_LICENCE before use
|
3 |
+
|
4 |
+
import functools
|
5 |
+
from typing import Optional
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
"""
|
12 |
+
The transformation matrices returned from the functions in this file assume
|
13 |
+
the points on which the transformation will be applied are column vectors.
|
14 |
+
i.e. the R matrix is structured as
|
15 |
+
R = [
|
16 |
+
[Rxx, Rxy, Rxz],
|
17 |
+
[Ryx, Ryy, Ryz],
|
18 |
+
[Rzx, Rzy, Rzz],
|
19 |
+
] # (3, 3)
|
20 |
+
This matrix can be applied to column vectors by post multiplication
|
21 |
+
by the points e.g.
|
22 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
23 |
+
transformed_points = R * points
|
24 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
25 |
+
can be transposed and pre multiplied by the points:
|
26 |
+
e.g.
|
27 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
28 |
+
transformed_points = points * R.transpose(1, 0)
|
29 |
+
"""
|
30 |
+
|
31 |
+
|
32 |
+
def quaternion_to_matrix(quaternions):
|
33 |
+
"""
|
34 |
+
Convert rotations given as quaternions to rotation matrices.
|
35 |
+
Args:
|
36 |
+
quaternions: quaternions with real part first,
|
37 |
+
as tensor of shape (..., 4).
|
38 |
+
Returns:
|
39 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
40 |
+
"""
|
41 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
42 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
43 |
+
|
44 |
+
o = torch.stack(
|
45 |
+
(
|
46 |
+
1 - two_s * (j * j + k * k),
|
47 |
+
two_s * (i * j - k * r),
|
48 |
+
two_s * (i * k + j * r),
|
49 |
+
two_s * (i * j + k * r),
|
50 |
+
1 - two_s * (i * i + k * k),
|
51 |
+
two_s * (j * k - i * r),
|
52 |
+
two_s * (i * k - j * r),
|
53 |
+
two_s * (j * k + i * r),
|
54 |
+
1 - two_s * (i * i + j * j),
|
55 |
+
),
|
56 |
+
-1,
|
57 |
+
)
|
58 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
59 |
+
|
60 |
+
|
61 |
+
def _copysign(a, b):
|
62 |
+
"""
|
63 |
+
Return a tensor where each element has the absolute value taken from the
|
64 |
+
corresponding element of a, with sign taken from the corresponding
|
65 |
+
element of b. This is like the standard copysign floating-point operation,
|
66 |
+
but is not careful about negative 0 and NaN.
|
67 |
+
Args:
|
68 |
+
a: source tensor.
|
69 |
+
b: tensor whose signs will be used, of the same shape as a.
|
70 |
+
Returns:
|
71 |
+
Tensor of the same shape as a with the signs of b.
|
72 |
+
"""
|
73 |
+
signs_differ = (a < 0) != (b < 0)
|
74 |
+
return torch.where(signs_differ, -a, a)
|
75 |
+
|
76 |
+
|
77 |
+
def _sqrt_positive_part(x):
|
78 |
+
"""
|
79 |
+
Returns torch.sqrt(torch.max(0, x))
|
80 |
+
but with a zero subgradient where x is 0.
|
81 |
+
"""
|
82 |
+
ret = torch.zeros_like(x)
|
83 |
+
positive_mask = x > 0
|
84 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
85 |
+
return ret
|
86 |
+
|
87 |
+
|
88 |
+
def matrix_to_quaternion(matrix):
|
89 |
+
"""
|
90 |
+
Convert rotations given as rotation matrices to quaternions.
|
91 |
+
Args:
|
92 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
93 |
+
Returns:
|
94 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
95 |
+
"""
|
96 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
97 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
98 |
+
m00 = matrix[..., 0, 0]
|
99 |
+
m11 = matrix[..., 1, 1]
|
100 |
+
m22 = matrix[..., 2, 2]
|
101 |
+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
102 |
+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
103 |
+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
104 |
+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
105 |
+
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
106 |
+
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
107 |
+
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
108 |
+
return torch.stack((o0, o1, o2, o3), -1)
|
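A hedged round-trip check for the two conversions above (editorial example; it assumes the repository root is on PYTHONPATH so this file is importable as utils.rotation_conversions):

```python
import torch
from utils.rotation_conversions import quaternion_to_matrix, matrix_to_quaternion

# ~45 degrees about X, real part first; already a unit quaternion.
q = torch.tensor([[0.9239, 0.3827, 0.0, 0.0]])
R = quaternion_to_matrix(q)          # (1, 3, 3)
q_back = matrix_to_quaternion(R)     # (1, 4); the real part comes back non-negative
assert torch.allclose(q_back, q, atol=1e-4)
```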
109 |
+
|
110 |
+
|
111 |
+
def _axis_angle_rotation(axis: str, angle):
|
112 |
+
"""
|
113 |
+
Return the rotation matrices for one of the rotations about an axis
|
114 |
+
which Euler angles describe, for each value of the angle given.
|
115 |
+
Args:
|
116 |
+
axis: Axis label "X" or "Y or "Z".
|
117 |
+
angle: any shape tensor of Euler angles in radians
|
118 |
+
Returns:
|
119 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
120 |
+
"""
|
121 |
+
|
122 |
+
cos = torch.cos(angle)
|
123 |
+
sin = torch.sin(angle)
|
124 |
+
one = torch.ones_like(angle)
|
125 |
+
zero = torch.zeros_like(angle)
|
126 |
+
|
127 |
+
if axis == "X":
|
128 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
129 |
+
if axis == "Y":
|
130 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
131 |
+
if axis == "Z":
|
132 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
133 |
+
|
134 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
135 |
+
|
136 |
+
|
137 |
+
def euler_angles_to_matrix(euler_angles, convention: str):
|
138 |
+
"""
|
139 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
140 |
+
Args:
|
141 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
142 |
+
convention: Convention string of three uppercase letters from
|
143 |
+
{"X", "Y", and "Z"}.
|
144 |
+
Returns:
|
145 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
146 |
+
"""
|
147 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
148 |
+
raise ValueError("Invalid input euler angles.")
|
149 |
+
if len(convention) != 3:
|
150 |
+
raise ValueError("Convention must have 3 letters.")
|
151 |
+
if convention[1] in (convention[0], convention[2]):
|
152 |
+
raise ValueError(f"Invalid convention {convention}.")
|
153 |
+
for letter in convention:
|
154 |
+
if letter not in ("X", "Y", "Z"):
|
155 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
156 |
+
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
157 |
+
return functools.reduce(torch.matmul, matrices)
|
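A small usage sketch for euler_angles_to_matrix (editorial example with made-up angles):

```python
import torch
from utils.rotation_conversions import euler_angles_to_matrix

angles = torch.tensor([[0.1, 0.2, 0.3]])     # radians, shape (1, 3)
R = euler_angles_to_matrix(angles, "XYZ")    # (1, 3, 3): R_X(0.1) @ R_Y(0.2) @ R_Z(0.3)
# A rotation matrix is orthogonal, so R @ R^T is the identity.
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3), atol=1e-5)
```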
158 |
+
|
159 |
+
|
160 |
+
def _angle_from_tan(
|
161 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
162 |
+
):
|
163 |
+
"""
|
164 |
+
Extract the first or third Euler angle from the two members of
|
165 |
+
the matrix which are positive constant times its sine and cosine.
|
166 |
+
Args:
|
167 |
+
axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
|
168 |
+
other_axis: Axis label "X" or "Y" or "Z" for the middle axis in the
|
169 |
+
convention.
|
170 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
171 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
172 |
+
which means the relevant entries are in the same row of the
|
173 |
+
rotation matrix. If not, they are in the same column.
|
174 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
175 |
+
Returns:
|
176 |
+
Euler Angles in radians for each matrix in data as a tensor
|
177 |
+
of shape (...).
|
178 |
+
"""
|
179 |
+
|
180 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
181 |
+
if horizontal:
|
182 |
+
i2, i1 = i1, i2
|
183 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
184 |
+
if horizontal == even:
|
185 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
186 |
+
if tait_bryan:
|
187 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
188 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
189 |
+
|
190 |
+
|
191 |
+
def _index_from_letter(letter: str):
|
192 |
+
if letter == "X":
|
193 |
+
return 0
|
194 |
+
if letter == "Y":
|
195 |
+
return 1
|
196 |
+
if letter == "Z":
|
197 |
+
return 2
|
198 |
+
|
199 |
+
|
200 |
+
def matrix_to_euler_angles(matrix, convention: str):
|
201 |
+
"""
|
202 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
203 |
+
Args:
|
204 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
205 |
+
convention: Convention string of three uppercase letters.
|
206 |
+
Returns:
|
207 |
+
Euler angles in radians as tensor of shape (..., 3).
|
208 |
+
"""
|
209 |
+
if len(convention) != 3:
|
210 |
+
raise ValueError("Convention must have 3 letters.")
|
211 |
+
if convention[1] in (convention[0], convention[2]):
|
212 |
+
raise ValueError(f"Invalid convention {convention}.")
|
213 |
+
for letter in convention:
|
214 |
+
if letter not in ("X", "Y", "Z"):
|
215 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
216 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
217 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
218 |
+
i0 = _index_from_letter(convention[0])
|
219 |
+
i2 = _index_from_letter(convention[2])
|
220 |
+
tait_bryan = i0 != i2
|
221 |
+
if tait_bryan:
|
222 |
+
central_angle = torch.asin(
|
223 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
224 |
+
)
|
225 |
+
else:
|
226 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
227 |
+
|
228 |
+
o = (
|
229 |
+
_angle_from_tan(
|
230 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
231 |
+
),
|
232 |
+
central_angle,
|
233 |
+
_angle_from_tan(
|
234 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
235 |
+
),
|
236 |
+
)
|
237 |
+
return torch.stack(o, -1)
|
238 |
+
|
239 |
+
|
240 |
+
def random_quaternions(
|
241 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
242 |
+
):
|
243 |
+
"""
|
244 |
+
Generate random quaternions representing rotations,
|
245 |
+
i.e. versors with nonnegative real part.
|
246 |
+
Args:
|
247 |
+
n: Number of quaternions in a batch to return.
|
248 |
+
dtype: Type to return.
|
249 |
+
device: Desired device of returned tensor. Default:
|
250 |
+
uses the current device for the default tensor type.
|
251 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
252 |
+
flag set.
|
253 |
+
Returns:
|
254 |
+
Quaternions as tensor of shape (N, 4).
|
255 |
+
"""
|
256 |
+
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
257 |
+
s = (o * o).sum(1)
|
258 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
259 |
+
return o
|
260 |
+
|
261 |
+
|
262 |
+
def random_rotations(
|
263 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
264 |
+
):
|
265 |
+
"""
|
266 |
+
Generate random rotations as 3x3 rotation matrices.
|
267 |
+
Args:
|
268 |
+
n: Number of rotation matrices in a batch to return.
|
269 |
+
dtype: Type to return.
|
270 |
+
device: Device of returned tensor. Default: if None,
|
271 |
+
uses the current device for the default tensor type.
|
272 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
273 |
+
flag set.
|
274 |
+
Returns:
|
275 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
276 |
+
"""
|
277 |
+
quaternions = random_quaternions(
|
278 |
+
n, dtype=dtype, device=device, requires_grad=requires_grad
|
279 |
+
)
|
280 |
+
return quaternion_to_matrix(quaternions)
|
281 |
+
|
282 |
+
|
283 |
+
def random_rotation(
|
284 |
+
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
285 |
+
):
|
286 |
+
"""
|
287 |
+
Generate a single random 3x3 rotation matrix.
|
288 |
+
Args:
|
289 |
+
dtype: Type to return
|
290 |
+
device: Device of returned tensor. Default: if None,
|
291 |
+
uses the current device for the default tensor type
|
292 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
293 |
+
flag set
|
294 |
+
Returns:
|
295 |
+
Rotation matrix as tensor of shape (3, 3).
|
296 |
+
"""
|
297 |
+
return random_rotations(1, dtype, device, requires_grad)[0]
|
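A hedged sanity check for the random-rotation helpers above (editorial example):

```python
import torch
from utils.rotation_conversions import random_rotations

R = random_rotations(4)                                            # (4, 3, 3)
eye = torch.eye(3).expand(4, 3, 3)
assert torch.allclose(R @ R.transpose(-1, -2), eye, atol=1e-5)     # orthogonal
assert torch.allclose(torch.det(R), torch.ones(4), atol=1e-5)      # det = +1, proper rotations
```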
298 |
+
|
299 |
+
|
300 |
+
def standardize_quaternion(quaternions):
|
301 |
+
"""
|
302 |
+
Convert a unit quaternion to a standard form: one in which the real
|
303 |
+
part is non-negative.
|
304 |
+
Args:
|
305 |
+
quaternions: Quaternions with real part first,
|
306 |
+
as tensor of shape (..., 4).
|
307 |
+
Returns:
|
308 |
+
Standardized quaternions as tensor of shape (..., 4).
|
309 |
+
"""
|
310 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
311 |
+
|
312 |
+
|
313 |
+
def quaternion_raw_multiply(a, b):
|
314 |
+
"""
|
315 |
+
Multiply two quaternions.
|
316 |
+
Usual torch rules for broadcasting apply.
|
317 |
+
Args:
|
318 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
319 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
320 |
+
Returns:
|
321 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
322 |
+
"""
|
323 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
324 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
325 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
326 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
327 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
328 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
329 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
330 |
+
|
331 |
+
|
332 |
+
def quaternion_multiply(a, b):
|
333 |
+
"""
|
334 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
335 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
336 |
+
Usual torch rules for broadcasting apply.
|
337 |
+
Args:
|
338 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
339 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
340 |
+
Returns:
|
341 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
342 |
+
"""
|
343 |
+
ab = quaternion_raw_multiply(a, b)
|
344 |
+
return standardize_quaternion(ab)
|
345 |
+
|
346 |
+
|
347 |
+
def quaternion_invert(quaternion):
|
348 |
+
"""
|
349 |
+
Given a quaternion representing rotation, get the quaternion representing
|
350 |
+
its inverse.
|
351 |
+
Args:
|
352 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
353 |
+
first, which must be versors (unit quaternions).
|
354 |
+
Returns:
|
355 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
356 |
+
"""
|
357 |
+
|
358 |
+
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
|
359 |
+
|
360 |
+
|
361 |
+
def quaternion_apply(quaternion, point):
|
362 |
+
"""
|
363 |
+
Apply the rotation given by a quaternion to a 3D point.
|
364 |
+
Usual torch rules for broadcasting apply.
|
365 |
+
Args:
|
366 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
367 |
+
point: Tensor of 3D points of shape (..., 3).
|
368 |
+
Returns:
|
369 |
+
Tensor of rotated points of shape (..., 3).
|
370 |
+
"""
|
371 |
+
if point.size(-1) != 3:
|
372 |
+
raise ValueError(f"Points are not in 3D, f{point.shape}.")
|
373 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
374 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
375 |
+
out = quaternion_raw_multiply(
|
376 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
377 |
+
quaternion_invert(quaternion),
|
378 |
+
)
|
379 |
+
return out[..., 1:]
|
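quaternion_apply should agree with converting to a matrix and rotating; a hedged check (editorial example):

```python
import torch
from utils.rotation_conversions import (
    quaternion_apply, quaternion_to_matrix, random_quaternions)

q = random_quaternions(2)        # (2, 4) unit quaternions
p = torch.randn(2, 3)            # (2, 3) points
p_from_quat = quaternion_apply(q, p)                                    # q * p * q^-1
p_from_matrix = torch.einsum("bij,bj->bi", quaternion_to_matrix(q), p)  # R @ p
assert torch.allclose(p_from_quat, p_from_matrix, atol=1e-5)
```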
380 |
+
|
381 |
+
|
382 |
+
def axis_angle_to_matrix(axis_angle):
|
383 |
+
"""
|
384 |
+
Convert rotations given as axis/angle to rotation matrices.
|
385 |
+
Args:
|
386 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
387 |
+
as a tensor of shape (..., 3), where the magnitude is
|
388 |
+
the angle turned anticlockwise in radians around the
|
389 |
+
vector's direction.
|
390 |
+
Returns:
|
391 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
392 |
+
"""
|
393 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
394 |
+
|
395 |
+
|
396 |
+
def matrix_to_axis_angle(matrix):
|
397 |
+
"""
|
398 |
+
Convert rotations given as rotation matrices to axis/angle.
|
399 |
+
Args:
|
400 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
401 |
+
Returns:
|
402 |
+
Rotations given as a vector in axis angle form, as a tensor
|
403 |
+
of shape (..., 3), where the magnitude is the angle
|
404 |
+
turned anticlockwise in radians around the vector's
|
405 |
+
direction.
|
406 |
+
"""
|
407 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
408 |
+
|
409 |
+
|
410 |
+
def axis_angle_to_quaternion(axis_angle):
|
411 |
+
"""
|
412 |
+
Convert rotations given as axis/angle to quaternions.
|
413 |
+
Args:
|
414 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
415 |
+
as a tensor of shape (..., 3), where the magnitude is
|
416 |
+
the angle turned anticlockwise in radians around the
|
417 |
+
vector's direction.
|
418 |
+
Returns:
|
419 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
420 |
+
"""
|
421 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
422 |
+
half_angles = 0.5 * angles
|
423 |
+
eps = 1e-6
|
424 |
+
small_angles = angles.abs() < eps
|
425 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
426 |
+
sin_half_angles_over_angles[~small_angles] = (
|
427 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
428 |
+
)
|
429 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
430 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
431 |
+
sin_half_angles_over_angles[small_angles] = (
|
432 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
433 |
+
)
|
434 |
+
quaternions = torch.cat(
|
435 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
436 |
+
)
|
437 |
+
return quaternions
|
438 |
+
|
439 |
+
|
440 |
+
def quaternion_to_axis_angle(quaternions):
|
441 |
+
"""
|
442 |
+
Convert rotations given as quaternions to axis/angle.
|
443 |
+
Args:
|
444 |
+
quaternions: quaternions with real part first,
|
445 |
+
as tensor of shape (..., 4).
|
446 |
+
Returns:
|
447 |
+
Rotations given as a vector in axis angle form, as a tensor
|
448 |
+
of shape (..., 3), where the magnitude is the angle
|
449 |
+
turned anticlockwise in radians around the vector's
|
450 |
+
direction.
|
451 |
+
"""
|
452 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
453 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
454 |
+
angles = 2 * half_angles
|
455 |
+
eps = 1e-6
|
456 |
+
small_angles = angles.abs() < eps
|
457 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
458 |
+
sin_half_angles_over_angles[~small_angles] = (
|
459 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
460 |
+
)
|
461 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
462 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
463 |
+
sin_half_angles_over_angles[small_angles] = (
|
464 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
465 |
+
)
|
466 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
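A hedged round trip between the axis-angle and quaternion forms (editorial example; the angle stays below pi so the round trip is exact):

```python
import torch
from utils.rotation_conversions import axis_angle_to_quaternion, quaternion_to_axis_angle

aa = torch.tensor([[0.0, 1.2, 0.0]])       # 1.2 rad about +Y
q = axis_angle_to_quaternion(aa)           # (1, 4), real part first
aa_back = quaternion_to_axis_angle(q)      # (1, 3)
assert torch.allclose(aa_back, aa, atol=1e-5)
```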
467 |
+
|
468 |
+
|
469 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
470 |
+
"""
|
471 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
472 |
+
using Gram--Schmidt orthogonalisation per Section B of [1].
|
473 |
+
Args:
|
474 |
+
d6: 6D rotation representation, of size (*, 6)
|
475 |
+
Returns:
|
476 |
+
batch of rotation matrices of size (*, 3, 3)
|
477 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
478 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
479 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
480 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
481 |
+
"""
|
482 |
+
|
483 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
484 |
+
b1 = F.normalize(a1, dim=-1)
|
485 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
486 |
+
b2 = F.normalize(b2, dim=-1)
|
487 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
488 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
489 |
+
|
490 |
+
|
491 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
492 |
+
"""
|
493 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
494 |
+
by dropping the last row. Note that 6D representation is not unique.
|
495 |
+
Args:
|
496 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
497 |
+
Returns:
|
498 |
+
6D rotation representation, of size (*, 6)
|
499 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
500 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
501 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
502 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
503 |
+
"""
|
504 |
+
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
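A hedged round trip through the 6D representation (editorial example):

```python
import torch
from utils.rotation_conversions import (
    random_rotations, matrix_to_rotation_6d, rotation_6d_to_matrix)

R = random_rotations(3)              # (3, 3, 3)
d6 = matrix_to_rotation_6d(R)        # (3, 6): the first two rows of each matrix
R_back = rotation_6d_to_matrix(d6)   # Gram-Schmidt recovers the full matrix
assert torch.allclose(R_back, R, atol=1e-5)
```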
505 |
+
|
506 |
+
def canonicalize_smplh(poses, trans = None):
|
507 |
+
bs, nframes, njoints = poses.shape[:3]
|
508 |
+
|
509 |
+
global_orient = poses[:, :, 0]
|
510 |
+
|
511 |
+
# first global rotations
|
512 |
+
rot2d = matrix_to_axis_angle(global_orient[:, 0])
|
513 |
+
#rot2d[:, :2] = 0 # Remove the rotation along the vertical axis
|
514 |
+
rot2d = axis_angle_to_matrix(rot2d)
|
515 |
+
|
516 |
+
# Rotate the global rotation to eliminate Z rotations
|
517 |
+
global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient)
|
518 |
+
|
519 |
+
# Construct canonicalized version of x
|
520 |
+
xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2)
|
521 |
+
|
522 |
+
if trans is not None:
|
523 |
+
vel = trans[:, 1:] - trans[:, :-1]
|
524 |
+
# Turn the translation as well
|
525 |
+
vel = torch.einsum("ikj,ilk->ilj", rot2d, vel)
|
526 |
+
trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device),
|
527 |
+
torch.cumsum(vel, 1)), 1)
|
528 |
+
return xc, trans
|
529 |
+
else:
|
530 |
+
return xc
|
531 |
+
|
532 |
+
|
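A hedged usage sketch for canonicalize_smplh (editorial example; the tensor shapes — rotation matrices of shape (batch, frames, joints, 3, 3) and a translation of shape (batch, frames, 3) — are assumptions inferred from the indexing above, and the sizes are made up):

```python
import torch
from utils.rotation_conversions import canonicalize_smplh, random_rotations

bs, nframes, njoints = 2, 8, 22
poses = random_rotations(bs * nframes * njoints).reshape(bs, nframes, njoints, 3, 3)
trans = torch.randn(bs, nframes, 3)

xc, trans_c = canonicalize_smplh(poses, trans)
# The canonicalized translation is rebuilt from per-frame velocities,
# so it restarts from the origin at frame 0.
assert torch.allclose(trans_c[:, 0], torch.zeros(bs, 3))
```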
utils/skeleton.py
ADDED
@@ -0,0 +1,199 @@
1 |
+
import numpy as np
import torch
from utils.quaternion import *  # qmul, qrot, qmul_np, qrot_np, qinv_np, qbetween_np, cont6d_to_matrix, ...
|
2 |
+
import scipy.ndimage.filters as filters
|
3 |
+
|
4 |
+
class Skeleton(object):
|
5 |
+
def __init__(self, offset, kinematic_tree, device):
|
6 |
+
self.device = device
|
7 |
+
self._raw_offset_np = offset.numpy()
|
8 |
+
self._raw_offset = offset.clone().detach().to(device).float()
|
9 |
+
self._kinematic_tree = kinematic_tree
|
10 |
+
self._offset = None
|
11 |
+
self._parents = [0] * len(self._raw_offset)
|
12 |
+
self._parents[0] = -1
|
13 |
+
for chain in self._kinematic_tree:
|
14 |
+
for j in range(1, len(chain)):
|
15 |
+
self._parents[chain[j]] = chain[j-1]
|
16 |
+
|
17 |
+
def njoints(self):
|
18 |
+
return len(self._raw_offset)
|
19 |
+
|
20 |
+
def offset(self):
|
21 |
+
return self._offset
|
22 |
+
|
23 |
+
def set_offset(self, offsets):
|
24 |
+
self._offset = offsets.clone().detach().to(self.device).float()
|
25 |
+
|
26 |
+
def kinematic_tree(self):
|
27 |
+
return self._kinematic_tree
|
28 |
+
|
29 |
+
def parents(self):
|
30 |
+
return self._parents
|
31 |
+
|
32 |
+
# joints (batch_size, joints_num, 3)
|
33 |
+
def get_offsets_joints_batch(self, joints):
|
34 |
+
assert len(joints.shape) == 3
|
35 |
+
_offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
|
36 |
+
for i in range(1, self._raw_offset.shape[0]):
|
37 |
+
_offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
|
38 |
+
|
39 |
+
self._offset = _offsets.detach()
|
40 |
+
return _offsets
|
41 |
+
|
42 |
+
# joints (joints_num, 3)
|
43 |
+
def get_offsets_joints(self, joints):
|
44 |
+
assert len(joints.shape) == 2
|
45 |
+
_offsets = self._raw_offset.clone()
|
46 |
+
for i in range(1, self._raw_offset.shape[0]):
|
47 |
+
# print(joints.shape)
|
48 |
+
_offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
|
49 |
+
|
50 |
+
self._offset = _offsets.detach()
|
51 |
+
return _offsets
|
52 |
+
|
53 |
+
# face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
|
54 |
+
# joints (batch_size, joints_num, 3)
|
55 |
+
def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
|
56 |
+
assert len(face_joint_idx) == 4
|
57 |
+
'''Get Forward Direction'''
|
58 |
+
l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
|
59 |
+
across1 = joints[:, r_hip] - joints[:, l_hip]
|
60 |
+
across2 = joints[:, sdr_r] - joints[:, sdr_l]
|
61 |
+
across = across1 + across2
|
62 |
+
across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
|
63 |
+
# print(across1.shape, across2.shape)
|
64 |
+
|
65 |
+
# forward (batch_size, 3)
|
66 |
+
forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
|
67 |
+
if smooth_forward:
|
68 |
+
forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
|
69 |
+
# forward (batch_size, 3)
|
70 |
+
forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
|
71 |
+
|
72 |
+
'''Get Root Rotation'''
|
73 |
+
target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
|
74 |
+
root_quat = qbetween_np(forward, target)
|
75 |
+
|
76 |
+
'''Inverse Kinematics'''
|
77 |
+
# quat_params (batch_size, joints_num, 4)
|
78 |
+
# print(joints.shape[:-1])
|
79 |
+
quat_params = np.zeros(joints.shape[:-1] + (4,))
|
80 |
+
# print(quat_params.shape)
|
81 |
+
root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
82 |
+
quat_params[:, 0] = root_quat
|
83 |
+
# quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
84 |
+
for chain in self._kinematic_tree:
|
85 |
+
R = root_quat
|
86 |
+
for j in range(len(chain) - 1):
|
87 |
+
# (batch, 3)
|
88 |
+
u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
|
89 |
+
# print(u.shape)
|
90 |
+
# (batch, 3)
|
91 |
+
v = joints[:, chain[j+1]] - joints[:, chain[j]]
|
92 |
+
v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
|
93 |
+
# print(u.shape, v.shape)
|
94 |
+
rot_u_v = qbetween_np(u, v)
|
95 |
+
|
96 |
+
R_loc = qmul_np(qinv_np(R), rot_u_v)
|
97 |
+
|
98 |
+
quat_params[:,chain[j + 1], :] = R_loc
|
99 |
+
R = qmul_np(R, R_loc)
|
100 |
+
|
101 |
+
return quat_params
|
102 |
+
|
103 |
+
# Be sure root joint is at the beginning of kinematic chains
|
104 |
+
def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
105 |
+
# quat_params (batch_size, joints_num, 4)
|
106 |
+
# joints (batch_size, joints_num, 3)
|
107 |
+
# root_pos (batch_size, 3)
|
108 |
+
if skel_joints is not None:
|
109 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
110 |
+
if len(self._offset.shape) == 2:
|
111 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
112 |
+
joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
|
113 |
+
joints[:, 0] = root_pos
|
114 |
+
for chain in self._kinematic_tree:
|
115 |
+
if do_root_R:
|
116 |
+
R = quat_params[:, 0]
|
117 |
+
else:
|
118 |
+
R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
|
119 |
+
for i in range(1, len(chain)):
|
120 |
+
R = qmul(R, quat_params[:, chain[i]])
|
121 |
+
offset_vec = offsets[:, chain[i]]
|
122 |
+
joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
|
123 |
+
return joints
|
124 |
+
|
125 |
+
# Be sure root joint is at the beginning of kinematic chains
|
126 |
+
def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
127 |
+
# quat_params (batch_size, joints_num, 4)
|
128 |
+
# joints (batch_size, joints_num, 3)
|
129 |
+
# root_pos (batch_size, 3)
|
130 |
+
if skel_joints is not None:
|
131 |
+
skel_joints = torch.from_numpy(skel_joints)
|
132 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
133 |
+
if len(self._offset.shape) == 2:
|
134 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
135 |
+
offsets = offsets.numpy()
|
136 |
+
joints = np.zeros(quat_params.shape[:-1] + (3,))
|
137 |
+
joints[:, 0] = root_pos
|
138 |
+
for chain in self._kinematic_tree:
|
139 |
+
if do_root_R:
|
140 |
+
R = quat_params[:, 0]
|
141 |
+
else:
|
142 |
+
R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
|
143 |
+
for i in range(1, len(chain)):
|
144 |
+
R = qmul_np(R, quat_params[:, chain[i]])
|
145 |
+
offset_vec = offsets[:, chain[i]]
|
146 |
+
joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
|
147 |
+
return joints
|
148 |
+
|
149 |
+
def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
150 |
+
# cont6d_params (batch_size, joints_num, 6)
|
151 |
+
# joints (batch_size, joints_num, 3)
|
152 |
+
# root_pos (batch_size, 3)
|
153 |
+
if skel_joints is not None:
|
154 |
+
skel_joints = torch.from_numpy(skel_joints)
|
155 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
156 |
+
if len(self._offset.shape) == 2:
|
157 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
158 |
+
offsets = offsets.numpy()
|
159 |
+
joints = np.zeros(cont6d_params.shape[:-1] + (3,))
|
160 |
+
joints[:, 0] = root_pos
|
161 |
+
for chain in self._kinematic_tree:
|
162 |
+
if do_root_R:
|
163 |
+
matR = cont6d_to_matrix_np(cont6d_params[:, 0])
|
164 |
+
else:
|
165 |
+
matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
|
166 |
+
for i in range(1, len(chain)):
|
167 |
+
matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
|
168 |
+
offset_vec = offsets[:, chain[i]][..., np.newaxis]
|
169 |
+
# print(matR.shape, offset_vec.shape)
|
170 |
+
joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
171 |
+
return joints
|
172 |
+
|
173 |
+
def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
174 |
+
# cont6d_params (batch_size, joints_num, 6)
|
175 |
+
# joints (batch_size, joints_num, 3)
|
176 |
+
# root_pos (batch_size, 3)
|
177 |
+
if skel_joints is not None:
|
178 |
+
# skel_joints = torch.from_numpy(skel_joints)
|
179 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
180 |
+
if len(self._offset.shape) == 2:
|
181 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
182 |
+
joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
|
183 |
+
joints[..., 0, :] = root_pos
|
184 |
+
for chain in self._kinematic_tree:
|
185 |
+
if do_root_R:
|
186 |
+
matR = cont6d_to_matrix(cont6d_params[:, 0])
|
187 |
+
else:
|
188 |
+
matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
|
189 |
+
for i in range(1, len(chain)):
|
190 |
+
matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
|
191 |
+
offset_vec = offsets[:, chain[i]].unsqueeze(-1)
|
192 |
+
# print(matR.shape, offset_vec.shape)
|
193 |
+
joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
194 |
+
return joints
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
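A hedged sketch of driving Skeleton.forward_kinematics with a toy chain (editorial example; the offsets and kinematic tree below are made up for illustration — real usage would take them from the dataset's skeleton definition, e.g. utils/paramUtil.py):

```python
import torch
from utils.skeleton import Skeleton

offsets = torch.tensor([[0.0, 0.0, 0.0],    # root
                        [0.0, 1.0, 0.0],    # joint 1, one unit above its parent
                        [0.0, 1.0, 0.0]])   # joint 2, one unit above joint 1
kinematic_tree = [[0, 1, 2]]                # a single chain: root -> 1 -> 2

skel = Skeleton(offsets, kinematic_tree, device="cpu")
skel.set_offset(offsets)

identity_quats = torch.tensor([1.0, 0.0, 0.0, 0.0]).repeat(1, 3, 1)   # (1, 3, 4)
root_pos = torch.zeros(1, 3)
joints = skel.forward_kinematics(identity_quats, root_pos)            # (1, 3, 3)
# With identity rotations the joints simply accumulate the offsets along the chain:
# [[0, 0, 0], [0, 1, 0], [0, 2, 0]]
print(joints[0])
```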
utils/utils_model.py
ADDED
@@ -0,0 +1,66 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.optim as optim
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
|
8 |
+
def getCi(accLog):
|
9 |
+
|
10 |
+
mean = np.mean(accLog)
|
11 |
+
std = np.std(accLog)
|
12 |
+
ci95 = 1.96*std/np.sqrt(len(accLog))
|
13 |
+
|
14 |
+
return mean, ci95
|
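A hedged usage sketch for getCi (editorial example; the metric values are made up). It returns the mean together with a 95% confidence half-width, 1.96 * std / sqrt(n):

```python
from utils.utils_model import getCi

acc_log = [0.71, 0.69, 0.73, 0.70, 0.72]   # e.g. a metric collected over repeated runs
mean, ci95 = getCi(acc_log)
print(f"{mean:.3f} +/- {ci95:.3f}")
```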
15 |
+
|
16 |
+
def get_logger(out_dir):
|
17 |
+
logger = logging.getLogger('Exp')
|
18 |
+
logger.setLevel(logging.INFO)
|
19 |
+
formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
|
20 |
+
|
21 |
+
file_path = os.path.join(out_dir, "run.log")
|
22 |
+
file_hdlr = logging.FileHandler(file_path)
|
23 |
+
file_hdlr.setFormatter(formatter)
|
24 |
+
|
25 |
+
strm_hdlr = logging.StreamHandler(sys.stdout)
|
26 |
+
strm_hdlr.setFormatter(formatter)
|
27 |
+
|
28 |
+
logger.addHandler(file_hdlr)
|
29 |
+
logger.addHandler(strm_hdlr)
|
30 |
+
return logger
|
31 |
+
|
32 |
+
## Optimizer
|
33 |
+
def initial_optim(decay_option, lr, weight_decay, net, optimizer):
|
34 |
+
|
35 |
+
if optimizer == 'adamw' :
|
36 |
+
optimizer_adam_family = optim.AdamW
|
37 |
+
elif optimizer == 'adam' :
|
38 |
+
optimizer_adam_family = optim.Adam
|
39 |
+
if decay_option == 'all':
|
40 |
+
#optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
|
41 |
+
optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=weight_decay)
|
42 |
+
|
43 |
+
elif decay_option == 'noVQ':
|
44 |
+
all_params = set(net.parameters())
|
45 |
+
no_decay = set([net.vq_layer])
|
46 |
+
|
47 |
+
decay = all_params - no_decay
|
48 |
+
optimizer = optimizer_adam_family([
|
49 |
+
{'params': list(no_decay), 'weight_decay': 0},
|
50 |
+
{'params': list(decay), 'weight_decay' : weight_decay}], lr=lr)
|
51 |
+
|
52 |
+
return optimizer
|
53 |
+
|
54 |
+
|
55 |
+
def get_motion_with_trans(motion, velocity):
|
56 |
+
'''
|
57 |
+
motion : torch.tensor, shape (batch_size, T, 72), joint motion with the global translation set to 0
|
58 |
+
velocity : torch.tensor, shape (batch_size, T, 3), per-frame root velocity used to recover the global translation
|
59 |
+
|
60 |
+
'''
|
61 |
+
trans = torch.cumsum(velocity, dim=1)
|
62 |
+
trans = trans - trans[:, :1] ## the first root is initialized at 0 (just for visualization)
|
63 |
+
trans = trans.repeat((1, 1, 21))
|
64 |
+
motion_with_trans = motion + trans
|
65 |
+
return motion_with_trans
|
66 |
+
|
utils/word_vectorizer.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
import numpy as np
|
2 |
+
import pickle
|
3 |
+
from os.path import join as pjoin
|
4 |
+
|
5 |
+
POS_enumerator = {
|
6 |
+
'VERB': 0,
|
7 |
+
'NOUN': 1,
|
8 |
+
'DET': 2,
|
9 |
+
'ADP': 3,
|
10 |
+
'NUM': 4,
|
11 |
+
'AUX': 5,
|
12 |
+
'PRON': 6,
|
13 |
+
'ADJ': 7,
|
14 |
+
'ADV': 8,
|
15 |
+
'Loc_VIP': 9,
|
16 |
+
'Body_VIP': 10,
|
17 |
+
'Obj_VIP': 11,
|
18 |
+
'Act_VIP': 12,
|
19 |
+
'Desc_VIP': 13,
|
20 |
+
'OTHER': 14,
|
21 |
+
}
|
22 |
+
|
23 |
+
Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
|
24 |
+
'up', 'down', 'straight', 'curve')
|
25 |
+
|
26 |
+
Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
|
27 |
+
|
28 |
+
Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
|
29 |
+
|
30 |
+
Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
|
31 |
+
'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
|
32 |
+
'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
|
33 |
+
|
34 |
+
Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
|
35 |
+
'angrily', 'sadly')
|
36 |
+
|
37 |
+
VIP_dict = {
|
38 |
+
'Loc_VIP': Loc_list,
|
39 |
+
'Body_VIP': Body_list,
|
40 |
+
'Obj_VIP': Obj_List,
|
41 |
+
'Act_VIP': Act_list,
|
42 |
+
'Desc_VIP': Desc_list,
|
43 |
+
}
|
44 |
+
|
45 |
+
|
46 |
+
class WordVectorizer(object):
|
47 |
+
def __init__(self, meta_root, prefix):
|
48 |
+
vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
|
49 |
+
words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
|
50 |
+
self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
|
51 |
+
self.word2vec = {w: vectors[self.word2idx[w]] for w in words}
|
52 |
+
|
53 |
+
def _get_pos_ohot(self, pos):
|
54 |
+
pos_vec = np.zeros(len(POS_enumerator))
|
55 |
+
if pos in POS_enumerator:
|
56 |
+
pos_vec[POS_enumerator[pos]] = 1
|
57 |
+
else:
|
58 |
+
pos_vec[POS_enumerator['OTHER']] = 1
|
59 |
+
return pos_vec
|
60 |
+
|
61 |
+
def __len__(self):
|
62 |
+
return len(self.word2vec)
|
63 |
+
|
64 |
+
def __getitem__(self, item):
|
65 |
+
word, pos = item.split('/')
|
66 |
+
if word in self.word2vec:
|
67 |
+
word_vec = self.word2vec[word]
|
68 |
+
vip_pos = None
|
69 |
+
for key, values in VIP_dict.items():
|
70 |
+
if word in values:
|
71 |
+
vip_pos = key
|
72 |
+
break
|
73 |
+
if vip_pos is not None:
|
74 |
+
pos_vec = self._get_pos_ohot(vip_pos)
|
75 |
+
else:
|
76 |
+
pos_vec = self._get_pos_ohot(pos)
|
77 |
+
else:
|
78 |
+
word_vec = self.word2vec['unk']
|
79 |
+
pos_vec = self._get_pos_ohot('OTHER')
|
80 |
+
return word_vec, pos_vec
|
81 |
+
|
82 |
+
|
83 |
+
class WordVectorizerV2(WordVectorizer):
|
84 |
+
def __init__(self, meta_root, prefix):
|
85 |
+
super(WordVectorizerV2, self).__init__(meta_root, prefix)
|
86 |
+
self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
|
87 |
+
|
88 |
+
def __getitem__(self, item):
|
89 |
+
word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item)
|
90 |
+
word, pos = item.split('/')
|
91 |
+
if word in self.word2vec:
|
92 |
+
return word_vec, pose_vec, self.word2idx[word]
|
93 |
+
else:
|
94 |
+
return word_vec, pose_vec, self.word2idx['unk']
|
95 |
+
|
96 |
+
def itos(self, idx):
|
97 |
+
if idx == len(self.idx2word):
|
98 |
+
return "pad"
|
99 |
+
return self.idx2word[idx]
|
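A hedged usage sketch for WordVectorizer (editorial example; the './glove' directory and the 'our_vab' prefix are assumptions about how the pre-computed GloVe metadata is laid out — the class only requires <prefix>_data.npy, <prefix>_words.pkl and <prefix>_idx.pkl under meta_root):

```python
from utils.word_vectorizer import WordVectorizer

w_vectorizer = WordVectorizer('./glove', 'our_vab')   # hypothetical path and prefix
word_emb, pos_onehot = w_vectorizer['walk/VERB']      # keys are 'word/POS' strings
# Assuming 'walk' is in the stored vocabulary: since it appears in Act_list,
# the POS one-hot lands in the Act_VIP slot rather than the plain VERB slot.
print(word_emb.shape, pos_onehot.shape)               # e.g. (300,) and (15,)
```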