|
import copy |
|
from typing import Callable, Optional |
|
|
|
import numpy as np |
|
from omegaconf import DictConfig |
|
|
|
import torch |
|
|
|
from .base import BaseDataModule |
|
from .humanml.dataset import Text2MotionDataset, MotionDataset |
|
from .humanml.scripts.motion_process import recover_from_ric |
|
|
|
|
|
|
|
# Per-dataset constants, keyed by dataset name. Each value is unpacked by
# DataModule.__init__ as (nfeats, njoints).
dataset_map = {
    'humanml3d': (263, 22),
    'kit': (251, 21),
}
|
|
|
|
|
class DataModule(BaseDataModule):
    """Data module for the datasets listed in ``dataset_map``.

    Besides the setup done in ``__init__``, it offers helpers that move
    tensors between normalized and raw value ranges using statistics held
    in ``self.hparams`` or on the sample set, and a switch that temporarily
    restricts the test set to a random subset of names ("MM mode").
    """

    def __init__(self,
                 name: str,
                 cfg: DictConfig,
                 motion_only: bool,
                 collate_fn: Optional[Callable] = None,
                 **kwargs) -> None:
        """Set up the module for dataset ``name``.

        Args:
            name: Key into ``dataset_map``; an unknown name raises
                ``KeyError``.
            cfg: Configuration object. ``cfg.TEST.MM_NUM_SAMPLES`` is read
                later by :meth:`mm_mode`.
            motion_only: If true, ``MotionDataset`` is used as the dataset
                class, otherwise ``Text2MotionDataset``.
            collate_fn: Forwarded to the base class constructor.
            **kwargs: Stored (deep-copied) in ``self.hparams``. Other
                methods of this class read the keys ``'mean'``, ``'std'``,
                ``'mean_eval'`` and ``'std_eval'`` from it.
        """
        super().__init__(collate_fn=collate_fn)
        self.cfg = cfg
        self.name = name
        self.nfeats, self.njoints = dataset_map[name]
        # 'njoints' is placed last so it overrides any same-named kwarg.
        self.hparams = copy.deepcopy({**kwargs, 'njoints': self.njoints})
        self.Dataset = MotionDataset if motion_only else Text2MotionDataset
        sample_overrides = {"tiny": True, "progress_bar": False}
        # The sample set supplies raw_mean / raw_std for the *_spatial
        # helpers below.
        self._sample_set = self.get_sample_set(overrides=sample_overrides)

    def denorm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
        """Return ``hint * raw_std + raw_mean`` using the sample set's stats.

        The statistics are converted to ``hint``'s dtype and device first.
        """
        raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
        raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
        hint = hint * raw_std + raw_mean
        return hint

    def norm_spatial(self, hint: torch.Tensor) -> torch.Tensor:
        """Return ``(hint - raw_mean) / raw_std`` using the sample set's stats.

        This is the inverse of :meth:`denorm_spatial`.
        """
        raw_mean = torch.tensor(self._sample_set.raw_mean).to(hint)
        raw_std = torch.tensor(self._sample_set.raw_std).to(hint)
        hint = (hint - raw_mean) / raw_std
        return hint

    def feats2joints(self, features: torch.Tensor) -> torch.Tensor:
        """De-normalize ``features`` and pass them to ``recover_from_ric``.

        De-normalization uses ``hparams['mean']`` and ``hparams['std']``;
        the result is whatever ``recover_from_ric(features, njoints)``
        returns (presumably joint positions -- confirm against that helper).
        """
        mean = torch.tensor(self.hparams['mean']).to(features)
        std = torch.tensor(self.hparams['std']).to(features)
        features = features * std + mean
        return recover_from_ric(features, self.njoints)

    def renorm4t2m(self, features: torch.Tensor) -> torch.Tensor:
        """Re-normalize ``features`` from the training stats to the eval stats.

        The input is first de-normalized with ``hparams['mean']`` /
        ``hparams['std']`` and then normalized with ``hparams['mean_eval']``
        / ``hparams['std_eval']``.
        """
        ori_mean = torch.tensor(self.hparams['mean']).to(features)
        ori_std = torch.tensor(self.hparams['std']).to(features)
        eval_mean = torch.tensor(self.hparams['mean_eval']).to(features)
        eval_std = torch.tensor(self.hparams['std_eval']).to(features)
        features = features * ori_std + ori_mean
        features = (features - eval_mean) / eval_std
        return features

    def mm_mode(self, mm_on: bool = True) -> None:
        """Switch the test dataset in or out of "MM mode".

        When switched on, ``cfg.TEST.MM_NUM_SAMPLES`` names are drawn
        without replacement from the full test name list and installed as
        ``test_dataset.name_list``. When switched off, the full list is put
        back.

        The full list is saved only when entering MM mode from the non-MM
        state, so repeated ``mm_mode(True)`` calls always resample from the
        complete list rather than from the previous subset. Calling
        ``mm_mode(False)`` before MM mode was ever enabled only clears the
        flag.

        Args:
            mm_on: ``True`` to enable MM mode, ``False`` to disable it.
        """
        if mm_on:
            already_on = (getattr(self, 'is_mm', False)
                          and hasattr(self, 'name_list'))
            if not already_on:
                # Snapshot the complete list so it can be restored later.
                self.name_list = self.test_dataset.name_list
            self.is_mm = True
            self.mm_list = np.random.choice(self.name_list,
                                            self.cfg.TEST.MM_NUM_SAMPLES,
                                            replace=False)
            self.test_dataset.name_list = self.mm_list
        else:
            self.is_mm = False
            # Nothing to restore if MM mode was never enabled.
            if hasattr(self, 'name_list'):
                self.test_dataset.name_list = self.name_list
|
|