import os
import random
import logging
import codecs as cs
from os.path import join as pjoin

import numpy as np
from rich.progress import track
import torch
from torch.utils.data import Dataset

from .scripts.motion_process import recover_from_ric
from .utils.word_vectorizer import WordVectorizer

logger = logging.getLogger(__name__)


class MotionDataset(Dataset):
    """Fixed-window motion snippet dataset.

    Each loaded motion of ``T`` frames contributes ``T - window_size``
    overlapping snippets; ``self.cumsum`` maps a flat snippet index back to
    a (motion, offset) pair. Snippets are Z-normalized with the provided
    mean/std before being returned.
    """

    def __init__(self,
                 mean: np.ndarray,
                 std: np.ndarray,
                 split_file: str,
                 motion_dir: str,
                 window_size: int,
                 tiny: bool = False,
                 progress_bar: bool = True,
                 **kwargs) -> None:
        """Load every motion listed in ``split_file`` from ``motion_dir``.

        Args:
            mean: per-feature mean used for Z-normalization.
            std: per-feature std used for Z-normalization.
            split_file: text file with one motion id per line.
            motion_dir: directory containing ``<id>.npy`` motion arrays.
            window_size: snippet length in frames; shorter motions are skipped.
            tiny: if True, stop after ~10 motions (debug mode).
            progress_bar: if True, show a rich progress bar while loading.
        """
        self.data = []
        self.lengths = []
        id_list = []
        with cs.open(split_file, "r") as f:
            for line in f.readlines():
                id_list.append(line.strip())

        maxdata = 10 if tiny else 1e10

        if progress_bar:
            enumerator = enumerate(
                track(
                    id_list,
                    f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
                ))
        else:
            enumerator = enumerate(id_list)

        count = 0
        for i, name in enumerator:
            if count > maxdata:
                break
            try:
                motion = np.load(pjoin(motion_dir, name + '.npy'))
                if motion.shape[0] < window_size:
                    continue
                self.lengths.append(motion.shape[0] - window_size)
                self.data.append(motion)
                # BUG FIX: the original never incremented ``count``, so the
                # ``maxdata`` cap above was dead code and tiny mode loaded
                # the entire split. Mirrors Text2MotionDataset below, which
                # increments correctly.
                count += 1
            except Exception as e:
                # Best-effort loading: a missing/corrupt file is reported
                # and skipped rather than aborting the whole dataset.
                print(e)

        # cumsum[i] == number of snippets contributed by motions 0..i-1;
        # cumsum[-1] is the total snippet count (== len(self)).
        self.cumsum = np.cumsum([0] + self.lengths)
        if not tiny:
            logger.info("Total number of motions {}, snippets {}".format(
                len(self.data), self.cumsum[-1]))
        self.mean = mean
        self.std = std
        self.window_size = window_size

    def __len__(self) -> int:
        """Total number of snippets across all motions."""
        return self.cumsum[-1]

    def __getitem__(self, item: int) -> tuple:
        """Return ``(snippet, window_size)`` for flat snippet index ``item``.

        The snippet is a ``window_size``-frame slice of one motion,
        Z-normalized with the dataset mean/std.
        """
        if item != 0:
            # Find the motion whose snippet range contains ``item``.
            motion_id = np.searchsorted(self.cumsum, item) - 1
            idx = item - self.cumsum[motion_id] - 1
        else:
            motion_id = 0
            idx = 0
        motion = self.data[motion_id][idx:idx + self.window_size]
        # Z Normalization
        motion = (motion - self.mean) / self.std
        return motion, self.window_size


class Text2MotionDataset(Dataset):
    """Text-to-motion dataset pairing motion clips with tokenized captions.

    Captions live in ``text_dir/<id>.txt``, one per line, formatted as
    ``caption#tok1 tok2 ...#from_tag#to_tag``. A non-zero tag pair denotes a
    sub-clip of the motion; such entries are stored under a prefixed name so
    they coexist with the full clip. Optionally produces joint-space control
    "hints" (masked, normalized joint positions) for spatially controlled
    generation.
    """

    def __init__(
        self,
        mean: np.ndarray,
        std: np.ndarray,
        split_file: str,
        w_vectorizer: WordVectorizer,
        max_motion_length: int,
        min_motion_length: int,
        max_text_len: int,
        unit_length: int,
        motion_dir: str,
        text_dir: str,
        fps: int,
        padding_to_max: bool,
        njoints: int,
        tiny: bool = False,
        progress_bar: bool = True,
        **kwargs,
    ) -> None:
        """Load motions + captions, filtering by length.

        Args:
            mean: per-feature mean for Z-normalization of motion features.
            std: per-feature std for Z-normalization of motion features.
            split_file: text file with one motion id per line.
            w_vectorizer: maps a ``word/POS`` token to (embedding, POS one-hot).
            max_motion_length: motions with ``len >= max`` are dropped.
            min_motion_length: motions with ``len < min`` are dropped.
            max_text_len: token budget before sos/eos/unk padding.
            unit_length: motions are cropped to a multiple of this.
            motion_dir: directory of ``<id>.npy`` motion feature arrays.
            text_dir: directory of ``<id>.txt`` caption files.
            fps: frames per second; converts caption time tags to frames.
            padding_to_max: if True, zero-pad outputs to ``max_motion_length``.
            njoints: number of skeleton joints (used for control hints).
            tiny: if True, stop after ~10 motions (debug mode).
            progress_bar: if True, show a rich progress bar while loading.
            **kwargs: must contain ``control_args`` (config for control mode).
        """
        self.w_vectorizer = w_vectorizer
        self.max_motion_length = max_motion_length
        self.min_motion_length = min_motion_length
        self.max_text_len = max_text_len
        self.unit_length = unit_length
        self.padding_to_max = padding_to_max
        self.njoints = njoints

        data_dict = {}
        id_list = []
        with cs.open(split_file, "r") as f:
            for line in f.readlines():
                id_list.append(line.strip())
        self.id_list = id_list

        maxdata = 10 if tiny else 1e10

        if progress_bar:
            enumerator = enumerate(
                track(
                    id_list,
                    f"Loading HumanML3D {split_file.split('/')[-1].split('.')[0]}",
                ))
        else:
            enumerator = enumerate(id_list)

        count = 0
        bad_count = 0
        new_name_list = []
        length_list = []
        for i, name in enumerator:
            if count > maxdata:
                break
            try:
                motion = np.load(pjoin(motion_dir, name + ".npy"))
                if len(motion) < self.min_motion_length or len(
                        motion) >= self.max_motion_length:
                    bad_count += 1
                    continue
                text_data = []
                flag = False
                with cs.open(pjoin(text_dir, name + ".txt")) as f:
                    for line in f.readlines():
                        text_dict = {}
                        # Line format: caption#tokens#from_tag#to_tag
                        line_split = line.strip().split("#")
                        caption = line_split[0]
                        tokens = line_split[1].split(" ")
                        f_tag = float(line_split[2])
                        to_tag = float(line_split[3])
                        f_tag = 0.0 if np.isnan(f_tag) else f_tag
                        to_tag = 0.0 if np.isnan(to_tag) else to_tag

                        text_dict["caption"] = caption
                        text_dict["tokens"] = tokens
                        if f_tag == 0.0 and to_tag == 0.0:
                            # Caption describes the whole motion.
                            flag = True
                            text_data.append(text_dict)
                        else:
                            # Caption describes a sub-clip: cut it out and
                            # register it under a letter-prefixed unique name.
                            try:
                                n_motion = motion[int(f_tag * fps):int(to_tag * fps)]
                                if (len(n_motion)) < self.min_motion_length or \
                                        len(n_motion) >= self.max_motion_length:
                                    continue
                                new_name = random.choice(
                                    "ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
                                while new_name in data_dict:
                                    new_name = random.choice(
                                        "ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
                                data_dict[new_name] = {
                                    "motion": n_motion,
                                    "length": len(n_motion),
                                    "text": [text_dict],
                                }
                                new_name_list.append(new_name)
                                length_list.append(len(n_motion))
                            except ValueError:
                                print(line_split)
                                print(line_split[2], line_split[3], f_tag,
                                      to_tag, name)
                if flag:
                    data_dict[name] = {
                        "motion": motion,
                        "length": len(motion),
                        "text": text_data,
                    }
                    new_name_list.append(name)
                    length_list.append(len(motion))
                    count += 1
            except Exception as e:
                # Best-effort loading: report and skip unreadable entries.
                print(e)

        # Sort entries by motion length (ascending).
        name_list, length_list = zip(
            *sorted(zip(new_name_list, length_list), key=lambda x: x[1]))

        if not tiny:
            logger.info(f"Reading {len(self.id_list)} motions from {split_file}.")
            logger.info(f"Total {len(name_list)} motions are used.")
            logger.info(f"{bad_count} motion sequences not within the length range of "
                        f"[{self.min_motion_length}, {self.max_motion_length}) are filtered out.")

        self.mean = mean
        self.std = std

        control_args = kwargs['control_args']
        self.control_mode = None
        # Raw (un-normalized joint-space) statistics for control hints.
        if os.path.exists(control_args.MEAN_STD_PATH):
            self.raw_mean = np.load(pjoin(control_args.MEAN_STD_PATH, 'Mean_raw.npy'))
            self.raw_std = np.load(pjoin(control_args.MEAN_STD_PATH, 'Std_raw.npy'))
        else:
            self.raw_mean = self.raw_std = None
        if not tiny and control_args.CONTROL:
            self.t_ctrl = control_args.TEMPORAL
            self.training_control_joints = np.array(control_args.TRAIN_JOINTS)
            self.testing_control_joints = np.array(control_args.TEST_JOINTS)
            self.training_density = control_args.TRAIN_DENSITY
            self.testing_density = control_args.TEST_DENSITY

            self.control_mode = 'val' if ('test' in split_file
                                          or 'val' in split_file) else 'train'
            if self.control_mode == 'train':
                logger.info(f'Training Control Joints: {self.training_control_joints}')
                logger.info(f'Training Control Density: {self.training_density}')
            else:
                logger.info(f'Testing Control Joints: {self.testing_control_joints}')
                logger.info(f'Testing Control Density: {self.testing_density}')
            logger.info(f"Temporal Control: {self.t_ctrl}")

        self.data_dict = data_dict
        self.name_list = name_list

    def __len__(self) -> int:
        """Number of (motion, caption) entries."""
        return len(self.name_list)

    def random_mask(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Build a deterministic-joint control hint for evaluation.

        Uses the configured testing joints/density. Returns the masked,
        raw-normalized joints and the binary mask, both of shape
        ``(length, njoints, 3)``.
        """
        choose_joint = self.testing_control_joints

        length = joints.shape[0]
        density = self.testing_density
        if density in [1, 2, 5]:
            # Small densities are absolute frame counts...
            choose_seq_num = density
        else:
            # ...otherwise a percentage of the sequence length.
            choose_seq_num = int(length * density / 100)

        if self.t_ctrl:
            # Temporal control: hint covers a prefix of the sequence.
            choose_seq = np.arange(0, choose_seq_num)
        else:
            choose_seq = np.random.choice(length, choose_seq_num, replace=False)
            choose_seq.sort()

        mask_seq = np.zeros((length, self.njoints, 3))
        for cj in choose_joint:
            mask_seq[choose_seq, cj] = 1.0

        # Normalize in raw joint space, then keep only the hinted entries.
        joints = (joints - self.raw_mean) / self.raw_std
        joints = joints * mask_seq
        return joints, mask_seq

    def random_mask_train(self, joints: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Build a randomized control hint for training.

        Samples one control joint (or all configured joints under temporal
        control) and a random subset of frames. Returns the masked,
        raw-normalized joints and the binary mask.
        """
        if self.t_ctrl:
            choose_joint = self.training_control_joints
        else:
            num_joints = len(self.training_control_joints)
            num_joints_control = 1
            choose_joint = np.random.choice(num_joints, num_joints_control,
                                            replace=False)
            choose_joint = self.training_control_joints[choose_joint]

        length = joints.shape[0]

        if self.training_density == 'random':
            choose_seq_num = np.random.choice(length - 1, 1) + 1
        else:
            choose_seq_num = int(length * random.uniform(
                self.training_density[0], self.training_density[1]) / 100)

        if self.t_ctrl:
            choose_seq = np.arange(0, choose_seq_num)
        else:
            choose_seq = np.random.choice(length, choose_seq_num, replace=False)
            choose_seq.sort()

        mask_seq = np.zeros((length, self.njoints, 3))
        for cj in choose_joint:
            mask_seq[choose_seq, cj] = 1

        # Normalize in raw joint space, then keep only the hinted entries.
        joints = (joints - self.raw_mean) / self.raw_std
        joints = joints * mask_seq
        return joints, mask_seq

    def __getitem__(self, idx: int) -> tuple:
        """Return one training sample.

        Returns:
            ``(word_embeddings, pos_one_hots, caption, sent_len, motion,
            m_length, tokens_str, (hint, hint_mask))`` where ``hint`` /
            ``hint_mask`` are None unless control mode is active.
        """
        data = self.data_dict[self.name_list[idx]]
        motion, m_length, text_list = data["motion"], data["length"], data["text"]

        # Randomly select a caption
        text_data = random.choice(text_list)
        caption, tokens = text_data["caption"], text_data["tokens"]

        if len(tokens) < self.max_text_len:
            # pad with "unk"
            tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
            sent_len = len(tokens)
            tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
        else:
            # crop
            tokens = tokens[:self.max_text_len]
            tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
            sent_len = len(tokens)
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Crop the motions in to times of 4, and introduce small variations
        if self.unit_length < 10:
            coin2 = np.random.choice(["single", "single", "double"])
        else:
            coin2 = "single"

        if coin2 == "double":
            m_length = (m_length // self.unit_length - 1) * self.unit_length
        elif coin2 == "single":
            m_length = (m_length // self.unit_length) * self.unit_length

        # Random temporal crop to the adjusted length.
        idx = random.randint(0, len(motion) - m_length)
        motion = motion[idx:idx + m_length]

        hint, hint_mask = None, None
        if self.control_mode is not None:
            # Control hints are computed in joint space recovered from the
            # rotation-invariant features.
            joints = recover_from_ric(torch.from_numpy(motion).float(), self.njoints)
            joints = joints.numpy()
            if self.control_mode == 'train':
                hint, hint_mask = self.random_mask_train(joints)
            else:
                hint, hint_mask = self.random_mask(joints)
            if self.padding_to_max:
                padding = np.zeros((self.max_motion_length - m_length, *hint.shape[1:]))
                hint = np.concatenate([hint, padding], axis=0)
                hint_mask = np.concatenate([hint_mask, padding], axis=0)

        # Z Normalization
        motion = (motion - self.mean) / self.std
        if self.padding_to_max:
            padding = np.zeros((self.max_motion_length - m_length, motion.shape[1]))
            motion = np.concatenate([motion, padding], axis=0)

        return (word_embeddings, pos_one_hots, caption, sent_len, motion,
                m_length, "_".join(tokens), (hint, hint_mask))