File size: 4,883 Bytes
02428a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import shutil
import argparse
import re
import json
import numpy as np
import cv2
import torch
from tqdm import tqdm

try:
    import mmpose  # noqa: F401
except Exception as e:
    # Caught broadly (not just ImportError) -- presumably because a broken mmpose
    # install can fail at import time with other exception types.
    import subprocess
    import sys

    print(e)
    print("mmpose error, installing transformer_utils")
    # Install with the running interpreter's pip so the package lands in this
    # environment rather than whichever `pip` is first on PATH. The path is
    # relative to the current working directory, as before.
    # NOTE(review): mmpose is not re-imported after the install in this process.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "./main/transformer_utils"],
        check=False,
    )


def extract_frame_number(file_name):
    """Return the first run of five consecutive digits in ``file_name`` as an int.

    Returns ``None`` when the name contains no five-digit run. The search is
    unanchored, so a longer digit run such as ``"123456"`` yields its first
    five digits (``12345``).
    """
    digits = re.search(r"(\d{5})", file_name)
    return int(digits.group(1)) if digits else None


def merge_npz_files(npz_files, output_file):
    """Stack a set of per-frame ``.npz`` archives into one ``.npz`` archive.

    Files are ordered by the frame number found in their basename via
    ``extract_frame_number``. For every array key, the per-file arrays are
    stacked along a new leading axis 0. The result is written with ``np.savez``.

    Args:
        npz_files: iterable of paths to ``.npz`` files.
        output_file: destination path passed to ``np.savez``.

    Notes:
        ``np.stack`` requires all arrays under the same key to share a shape.
        A key missing from some files is stacked from fewer entries.
        Files whose basename has no five-digit frame number are placed last,
        keeping their input order. Previously they raised TypeError when their
        ``None`` key was compared against an int.
    """

    def _sort_key(path):
        frame = extract_frame_number(os.path.basename(path))
        return (frame is None, frame if frame is not None else 0)

    merged_data = {}
    for file in sorted(npz_files, key=_sort_key):
        # NpzFile keeps its file handle open until closed; read and release it.
        with np.load(file) as data:
            for key in data.files:
                merged_data.setdefault(key, []).append(data[key])
    stacked = {key: np.stack(arrays, axis=0) for key, arrays in merged_data.items()}
    np.savez(output_file, **stacked)


def npz_to_npz(pkl_path, npz_path):
    """Convert a merged per-part SMPL-X archive into a flat ``poses``/``trans`` ``.npz``.

    The pose arrays in ``pkl_path`` are concatenated along axis 1 in this order:
    global_orient, body_pose, jaw_pose, leye_pose, reye_pose, left_hand_pose,
    right_hand_pose. ``npz_path`` is then written with:

    - ``betas``: zeros of length 300
    - ``poses``: the concatenated pose reshaped to ``(n, -1)``
    - ``expressions``: zeros of shape ``(n, 100)``
    - ``trans``: the input ``transl`` reshaped to ``(n, -1)``
    - ``model`` / ``gender`` / ``mocap_frame_rate``: fixed metadata

    Here ``n`` is ``expression.shape[0]`` of the input. The input ``expression``
    values themselves are not copied; the written expressions are zeros.

    ``pkl_path`` and ``npz_path`` may be the same file, as in ``infer``. Every
    input array is read and the input archive closed before the output is written.
    """
    part_keys = (
        "global_orient",
        "body_pose",
        "jaw_pose",
        "leye_pose",
        "reye_pose",
        "left_hand_pose",
        "right_hand_pose",
    )
    with np.load(pkl_path, allow_pickle=True) as src:
        n = src["expression"].shape[0]  # leading axis; presumably the frame count
        full_pose = np.concatenate([src[key] for key in part_keys], axis=1)
        transl = src["transl"]
    # The source archive is closed here, so overwriting the same path is safe.
    np.savez(
        npz_path,
        betas=np.zeros(300),
        poses=full_pose.reshape(n, -1),
        expressions=np.zeros((n, 100)),
        trans=transl.reshape(n, -1),
        model="smplx2020",
        gender="neutral",
        mocap_frame_rate=30,
    )


def get_json(root_dir, output_dir, *, min_frames=1):
    """Write a JSON clip index for every ``<id>.mp4`` in ``root_dir`` with a readable ``<id>.npz``.

    Each entry records the video id, ``root_dir`` as both video and motion path,
    mode ``"test"``, and the frame range ``[0, poses.shape[0])``. Videos whose
    ``.npz`` is missing or unreadable are reported and skipped. Entries are
    listed in sorted filename order.

    Args:
        root_dir: directory containing the ``.mp4`` / ``.npz`` pairs.
        output_dir: path of the JSON *file* to write (despite the name).
        min_frames: if the summed frame count is below this, nothing is written.

    Returns:
        The total number of frames indexed, or 0 when the index was skipped.
    """
    clips = []
    all_length = 0
    for name in sorted(os.listdir(root_dir)):
        if not name.endswith(".mp4"):
            continue
        video_id = name[:-4]
        try:
            with np.load(os.path.join(root_dir, video_id + ".npz"), allow_pickle=True) as motion:
                length = motion["poses"].shape[0]
        except Exception as e:
            # Best-effort: one bad/missing motion file should not abort the index.
            print("cant open ", name, e)
            continue
        all_length += length
        clips.append(
            {
                "video_id": video_id,
                "video_path": root_dir,
                "motion_path": root_dir,
                "mode": "test",
                "start_idx": 0,
                "end_idx": length,
            }
        )
    if all_length < min_frames:
        print(f"skip due to total frames is less than {min_frames} for {root_dir}")
        return 0
    with open(output_dir, "w") as f:
        json.dump(clips, f, indent=4)
    return all_length


def infer(video_input, in_threshold, num_people, render_mesh, inferer, OUT_FOLDER):
    """Run per-frame inference on one video and write ``<stem>.npz`` plus a copy of the video.

    Steps:

    1. Clear and recreate ``OUT_FOLDER/smplx``.
    2. Read every frame with OpenCV and call ``inferer.infer`` on it. Frame
       indices are 1-based.
    3. Merge the ``.npz`` files that appear in ``OUT_FOLDER/smplx`` into
       ``OUT_FOLDER/<stem>.npz``. These files are presumably written there by
       ``inferer``; confirm.
    4. Convert that archive in place with ``npz_to_npz`` and delete the
       scratch directory.
    5. Copy the source video into ``OUT_FOLDER`` under its original file name.

    If no per-frame output was produced (e.g. the video could not be opened or
    nothing was inferred), a message is printed and the video is skipped. This
    avoids crashing the caller's batch loop.

    Args:
        video_input: path to the input video.
        in_threshold: passed through to ``inferer.infer``; presumably a detection
            score threshold.
        num_people: passed through to ``inferer.infer`` as its multi-person argument.
        render_mesh: ``not render_mesh`` is passed to ``inferer.infer``.
        inferer: object exposing ``infer(img, threshold, frame, multi_person, flag)``.
        OUT_FOLDER: output directory.
    """
    smplx_dir = os.path.join(OUT_FOLDER, "smplx")
    shutil.rmtree(smplx_dir, ignore_errors=True)
    os.makedirs(smplx_dir, exist_ok=True)

    video_name = os.path.basename(video_input)
    # splitext instead of str.replace so other ".mp4"/".npz" substrings in the
    # name or in OUT_FOLDER are left alone.
    motion_path = os.path.join(OUT_FOLDER, os.path.splitext(video_name)[0] + ".npz")

    cap = cv2.VideoCapture(video_input)
    try:
        frame = 0
        while True:
            success, original_img = cap.read()
            if not success:
                break
            frame += 1
            _, _, _ = inferer.infer(original_img, in_threshold, frame, num_people, not render_mesh)
    finally:
        cap.release()

    npz_files = [os.path.join(smplx_dir, x) for x in os.listdir(smplx_dir) if x.endswith(".npz")]
    if not npz_files:
        print(f"no inference output for {video_input}, skipping")
        shutil.rmtree(smplx_dir, ignore_errors=True)
        return

    merge_npz_files(npz_files, motion_path)
    shutil.rmtree(smplx_dir, ignore_errors=True)
    npz_to_npz(motion_path, motion_path)
    shutil.copy(video_input, os.path.join(OUT_FOLDER, video_name))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run SMPLer-X inference on every .mp4 in a folder and write a JSON clip index."
    )
    # Required: an empty path previously failed late (e.g. json_save_path only
    # after every video had been processed).
    parser.add_argument("--video_folder_path", type=str, required=True, help="folder of input .mp4 files")
    parser.add_argument("--data_save_path", type=str, required=True, help="output folder for .npz/.mp4 pairs")
    parser.add_argument("--json_save_path", type=str, required=True, help="path of the JSON index to write")
    args = parser.parse_args()
    video_folder = args.video_folder_path

    DEFAULT_MODEL = "smpler_x_s32"
    OUT_FOLDER = args.data_save_path
    os.makedirs(OUT_FOLDER, exist_ok=True)
    # -1 when CUDA is unavailable (presumably Inferer's CPU mode; confirm).
    # torch.cuda.current_device() is intentionally not called: its result was
    # unused and it raises on machines without CUDA.
    num_gpus = 1 if torch.cuda.is_available() else -1
    from main.inference import Inferer

    inferer = Inferer(DEFAULT_MODEL, num_gpus, OUT_FOLDER)

    videos = sorted(name for name in os.listdir(video_folder) if name.endswith(".mp4"))
    for video_input in tqdm(videos):
        # threshold 0.5, num_people=False, render_mesh=False
        infer(os.path.join(video_folder, video_input), 0.5, False, False, inferer, OUT_FOLDER)
    get_json(OUT_FOLDER, args.json_save_path)