"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os | |
from typing import Dict, Iterable, List, Union | |
import numpy as np | |
import torch | |
from torch.utils import data | |
from utils.misc import prGreen | |
class Social(data.Dataset):
    """Paired (motion, audio) sequences for social-conversation modeling.

    Depending on ``args.data_format`` the "motion" stream is either body
    pose (``"pose"``) or face codes (``"face"``). Sequences are
    z-normalized with statistics loaded from
    ``<data_root>/data_stats.pth``; the dataset is split by taking all
    but the last 6 sequences for train, the next 2 for val, and the
    final 4 for test. Test data is chunked into fixed-length windows.
    """

    def __init__(
        self,
        args,
        data_dict: Dict[str, Iterable],
        split: str = "train",
        chunk: bool = False,
        add_padding: bool = True,
    ) -> None:
        """Select the requested split and prepare sequences.

        Args:
            args: namespace with at least ``data_format``,
                ``add_frame_cond``, ``data_root``, ``max_seq_length``
                (and optionally ``curr_seq_length``).
            data_dict: raw data keyed by "data", "face", "missing",
                "audio", "lengths". Mutated in place for the
                face/pose-only setups.
            split: one of "train", "val"; anything else selects test.
            chunk: if True, skip random subsection sampling in
                ``__getitem__`` (test data is always chunked anyway).
            add_padding: if True, sample variable-length subsections and
                zero-pad them up to ``max_seq_length`` during training.
        """
        if args.data_format == "face":
            prGreen("[dataset.py] training face only model")
            data_dict["data"] = data_dict["face"]
        elif args.data_format == "pose":
            prGreen("[dataset.py] training pose only model")
            # The pose stream has no per-frame visibility mask; mark
            # every frame as present.
            data_dict["missing"] = [np.ones_like(d) for d in data_dict["data"]]
        # set up variables for dataloader
        self.data_format = args.data_format
        self.add_frame_cond = args.add_frame_cond
        self._register_keyframe_step()
        self.data_root = args.data_root
        self.max_seq_length = args.max_seq_length
        # Curriculum override: a (shorter) current sequence length wins.
        if getattr(args, "curr_seq_length", None) is not None:
            self.max_seq_length = args.curr_seq_length
        prGreen([f"[dataset.py] sequences of {self.max_seq_length}"])
        self.add_padding = add_padding
        self.audio_per_frame = 1600  # audio samples per motion frame
        self.max_audio_length = self.max_seq_length * self.audio_per_frame
        self.min_seq_length = 400
        # set up training/validation splits:
        # train = all but the last 6 sequences, val = next 2, test = last 4
        num_seqs = len(data_dict["data"])
        train_idx = list(range(0, num_seqs - 6))
        val_idx = list(range(num_seqs - 6, num_seqs - 4))
        test_idx = list(range(num_seqs - 4, num_seqs))
        self.split = split
        if split == "train":
            self._pick_sequences(data_dict, train_idx)
        elif split == "val":
            self._pick_sequences(data_dict, val_idx)
        else:
            self._pick_sequences(data_dict, test_idx)
        self.chunk = chunk
        if split == "test":
            print("[dataset.py] chunking data...")
            self._chunk_data()
        self._load_std()
        prGreen(
            f"[dataset.py] {split} | {len(self.data)} sequences ({self.data[0].shape}) | total len {self.total_len}"
        )

    def inv_transform(
        self, data: Union[np.ndarray, torch.Tensor], data_type: str
    ) -> Union[np.ndarray, torch.Tensor]:
        """Undo z-normalization for one stream.

        Args:
            data: normalized array or tensor.
            data_type: one of "pose", "face", "audio".

        Returns:
            ``data * std + mean`` using the statistics for that stream,
            as the same type (tensor stats are placed on ``data.device``).

        Raises:
            ValueError: if ``data_type`` is not a known stream.
        """
        if data_type == "pose":
            std, mean = self.std, self.mean
        elif data_type == "face":
            std, mean = self.face_std, self.face_mean
        elif data_type == "audio":
            std, mean = self.audio_std, self.audio_mean
        else:
            # Was `assert False`: asserts are stripped under `python -O`,
            # which would silently return None here. Raise instead.
            raise ValueError(f"datatype not defined: {data_type}")
        if torch.is_tensor(data):
            return data * torch.tensor(
                std, device=data.device, requires_grad=False
            ) + torch.tensor(mean, device=data.device, requires_grad=False)
        return data * std + mean

    def _pick_sequences(self, data_dict: Dict[str, Iterable], idx: List[int]) -> None:
        """Keep only the sequences selected by ``idx`` for this split."""
        # NOTE(review): np.take over ragged per-sequence arrays relies on
        # numpy treating them as an object array — confirm upstream shapes.
        self.data = np.take(data_dict["data"], idx, axis=0)
        self.missing = np.take(data_dict["missing"], idx, axis=0)
        self.audio = np.take(data_dict["audio"], idx, axis=0)
        self.lengths = np.take(data_dict["lengths"], idx, axis=0)
        # Total number of frames across the kept sequences; used by
        # __len__ for the un-chunked (train/val) splits.
        self.total_len = sum(len(d) for d in self.data)

    def _load_std(self) -> None:
        """Load per-stream normalization statistics from data_stats.pth."""
        stats = torch.load(os.path.join(self.data_root, "data_stats.pth"))
        print(
            f'[dataset.py] loading from... {os.path.join(self.data_root, "data_stats.pth")}'
        )
        self.mean = stats["pose_mean"].reshape(-1)
        self.std = stats["pose_std"].reshape(-1)
        self.face_mean = stats["code_mean"]
        self.face_std = stats["code_std"]
        self.audio_mean = stats["audio_mean"]
        self.audio_std = stats["audio_std_flat"]

    def _chunk_data(self) -> None:
        """Split each sequence into shuffled, non-overlapping windows of
        exactly ``max_seq_length`` frames (with the matching audio span)."""
        chunk_data = []
        chunk_missing = []
        chunk_lengths = []
        chunk_audio = []
        # create sequences of set lengths
        for d_idx in range(len(self.data)):
            curr_data = self.data[d_idx]
            curr_missing = self.missing[d_idx]
            curr_audio = self.audio[d_idx]
            # Iterate over every window start; the short trailing window
            # is dropped by the length guard below. (Previously the
            # range stopped at len - max_seq_length, which both made the
            # guard dead code and silently discarded the final full
            # chunk of each sequence.)
            for chunk_idx in range(0, len(curr_data), self.max_seq_length):
                chunk_end = chunk_idx + self.max_seq_length
                curr_data_chunk = curr_data[chunk_idx:chunk_end, :]
                curr_missing_chunk = curr_missing[chunk_idx:chunk_end, :]
                curr_audio_chunk = curr_audio[
                    chunk_idx * self.audio_per_frame : chunk_end * self.audio_per_frame,
                    :,
                ]
                if curr_data_chunk.shape[0] < self.max_seq_length:
                    # do not add a short chunk to the list
                    continue
                chunk_lengths.append(curr_data_chunk.shape[0])
                chunk_data.append(curr_data_chunk)
                chunk_missing.append(curr_missing_chunk)
                chunk_audio.append(curr_audio_chunk)
        # Shuffle once; all four streams are permuted consistently.
        idx = np.random.permutation(len(chunk_data))
        print("==> shuffle", idx)
        self.data = np.take(chunk_data, idx, axis=0)
        self.missing = np.take(chunk_missing, idx, axis=0)
        self.lengths = np.take(chunk_lengths, idx, axis=0)
        self.audio = np.take(chunk_audio, idx, axis=0)
        # After chunking, __len__ counts chunks, not frames.
        self.total_len = len(self.data)

    def _register_keyframe_step(self) -> None:
        """Set the temporal stride used to subsample keyframes."""
        if self.add_frame_cond == 1:
            self.step = 30
        elif self.add_frame_cond is None:
            self.step = 1
        else:
            # Previously self.step was silently left unset here, causing
            # a confusing AttributeError at first use. Fail fast instead.
            raise ValueError(f"unsupported add_frame_cond: {self.add_frame_cond}")

    def _pad_sequence(
        self, sequence: np.ndarray, actual_length: int, max_length: int
    ) -> np.ndarray:
        """Zero-pad a (length, feat) array along time to ``max_length``."""
        return np.concatenate(
            (
                sequence,
                np.zeros((max_length - actual_length, sequence.shape[-1])),
            ),
            axis=0,
        )

    def _get_idx(self, item: int) -> int:
        """Map a flat frame index in [0, total_len) to the index of the
        sequence that contains it."""
        cumulative_len = 0
        seq_idx = 0
        # `>=` (not `>`): with `>` an item of 0 fell through to -1
        # (wrapping to the *last* sequence) and boundary indices were
        # assigned to the previous sequence.
        while item >= cumulative_len:
            cumulative_len += len(self.data[seq_idx])
            seq_idx += 1
        return seq_idx - 1

    def _get_random_subsection(
        self, data_dict: Dict[str, Iterable]
    ) -> Dict[str, np.ndarray]:
        """Sample a random window from the sequence in ``data_dict`` and
        pad motion/missing/audio/keyframes up to their max lengths.

        Mutates and returns ``data_dict``. Resamples until the window
        contains at least one non-missing frame.
        """
        # NOTE(review): if every frame of the sequence is missing this
        # loop never terminates, and randint raises when
        # m_length <= max_seq_length — presumably upstream guarantees
        # both; confirm against the data pipeline.
        isnonzero = False
        while not isnonzero:
            start = np.random.randint(0, data_dict["m_length"] - self.max_seq_length)
            if self.add_padding:
                # Train/val with padding: variable-length windows; test
                # always uses the full window.
                length = (
                    np.random.randint(self.min_seq_length, self.max_seq_length)
                    if not self.split == "test"
                    else self.max_seq_length
                )
            else:
                length = self.max_seq_length
            curr_missing = data_dict["missing"][start : start + length]
            isnonzero = np.any(curr_missing)
        missing = curr_missing
        motion = data_dict["motion"][start : start + length, :]
        keyframes = motion[:: self.step]
        audio = data_dict["audio"][
            start * self.audio_per_frame : (start + length) * self.audio_per_frame,
            :,
        ]
        data_dict["m_length"] = len(motion)
        data_dict["k_length"] = len(keyframes)
        data_dict["a_length"] = len(audio)
        if data_dict["m_length"] < self.max_seq_length:
            motion = self._pad_sequence(
                motion, data_dict["m_length"], self.max_seq_length
            )
            missing = self._pad_sequence(
                missing, data_dict["m_length"], self.max_seq_length
            )
            audio = self._pad_sequence(
                audio, data_dict["a_length"], self.max_audio_length
            )
            # Number of keyframes a full-length window would produce.
            max_step_length = len(np.zeros(self.max_seq_length)[:: self.step])
            keyframes = self._pad_sequence(
                keyframes, data_dict["k_length"], max_step_length
            )
        data_dict["motion"] = motion
        data_dict["keyframes"] = keyframes
        data_dict["audio"] = audio
        data_dict["missing"] = missing
        return data_dict

    def __len__(self) -> int:
        """Frames (train/val) or chunks (test) available for sampling."""
        return self.total_len

    def __getitem__(self, item: int) -> Dict[str, np.ndarray]:
        """Return one normalized (motion, audio, keyframes, missing) dict."""
        # figure out which sequence to randomly sample from
        if not self.split == "test":
            item = self._get_idx(item)
        motion = self.data[item]
        audio = self.audio[item]
        m_length = self.lengths[item]
        missing = self.missing[item]
        a_length = len(audio)
        # Z Normalization
        if self.data_format == "pose":
            motion = (motion - self.mean) / self.std
        elif self.data_format == "face":
            motion = (motion - self.face_mean) / self.face_std
        audio = (audio - self.audio_mean) / self.audio_std
        keyframes = motion[:: self.step]
        k_length = len(keyframes)
        data_dict = {
            "motion": motion,
            "m_length": m_length,
            "audio": audio,
            "a_length": a_length,
            "keyframes": keyframes,
            "k_length": k_length,
            "missing": missing,
        }
        # Train/val draw a random window; chunked/test data is returned whole.
        if not self.split == "test" and not self.chunk:
            data_dict = self._get_random_subsection(data_dict)
        if self.data_format == "face":
            # Zero out face codes on frames flagged as missing.
            data_dict["motion"] *= data_dict["missing"]
        return data_dict