# Uploaded via huggingface_hub (commit 66b7c56, verified)
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
from typing import Dict, List
import numpy as np
import torch
import torchaudio
from data_loaders.data import Social
from data_loaders.tensors import social_collate
from torch.utils.data import DataLoader
from utils.misc import prGreen
def get_dataset_loader(
    args,
    data_dict: Dict[str, np.ndarray],
    split: str = "train",
    chunk: bool = False,
    add_padding: bool = True,
) -> DataLoader:
    """Build a ``Social`` dataset for ``split`` and wrap it in a DataLoader.

    :param args: experiment config; must provide ``batch_size``.
    :param data_dict: pre-loaded arrays passed straight through to ``Social``.
    :param split: dataset split name; anything other than "test" is shuffled.
    :param chunk: forwarded to ``Social``.
    :param add_padding: forwarded to ``Social``.
    :return: a DataLoader over the requested split.
    """
    social_dataset = Social(
        args=args,
        data_dict=data_dict,
        split=split,
        chunk=chunk,
        add_padding=add_padding,
    )
    # Keep test-time iteration order deterministic; shuffle everything else.
    is_test_split = split == "test"
    return DataLoader(
        social_dataset,
        batch_size=args.batch_size,
        shuffle=not is_test_split,
        num_workers=8,
        drop_last=True,
        collate_fn=social_collate,
        pin_memory=True,
    )
def _load_pose_data(
    all_paths: List[str], audio_per_frame: int, flip_person: bool = False
) -> Dict[str, List]:
    """Load pose, face-expression, audio, and missing-frame masks for every
    "*_body_pose.npy" file in ``all_paths``.

    Sibling files are located by filename substitution on the pose path:
    "_face_expression.npy", "_missing_face_frames.npy", and "_audio.wav".

    :param all_paths: candidate file paths; entries not ending in
        "_body_pose.npy" are skipped.
    :param audio_per_frame: expected number of audio samples per pose frame;
        used only to assert audio/motion alignment.
    :param flip_person: if True, swap the two audio channels so the left and
        right speakers are exchanged.
    :return: dict with keys "data", "face", "audio", "lengths", "missing",
        each a list with one entry per successfully loaded sequence.
    """
    data = []
    face = []
    audio = []
    lengths = []
    missing = []
    for _, curr_path_name in enumerate(all_paths):
        if not curr_path_name.endswith("_body_pose.npy"):
            continue
        # load face information and deal with missing codes
        curr_code = np.load(
            curr_path_name.replace("_body_pose.npy", "_face_expression.npy")
        ).astype(float)
        # curr_code = np.array(curr_face["codes"], dtype=float)
        missing_list = np.load(
            curr_path_name.replace("_body_pose.npy", "_missing_face_frames.npy")
        )
        # If every frame is listed as missing, the sequence has no usable
        # face data at all — skip it entirely.
        if len(missing_list) == len(curr_code):
            print("skipping", curr_path_name, curr_code.shape)
            continue
        # Validity mask shaped like the face codes: 1.0 everywhere,
        # 0.0 at the missing frame indices.
        curr_missing = np.ones_like(curr_code)
        curr_missing[missing_list] = 0.0
        # load pose information and deal with discontinuities
        curr_pose = np.load(curr_path_name)
        if "PXB184" in curr_path_name or "RLW104" in curr_path_name:  # Capture 1 or 2
            # Applied twice, this wraps column 3 into [0, 2*pi).
            # NOTE(review): the duplicated line is suspicious — a single
            # application would shift by pi instead; if a pi-shift was the
            # intent, the second line undoes it. Confirm against the
            # capture pipeline before touching.
            curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)
            curr_pose[:, 3] = (curr_pose[:, 3] + np.pi) % (2 * np.pi)
        # load audio information
        curr_audio, _ = torchaudio.load(
            curr_path_name.replace("_body_pose.npy", "_audio.wav")
        )
        # Transpose so channels sit on axis 1 (see the swap below).
        curr_audio = curr_audio.T
        if flip_person:
            prGreen("[get_data.py] flipping the dataset of left right person")
            # Swap the two audio channels: speaker 0 <-> speaker 1.
            tmp = torch.zeros_like(curr_audio)
            tmp[:, 1] = curr_audio[:, 0]
            tmp[:, 0] = curr_audio[:, 1]
            curr_audio = tmp
        # Audio must align exactly: audio_per_frame samples per pose frame.
        assert len(curr_pose) * audio_per_frame == len(
            curr_audio
        ), f"motion {curr_pose.shape} vs audio {curr_audio.shape}"
        data.append(curr_pose)
        face.append(curr_code)
        missing.append(curr_missing)
        audio.append(curr_audio)
        lengths.append(len(curr_pose))
    data_dict = {
        "data": data,
        "face": face,
        "audio": audio,
        "lengths": lengths,
        "missing": missing,
    }
    return data_dict
def load_local_data(
    data_root: str, audio_per_frame: int, flip_person: bool = False
) -> Dict[str, List]:
    """Enumerate every file under ``data_root`` and load its sequence data.

    When ``flip_person`` is set, the participant ID in the path is swapped
    for its capture partner's ID (so data is read from the other person's
    directory) and the loader additionally exchanges the audio channels.

    :param data_root: directory containing the "*_body_pose.npy" files and
        their sibling face/audio files.
    :param audio_per_frame: audio samples per pose frame, forwarded to the
        loader for alignment checking.
    :param flip_person: swap the left/right person, see above.
    :return: dict of lists keyed "data", "face", "audio", "lengths", "missing".
    """
    if flip_person:
        # Each capture pairs two participants; redirect to the partner's
        # folder. Only the first matching ID is replaced (mirrors the
        # original first-match elif chain).
        partner_ids = (
            ("PXB184", "RLW104"),
            ("RLW104", "PXB184"),
            ("TXB805", "GQS883"),
            ("GQS883", "TXB805"),
        )
        for own_id, partner_id in partner_ids:
            if own_id in data_root:
                data_root = data_root.replace(own_id, partner_id)
                break
    sorted_paths = sorted(
        os.path.join(data_root, entry) for entry in os.listdir(data_root)
    )
    return _load_pose_data(
        sorted_paths,
        audio_per_frame,
        flip_person=flip_person,
    )