"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import os
from typing import Dict, List

import numpy as np
import torch
import torchaudio

from data_loaders.data import Social
from data_loaders.tensors import social_collate
from torch.utils.data import DataLoader
from utils.misc import prGreen

# Subject IDs swapped by ``load_local_data`` when ``flip_person`` is set.
# Checked in insertion order and only the first ID found in ``data_root`` is
# replaced, which mirrors the original if/elif chain exactly.
_FLIP_PARTNER: Dict[str, str] = {
    "PXB184": "RLW104",
    "RLW104": "PXB184",
    "TXB805": "GQS883",
    "GQS883": "TXB805",
}


def get_dataset_loader(
    args,
    data_dict: Dict[str, np.ndarray],
    split: str = "train",
    chunk: bool = False,
    add_padding: bool = True,
    num_workers: int = 8,
) -> DataLoader:
    """Build a ``Social`` dataset and wrap it in a ``DataLoader``.

    Args:
        args: Run configuration. It is forwarded to ``Social``, and
            ``args.batch_size`` is used as the loader batch size.
        data_dict: Loaded data forwarded to ``Social``. The loaders in this
            module actually produce a dict of lists (see ``load_local_data``).
        split: Dataset split name. Every split except ``"test"`` is shuffled.
        chunk: Forwarded to ``Social``.
        add_padding: Forwarded to ``Social``.
        num_workers: Number of DataLoader worker processes (default 8, the
            previously hard-coded value).

    Returns:
        A ``DataLoader`` that drops the last incomplete batch, collates with
        ``social_collate``, and uses pinned memory.
    """
    dataset = Social(
        args=args,
        data_dict=data_dict,
        split=split,
        chunk=chunk,
        add_padding=add_padding,
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=split != "test",
        num_workers=num_workers,
        drop_last=True,
        collate_fn=social_collate,
        pin_memory=True,
    )


def _load_pose_data(
    all_paths: List[str], audio_per_frame: int, flip_person: bool = False
) -> Dict[str, List]:
    """Load pose, face expression, missing-face mask and audio per recording.

    Only paths ending in ``_body_pose.npy`` are treated as recordings. The
    sibling ``_face_expression.npy``, ``_missing_face_frames.npy`` and
    ``_audio.wav`` files are found by replacing that suffix.

    Args:
        all_paths: Candidate file paths. Paths that are not pose files are
            ignored.
        audio_per_frame: Audio samples per pose frame. It is used to check
            that the audio and motion lengths agree.
        flip_person: If True, swap audio channels 0 and 1.

    Returns:
        Dict of parallel lists (one entry per kept recording):
        "data" (pose arrays), "face" (face-expression arrays cast to float),
        "audio" ((samples, channels) tensors), "lengths" (pose frame counts)
        and "missing" (arrays shaped like the face array, 0.0 on rows listed
        as missing and 1.0 elsewhere).

    Raises:
        AssertionError: If assertions are enabled and a recording's audio
            length is not ``len(pose) * audio_per_frame``.
    """
    data: List = []
    face: List = []
    audio: List = []
    lengths: List = []
    missing: List = []

    if flip_person:
        # Log once per call instead of once per recording.
        prGreen("[get_data.py] flipping the dataset of left right person")

    for curr_path_name in all_paths:
        if not curr_path_name.endswith("_body_pose.npy"):
            continue

        # Load face information and the row indices with no face.
        curr_code = np.load(
            curr_path_name.replace("_body_pose.npy", "_face_expression.npy")
        ).astype(float)
        # Cast to an integer dtype: an empty index list saved from
        # ``np.array([])`` loads as float64, which numpy rejects as an index.
        missing_list = np.load(
            curr_path_name.replace("_body_pose.npy", "_missing_face_frames.npy")
        ).astype(np.int64)
        if len(missing_list) == len(curr_code):
            # As many missing indices as face rows: no usable face frames
            # (assuming the indices are unique).
            print("skipping", curr_path_name, curr_code.shape)
            continue
        curr_missing = np.ones_like(curr_code)
        curr_missing[missing_list] = 0.0

        # Load pose information and deal with discontinuities.
        curr_pose = np.load(curr_path_name)
        if "PXB184" in curr_path_name or "RLW104" in curr_path_name:  # Capture 1 or 2
            # Shifting by pi twice nets +2*pi, so the pair wraps column 3 into
            # [0, 2*pi) (up to float rounding). Kept verbatim to reproduce the
            # exact floating-point values of the original preprocessing.
            curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)
            curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)

        # torchaudio.load returns (channels, samples). Transpose to
        # (samples, channels) so len() counts samples.
        curr_audio, _ = torchaudio.load(
            curr_path_name.replace("_body_pose.npy", "_audio.wav")
        )
        curr_audio = curr_audio.T
        if flip_person:
            # Swap the first two channels. Any further channels are left zero,
            # as in the original implementation.
            flipped = torch.zeros_like(curr_audio)
            flipped[:, 1] = curr_audio[:, 0]
            flipped[:, 0] = curr_audio[:, 1]
            curr_audio = flipped

        assert len(curr_pose) * audio_per_frame == len(
            curr_audio
        ), f"motion {curr_pose.shape} vs audio {curr_audio.shape}"

        data.append(curr_pose)
        face.append(curr_code)
        missing.append(curr_missing)
        audio.append(curr_audio)
        lengths.append(len(curr_pose))

    return {
        "data": data,
        "face": face,
        "audio": audio,
        "lengths": lengths,
        "missing": missing,
    }


def load_local_data(
    data_root: str, audio_per_frame: int, flip_person: bool = False
) -> Dict[str, List]:
    """Load every recording listed directly inside ``data_root``.

    When ``flip_person`` is True, the first subject ID from ``_FLIP_PARTNER``
    that appears in ``data_root`` is replaced by its paired ID before the
    directory is listed. Audio channels are then swapped by
    ``_load_pose_data``.

    Args:
        data_root: Directory containing ``*_body_pose.npy`` files and their
            sibling face/missing/audio files.
        audio_per_frame: Audio samples per pose frame (see
            ``_load_pose_data``).
        flip_person: Whether to load the paired subject's directory and swap
            audio channels.

    Returns:
        The dict produced by ``_load_pose_data`` for the sorted file list.
    """
    if flip_person:
        for subject, partner in _FLIP_PARTNER.items():
            if subject in data_root:
                data_root = data_root.replace(subject, partner)
                break

    all_paths = sorted(os.path.join(data_root, x) for x in os.listdir(data_root))
    return _load_pose_data(
        all_paths,
        audio_per_frame,
        flip_person=flip_person,
    )