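"""Visualize ground-truth LAFAN1 poses for the in-betweening sampler.

Loads the latest trained checkpoint (for its in-betweening frame range),
the LAFAN1 test split, and the mocap skeleton, then renders the
start/target/stopover pose of every test sequence with plot_pose_with_stop.
"""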
import argparse
import os
from pathlib import Path

import imageio
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from sklearn.preprocessing import LabelEncoder

from cmib.data.lafan1_dataset import LAFAN1Dataset
from cmib.data.utils import write_json
from cmib.lafan1.utils import quat_ik
from cmib.model.network import TransformerModel
from cmib.model.preprocess import (lerp_input_repr, replace_constant,
                                   slerp_input_repr, vectorize_representation)
from cmib.model.skeleton import (Skeleton, joint_names, sk_joints_to_remove,
                                 sk_offsets, sk_parents)
from cmib.vis.pose import plot_pose_with_stop


def test(opt, device):
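    """Render ground-truth start/target/stopover poses for each test sequence.

    The latest checkpoint under runs/train/<exp_name>/weights supplies the
    in-betweening frame range; one image per sequence is written to
    <save_path>/sampler via plot_pose_with_stop.
    """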
    save_dir = Path(os.path.join('runs', 'train', opt.exp_name))
    wdir = save_dir / 'weights'
    weights = os.listdir(wdir)
    weights_paths = [wdir / weight for weight in weights]
    latest_weight = max(weights_paths, key=os.path.getctime)
    ckpt = torch.load(latest_weight, map_location=device)
    print(f"Loaded weight: {latest_weight}")
    # Load skeleton
    skeleton_mocap = Skeleton(offsets=sk_offsets, parents=sk_parents, device=device)
    skeleton_mocap.remove_joints(sk_joints_to_remove)

    # Load LAFAN1 dataset (test split)
    Path(opt.processed_data_dir).mkdir(parents=True, exist_ok=True)
    lafan_dataset = LAFAN1Dataset(lafan_path=opt.data_path, processed_data_dir=opt.processed_data_dir, train=False, device=device)
    total_data = lafan_dataset.data['global_pos'].shape[0]
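    # 'global_pos' is indexed as [sequence, frame] below; each entry is
    # presumably the global joint positions of one frame (shape (joints, 3)).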
    # Frame range whose in-between frames are replaced (with noise or interpolation)
    from_idx, target_idx = ckpt['from_idx'], ckpt['target_idx']  # default: 9-40, max: 48
    horizon = ckpt['horizon']
    print(f"HORIZON: {horizon}")
    test_idx = list(range(total_data))

    # Plot the ground-truth start/target/stopover poses for each test sequence
    save_path = os.path.join(opt.save_path, 'sampler')
    Path(save_path).mkdir(parents=True, exist_ok=True)
    for i in range(len(test_idx)):
        start_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
        target_pose = lafan_dataset.data['global_pos'][test_idx[i], target_idx]
        gt_stopover_pose = lafan_dataset.data['global_pos'][test_idx[i], from_idx]
        # The target pose doubles as the prediction argument: only GT is plotted here
        plot_pose_with_stop(start_pose, target_pose, target_pose, gt_stopover_pose, i, skeleton_mocap, save_dir=save_path, prefix='gt')
        print(f"ID {test_idx[i]}: completed.")


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--project', default='runs/train', help='project/name')
    parser.add_argument('--exp_name', default='slerp_40', help='experiment name')
    parser.add_argument('--data_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH', help='BVH dataset path')
    parser.add_argument('--skeleton_path', type=str, default='ubisoft-laforge-animation-dataset/output/BVH/walk1_subject1.bvh', help='path to reference skeleton')
    parser.add_argument('--processed_data_dir', type=str, default='processed_data_original/', help='path to save pickled processed data')
    parser.add_argument('--save_path', type=str, default='runs/test', help='path to save rendered results')
    parser.add_argument('--motion_type', type=str, default='jumps', help='motion type')
    opt = parser.parse_args()
    return opt


if __name__ == "__main__":
    opt = parse_opt()
    device = torch.device("cpu")
    test(opt, device)
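
# Example run (script filename assumed; flags and defaults match parse_opt above):
#   python test_sampler.py --exp_name slerp_40 --data_path ubisoft-laforge-animation-dataset/output/BVH --save_path runs/test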