# Scraped page header: MotionLCM/mld/data/get_data.py | wxDai | [Init] | eb339cb
from typing import Optional
from os.path import join as pjoin
import numpy as np
from omegaconf import DictConfig
from .data import DataModule
from .base import BaseDataModule
from .utils import mld_collate, mld_collate_motion_only
from .humanml.utils.word_vectorizer import WordVectorizer
def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the normalization statistics (mean, std) for a dataset.

    Args:
        phase: ``"val"`` loads ``mean.npy``/``std.npy`` from
            ``<cfg.model.t2m_path>/<name>/<Comp_v6_*>/meta``; any other value
            loads ``Mean.npy``/``Std.npy`` from ``cfg.DATASET.<NAME>.ROOT``.
        cfg: Project configuration providing the paths described above.
        dataset_name: ``"humanml3d"`` (alias ``"t2m"``) or ``"kit"``,
            matched case-insensitively.

    Returns:
        The ``(mean, std)`` arrays as returned by ``np.load``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported datasets.
    """
    key = dataset_name.lower()
    # HumanML3D statistics are stored under the short name "t2m".
    name = "t2m" if key == "humanml3d" else key

    # Directory that holds the "meta" folder for each supported short name.
    meta_parent = {"t2m": "Comp_v6_KLD01", "kit": "Comp_v6_KLD005"}
    if name not in meta_parent:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError("Only support t2m and kit")

    if phase == "val":
        data_root = pjoin(cfg.model.t2m_path, name, meta_parent[name], "meta")
        mean_file, std_file = "mean.npy", "std.npy"
    else:
        # Plain attribute lookup; no need to eval() a string built from config.
        data_root = getattr(cfg.DATASET, dataset_name.upper()).ROOT
        mean_file, std_file = "Mean.npy", "Std.npy"

    mean = np.load(pjoin(data_root, mean_file))
    std = np.load(pjoin(data_root, std_file))
    return mean, std
def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
    """Create the ``WordVectorizer`` used for a text-annotated dataset.

    Args:
        cfg: Project configuration; ``cfg.DATASET.WORD_VERTILIZER_PATH`` is
            passed to ``WordVectorizer`` as its first argument.
        dataset_name: ``"humanml3d"`` or ``"kit"`` (case-insensitive).

    Returns:
        A ``WordVectorizer`` built with the ``"our_vab"`` prefix.

    Raises:
        ValueError: If ``dataset_name`` is neither HumanML3D nor KIT.
    """
    supported = dataset_name.lower() in ("humanml3d", "kit")
    if not supported:
        raise ValueError("Only support WordVectorizer for HumanML3D and KIT")

    vectorizer_root = cfg.DATASET.WORD_VERTILIZER_PATH
    return WordVectorizer(vectorizer_root, "our_vab")
# Data module class instantiated for each supported dataset (lowercase name).
dataset_module_map = dict(humanml3d=DataModule, kit=DataModule)
# Sub-directory of the dataset root that is passed to the data module as
# ``motion_dir`` (lowercase name).
motion_subdir = dict(humanml3d="new_joint_vecs", kit="new_joint_vecs")
def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
    """Build the data module for the dataset named by ``cfg.DATASET.NAME``.

    The dataset name is matched case-insensitively. After construction the
    data module's ``nfeats`` and ``njoints`` are written back to
    ``cfg.DATASET.NFEATS`` and ``cfg.DATASET.NJOINTS``.

    Args:
        cfg: Project configuration. Reads ``cfg.DATASET`` (``NAME``,
            ``SAMPLER``, ``PADDING_TO_MAX``, ``WINDOW_SIZE``) and the
            per-dataset node ``cfg.DATASET.<NAME>`` (``ROOT``, ``UNIT_LEN``,
            ``FRAME_RATE``, ``CONTROL_ARGS``).
        motion_only: When true, no word vectorizer is created and the
            motion-only collate function is used.

    Returns:
        The constructed data module.

    Raises:
        NotImplementedError: For ``humanact12``, ``uestc`` and ``amass``.
        ValueError: For any other unrecognized dataset name.
    """
    dataset_name = cfg.DATASET.NAME
    key = dataset_name.lower()

    if key in ("humanact12", "uestc", "amass"):
        raise NotImplementedError
    if key not in ("humanml3d", "kit"):
        # Fail loudly rather than implicitly returning None.
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")

    # Per-dataset config node; attribute lookup replaces eval() on a built string.
    dataset_cfg = getattr(cfg.DATASET, dataset_name.upper())
    data_root = dataset_cfg.ROOT

    # Pass the normalized key so mixed-case names resolve consistently.
    mean, std = get_mean_std("train", cfg, key)
    mean_eval, std_eval = get_mean_std("val", cfg, key)

    word_vectorizer = None if motion_only else get_WordVectorizer(cfg, key)
    collate_fn = mld_collate_motion_only if motion_only else mld_collate

    dataset = dataset_module_map[key](
        name=key,
        cfg=cfg,
        motion_only=motion_only,
        collate_fn=collate_fn,
        mean=mean,
        std=std,
        mean_eval=mean_eval,
        std_eval=std_eval,
        w_vectorizer=word_vectorizer,
        text_dir=pjoin(data_root, "texts"),
        # Index with the lowercase key, matching dataset_module_map above.
        motion_dir=pjoin(data_root, motion_subdir[key]),
        max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
        min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
        max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
        unit_length=dataset_cfg.UNIT_LEN,
        fps=dataset_cfg.FRAME_RATE,
        padding_to_max=cfg.DATASET.PADDING_TO_MAX,
        window_size=cfg.DATASET.WINDOW_SIZE,
        control_args=dataset_cfg.CONTROL_ARGS,
    )

    cfg.DATASET.NFEATS = dataset.nfeats
    cfg.DATASET.NJOINTS = dataset.njoints
    return dataset