|
from typing import Optional |
|
from os.path import join as pjoin |
|
|
|
import numpy as np |
|
|
|
from omegaconf import DictConfig |
|
|
|
from .data import DataModule |
|
from .base import BaseDataModule |
|
from .utils import mld_collate, mld_collate_motion_only |
|
from .humanml.utils.word_vectorizer import WordVectorizer |
|
|
|
|
|
def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]: |
|
name = "t2m" if dataset_name == "humanml3d" else dataset_name |
|
assert name in ["t2m", "kit"] |
|
if phase in ["val"]: |
|
if name == 't2m': |
|
data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta") |
|
elif name == 'kit': |
|
data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta") |
|
else: |
|
raise ValueError("Only support t2m and kit") |
|
mean = np.load(pjoin(data_root, "mean.npy")) |
|
std = np.load(pjoin(data_root, "std.npy")) |
|
else: |
|
data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") |
|
mean = np.load(pjoin(data_root, "Mean.npy")) |
|
std = np.load(pjoin(data_root, "Std.npy")) |
|
|
|
return mean, std |
|
|
|
|
|
def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]: |
|
if dataset_name.lower() in ["humanml3d", "kit"]: |
|
return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab") |
|
else: |
|
raise ValueError("Only support WordVectorizer for HumanML3D and KIT") |
|
|
|
|
|
dataset_module_map = {"humanml3d": DataModule, "kit": DataModule} |
|
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"} |
|
|
|
|
|
def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule: |
|
dataset_name = cfg.DATASET.NAME |
|
if dataset_name.lower() in ["humanml3d", "kit"]: |
|
data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT") |
|
mean, std = get_mean_std('train', cfg, dataset_name) |
|
mean_eval, std_eval = get_mean_std("val", cfg, dataset_name) |
|
wordVectorizer = None if motion_only else get_WordVectorizer(cfg, dataset_name) |
|
collate_fn = mld_collate_motion_only if motion_only else mld_collate |
|
dataset = dataset_module_map[dataset_name.lower()]( |
|
name=dataset_name.lower(), |
|
cfg=cfg, |
|
motion_only=motion_only, |
|
collate_fn=collate_fn, |
|
mean=mean, |
|
std=std, |
|
mean_eval=mean_eval, |
|
std_eval=std_eval, |
|
w_vectorizer=wordVectorizer, |
|
text_dir=pjoin(data_root, "texts"), |
|
motion_dir=pjoin(data_root, motion_subdir[dataset_name]), |
|
max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN, |
|
min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN, |
|
max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN, |
|
unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"), |
|
fps=eval(f"cfg.DATASET.{dataset_name.upper()}.FRAME_RATE"), |
|
padding_to_max=cfg.DATASET.PADDING_TO_MAX, |
|
window_size=cfg.DATASET.WINDOW_SIZE, |
|
control_args=eval(f"cfg.DATASET.{dataset_name.upper()}.CONTROL_ARGS")) |
|
|
|
cfg.DATASET.NFEATS = dataset.nfeats |
|
cfg.DATASET.NJOINTS = dataset.njoints |
|
return dataset |
|
|
|
elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]: |
|
raise NotImplementedError |
|
|