rsax committed on
Commit 822ae8b · verified · 1 Parent(s): aca2c5a

Upload 10 files

utils/config.py ADDED
@@ -0,0 +1,17 @@
+ import os
+
+ SMPL_DATA_PATH = "./body_models/smpl"
+
+ SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
+ SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
+ JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')
+
+ ROT_CONVENTION_TO_ROT_NUMBER = {
+     'legacy': 23,
+     'no_hands': 21,
+     'full_hands': 51,
+     'mitten_hands': 33,
+ }
+
+ GENDERS = ['neutral', 'male', 'female']
+ NUM_BETAS = 10
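
A minimal sketch (not part of the repository) of how these constants might be consumed at start-up, failing early if the SMPL assets were not downloaded into ./body_models/smpl:

import os
from utils.config import SMPL_KINTREE_PATH, SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA, GENDERS, NUM_BETAS

# Check that the configured SMPL files actually exist before loading a body model
for path in (SMPL_KINTREE_PATH, SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA):
    assert os.path.exists(path), f"Missing SMPL asset: {path}"
print(GENDERS, NUM_BETAS)   # ['neutral', 'male', 'female'] 10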
utils/eval_trans.py ADDED
@@ -0,0 +1,598 @@
+ import os
+
+ import clip
+ import numpy as np
+ import torch
+ from scipy import linalg
+
+ import visualization.plot_3d_global as plot_3d
+ from utils.motion_process import recover_from_ric
+
+
+ def tensorborad_add_video_xyz(writer, xyz, nb_iter, tag, nb_vis=4, title_batch=None, outname=None, fps=25):
+     # Validate that xyz has the expected dimensions
+     if xyz.ndimension() != 4 or xyz.shape[-1] != 3:
+         raise ValueError(f"Expected a 4D tensor with shape (batch_size, seq_length, n_joints, 3), but got {xyz.shape}")
+
+     # Convert tensor to NumPy array for drawing the batch
+     xyz_numpy = xyz.cpu().numpy()
+
+     # Generate animations using draw_to_batch
+     plot_xyz = plot_3d.draw_to_batch(xyz_numpy, title_batch, outname, fps=fps)
+
+     # Ensure the correct TensorBoard shape (batch, seq, channels, height, width)
+     plot_xyz = np.transpose(plot_xyz.numpy(), (0, 1, 4, 2, 3))
+
+     if plot_xyz.ndim != 5:
+         raise ValueError(f"Expected a 5D array with (batch_size, seq_length, channels, height, width), but got {plot_xyz.shape}")
+
+     # Add the video to TensorBoard
+     writer.add_video(tag, plot_xyz, nb_iter, fps=fps)
+
+
+
+ @torch.no_grad()
+ def evaluation_vqvae(out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) :
+     net.eval()
+     nb_sample = 0
+
+     draw_org = []
+     draw_pred = []
+     draw_text = []
+
+
+     motion_annotation_list = []
+     motion_pred_list = []
+
+     R_precision_real = 0
+     R_precision = 0
+
+     nb_sample = 0
+     matching_score_real = 0
+     matching_score_pred = 0
+     for batch in val_loader:
+         word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, token, name = batch
+
+         motion = motion.cuda()
+         et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, motion, m_length)
+         bs, seq = motion.shape[0], motion.shape[1]
+
+         num_joints = 21 if motion.shape[-1] == 251 else 22
+
+         pred_pose_eval = torch.zeros((bs, seq, motion.shape[-1])).cuda()
+
+         for i in range(bs):
+             pose = val_loader.dataset.inv_transform(motion[i:i+1, :m_length[i], :].detach().cpu().numpy())
+             pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
+
+
+             pred_pose, loss_commit, perplexity = net(motion[i:i+1, :m_length[i]])
+             pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
+             pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)
+
+             if savenpy:
+                 np.save(os.path.join(out_dir, name[i]+'_gt.npy'), pose_xyz[:, :m_length[i]].cpu().numpy())
+                 np.save(os.path.join(out_dir, name[i]+'_pred.npy'), pred_xyz.detach().cpu().numpy())
+
+             pred_pose_eval[i:i+1,:m_length[i],:] = pred_pose
+
+             if i < min(4, bs):
+                 draw_org.append(pose_xyz)
+                 draw_pred.append(pred_xyz)
+                 draw_text.append(caption[i])
+
+         et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, m_length)
+
+         motion_pred_list.append(em_pred)
+         motion_annotation_list.append(em)
+
+         temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
+         R_precision_real += temp_R
+         matching_score_real += temp_match
+         temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
+         R_precision += temp_R
+         matching_score_pred += temp_match
+
+         nb_sample += bs
+
+     motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
+     motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
+     gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
+     mu, cov = calculate_activation_statistics(motion_pred_np)
+
+     diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
+     diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
+
+     R_precision_real = R_precision_real / nb_sample
+     R_precision = R_precision / nb_sample
+
+     matching_score_real = matching_score_real / nb_sample
+     matching_score_pred = matching_score_pred / nb_sample
+
+     fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
+
+     msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
+     logger.info(msg)
+
+     if draw:
+         writer.add_scalar('./Test/FID', fid, nb_iter)
+         writer.add_scalar('./Test/Diversity', diversity, nb_iter)
+         writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
+         writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
+         writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
+         writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)
+
+
+         #if nb_iter % 5000 == 0 :
+             #for ii in range(4):
+                 #tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)
+
+         #if nb_iter % 5000 == 0 :
+             #for ii in range(4):
+                 #tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)
+
+
+     if fid < best_fid :
+         msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
+         logger.info(msg)
+         best_fid, best_iter = fid, nb_iter
+         if save:
+             torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))
+
+     if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
+         msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
+         logger.info(msg)
+         best_div = diversity
+         if save:
+             torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_div.pth'))
+
+     if R_precision[0] > best_top1 :
+         msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
+         logger.info(msg)
+         best_top1 = R_precision[0]
+         if save:
+             torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_top1.pth'))
+
+     if R_precision[1] > best_top2 :
+         msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
+         logger.info(msg)
+         best_top2 = R_precision[1]
+
+     if R_precision[2] > best_top3 :
+         msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
+         logger.info(msg)
+         best_top3 = R_precision[2]
+
+     if matching_score_pred < best_matching :
+         msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
+         logger.info(msg)
+         best_matching = matching_score_pred
+         if save:
+             torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_best_matching.pth'))
+
+     if save:
+         torch.save({'net' : net.state_dict()}, os.path.join(out_dir, 'net_last.pth'))
+
+     net.train()
+     return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger
+
+
+ @torch.no_grad()
+ def evaluation_transformer(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, clip_model, eval_wrapper, draw = True, save = True, savegif=False) :
+
+     trans.eval()
+     nb_sample = 0
+
+     draw_org = []
+     draw_pred = []
+     draw_text = []
+     draw_text_pred = []
+
+     motion_annotation_list = []
+     motion_pred_list = []
+     R_precision_real = 0
+     R_precision = 0
+     matching_score_real = 0
+     matching_score_pred = 0
+
+     nb_sample = 0
+     for i in range(1):
+         for batch in val_loader:
+             word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch
+
+             bs, seq = pose.shape[:2]
+             num_joints = 21 if pose.shape[-1] == 251 else 22
+
+             text = clip.tokenize(clip_text, truncate=True).cuda()
+
+             feat_clip_text = clip_model.encode_text(text).float()
+             pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
+             pred_len = torch.ones(bs).long()
+
+             for k in range(bs):
+                 try:
+                     index_motion = trans.sample(feat_clip_text[k:k+1], False)
+                 except:
+                     index_motion = torch.ones(1,1).cuda().long()
+
+                 pred_pose = net.forward_decoder(index_motion)
+                 cur_len = pred_pose.shape[1]
+
+                 pred_len[k] = min(cur_len, seq)
+                 pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]
+
+                 if draw:
+                     pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
+                     pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)
+
+                     if i == 0 and k < 4:
+                         draw_pred.append(pred_xyz)
+                         draw_text_pred.append(clip_text[k])
+
+             et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)
+
+             if i == 0:
+                 pose = pose.cuda().float()
+
+                 et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
+                 motion_annotation_list.append(em)
+                 motion_pred_list.append(em_pred)
+
+                 if draw:
+                     pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
+                     pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
+
+
+                     for j in range(min(4, bs)):
+                         draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
+                         draw_text.append(clip_text[j])
+
+                 temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
+                 R_precision_real += temp_R
+                 matching_score_real += temp_match
+                 temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
+                 R_precision += temp_R
+                 matching_score_pred += temp_match
+
+                 nb_sample += bs
+
+     motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
+     motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
+     gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
+     mu, cov = calculate_activation_statistics(motion_pred_np)
+
+     diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
+     diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
+
+     R_precision_real = R_precision_real / nb_sample
+     R_precision = R_precision / nb_sample
+
+     matching_score_real = matching_score_real / nb_sample
+     matching_score_pred = matching_score_pred / nb_sample
+
+
+     fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
+
+     msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}"
+     logger.info(msg)
+
+
+     if draw:
+         writer.add_scalar('./Test/FID', fid, nb_iter)
+         writer.add_scalar('./Test/Diversity', diversity, nb_iter)
+         writer.add_scalar('./Test/top1', R_precision[0], nb_iter)
+         writer.add_scalar('./Test/top2', R_precision[1], nb_iter)
+         writer.add_scalar('./Test/top3', R_precision[2], nb_iter)
+         writer.add_scalar('./Test/matching_score', matching_score_pred, nb_iter)
+
+
+         #if nb_iter % 10000 == 0 :
+             #for ii in range(4):
+                 #tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/org_eval'+str(ii), nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, 'gt'+str(ii)+'.gif')] if savegif else None)
+
+         #if nb_iter % 10000 == 0 :
+             #for ii in range(4):
+                 #tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/pred_eval'+str(ii), nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, 'pred'+str(ii)+'.gif')] if savegif else None)
+
+
+     if fid < best_fid :
+         msg = f"--> --> \t FID Improved from {best_fid:.5f} to {fid:.5f} !!!"
+         logger.info(msg)
+         best_fid, best_iter = fid, nb_iter
+         if save:
+             torch.save({'trans' : trans.state_dict()}, os.path.join(out_dir, 'net_best_fid.pth'))
+
+     if matching_score_pred < best_matching :
+         msg = f"--> --> \t matching_score Improved from {best_matching:.5f} to {matching_score_pred:.5f} !!!"
+         logger.info(msg)
+         best_matching = matching_score_pred
+
+     if abs(diversity_real - diversity) < abs(diversity_real - best_div) :
+         msg = f"--> --> \t Diversity Improved from {best_div:.5f} to {diversity:.5f} !!!"
+         logger.info(msg)
+         best_div = diversity
+
+     if R_precision[0] > best_top1 :
+         msg = f"--> --> \t Top1 Improved from {best_top1:.4f} to {R_precision[0]:.4f} !!!"
+         logger.info(msg)
+         best_top1 = R_precision[0]
+
+     if R_precision[1] > best_top2 :
+         msg = f"--> --> \t Top2 Improved from {best_top2:.4f} to {R_precision[1]:.4f} !!!"
+         logger.info(msg)
+         best_top2 = R_precision[1]
+
+     if R_precision[2] > best_top3 :
+         msg = f"--> --> \t Top3 Improved from {best_top3:.4f} to {R_precision[2]:.4f} !!!"
+         logger.info(msg)
+         best_top3 = R_precision[2]
+
+     if save:
+         torch.save({'trans' : trans.state_dict()}, os.path.join(out_dir, 'net_last.pth'))
+
+     trans.train()
+     return best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger
+
+
+ @torch.no_grad()
+ def evaluation_transformer_test(out_dir, val_loader, net, trans, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, best_multi, clip_model, eval_wrapper, draw = True, save = True, savegif=False, savenpy=False) :
+
+     trans.eval()
+     nb_sample = 0
+
+     draw_org = []
+     draw_pred = []
+     draw_text = []
+     draw_text_pred = []
+     draw_name = []
+
+     motion_annotation_list = []
+     motion_pred_list = []
+     motion_multimodality = []
+     R_precision_real = 0
+     R_precision = 0
+     matching_score_real = 0
+     matching_score_pred = 0
+
+     nb_sample = 0
+
+     for batch in val_loader:
+
+         word_embeddings, pos_one_hots, clip_text, sent_len, pose, m_length, token, name = batch
+         bs, seq = pose.shape[:2]
+         num_joints = 21 if pose.shape[-1] == 251 else 22
+
+         text = clip.tokenize(clip_text, truncate=True).cuda()
+
+         feat_clip_text = clip_model.encode_text(text).float()
+         motion_multimodality_batch = []
+         for i in range(30):
+             pred_pose_eval = torch.zeros((bs, seq, pose.shape[-1])).cuda()
+             pred_len = torch.ones(bs).long()
+
+             for k in range(bs):
+                 try:
+                     index_motion = trans.sample(feat_clip_text[k:k+1], True)
+                 except:
+                     index_motion = torch.ones(1,1).cuda().long()
+
+                 pred_pose = net.forward_decoder(index_motion)
+                 cur_len = pred_pose.shape[1]
+
+                 pred_len[k] = min(cur_len, seq)
+                 pred_pose_eval[k:k+1, :cur_len] = pred_pose[:, :seq]
+
+                 if i == 0 and (draw or savenpy):
+                     pred_denorm = val_loader.dataset.inv_transform(pred_pose.detach().cpu().numpy())
+                     pred_xyz = recover_from_ric(torch.from_numpy(pred_denorm).float().cuda(), num_joints)
+
+                     if savenpy:
+                         np.save(os.path.join(out_dir, name[k]+'_pred.npy'), pred_xyz.detach().cpu().numpy())
+
+                     if draw:
+                         if i == 0:
+                             draw_pred.append(pred_xyz)
+                             draw_text_pred.append(clip_text[k])
+                             draw_name.append(name[k])
+
+             et_pred, em_pred = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pred_pose_eval, pred_len)
+
+             motion_multimodality_batch.append(em_pred.reshape(bs, 1, -1))
+
+             if i == 0:
+                 pose = pose.cuda().float()
+
+                 et, em = eval_wrapper.get_co_embeddings(word_embeddings, pos_one_hots, sent_len, pose, m_length)
+                 motion_annotation_list.append(em)
+                 motion_pred_list.append(em_pred)
+
+                 if draw or savenpy:
+                     pose = val_loader.dataset.inv_transform(pose.detach().cpu().numpy())
+                     pose_xyz = recover_from_ric(torch.from_numpy(pose).float().cuda(), num_joints)
+
+                     if savenpy:
+                         for j in range(bs):
+                             np.save(os.path.join(out_dir, name[j]+'_gt.npy'), pose_xyz[j][:m_length[j]].unsqueeze(0).cpu().numpy())
+
+                     if draw:
+                         for j in range(bs):
+                             draw_org.append(pose_xyz[j][:m_length[j]].unsqueeze(0))
+                             draw_text.append(clip_text[j])
+
+                 temp_R, temp_match = calculate_R_precision(et.cpu().numpy(), em.cpu().numpy(), top_k=3, sum_all=True)
+                 R_precision_real += temp_R
+                 matching_score_real += temp_match
+                 temp_R, temp_match = calculate_R_precision(et_pred.cpu().numpy(), em_pred.cpu().numpy(), top_k=3, sum_all=True)
+                 R_precision += temp_R
+                 matching_score_pred += temp_match
+
+                 nb_sample += bs
+
+         motion_multimodality.append(torch.cat(motion_multimodality_batch, dim=1))
+
+     motion_annotation_np = torch.cat(motion_annotation_list, dim=0).cpu().numpy()
+     motion_pred_np = torch.cat(motion_pred_list, dim=0).cpu().numpy()
+     gt_mu, gt_cov = calculate_activation_statistics(motion_annotation_np)
+     mu, cov = calculate_activation_statistics(motion_pred_np)
+
+     diversity_real = calculate_diversity(motion_annotation_np, 300 if nb_sample > 300 else 100)
+     diversity = calculate_diversity(motion_pred_np, 300 if nb_sample > 300 else 100)
+
+     R_precision_real = R_precision_real / nb_sample
+     R_precision = R_precision / nb_sample
+
+     matching_score_real = matching_score_real / nb_sample
+     matching_score_pred = matching_score_pred / nb_sample
+
+     multimodality = 0
+     motion_multimodality = torch.cat(motion_multimodality, dim=0).cpu().numpy()
+     multimodality = calculate_multimodality(motion_multimodality, 10)
+
+     fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
+
+     msg = f"--> \t Eva. Iter {nb_iter} :, FID. {fid:.4f}, Diversity Real. {diversity_real:.4f}, Diversity. {diversity:.4f}, R_precision_real. {R_precision_real}, R_precision. {R_precision}, matching_score_real. {matching_score_real}, matching_score_pred. {matching_score_pred}, multimodality. {multimodality:.4f}"
+     logger.info(msg)
+
+
+     #if draw:
+         #for ii in range(len(draw_org)):
+             #tensorborad_add_video_xyz(writer, draw_org[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_org', nb_vis=1, title_batch=[draw_text[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_gt.gif')] if savegif else None)
+
+             #tensorborad_add_video_xyz(writer, draw_pred[ii], nb_iter, tag='./Vis/'+draw_name[ii]+'_pred', nb_vis=1, title_batch=[draw_text_pred[ii]], outname=[os.path.join(out_dir, draw_name[ii]+'_skel_pred.gif')] if savegif else None)
+
+     trans.train()
+     return fid, best_iter, diversity, R_precision[0], R_precision[1], R_precision[2], matching_score_pred, multimodality, writer, logger
+
+ # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
+ def euclidean_distance_matrix(matrix1, matrix2):
+     """
+     Params:
+     -- matrix1: N1 x D
+     -- matrix2: N2 x D
+     Returns:
+     -- dist: N1 x N2
+     dist[i, j] == distance(matrix1[i], matrix2[j])
+     """
+     assert matrix1.shape[1] == matrix2.shape[1]
+     d1 = -2 * np.dot(matrix1, matrix2.T)    # shape (num_test, num_train)
+     d2 = np.sum(np.square(matrix1), axis=1, keepdims=True)    # shape (num_test, 1)
+     d3 = np.sum(np.square(matrix2), axis=1)     # shape (num_train, )
+     dists = np.sqrt(d1 + d2 + d3)  # broadcasting
+     return dists
+
+
+
+ def calculate_top_k(mat, top_k):
+     size = mat.shape[0]
+     gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
+     bool_mat = (mat == gt_mat)
+     correct_vec = False
+     top_k_list = []
+     for i in range(top_k):
+         # print(correct_vec, bool_mat[:, i])
+         correct_vec = (correct_vec | bool_mat[:, i])
+         # print(correct_vec)
+         top_k_list.append(correct_vec[:, None])
+     top_k_mat = np.concatenate(top_k_list, axis=1)
+     return top_k_mat
+
+
+ def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
+     dist_mat = euclidean_distance_matrix(embedding1, embedding2)
+     matching_score = dist_mat.trace()
+     argmax = np.argsort(dist_mat, axis=1)
+     top_k_mat = calculate_top_k(argmax, top_k)
+     if sum_all:
+         return top_k_mat.sum(axis=0), matching_score
+     else:
+         return top_k_mat, matching_score
+
+ def calculate_multimodality(activation, multimodality_times):
+     assert len(activation.shape) == 3
+     assert activation.shape[1] > multimodality_times
+     num_per_sent = activation.shape[1]
+
+     first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
+     second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
+     dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
+     return dist.mean()
+
+
+ def calculate_diversity(activation, diversity_times):
+     assert len(activation.shape) == 2
+     assert activation.shape[0] > diversity_times
+     num_samples = activation.shape[0]
+
+     first_indices = np.random.choice(num_samples, diversity_times, replace=False)
+     second_indices = np.random.choice(num_samples, diversity_times, replace=False)
+     dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
+     return dist.mean()
+
+
+
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+
+     mu1 = np.atleast_1d(mu1)
+     mu2 = np.atleast_1d(mu2)
+
+     sigma1 = np.atleast_2d(sigma1)
+     sigma2 = np.atleast_2d(sigma2)
+
+     assert mu1.shape == mu2.shape, \
+         'Training and test mean vectors have different lengths'
+     assert sigma1.shape == sigma2.shape, \
+         'Training and test covariances have different dimensions'
+
+     diff = mu1 - mu2
+
+     # Product might be almost singular
+     covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+     if not np.isfinite(covmean).all():
+         msg = ('fid calculation produces singular product; '
+                'adding %s to diagonal of cov estimates') % eps
+         print(msg)
+         offset = np.eye(sigma1.shape[0]) * eps
+         covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+     # Numerical error might give slight imaginary component
+     if np.iscomplexobj(covmean):
+         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+             m = np.max(np.abs(covmean.imag))
+             raise ValueError('Imaginary component {}'.format(m))
+         covmean = covmean.real
+
+     tr_covmean = np.trace(covmean)
+
+     return (diff.dot(diff) + np.trace(sigma1)
+             + np.trace(sigma2) - 2 * tr_covmean)
+
+
+
+ def calculate_activation_statistics(activations):
+
+     mu = np.mean(activations, axis=0)
+     cov = np.cov(activations, rowvar=False)
+     return mu, cov
+
+
+ def calculate_frechet_feature_distance(feature_list1, feature_list2):
+     feature_list1 = np.stack(feature_list1)
+     feature_list2 = np.stack(feature_list2)
+
+     # normalize the scale
+     mean = np.mean(feature_list1, axis=0)
+     std = np.std(feature_list1, axis=0) + 1e-10
+     feature_list1 = (feature_list1 - mean) / std
+     feature_list2 = (feature_list2 - mean) / std
+
+     dist = calculate_frechet_distance(
+         mu1=np.mean(feature_list1, axis=0),
+         sigma1=np.cov(feature_list1, rowvar=False),
+         mu2=np.mean(feature_list2, axis=0),
+         sigma2=np.cov(feature_list2, rowvar=False),
+     )
+     return dist
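
A minimal, self-contained sketch of the metric helpers above, using random vectors in place of the text/motion embeddings produced by eval_wrapper (the numbers are meaningless; it only illustrates the call pattern, and importing utils.eval_trans also requires the repository's clip and visualization dependencies):

import numpy as np
from utils.eval_trans import (calculate_activation_statistics,
                              calculate_frechet_distance,
                              calculate_diversity,
                              calculate_R_precision)

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(1024, 512))   # stand-in for ground-truth motion embeddings
fake = rng.normal(0.1, 1.0, size=(1024, 512))   # stand-in for generated motion embeddings

gt_mu, gt_cov = calculate_activation_statistics(real)
mu, cov = calculate_activation_statistics(fake)
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)

diversity = calculate_diversity(fake, 300)      # diversity_times must be smaller than the sample count
top_k, matching = calculate_R_precision(real[:32], fake[:32], top_k=3, sum_all=True)

print(f"FID {fid:.4f}  Diversity {diversity:.4f}  R@1-3 {top_k / 32}  matching {matching / 32:.4f}")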
utils/losses.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ import torch.nn as nn
+
+ class ReConsLoss(nn.Module):
+     def __init__(self, recons_loss, nb_joints):
+         super(ReConsLoss, self).__init__()
+
+         if recons_loss == 'l1':
+             self.Loss = torch.nn.L1Loss()
+         elif recons_loss == 'l2' :
+             self.Loss = torch.nn.MSELoss()
+         elif recons_loss == 'l1_smooth' :
+             self.Loss = torch.nn.SmoothL1Loss()
+
+         # 4 global motion associated to root
+         # 12 local motion (3 local xyz, 3 vel xyz, 6 rot6d)
+         # 3 global vel xyz
+         # 4 foot contact
+         self.nb_joints = nb_joints
+         self.motion_dim = (nb_joints - 1) * 12 + 4 + 3 + 4
+
+     def forward(self, motion_pred, motion_gt) :
+         loss = self.Loss(motion_pred[..., : self.motion_dim], motion_gt[..., :self.motion_dim])
+         return loss
+
+     def forward_vel(self, motion_pred, motion_gt) :
+         loss = self.Loss(motion_pred[..., 4 : (self.nb_joints - 1) * 3 + 4], motion_gt[..., 4 : (self.nb_joints - 1) * 3 + 4])
+         return loss
+
+
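
A small usage sketch (not from the repository): with the 22-joint HumanML3D layout the feature size works out to (22 - 1) * 12 + 4 + 3 + 4 = 263, which is exactly what the loss slices against.

import torch
from utils.losses import ReConsLoss

criterion = ReConsLoss(recons_loss='l1_smooth', nb_joints=22)   # motion_dim == 263

motion_pred = torch.randn(4, 64, 263)   # (batch, frames, features)
motion_gt = torch.randn(4, 64, 263)

loss = criterion(motion_pred, motion_gt)                  # all 263 feature channels
loss_vel = criterion.forward_vel(motion_pred, motion_gt)  # channels 4..67, the local joint block
print(loss.item(), loss_vel.item())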
utils/motion_process.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ from utils.quaternion import quaternion_to_cont6d, qrot, qinv
+
+ def recover_root_rot_pos(data):
+     rot_vel = data[..., 0]
+     r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
+     '''Get Y-axis rotation from rotation velocity'''
+     r_rot_ang[..., 1:] = rot_vel[..., :-1]
+     r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
+
+     r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
+     r_rot_quat[..., 0] = torch.cos(r_rot_ang)
+     r_rot_quat[..., 2] = torch.sin(r_rot_ang)
+
+     r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
+     r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
+     '''Add Y-axis rotation to root position'''
+     r_pos = qrot(qinv(r_rot_quat), r_pos)
+
+     r_pos = torch.cumsum(r_pos, dim=-2)
+
+     r_pos[..., 1] = data[..., 3]
+     return r_rot_quat, r_pos
+
+
+ def recover_from_rot(data, joints_num, skeleton):
+     r_rot_quat, r_pos = recover_root_rot_pos(data)
+
+     r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
+
+     start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
+     end_indx = start_indx + (joints_num - 1) * 6
+     cont6d_params = data[..., start_indx:end_indx]
+     # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
+     cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
+     cont6d_params = cont6d_params.view(-1, joints_num, 6)
+
+     positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)
+
+     return positions
+
+
+ def recover_from_ric(data, joints_num):
+     r_rot_quat, r_pos = recover_root_rot_pos(data)
+     positions = data[..., 4:(joints_num - 1) * 3 + 4]
+     positions = positions.view(positions.shape[:-1] + (-1, 3))
+
+     '''Add Y-axis rotation to local joints'''
+     positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
+
+     '''Add root XZ to joints'''
+     positions[..., 0] += r_pos[..., 0:1]
+     positions[..., 2] += r_pos[..., 2:3]
+
+     '''Concatenate root and joints'''
+     positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
+
+     return positions
+
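
A minimal sketch of recover_from_ric on a dummy, de-normalized HumanML3D-style feature tensor (263 channels correspond to 22 joints; KIT uses 251 channels and 21 joints), assuming the repository root is on PYTHONPATH:

import torch
from utils.motion_process import recover_from_ric

feats = torch.randn(1, 60, 263)        # (batch, frames, features), e.g. after dataset.inv_transform
joints = recover_from_ric(feats, 22)   # root rotation/position are integrated internally
print(joints.shape)                    # torch.Size([1, 60, 22, 3]) global xyz joint positions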
utils/paramUtil.py ADDED
@@ -0,0 +1,63 @@
+ import numpy as np
+
+ # Define a kinematic tree for the skeletal structure
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
+
+ kit_raw_offsets = np.array(
+     [
+         [0, 0, 0],
+         [0, 1, 0],
+         [0, 1, 0],
+         [0, 1, 0],
+         [0, 1, 0],
+         [1, 0, 0],
+         [0, -1, 0],
+         [0, -1, 0],
+         [-1, 0, 0],
+         [0, -1, 0],
+         [0, -1, 0],
+         [1, 0, 0],
+         [0, -1, 0],
+         [0, -1, 0],
+         [0, 0, 1],
+         [0, 0, 1],
+         [-1, 0, 0],
+         [0, -1, 0],
+         [0, -1, 0],
+         [0, 0, 1],
+         [0, 0, 1]
+     ]
+ )
+
+ t2m_raw_offsets = np.array([[0, 0, 0],
+                             [1, 0, 0],
+                             [-1, 0, 0],
+                             [0, 1, 0],
+                             [0, -1, 0],
+                             [0, -1, 0],
+                             [0, 1, 0],
+                             [0, -1, 0],
+                             [0, -1, 0],
+                             [0, 1, 0],
+                             [0, 0, 1],
+                             [0, 0, 1],
+                             [0, 1, 0],
+                             [1, 0, 0],
+                             [-1, 0, 0],
+                             [0, 0, 1],
+                             [0, -1, 0],
+                             [0, -1, 0],
+                             [0, -1, 0],
+                             [0, -1, 0],
+                             [0, -1, 0],
+                             [0, -1, 0]])
+
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
+
+
+ kit_tgt_skel_id = '03950'
+
+ t2m_tgt_skel_id = '000021'
+
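
An illustrative snippet (not part of the file) that flattens a kinematic chain into (parent, child) bone pairs, e.g. for drawing a stick figure:

from utils.paramUtil import t2m_kinematic_chain

bones = [(chain[i], chain[i + 1])
         for chain in t2m_kinematic_chain
         for i in range(len(chain) - 1)]
print(len(bones))   # 21 bones connecting the 22 HumanML3D joints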
utils/quaternion.py ADDED
@@ -0,0 +1,423 @@
+ # Copyright (c) 2018-present, Facebook, Inc.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ #
+
+ import torch
+ import numpy as np
+
+ _EPS4 = np.finfo(float).eps * 4.0
+
+ _FLOAT_EPS = np.finfo(float).eps
+
+ # PyTorch-backed implementations
+ def qinv(q):
+     assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
+     mask = torch.ones_like(q)
+     mask[..., 1:] = -mask[..., 1:]
+     return q * mask
+
+
+ def qinv_np(q):
+     assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
+     return qinv(torch.from_numpy(q).float()).numpy()
+
+
+ def qnormalize(q):
+     assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
+     return q / torch.norm(q, dim=-1, keepdim=True)
+
+
+ def qmul(q, r):
+     """
+     Multiply quaternion(s) q with quaternion(s) r.
+     Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
+     Returns q*r as a tensor of shape (*, 4).
+     """
+     assert q.shape[-1] == 4
+     assert r.shape[-1] == 4
+
+     original_shape = q.shape
+
+     # Compute outer product
+     terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
+
+     w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
+     x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
+     y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
+     z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
+     return torch.stack((w, x, y, z), dim=1).view(original_shape)
+
+
+ def qrot(q, v):
+     """
+     Rotate vector(s) v about the rotation described by quaternion(s) q.
+     Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
+     where * denotes any number of dimensions.
+     Returns a tensor of shape (*, 3).
+     """
+     assert q.shape[-1] == 4
+     assert v.shape[-1] == 3
+     assert q.shape[:-1] == v.shape[:-1]
+
+     original_shape = list(v.shape)
+     # print(q.shape)
+     q = q.contiguous().view(-1, 4)
+     v = v.contiguous().view(-1, 3)
+
+     qvec = q[:, 1:]
+     uv = torch.cross(qvec, v, dim=1)
+     uuv = torch.cross(qvec, uv, dim=1)
+     return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
+
+
+ def qeuler(q, order, epsilon=0, deg=True):
+     """
+     Convert quaternion(s) q to Euler angles.
+     Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
+     Returns a tensor of shape (*, 3).
+     """
+     assert q.shape[-1] == 4
+
+     original_shape = list(q.shape)
+     original_shape[-1] = 3
+     q = q.view(-1, 4)
+
+     q0 = q[:, 0]
+     q1 = q[:, 1]
+     q2 = q[:, 2]
+     q3 = q[:, 3]
+
+     if order == 'xyz':
+         x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
+         y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
+         z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
+     elif order == 'yzx':
+         x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
+         y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
+         z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
+     elif order == 'zxy':
+         x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
+         y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
+         z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
+     elif order == 'xzy':
+         x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
+         y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
+         z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
+     elif order == 'yxz':
+         x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
+         y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
+         z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
+     elif order == 'zyx':
+         x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
+         y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
+         z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
+     else:
+         raise ValueError(f'Unknown Euler angle order: {order}')
+
+     if deg:
+         return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
+     else:
+         return torch.stack((x, y, z), dim=1).view(original_shape)
+
+
+ # Numpy-backed implementations
+
+ def qmul_np(q, r):
+     q = torch.from_numpy(q).contiguous().float()
+     r = torch.from_numpy(r).contiguous().float()
+     return qmul(q, r).numpy()
+
+
+ def qrot_np(q, v):
+     q = torch.from_numpy(q).contiguous().float()
+     v = torch.from_numpy(v).contiguous().float()
+     return qrot(q, v).numpy()
+
+
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
+     if use_gpu:
+         q = torch.from_numpy(q).cuda().float()
+         return qeuler(q, order, epsilon).cpu().numpy()
+     else:
+         q = torch.from_numpy(q).contiguous().float()
+         return qeuler(q, order, epsilon).numpy()
+
+
+ def qfix(q):
+     """
+     Enforce quaternion continuity across the time dimension by selecting
+     the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
+     between two consecutive frames.
+
+     Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
+     Returns a tensor of the same shape.
+     """
+     assert len(q.shape) == 3
+     assert q.shape[-1] == 4
+
+     result = q.copy()
+     dot_products = np.sum(q[1:] * q[:-1], axis=2)
+     mask = dot_products < 0
+     mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
+     result[1:][mask] *= -1
+     return result
+
+
+ def euler2quat(e, order, deg=True):
+     """
+     Convert Euler angles to quaternions.
+     """
+     assert e.shape[-1] == 3
+
+     original_shape = list(e.shape)
+     original_shape[-1] = 4
+
+     e = e.view(-1, 3)
+
+     ## if euler angles in degrees
+     if deg:
+         e = e * np.pi / 180.
+
+     x = e[:, 0]
+     y = e[:, 1]
+     z = e[:, 2]
+
+     rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
+     ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
+     rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
+
+     result = None
+     for coord in order:
+         if coord == 'x':
+             r = rx
+         elif coord == 'y':
+             r = ry
+         elif coord == 'z':
+             r = rz
+         else:
+             raise ValueError(f'Unknown axis {coord} in order {order}')
+         if result is None:
+             result = r
+         else:
+             result = qmul(result, r)
+
+     # Reverse antipodal representation to have a non-negative "w"
+     if order in ['xyz', 'yzx', 'zxy']:
+         result *= -1
+
+     return result.view(original_shape)
+
+
+ def expmap_to_quaternion(e):
+     """
+     Convert axis-angle rotations (aka exponential maps) to quaternions.
+     Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
+     Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
+     Returns a tensor of shape (*, 4).
+     """
+     assert e.shape[-1] == 3
+
+     original_shape = list(e.shape)
+     original_shape[-1] = 4
+     e = e.reshape(-1, 3)
+
+     theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
+     w = np.cos(0.5 * theta).reshape(-1, 1)
+     xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
+     return np.concatenate((w, xyz), axis=1).reshape(original_shape)
+
+
+ def euler_to_quaternion(e, order):
+     """
+     Convert Euler angles to quaternions.
+     """
+     assert e.shape[-1] == 3
+
+     original_shape = list(e.shape)
+     original_shape[-1] = 4
+
+     e = e.reshape(-1, 3)
+
+     x = e[:, 0]
+     y = e[:, 1]
+     z = e[:, 2]
+
+     rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
+     ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
+     rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
+
+     result = None
+     for coord in order:
+         if coord == 'x':
+             r = rx
+         elif coord == 'y':
+             r = ry
+         elif coord == 'z':
+             r = rz
+         else:
+             raise ValueError(f'Unknown axis {coord} in order {order}')
+         if result is None:
+             result = r
+         else:
+             result = qmul_np(result, r)
+
+     # Reverse antipodal representation to have a non-negative "w"
+     if order in ['xyz', 'yzx', 'zxy']:
+         result *= -1
+
+     return result.reshape(original_shape)
+
+
+ def quaternion_to_matrix(quaternions):
+     """
+     Convert rotations given as quaternions to rotation matrices.
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     r, i, j, k = torch.unbind(quaternions, -1)
+     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+     o = torch.stack(
+         (
+             1 - two_s * (j * j + k * k),
+             two_s * (i * j - k * r),
+             two_s * (i * k + j * r),
+             two_s * (i * j + k * r),
+             1 - two_s * (i * i + k * k),
+             two_s * (j * k - i * r),
+             two_s * (i * k - j * r),
+             two_s * (j * k + i * r),
+             1 - two_s * (i * i + j * j),
+         ),
+         -1,
+     )
+     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+ def quaternion_to_matrix_np(quaternions):
+     q = torch.from_numpy(quaternions).contiguous().float()
+     return quaternion_to_matrix(q).numpy()
+
+
+ def quaternion_to_cont6d_np(quaternions):
+     rotation_mat = quaternion_to_matrix_np(quaternions)
+     cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
+     return cont_6d
+
+
+ def quaternion_to_cont6d(quaternions):
+     rotation_mat = quaternion_to_matrix(quaternions)
+     cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
+     return cont_6d
+
+
+ def cont6d_to_matrix(cont6d):
+     assert cont6d.shape[-1] == 6, "The last dimension must be 6"
+     x_raw = cont6d[..., 0:3]
+     y_raw = cont6d[..., 3:6]
+
+     x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
+     z = torch.cross(x, y_raw, dim=-1)
+     z = z / torch.norm(z, dim=-1, keepdim=True)
+
+     y = torch.cross(z, x, dim=-1)
+
+     x = x[..., None]
+     y = y[..., None]
+     z = z[..., None]
+
+     mat = torch.cat([x, y, z], dim=-1)
+     return mat
+
+
+ def cont6d_to_matrix_np(cont6d):
+     q = torch.from_numpy(cont6d).contiguous().float()
+     return cont6d_to_matrix(q).numpy()
+
+
+ def qpow(q0, t, dtype=torch.float):
+     ''' q0 : tensor of quaternions
+     t: tensor of powers
+     '''
+     q0 = qnormalize(q0)
+     theta0 = torch.acos(q0[..., 0])
+
+     ## if theta0 is close to zero, add epsilon to avoid NaNs
+     mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
+     theta0 = (1 - mask) * theta0 + mask * 10e-10
+     v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
+
+     if isinstance(t, torch.Tensor):
+         q = torch.zeros(t.shape + q0.shape)
+         theta = t.view(-1, 1) * theta0.view(1, -1)
+     else:  ## if t is a number
+         q = torch.zeros(q0.shape)
+         theta = t * theta0
+
+     q[..., 0] = torch.cos(theta)
+     q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
+
+     return q.to(dtype)
+
+
+ def qslerp(q0, q1, t):
+     '''
+     q0: starting quaternion
+     q1: ending quaternion
+     t: array of points along the way
+
+     Returns:
+     Tensor of Slerps: t.shape + q0.shape
+     '''
+
+     q0 = qnormalize(q0)
+     q1 = qnormalize(q1)
+     q_ = qpow(qmul(q1, qinv(q0)), t)
+
+     return qmul(q_,
+                 q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
+
+
+ def qbetween(v0, v1):
+     '''
+     find the quaternion used to rotate v0 to v1
+     '''
+     assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
+     assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
+
+     v = torch.cross(v0, v1)
+     w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, keepdim=True)
+     return qnormalize(torch.cat([w, v], dim=-1))
+
+
+ def qbetween_np(v0, v1):
+     '''
+     find the quaternion used to rotate v0 to v1
+     '''
+     assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
+     assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
+
+     v0 = torch.from_numpy(v0).float()
+     v1 = torch.from_numpy(v1).float()
+     return qbetween(v0, v1).numpy()
+
+
+ def lerp(p0, p1, t):
+     if not isinstance(t, torch.Tensor):
+         t = torch.Tensor([t])
+
+     new_shape = t.shape + p0.shape
+     new_view_t = t.shape + torch.Size([1] * len(p0.shape))
+     new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
+     p0 = p0.view(new_view_p).expand(new_shape)
+     p1 = p1.view(new_view_p).expand(new_shape)
+     t = t.view(new_view_t).expand(new_shape)
+
+     return p0 + t * (p1 - p0)
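
A quick sanity check of the PyTorch-backed helpers above (a sketch, not repository code): rotating a vector and then rotating it back with the inverse quaternion should recover the original.

import torch
from utils.quaternion import qnormalize, qinv, qmul, qrot

q = qnormalize(torch.randn(8, 4))   # batch of random unit quaternions
v = torch.randn(8, 3)               # batch of 3D vectors

v_back = qrot(qinv(q), qrot(q, v))  # rotate, then undo with the conjugate
print(torch.allclose(v, v_back, atol=1e-5))   # True up to float32 error

ident = qmul(q, qinv(q))            # q * q^-1 should be (1, 0, 0, 0)
print(torch.allclose(ident, torch.tensor([1., 0., 0., 0.]).expand_as(ident), atol=1e-5))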
utils/rotation_conversions.py ADDED
@@ -0,0 +1,532 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+ # Check PYTORCH3D_LICENCE before use
+
+ import functools
+ from typing import Optional
+
+ import torch
+ import torch.nn.functional as F
+
+
+ """
+ The transformation matrices returned from the functions in this file assume
+ the points on which the transformation will be applied are column vectors.
+ i.e. the R matrix is structured as
+     R = [
+             [Rxx, Rxy, Rxz],
+             [Ryx, Ryy, Ryz],
+             [Rzx, Rzy, Rzz],
+         ]  # (3, 3)
+ This matrix can be applied to column vectors by post multiplication
+ by the points e.g.
+     points = [[0], [1], [2]]  # (3 x 1) xyz coordinates of a point
+     transformed_points = R * points
+ To apply the same matrix to points which are row vectors, the R matrix
+ can be transposed and pre multiplied by the points:
+ e.g.
+     points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
+     transformed_points = points * R.transpose(1, 0)
+ """
+
+
+ def quaternion_to_matrix(quaternions):
+     """
+     Convert rotations given as quaternions to rotation matrices.
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     r, i, j, k = torch.unbind(quaternions, -1)
+     two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+     o = torch.stack(
+         (
+             1 - two_s * (j * j + k * k),
+             two_s * (i * j - k * r),
+             two_s * (i * k + j * r),
+             two_s * (i * j + k * r),
+             1 - two_s * (i * i + k * k),
+             two_s * (j * k - i * r),
+             two_s * (i * k - j * r),
+             two_s * (j * k + i * r),
+             1 - two_s * (i * i + j * j),
+         ),
+         -1,
+     )
+     return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+ def _copysign(a, b):
+     """
+     Return a tensor where each element has the absolute value taken from the
+     corresponding element of a, with sign taken from the corresponding
+     element of b. This is like the standard copysign floating-point operation,
+     but is not careful about negative 0 and NaN.
+     Args:
+         a: source tensor.
+         b: tensor whose signs will be used, of the same shape as a.
+     Returns:
+         Tensor of the same shape as a with the signs of b.
+     """
+     signs_differ = (a < 0) != (b < 0)
+     return torch.where(signs_differ, -a, a)
+
+
+ def _sqrt_positive_part(x):
+     """
+     Returns torch.sqrt(torch.max(0, x))
+     but with a zero subgradient where x is 0.
+     """
+     ret = torch.zeros_like(x)
+     positive_mask = x > 0
+     ret[positive_mask] = torch.sqrt(x[positive_mask])
+     return ret
+
+
+ def matrix_to_quaternion(matrix):
+     """
+     Convert rotations given as rotation matrices to quaternions.
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+     Returns:
+         quaternions with real part first, as tensor of shape (..., 4).
+     """
+     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+     m00 = matrix[..., 0, 0]
+     m11 = matrix[..., 1, 1]
+     m22 = matrix[..., 2, 2]
+     o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
+     x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
+     y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
+     z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
+     o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
+     o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
+     o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
+     return torch.stack((o0, o1, o2, o3), -1)
+
+
+ def _axis_angle_rotation(axis: str, angle):
+     """
+     Return the rotation matrices for one of the rotations about an axis
+     of which Euler angles describe, for each value of the angle given.
+     Args:
+         axis: Axis label "X" or "Y" or "Z".
+         angle: any shape tensor of Euler angles in radians
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+
+     cos = torch.cos(angle)
+     sin = torch.sin(angle)
+     one = torch.ones_like(angle)
+     zero = torch.zeros_like(angle)
+
+     if axis == "X":
+         R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
+     if axis == "Y":
+         R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
+     if axis == "Z":
+         R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
+
+     return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+
+
+ def euler_angles_to_matrix(euler_angles, convention: str):
+     """
+     Convert rotations given as Euler angles in radians to rotation matrices.
+     Args:
+         euler_angles: Euler angles in radians as tensor of shape (..., 3).
+         convention: Convention string of three uppercase letters from
+             {"X", "Y", and "Z"}.
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
+         raise ValueError("Invalid input euler angles.")
+     if len(convention) != 3:
+         raise ValueError("Convention must have 3 letters.")
+     if convention[1] in (convention[0], convention[2]):
+         raise ValueError(f"Invalid convention {convention}.")
+     for letter in convention:
+         if letter not in ("X", "Y", "Z"):
+             raise ValueError(f"Invalid letter {letter} in convention string.")
+     matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
+     return functools.reduce(torch.matmul, matrices)
+
+
+ def _angle_from_tan(
+     axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
+ ):
+     """
+     Extract the first or third Euler angle from the two members of
+     the matrix which are positive constant times its sine and cosine.
+     Args:
+         axis: Axis label "X" or "Y" or "Z" for the angle we are finding.
+         other_axis: Axis label "X" or "Y" or "Z" for the middle axis in the
+             convention.
+         data: Rotation matrices as tensor of shape (..., 3, 3).
+         horizontal: Whether we are looking for the angle for the third axis,
+             which means the relevant entries are in the same row of the
+             rotation matrix. If not, they are in the same column.
+         tait_bryan: Whether the first and third axes in the convention differ.
+     Returns:
+         Euler Angles in radians for each matrix in data as a tensor
+         of shape (...).
+     """
+
+     i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
+     if horizontal:
+         i2, i1 = i1, i2
+     even = (axis + other_axis) in ["XY", "YZ", "ZX"]
+     if horizontal == even:
+         return torch.atan2(data[..., i1], data[..., i2])
+     if tait_bryan:
+         return torch.atan2(-data[..., i2], data[..., i1])
+     return torch.atan2(data[..., i2], -data[..., i1])
+
+
+ def _index_from_letter(letter: str):
+     if letter == "X":
+         return 0
+     if letter == "Y":
+         return 1
+     if letter == "Z":
+         return 2
+
+
+ def matrix_to_euler_angles(matrix, convention: str):
+     """
+     Convert rotations given as rotation matrices to Euler angles in radians.
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+         convention: Convention string of three uppercase letters.
+     Returns:
+         Euler angles in radians as tensor of shape (..., 3).
+     """
+     if len(convention) != 3:
+         raise ValueError("Convention must have 3 letters.")
+     if convention[1] in (convention[0], convention[2]):
+         raise ValueError(f"Invalid convention {convention}.")
+     for letter in convention:
+         if letter not in ("X", "Y", "Z"):
+             raise ValueError(f"Invalid letter {letter} in convention string.")
+     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+         raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+     i0 = _index_from_letter(convention[0])
+     i2 = _index_from_letter(convention[2])
+     tait_bryan = i0 != i2
+     if tait_bryan:
+         central_angle = torch.asin(
+             matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+         )
+     else:
+         central_angle = torch.acos(matrix[..., i0, i0])
+
+     o = (
+         _angle_from_tan(
+             convention[0], convention[1], matrix[..., i2], False, tait_bryan
+         ),
+         central_angle,
+         _angle_from_tan(
+             convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
+         ),
+     )
+     return torch.stack(o, -1)
+
+
+ def random_quaternions(
+     n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
+ ):
+     """
+     Generate random quaternions representing rotations,
+     i.e. versors with nonnegative real part.
+     Args:
+         n: Number of quaternions in a batch to return.
+         dtype: Type to return.
+         device: Desired device of returned tensor. Default:
+             uses the current device for the default tensor type.
+         requires_grad: Whether the resulting tensor should have the gradient
+             flag set.
+     Returns:
+         Quaternions as tensor of shape (N, 4).
+     """
+     o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
+     s = (o * o).sum(1)
+     o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+     return o
+
+
+ def random_rotations(
+     n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
+ ):
+     """
+     Generate random rotations as 3x3 rotation matrices.
+     Args:
+         n: Number of rotation matrices in a batch to return.
+         dtype: Type to return.
+         device: Device of returned tensor. Default: if None,
+             uses the current device for the default tensor type.
+         requires_grad: Whether the resulting tensor should have the gradient
+             flag set.
+     Returns:
+         Rotation matrices as tensor of shape (n, 3, 3).
+     """
+     quaternions = random_quaternions(
+         n, dtype=dtype, device=device, requires_grad=requires_grad
+     )
+     return quaternion_to_matrix(quaternions)
+
+
+ def random_rotation(
+     dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
+ ):
+     """
+     Generate a single random 3x3 rotation matrix.
+     Args:
+         dtype: Type to return
+         device: Device of returned tensor. Default: if None,
+             uses the current device for the default tensor type
+         requires_grad: Whether the resulting tensor should have the gradient
+             flag set
+     Returns:
+         Rotation matrix as tensor of shape (3, 3).
+     """
+     return random_rotations(1, dtype, device, requires_grad)[0]
298
+
299
+
300
+ def standardize_quaternion(quaternions):
301
+ """
302
+ Convert a unit quaternion to a standard form: one in which the real
303
+ part is non negative.
304
+ Args:
305
+ quaternions: Quaternions with real part first,
306
+ as tensor of shape (..., 4).
307
+ Returns:
308
+ Standardized quaternions as tensor of shape (..., 4).
309
+ """
310
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
311
+
312
+
313
+ def quaternion_raw_multiply(a, b):
314
+ """
315
+ Multiply two quaternions.
316
+ Usual torch rules for broadcasting apply.
317
+ Args:
318
+ a: Quaternions as tensor of shape (..., 4), real part first.
319
+ b: Quaternions as tensor of shape (..., 4), real part first.
320
+ Returns:
321
+ The product of a and b, a tensor of quaternions shape (..., 4).
322
+ """
323
+ aw, ax, ay, az = torch.unbind(a, -1)
324
+ bw, bx, by, bz = torch.unbind(b, -1)
325
+ ow = aw * bw - ax * bx - ay * by - az * bz
326
+ ox = aw * bx + ax * bw + ay * bz - az * by
327
+ oy = aw * by - ax * bz + ay * bw + az * bx
328
+ oz = aw * bz + ax * by - ay * bx + az * bw
329
+ return torch.stack((ow, ox, oy, oz), -1)
330
+
331
+
332
+ def quaternion_multiply(a, b):
333
+ """
334
+ Multiply two quaternions representing rotations, returning the quaternion
335
+ representing their composition, i.e. the versor with nonnegative real part.
336
+ Usual torch rules for broadcasting apply.
337
+ Args:
338
+ a: Quaternions as tensor of shape (..., 4), real part first.
339
+ b: Quaternions as tensor of shape (..., 4), real part first.
340
+ Returns:
341
+ The product of a and b, a tensor of quaternions of shape (..., 4).
342
+ """
343
+ ab = quaternion_raw_multiply(a, b)
344
+ return standardize_quaternion(ab)
345
+
346
+
347
+ def quaternion_invert(quaternion):
348
+ """
349
+ Given a quaternion representing rotation, get the quaternion representing
350
+ its inverse.
351
+ Args:
352
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
353
+ first, which must be versors (unit quaternions).
354
+ Returns:
355
+ The inverse, a tensor of quaternions of shape (..., 4).
356
+ """
357
+
358
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
359
+
360
+
361
+ def quaternion_apply(quaternion, point):
362
+ """
363
+ Apply the rotation given by a quaternion to a 3D point.
364
+ Usual torch rules for broadcasting apply.
365
+ Args:
366
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
367
+ point: Tensor of 3D points of shape (..., 3).
368
+ Returns:
369
+ Tensor of rotated points of shape (..., 3).
370
+ """
371
+ if point.size(-1) != 3:
372
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
373
+     real_parts = point.new_zeros(point.shape[:-1] + (1,))
+     point_as_quaternion = torch.cat((real_parts, point), -1)
+     out = quaternion_raw_multiply(
+         quaternion_raw_multiply(quaternion, point_as_quaternion),
+         quaternion_invert(quaternion),
+     )
+     return out[..., 1:]
+
+
+ def axis_angle_to_matrix(axis_angle):
+     """
+     Convert rotations given as axis/angle to rotation matrices.
+     Args:
+         axis_angle: Rotations given as a vector in axis angle form,
+             as a tensor of shape (..., 3), where the magnitude is
+             the angle turned anticlockwise in radians around the
+             vector's direction.
+     Returns:
+         Rotation matrices as tensor of shape (..., 3, 3).
+     """
+     return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
+
+
+ def matrix_to_axis_angle(matrix):
+     """
+     Convert rotations given as rotation matrices to axis/angle.
+     Args:
+         matrix: Rotation matrices as tensor of shape (..., 3, 3).
+     Returns:
+         Rotations given as a vector in axis angle form, as a tensor
+         of shape (..., 3), where the magnitude is the angle
+         turned anticlockwise in radians around the vector's
+         direction.
+     """
+     return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
+
+
+ def axis_angle_to_quaternion(axis_angle):
+     """
+     Convert rotations given as axis/angle to quaternions.
+     Args:
+         axis_angle: Rotations given as a vector in axis angle form,
+             as a tensor of shape (..., 3), where the magnitude is
+             the angle turned anticlockwise in radians around the
+             vector's direction.
+     Returns:
+         quaternions with real part first, as tensor of shape (..., 4).
+     """
+     angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
+     half_angles = 0.5 * angles
+     eps = 1e-6
+     small_angles = angles.abs() < eps
+     sin_half_angles_over_angles = torch.empty_like(angles)
+     sin_half_angles_over_angles[~small_angles] = (
+         torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+     )
+     # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+     # so sin(x/2)/x is about 1/2 - (x*x)/48
+     sin_half_angles_over_angles[small_angles] = (
+         0.5 - (angles[small_angles] * angles[small_angles]) / 48
+     )
+     quaternions = torch.cat(
+         [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
+     )
+     return quaternions
+
+
+ def quaternion_to_axis_angle(quaternions):
+     """
+     Convert rotations given as quaternions to axis/angle.
+     Args:
+         quaternions: quaternions with real part first,
+             as tensor of shape (..., 4).
+     Returns:
+         Rotations given as a vector in axis angle form, as a tensor
+         of shape (..., 3), where the magnitude is the angle
+         turned anticlockwise in radians around the vector's
+         direction.
+     """
+     norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
+     half_angles = torch.atan2(norms, quaternions[..., :1])
+     angles = 2 * half_angles
+     eps = 1e-6
+     small_angles = angles.abs() < eps
+     sin_half_angles_over_angles = torch.empty_like(angles)
+     sin_half_angles_over_angles[~small_angles] = (
+         torch.sin(half_angles[~small_angles]) / angles[~small_angles]
+     )
+     # for x small, sin(x/2) is about x/2 - (x/2)^3/6
+     # so sin(x/2)/x is about 1/2 - (x*x)/48
+     sin_half_angles_over_angles[small_angles] = (
+         0.5 - (angles[small_angles] * angles[small_angles]) / 48
+     )
+     return quaternions[..., 1:] / sin_half_angles_over_angles
+
+
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
+     """
+     Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
+     using Gram--Schmidt orthogonalisation per Section B of [1].
+     Args:
+         d6: 6D rotation representation, of size (*, 6)
+     Returns:
+         batch of rotation matrices of size (*, 3, 3)
+     [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+     On the Continuity of Rotation Representations in Neural Networks.
+     IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+     Retrieved from http://arxiv.org/abs/1812.07035
+     """
+
+     a1, a2 = d6[..., :3], d6[..., 3:]
+     b1 = F.normalize(a1, dim=-1)
+     b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
+     b2 = F.normalize(b2, dim=-1)
+     b3 = torch.cross(b1, b2, dim=-1)
+     return torch.stack((b1, b2, b3), dim=-2)
+
+
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
+     """
+     Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
+     by dropping the last row. Note that 6D representation is not unique.
+     Args:
+         matrix: batch of rotation matrices of size (*, 3, 3)
+     Returns:
+         6D rotation representation, of size (*, 6)
+     [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
+     On the Continuity of Rotation Representations in Neural Networks.
+     IEEE Conference on Computer Vision and Pattern Recognition, 2019.
+     Retrieved from http://arxiv.org/abs/1812.07035
+     """
+     return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
+
+
+ def canonicalize_smplh(poses, trans=None):
+     bs, nframes, njoints = poses.shape[:3]
+
+     global_orient = poses[:, :, 0]
+
+     # first global rotations
+     rot2d = matrix_to_axis_angle(global_orient[:, 0])
+     # rot2d[:, :2] = 0  # Remove the rotation along the vertical axis
+     rot2d = axis_angle_to_matrix(rot2d)
+
+     # Rotate the global rotation to eliminate Z rotations
+     global_orient = torch.einsum("ikj,imkl->imjl", rot2d, global_orient)
+
+     # Construct canonicalized version of x
+     xc = torch.cat((global_orient[:, :, None], poses[:, :, 1:]), dim=2)
+
+     if trans is not None:
+         vel = trans[:, 1:] - trans[:, :-1]
+         # Turn the translation as well
+         vel = torch.einsum("ikj,ilk->ilj", rot2d, vel)
+         trans = torch.cat((torch.zeros(bs, 1, 3, device=vel.device),
+                            torch.cumsum(vel, 1)), 1)
+         return xc, trans
+     else:
+         return xc
+
+
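For reference, a minimal usage sketch of the conversion helpers above. The import path (utils.rotation_conversions) and the batch shapes are assumptions for illustration; only the function names come from the file itself.

    import torch
    from utils.rotation_conversions import (   # assumed module path, not shown in this diff
        axis_angle_to_matrix, matrix_to_axis_angle,
        matrix_to_rotation_6d, rotation_6d_to_matrix,
    )

    aa = torch.randn(8, 22, 3) * 0.3            # a batch of axis-angle joint rotations (assumed shape)
    mats = axis_angle_to_matrix(aa)             # (8, 22, 3, 3) rotation matrices
    d6 = matrix_to_rotation_6d(mats)            # (8, 22, 6) continuous 6D representation
    mats_back = rotation_6d_to_matrix(d6)       # back to (8, 22, 3, 3)
    aa_back = matrix_to_axis_angle(mats_back)   # and back to axis-angle
    assert torch.allclose(mats, mats_back, atol=1e-5)
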
utils/skeleton.py ADDED
@@ -0,0 +1,199 @@
+ from utils.quaternion import *
+ import scipy.ndimage.filters as filters
+
+ class Skeleton(object):
+     def __init__(self, offset, kinematic_tree, device):
+         self.device = device
+         self._raw_offset_np = offset.numpy()
+         self._raw_offset = offset.clone().detach().to(device).float()
+         self._kinematic_tree = kinematic_tree
+         self._offset = None
+         self._parents = [0] * len(self._raw_offset)
+         self._parents[0] = -1
+         for chain in self._kinematic_tree:
+             for j in range(1, len(chain)):
+                 self._parents[chain[j]] = chain[j-1]
+
+     def njoints(self):
+         return len(self._raw_offset)
+
+     def offset(self):
+         return self._offset
+
+     def set_offset(self, offsets):
+         self._offset = offsets.clone().detach().to(self.device).float()
+
+     def kinematic_tree(self):
+         return self._kinematic_tree
+
+     def parents(self):
+         return self._parents
+
+     # joints (batch_size, joints_num, 3)
+     def get_offsets_joints_batch(self, joints):
+         assert len(joints.shape) == 3
+         _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
+         for i in range(1, self._raw_offset.shape[0]):
+             _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
+
+         self._offset = _offsets.detach()
+         return _offsets
+
+     # joints (joints_num, 3)
+     def get_offsets_joints(self, joints):
+         assert len(joints.shape) == 2
+         _offsets = self._raw_offset.clone()
+         for i in range(1, self._raw_offset.shape[0]):
+             # print(joints.shape)
+             _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
+
+         self._offset = _offsets.detach()
+         return _offsets
+
+     # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
+     # joints (batch_size, joints_num, 3)
+     def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
+         assert len(face_joint_idx) == 4
+         '''Get Forward Direction'''
+         l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
+         across1 = joints[:, r_hip] - joints[:, l_hip]
+         across2 = joints[:, sdr_r] - joints[:, sdr_l]
+         across = across1 + across2
+         across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
+         # print(across1.shape, across2.shape)
+
+         # forward (batch_size, 3)
+         forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
+         if smooth_forward:
+             forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
+             # forward (batch_size, 3)
+         forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
+
+         '''Get Root Rotation'''
+         target = np.array([[0, 0, 1]]).repeat(len(forward), axis=0)
+         root_quat = qbetween_np(forward, target)
+
+         '''Inverse Kinematics'''
+         # quat_params (batch_size, joints_num, 4)
+         # print(joints.shape[:-1])
+         quat_params = np.zeros(joints.shape[:-1] + (4,))
+         # print(quat_params.shape)
+         root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
+         quat_params[:, 0] = root_quat
+         # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
+         for chain in self._kinematic_tree:
+             R = root_quat
+             for j in range(len(chain) - 1):
+                 # (batch, 3)
+                 u = self._raw_offset_np[chain[j+1]][np.newaxis, ...].repeat(len(joints), axis=0)
+                 # print(u.shape)
+                 # (batch, 3)
+                 v = joints[:, chain[j+1]] - joints[:, chain[j]]
+                 v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
+                 # print(u.shape, v.shape)
+                 rot_u_v = qbetween_np(u, v)
+
+                 R_loc = qmul_np(qinv_np(R), rot_u_v)
+
+                 quat_params[:, chain[j + 1], :] = R_loc
+                 R = qmul_np(R, R_loc)
+
+         return quat_params
+
+     # Be sure root joint is at the beginning of kinematic chains
+     def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
+         # quat_params (batch_size, joints_num, 4)
+         # joints (batch_size, joints_num, 3)
+         # root_pos (batch_size, 3)
+         if skel_joints is not None:
+             offsets = self.get_offsets_joints_batch(skel_joints)
+         if len(self._offset.shape) == 2:
+             offsets = self._offset.expand(quat_params.shape[0], -1, -1)
+         joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
+         joints[:, 0] = root_pos
+         for chain in self._kinematic_tree:
+             if do_root_R:
+                 R = quat_params[:, 0]
+             else:
+                 R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
+             for i in range(1, len(chain)):
+                 R = qmul(R, quat_params[:, chain[i]])
+                 offset_vec = offsets[:, chain[i]]
+                 joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
+         return joints
+
+     # Be sure root joint is at the beginning of kinematic chains
+     def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
+         # quat_params (batch_size, joints_num, 4)
+         # joints (batch_size, joints_num, 3)
+         # root_pos (batch_size, 3)
+         if skel_joints is not None:
+             skel_joints = torch.from_numpy(skel_joints)
+             offsets = self.get_offsets_joints_batch(skel_joints)
+         if len(self._offset.shape) == 2:
+             offsets = self._offset.expand(quat_params.shape[0], -1, -1)
+         offsets = offsets.numpy()
+         joints = np.zeros(quat_params.shape[:-1] + (3,))
+         joints[:, 0] = root_pos
+         for chain in self._kinematic_tree:
+             if do_root_R:
+                 R = quat_params[:, 0]
+             else:
+                 R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
+             for i in range(1, len(chain)):
+                 R = qmul_np(R, quat_params[:, chain[i]])
+                 offset_vec = offsets[:, chain[i]]
+                 joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
+         return joints
+
+     def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
+         # cont6d_params (batch_size, joints_num, 6)
+         # joints (batch_size, joints_num, 3)
+         # root_pos (batch_size, 3)
+         if skel_joints is not None:
+             skel_joints = torch.from_numpy(skel_joints)
+             offsets = self.get_offsets_joints_batch(skel_joints)
+         if len(self._offset.shape) == 2:
+             offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
+         offsets = offsets.numpy()
+         joints = np.zeros(cont6d_params.shape[:-1] + (3,))
+         joints[:, 0] = root_pos
+         for chain in self._kinematic_tree:
+             if do_root_R:
+                 matR = cont6d_to_matrix_np(cont6d_params[:, 0])
+             else:
+                 matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
+             for i in range(1, len(chain)):
+                 matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
+                 offset_vec = offsets[:, chain[i]][..., np.newaxis]
+                 # print(matR.shape, offset_vec.shape)
+                 joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
+         return joints
+
+     def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
+         # cont6d_params (batch_size, joints_num, 6)
+         # joints (batch_size, joints_num, 3)
+         # root_pos (batch_size, 3)
+         if skel_joints is not None:
+             # skel_joints = torch.from_numpy(skel_joints)
+             offsets = self.get_offsets_joints_batch(skel_joints)
+         if len(self._offset.shape) == 2:
+             offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
+         joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
+         joints[..., 0, :] = root_pos
+         for chain in self._kinematic_tree:
+             if do_root_R:
+                 matR = cont6d_to_matrix(cont6d_params[:, 0])
+             else:
+                 matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
+             for i in range(1, len(chain)):
+                 matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
+                 offset_vec = offsets[:, chain[i]].unsqueeze(-1)
+                 # print(matR.shape, offset_vec.shape)
+                 joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
+         return joints
+
+
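A minimal usage sketch of the Skeleton class above. The toy kinematic chain, raw offsets, and shapes are illustrative placeholders; the codebase's real SMPL/HumanML3D chain and offset definitions are not part of this file.

    import torch
    from utils.skeleton import Skeleton

    # Toy 5-joint skeleton: pelvis -> spine -> head, plus two hips (placeholder values).
    kinematic_tree = [[0, 1, 2], [0, 3], [0, 4]]
    raw_offsets = torch.tensor([[0, 0, 0],
                                [0, 1, 0],
                                [0, 1, 0],
                                [1, 0, 0],
                                [-1, 0, 0]], dtype=torch.float32)

    skel = Skeleton(raw_offsets, kinematic_tree, device='cpu')
    ref_joints = torch.rand(2, 5, 3)      # reference poses used to derive bone lengths
    quat_params = torch.zeros(2, 5, 4)    # identity rotation for every joint
    quat_params[..., 0] = 1.0
    root_pos = torch.zeros(2, 3)
    joints = skel.forward_kinematics(quat_params, root_pos, skel_joints=ref_joints)
    print(joints.shape)                   # torch.Size([2, 5, 3])
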
utils/utils_model.py ADDED
@@ -0,0 +1,66 @@
+ import numpy as np
+ import torch
+ import torch.optim as optim
+ import logging
+ import os
+ import sys
+
+ def getCi(accLog):
+
+     mean = np.mean(accLog)
+     std = np.std(accLog)
+     ci95 = 1.96*std/np.sqrt(len(accLog))
+
+     return mean, ci95
+
+ def get_logger(out_dir):
+     logger = logging.getLogger('Exp')
+     logger.setLevel(logging.INFO)
+     formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
+
+     file_path = os.path.join(out_dir, "run.log")
+     file_hdlr = logging.FileHandler(file_path)
+     file_hdlr.setFormatter(formatter)
+
+     strm_hdlr = logging.StreamHandler(sys.stdout)
+     strm_hdlr.setFormatter(formatter)
+
+     logger.addHandler(file_hdlr)
+     logger.addHandler(strm_hdlr)
+     return logger
+
+ ## Optimizer
+ def initial_optim(decay_option, lr, weight_decay, net, optimizer) :
+
+     if optimizer == 'adamw' :
+         optimizer_adam_family = optim.AdamW
+     elif optimizer == 'adam' :
+         optimizer_adam_family = optim.Adam
+     if decay_option == 'all':
+         #optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
+         optimizer = optimizer_adam_family(net.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=weight_decay)
+
+     elif decay_option == 'noVQ':
+         all_params = set(net.parameters())
+         no_decay = set([net.vq_layer])
+
+         decay = all_params - no_decay
+         optimizer = optimizer_adam_family([
+             {'params': list(no_decay), 'weight_decay': 0},
+             {'params': list(decay), 'weight_decay': weight_decay}], lr=lr)
+
+     return optimizer
+
+
+ def get_motion_with_trans(motion, velocity):
+     '''
+     motion : torch.tensor, shape (batch_size, T, 72), motion with the global translation set to 0
+     velocity : torch.tensor, shape (batch_size, T, 3), root velocity at each frame
+     '''
+     trans = torch.cumsum(velocity, dim=1)
+     trans = trans - trans[:, :1]  ## the first root is initialized at 0 (just for visualization)
+     trans = trans.repeat((1, 1, 21))
+     motion_with_trans = motion + trans
+     return motion_with_trans
+
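A short usage sketch of the logging and evaluation helpers above (the output directory and accuracy values are illustrative):

    import os
    import numpy as np
    from utils.utils_model import get_logger, getCi

    out_dir = './output/exp1'                     # illustrative path
    os.makedirs(out_dir, exist_ok=True)           # get_logger writes run.log inside out_dir
    logger = get_logger(out_dir)                  # logs to file and to stdout
    top1_runs = np.array([0.512, 0.498, 0.505])   # e.g. Top-1 accuracy over repeated evaluations
    mean, ci95 = getCi(top1_runs)
    logger.info('Top-1: %.3f +/- %.3f (95%% CI)', mean, ci95)
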
utils/word_vectorizer.py ADDED
@@ -0,0 +1,99 @@
+ import numpy as np
+ import pickle
+ from os.path import join as pjoin
+
+ POS_enumerator = {
+     'VERB': 0,
+     'NOUN': 1,
+     'DET': 2,
+     'ADP': 3,
+     'NUM': 4,
+     'AUX': 5,
+     'PRON': 6,
+     'ADJ': 7,
+     'ADV': 8,
+     'Loc_VIP': 9,
+     'Body_VIP': 10,
+     'Obj_VIP': 11,
+     'Act_VIP': 12,
+     'Desc_VIP': 13,
+     'OTHER': 14,
+ }
+
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
+             'up', 'down', 'straight', 'curve')
+
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
+
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
+
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
+             'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
+             'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
+
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
+              'angrily', 'sadly')
+
+ VIP_dict = {
+     'Loc_VIP': Loc_list,
+     'Body_VIP': Body_list,
+     'Obj_VIP': Obj_List,
+     'Act_VIP': Act_list,
+     'Desc_VIP': Desc_list,
+ }
+
+
+ class WordVectorizer(object):
+     def __init__(self, meta_root, prefix):
+         vectors = np.load(pjoin(meta_root, '%s_data.npy' % prefix))
+         words = pickle.load(open(pjoin(meta_root, '%s_words.pkl' % prefix), 'rb'))
+         self.word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl' % prefix), 'rb'))
+         self.word2vec = {w: vectors[self.word2idx[w]] for w in words}
+
+     def _get_pos_ohot(self, pos):
+         pos_vec = np.zeros(len(POS_enumerator))
+         if pos in POS_enumerator:
+             pos_vec[POS_enumerator[pos]] = 1
+         else:
+             pos_vec[POS_enumerator['OTHER']] = 1
+         return pos_vec
+
+     def __len__(self):
+         return len(self.word2vec)
+
+     def __getitem__(self, item):
+         word, pos = item.split('/')
+         if word in self.word2vec:
+             word_vec = self.word2vec[word]
+             vip_pos = None
+             for key, values in VIP_dict.items():
+                 if word in values:
+                     vip_pos = key
+                     break
+             if vip_pos is not None:
+                 pos_vec = self._get_pos_ohot(vip_pos)
+             else:
+                 pos_vec = self._get_pos_ohot(pos)
+         else:
+             word_vec = self.word2vec['unk']
+             pos_vec = self._get_pos_ohot('OTHER')
+         return word_vec, pos_vec
+
+
+ class WordVectorizerV2(WordVectorizer):
+     def __init__(self, meta_root, prefix):
+         super(WordVectorizerV2, self).__init__(meta_root, prefix)
+         self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
+
+     def __getitem__(self, item):
+         word_vec, pose_vec = super(WordVectorizerV2, self).__getitem__(item)
+         word, pos = item.split('/')
+         if word in self.word2vec:
+             return word_vec, pose_vec, self.word2idx[word]
+         else:
+             return word_vec, pose_vec, self.word2idx['unk']
+
+     def itos(self, idx):
+         if idx == len(self.idx2word):
+             return "pad"
+         return self.idx2word[idx]
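A brief usage sketch for WordVectorizer. The metadata directory, prefix, and printed dimensions below are assumptions (they follow the common GloVe-based convention in text-to-motion codebases), not values taken from this file:

    from utils.word_vectorizer import WordVectorizer

    w_vectorizer = WordVectorizer('./glove', 'our_vab')   # expects our_vab_data.npy, our_vab_words.pkl, our_vab_idx.pkl
    word_emb, pos_onehot = w_vectorizer['walk/VERB']      # items are 'word/POS' strings
    print(word_emb.shape, pos_onehot.shape)               # e.g. (300,) and (15,)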