import torch
from utils.word_vectorizer import WordVectorizer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from eval.evaluator_modules import *
from torch.utils.data._utils.collate import default_collate


class GeneratedDataset(Dataset):
    """
    Motions generated by `pipeline` for the captions of `dataset`, wrapped in
    the format expected by the T2M evaluators.

    Reads the following fields of `opt`:
        opt.dataset_name
        opt.max_motion_length
        opt.unit_length
    """

    def __init__(
        self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
    ):
        assert mm_num_samples < len(dataset)
        self.dataset = dataset
        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        generated_motion = []
        min_mov_length = 10 if opt.dataset_name == "t2m" else 6

        # Pre-process all target captions
        mm_generated_motions = []
        if mm_num_samples > 0:
            # Indices of the samples reserved for the multimodality evaluation
            mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
            mm_idxs = np.sort(mm_idxs)

        all_caption = []
        all_m_lens = []
        all_data = []
        # First pass: collect every caption and target length (multimodality
        # samples are repeated mm_num_repeats times) so that all sequences can
        # be generated in a single batched call below.
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                all_data.append(data)
                tokens = tokens[0].split("_")
                mm_num_now = len(mm_generated_motions)
                is_mm = (
                    True
                    if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
                    else False
                )
                repeat_times = mm_num_repeats if is_mm else 1
                # Snap the target length to a multiple of unit_length and clamp
                # it to [min_mov_length * unit_length, max_motion_length]
                m_lens = max(
                    torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
                    * opt.unit_length,
                    min_mov_length * opt.unit_length,
                )
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    # Placeholder entry: only the count is used in this pass
                    mm_generated_motions.append(0)
            all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences
        with torch.no_grad():
            all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
            self.eval_generate_time = t_eval

        cur_idx = 0
        mm_generated_motions = []
        # Second pass: re-associate the generated motions with their source
        # samples and collect the repeated generations for the multimodality
        # metric.
        with torch.no_grad():
            for i, data_dummy in tqdm(enumerate(dataloader)):
                # The shuffled dataloader yields a different order on this
                # second epoch, so data_dummy is ignored and the data cached in
                # the first pass is re-used instead.
                data = all_data[i]
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split("_")
                mm_num_now = len(mm_generated_motions)
                is_mm = (
                    True
                    if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
                    else False
                )
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                for t in range(repeat_times):
                    pred_motions = all_pred_motions[cur_idx]
                    cur_idx += 1
                    if t == 0:
                        sub_dict = {
                            "motion": pred_motions.cpu().numpy(),
                            "length": pred_motions.shape[0],  # m_lens[0].item()
                            "caption": caption[0],
                            "cap_len": cap_lens[0].item(),
                            "tokens": tokens,
                        }
                        generated_motion.append(sub_dict)
                    if is_mm:
                        mm_motions.append(
                            {
                                "motion": pred_motions.cpu().numpy(),
                                "length": pred_motions.shape[0],  # m_lens[0].item()
                            }
                        )
                if is_mm:
                    mm_generated_motions.append(
                        {
                            "caption": caption[0],
                            "tokens": tokens,
                            "cap_len": cap_lens[0].item(),
                            "mm_motions": mm_motions,
                        }
                    )

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = (
            data["motion"],
            data["length"],
            data["caption"],
            data["tokens"],
        )
        sent_len = data["cap_len"]

        # This step is needed because the T2M evaluators expect their own norm
        # convention: undo this dataset's normalization, then re-normalize with
        # the T2M evaluation mean/std.
        normed_motion = motion
        denormed_motion = self.dataset.inv_transform(normed_motion)
        renormed_motion = (
            denormed_motion - self.dataset.mean_for_eval
        ) / self.dataset.std_for_eval  # according to T2M norms
        motion = renormed_motion

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Zero-pad the motion to the fixed evaluation length
        length = len(motion)
        if length < self.opt.max_motion_length:
            motion = np.concatenate(
                [
                    motion,
                    np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
                ],
                axis=0,
            )
        return (
            word_embeddings,
            pos_one_hots,
            caption,
            sent_len,
            motion,
            m_length,
            "_".join(tokens),
        )


def collate_fn(batch):
    # Sort the batch by caption length, longest first, before default collation
    batch.sort(key=lambda x: x[3], reverse=True)
    return default_collate(batch)


class MMGeneratedDataset(Dataset):
    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        mm_motions = data["mm_motions"]
        m_lens = []
        motions = []
        for mm_motion in mm_motions:
            m_lens.append(mm_motion["length"])
            motion = mm_motion["motion"]
            # Zero-pad each repeated generation to the fixed evaluation length
            if len(motion) < self.opt.max_motion_length:
                motion = np.concatenate(
                    [
                        motion,
                        np.zeros(
                            (self.opt.max_motion_length - len(motion), motion.shape[1])
                        ),
                    ],
                    axis=0,
                )
            motion = motion[None, :]
            motions.append(motion)
        m_lens = np.array(m_lens, dtype=np.int32)
        motions = np.concatenate(motions, axis=0)
        # Sort the repeated generations by length, longest first
        sort_indx = np.argsort(m_lens)[::-1].copy()
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens


def get_motion_loader(
    opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
):
    # Currently the configurations of the two datasets are almost the same
    if opt.dataset_name == "t2m" or opt.dataset_name == "kit":
        w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
    else:
        raise KeyError("Dataset not recognized!!")

    dataset = GeneratedDataset(
        opt,
        pipeline,
        ground_truth_dataset,
        w_vectorizer,
        mm_num_samples,
        mm_num_repeats,
    )
    mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)

    motion_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
    )
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
    return motion_loader, mm_motion_loader, dataset.eval_generate_time
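

# -----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). It assumes an `opt`
# namespace exposing the fields read above (dataset_name, glove_dir, device,
# unit_length, max_motion_length), a `pipeline` with the generate(captions,
# m_lens) method used by GeneratedDataset, and a ground-truth dataset providing
# inv_transform, mean_for_eval and std_for_eval. The variable names and numeric
# arguments below are placeholders.
#
#   motion_loader, mm_motion_loader, gen_time = get_motion_loader(
#       opt,
#       batch_size=32,
#       pipeline=pipeline,
#       ground_truth_dataset=gt_dataset,
#       mm_num_samples=100,
#       mm_num_repeats=30,
#   )
#   for word_embs, pos_one_hots, captions, sent_lens, motions, m_lens, tokens in motion_loader:
#       pass  # feed each batch to the T2M evaluators
# -----------------------------------------------------------------------------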