from typing import Optional
from os.path import join as pjoin

import numpy as np
from omegaconf import DictConfig

from .data import DataModule
from .base import BaseDataModule
from .utils import mld_collate, mld_collate_motion_only
from .humanml.utils.word_vectorizer import WordVectorizer


def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
    # HumanML3D reuses the statistics stored under the "t2m" name.
    name = "t2m" if dataset_name == "humanml3d" else dataset_name
    assert name in ["t2m", "kit"]
    if phase in ["val"]:
        # Validation/evaluation uses the mean/std bundled with the pretrained
        # text-to-motion evaluator under cfg.model.t2m_path.
        if name == 't2m':
            data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD01", "meta")
        elif name == 'kit':
            data_root = pjoin(cfg.model.t2m_path, name, "Comp_v6_KLD005", "meta")
        else:
            raise ValueError("Only support t2m and kit")
        mean = np.load(pjoin(data_root, "mean.npy"))
        std = np.load(pjoin(data_root, "std.npy"))
    else:
        # Training uses the statistics stored alongside the dataset itself.
        data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
        mean = np.load(pjoin(data_root, "Mean.npy"))
        std = np.load(pjoin(data_root, "Std.npy"))
    return mean, std


def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
    if dataset_name.lower() in ["humanml3d", "kit"]:
        return WordVectorizer(cfg.DATASET.WORD_VERTILIZER_PATH, "our_vab")
    else:
        raise ValueError("Only support WordVectorizer for HumanML3D and KIT")


dataset_module_map = {"humanml3d": DataModule, "kit": DataModule}
motion_subdir = {"humanml3d": "new_joint_vecs", "kit": "new_joint_vecs"}


def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
    dataset_name = cfg.DATASET.NAME
    if dataset_name.lower() in ["humanml3d", "kit"]:
        data_root = eval(f"cfg.DATASET.{dataset_name.upper()}.ROOT")
        # Training statistics normalize the motions; the "val" statistics are
        # kept separately for evaluation.
        mean, std = get_mean_std('train', cfg, dataset_name)
        mean_eval, std_eval = get_mean_std("val", cfg, dataset_name)
        # The text pipeline (word vectorizer + text-aware collate) is skipped
        # in motion-only mode.
        wordVectorizer = None if motion_only else get_WordVectorizer(cfg, dataset_name)
        collate_fn = mld_collate_motion_only if motion_only else mld_collate
        dataset = dataset_module_map[dataset_name.lower()](
            name=dataset_name.lower(),
            cfg=cfg,
            motion_only=motion_only,
            collate_fn=collate_fn,
            mean=mean,
            std=std,
            mean_eval=mean_eval,
            std_eval=std_eval,
            w_vectorizer=wordVectorizer,
            text_dir=pjoin(data_root, "texts"),
            motion_dir=pjoin(data_root, motion_subdir[dataset_name.lower()]),
            max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
            min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
            max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
            unit_length=eval(f"cfg.DATASET.{dataset_name.upper()}.UNIT_LEN"),
            fps=eval(f"cfg.DATASET.{dataset_name.upper()}.FRAME_RATE"),
            padding_to_max=cfg.DATASET.PADDING_TO_MAX,
            window_size=cfg.DATASET.WINDOW_SIZE,
            control_args=eval(f"cfg.DATASET.{dataset_name.upper()}.CONTROL_ARGS"))
        # Expose feature/joint dimensionality back to the config for model construction.
        cfg.DATASET.NFEATS = dataset.nfeats
        cfg.DATASET.NJOINTS = dataset.njoints
        return dataset
    elif dataset_name.lower() in ["humanact12", 'uestc', "amass"]:
        raise NotImplementedError
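

# Usage sketch (not part of the original module; names below are illustrative):
# the helpers above expect an OmegaConf config whose keys mirror the accesses in
# this file (cfg.DATASET.NAME, cfg.DATASET.<NAME>.ROOT, cfg.DATASET.SAMPLER.*,
# cfg.model.t2m_path, ...). A minimal call might look like:
#
#     from omegaconf import OmegaConf
#     cfg = OmegaConf.load("configs/train.yaml")      # illustrative config path
#     datamodule = get_dataset(cfg)                   # text + motion pipeline
#     motion_dm = get_dataset(cfg, motion_only=True)  # motion-only collate, no WordVectorizer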