"""Audio-to-motion dataset and the landmark statistics it derives per item."""
import glob
import json
import os
import pickle
import random
import re
import subprocess
import csv
from functools import partial

import librosa.core
import numpy as np
import torch
import torch.distributions
import torch.distributed as dist
import torch.optim
import torch.utils.data
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import pandas as pd
import tqdm

from utils.commons.indexed_datasets import IndexedDataset
from utils.commons.hparams import hparams, set_hparams
from utils.commons.meters import Timer
from data_util.face3d_helper import Face3DHelper
from utils.audio import librosa_wav2mfcc
from utils.commons.dataset_utils import collate_xd
from utils.commons.tensor_utils import convert_to_tensor

# Lazily created on first use inside ``Audio2Motion_Dataset.__getitem__``.
face3d_helper = None


def erosion_1d(arr):
    """Collapse every run of consecutive 1s in a 1-D sequence to a single 1.

    A run that is followed by a non-1 value is replaced by zeros except for
    the element at ``start + length // 2``, which is kept as 1.  A run that
    reaches the end of the sequence is zeroed completely (no element kept).

    Args:
        arr: 1-D sequence supporting ``copy()`` and item assignment
            (e.g. a numpy array or a list).

    Returns:
        A copy of ``arr`` with the runs eroded; ``arr`` itself is not modified.
    """
    result = arr.copy()
    run_start = None
    run_length = 0
    for i, num in enumerate(arr):
        if num == 1:
            if run_length == 0:
                run_start = i
            run_length += 1
        elif run_length > 0:
            # The run just ended: clear it and keep only its middle element.
            for j in range(run_start, run_start + run_length):
                result[j] = 0
            result[run_start + run_length // 2] = 1
            run_length = 0
    if run_length > 0:
        # The run touches the end of the sequence: clear it entirely.
        for j in range(run_start, run_start + run_length):
            result[j] = 0
    return result


def _is_mediapipe_layout(ldm):
    """Return False for a 68-point landmark layout, True for a 468/478-point one."""
    num_points = ldm.shape[1]
    if num_points == 68:
        return False
    assert num_points in [468, 478], f"unsupported landmark count: {num_points}"
    return True


def _as_float_tensor(ldm):
    """Return ``ldm`` unchanged if it is a tensor, else convert it to a FloatTensor."""
    if isinstance(ldm, torch.Tensor):
        return ldm
    return torch.FloatTensor(ldm)


def _eye_distance(ldm, is_mediapipe):
    """Per-frame sum of L1 distances between the paired eye landmarks.

    The index pairs are presumably upper/lower eyelid points of both eyes in
    the respective layout -- TODO confirm against the landmark definitions.
    """
    if is_mediapipe:
        pairs = ((159, 145), (386, 374))
    else:
        pairs = ((41, 37), (40, 38), (47, 43), (46, 44))
    first_a, first_b = pairs[0]
    eye_d = (ldm[:, first_a] - ldm[:, first_b]).abs().sum(-1)
    for a, b in pairs[1:]:
        eye_d = eye_d + (ldm[:, a] - ldm[:, b]).abs().sum(-1)
    return eye_d


def get_mouth_amp(ldm):
    """Return the 0.9 quantile over time of the per-frame mouth distance.

    Args:
        ldm: landmarks, [T, 68/468, 3] (478 points are also accepted);
            tensor or array-like.

    Returns:
        A scalar tensor.
    """
    is_mediapipe = _is_mediapipe_layout(ldm)
    ldm = _as_float_tensor(ldm)
    if is_mediapipe:
        mouth_d = (ldm[:, 0] - ldm[:, 17]).abs().sum(-1)
    else:
        mouth_d = (ldm[:, 51] - ldm[:, 57]).abs().sum(-1)
    return torch.quantile(mouth_d, 0.9, dim=0)


def get_eye_amp(ldm):
    """Return the 0.9 quantile over time of the per-frame eye distance.

    Args:
        ldm: landmarks, [T, 68/468, 3] (478 points are also accepted);
            tensor or array-like.

    Returns:
        A scalar tensor.
    """
    is_mediapipe = _is_mediapipe_layout(ldm)
    ldm = _as_float_tensor(ldm)
    eye_d = _eye_distance(ldm, is_mediapipe)
    return torch.quantile(eye_d, 0.9, dim=0)


def get_blink(ldm):
    """Return a per-frame 0/1 blink indicator derived from the eye distance.

    A frame is a blink candidate when its eye distance is below 0.85 times the
    0.75 quantile of the sequence; candidate runs are then thinned with
    ``erosion_1d``.

    Args:
        ldm: landmarks, [T, 68/468, 3] (478 points are also accepted);
            tensor or array-like.

    Returns:
        A long tensor (on the input's device) if ``ldm`` was a tensor,
        otherwise a numpy integer array; length T.
    """
    is_mediapipe = _is_mediapipe_layout(ldm)
    is_torch = isinstance(ldm, torch.Tensor)
    ldm = _as_float_tensor(ldm)
    eye_d = _eye_distance(ldm, is_mediapipe)
    eye_d_qtl = torch.quantile(eye_d, 0.75, dim=0)
    ratio = eye_d / eye_d_qtl
    # ``.cpu()`` so that tensors living on an accelerator can be converted.
    blink = (ratio < 0.85).long().cpu().numpy()
    blink = erosion_1d(blink)
    if is_torch:
        blink = torch.from_numpy(blink).to(ldm.device)
    return blink


class Audio2Motion_Dataset(Dataset):
    """Dataset over a binarized ``IndexedDataset`` yielding audio/motion items."""

    def __init__(self, prefix='train', data_dir=None):
        """
        Args:
            prefix: name of the split; the data is read from
                ``{ds_path}/{prefix}``.
            data_dir: directory of the binarized data; defaults to
                ``hparams['binary_data_dir']``.
        """
        self.hparams = hparams
        self.db_key = prefix
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir
        self.ds = None  # opened lazily in _get_item
        self.sizes = None  # filled by ordered_indices
        self._num_items = None  # cached result of __len__
        # Original note: "50 video frames"; not used in this file -- TODO confirm.
        self.x_maxframes = 200
        self.x_multiply = 8

    def __len__(self):
        """Return the number of items, reading the index only on the first call.

        ``self.ds`` is deliberately left untouched so that the dataset handle
        is still opened lazily by ``_get_item``.
        """
        num_items = getattr(self, '_num_items', None)
        if num_items is None:
            num_items = len(IndexedDataset(f'{self.ds_path}/{self.db_key}'))
            self._num_items = num_items
        return num_items

    def _get_item(self, index):
        """
        This func is necessary to open files in multi-threads!
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def __getitem__(self, idx):
        """Load item ``idx`` and derive blink / eye / mouth statistics for it.

        Returns:
            A dict of tensors plus ``idx`` and ``item_id``, or ``None`` when
            the raw item could not be loaded.
        """
        global face3d_helper

        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None

        item = {
            'idx': idx,
            'item_id': raw_item['img_dir'],
            'id': torch.from_numpy(raw_item['id']).float(),
            'exp': torch.from_numpy(raw_item['exp']).float(),
        }
        if item['id'].shape[0] == 1:
            # A single global identity code: repeat it for every 'exp' frame.
            item['id'] = item['id'].repeat([item['exp'].shape[0], 1])
        item['hubert'] = torch.from_numpy(raw_item['hubert']).float()
        item['f0'] = torch.from_numpy(raw_item['f0']).float()

        if face3d_helper is None:
            face3d_helper = Face3DHelper(use_gpu=False)
        cano_lm3d = face3d_helper.reconstruct_cano_lm3d(item['id'], item['exp'])
        item['blink_unit'] = get_blink(cano_lm3d)
        item['eye_amp'] = get_eye_amp(cano_lm3d)
        item['mouth_amp'] = get_mouth_amp(cano_lm3d)

        # Truncate the audio length to a multiple of x_multiply (original note:
        # "make it divisible by our CNN"); motion tracks use half that length
        # (original note: "video is 25fps").
        x_len = len(item['hubert']) // self.x_multiply * self.x_multiply
        y_len = x_len // 2

        item['hubert'] = item['hubert'][:x_len]
        item['f0'] = item['f0'][:x_len]
        item['id'] = item['id'][:y_len]
        item['exp'] = item['exp'][:y_len]
        item['euler'] = convert_to_tensor(raw_item['euler'][:y_len])
        item['trans'] = convert_to_tensor(raw_item['trans'][:y_len])
        item['blink_unit'] = item['blink_unit'][:y_len].reshape([-1, 1])
        item['eye_amp'] = item['eye_amp'].reshape([1, ])
        item['mouth_amp'] = item['mouth_amp'].reshape([1, ])
        return item

    def ordered_indices(self):
        """Return dataset indices sorted by item size (stable sort).

        Sizes are loaded from ``sizes_{db_key}.npy`` in the data directory if
        that file exists; otherwise they are counted once and saved there.
        Batches will be constructed based on this order.
        """
        sizes_fname = os.path.join(self.ds_path, f"sizes_{self.db_key}.npy")
        if os.path.exists(sizes_fname):
            # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
            # load size files from trusted data directories.
            self.sizes = np.load(sizes_fname, allow_pickle=True)
        if self.sizes is None:
            self.sizes = []
            print("Counting the size of each item in dataset...")
            ds = IndexedDataset(f"{self.ds_path}/{self.db_key}")
            for i_sample in tqdm.trange(len(ds)):
                sample = ds[i_sample]
                if sample is None:
                    size = 0
                else:
                    # NOTE(review): sizes come from 'mel' while __getitem__
                    # reads 'hubert' -- confirm the binarized items carry 'mel'.
                    size = sample['mel'].shape[-1]
                self.sizes.append(size)
            np.save(sizes_fname, self.sizes)
        indices = np.arange(len(self))
        order = np.argsort(np.array(self.sizes)[indices], kind='mergesort')
        return indices[order]

    def batch_by_size(self, indices, max_tokens=None, max_sentences=None,
                      required_batch_size_multiple=1):
        """Group indices into mini-batches bucketed by size.

        Batches may contain sequences of different lengths.  The size of an
        index is ``self.sizes[index]``, so ``self.sizes`` must already be
        populated (``ordered_indices`` does that).

        Args:
            indices (List[int]): ordered list of dataset indices
            max_tokens (int, optional): max number of tokens in each batch,
                counted as batch length times the longest item
                (default: 60000).
            max_sentences (int, optional): max number of sentences in each
                batch (default: 512).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).

        Returns:
            A list of batches, each a list of indices.
        """
        max_tokens = max_tokens if max_tokens is not None else 60000
        max_sentences = max_sentences if max_sentences is not None else 512
        bsz_mult = required_batch_size_multiple

        def _is_batch_full(batch, num_tokens):
            if len(batch) == 0:
                return False
            return len(batch) == max_sentences or num_tokens > max_tokens

        sample_len = 0
        sample_lens = []
        batch = []
        batches = []
        for idx in indices:
            num_tokens = self.sizes[idx]
            sample_lens.append(num_tokens)
            sample_len = max(sample_len, num_tokens)
            assert sample_len <= max_tokens, (
                "sentence at index {} of size {} exceeds max_tokens "
                "limit of {}!".format(idx, sample_len, max_tokens)
            )
            # Cost of the batch if this item were added: every item is padded
            # to the longest one.
            num_tokens = (len(batch) + 1) * sample_len

            if _is_batch_full(batch, num_tokens):
                # Emit the largest prefix whose length is a multiple of
                # bsz_mult (or the whole remainder if shorter than bsz_mult).
                mod_len = max(
                    bsz_mult * (len(batch) // bsz_mult),
                    len(batch) % bsz_mult,
                )
                batches.append(batch[:mod_len])
                batch = batch[mod_len:]
                sample_lens = sample_lens[mod_len:]
                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
            batch.append(idx)
        if len(batch) > 0:
            batches.append(batch)
        return batches

    def get_dataloader(self, batch_size=1, num_workers=0):
        """Build a DataLoader over size-bucketed, shuffled batches.

        Args:
            batch_size: kept for backward compatibility; batching is driven by
                ``hparams['max_tokens_per_batch']`` and
                ``hparams['max_sentences_per_batch']``, so this value is not
                used.
            num_workers: number of DataLoader worker processes.
        """
        batches_idx = self.batch_by_size(
            self.ordered_indices(),
            max_tokens=self.hparams['max_tokens_per_batch'],
            max_sentences=self.hparams['max_sentences_per_batch'],
        )
        # Repeat the list of batches 50 times, then shuffle the batch order.
        batches_idx = batches_idx * 50
        random.shuffle(batches_idx)
        return DataLoader(self, pin_memory=True, collate_fn=self.collater,
                          batch_sampler=batches_idx, num_workers=num_workers)

    def collater(self, samples):
        """Pad and stack a list of items into a batch dict.

        Items that are ``None`` (failed loads from ``__getitem__``) are
        dropped.  Returns an empty dict when no item remains.
        """
        samples = [s for s in samples if s is not None]
        if len(samples) == 0:
            return {}

        x_len = max(s['hubert'].size(0) for s in samples)
        assert x_len % self.x_multiply == 0
        y_len = x_len // 2

        batch = {}
        batch['hubert'] = collate_xd([s["hubert"] for s in samples], max_len=x_len, pad_idx=0)
        # A frame counts as valid when its feature vector is not all zeros.
        batch['x_mask'] = (batch['hubert'].abs().sum(dim=-1) > 0).float()
        batch['f0'] = collate_xd([s["f0"].reshape([-1, 1]) for s in samples],
                                 max_len=x_len, pad_idx=0).squeeze(-1)
        batch.update({
            'item_id': [s['item_id'] for s in samples],
        })

        for key in ('id', 'exp', 'euler', 'trans', 'blink_unit'):
            batch[key] = collate_xd([s[key] for s in samples], max_len=y_len, pad_idx=0)
        for key in ('eye_amp', 'mouth_amp'):
            batch[key] = collate_xd([s[key] for s in samples], max_len=1, pad_idx=0)
        batch['y_mask'] = (batch['id'].abs().sum(dim=-1) > 0).float()
        return batch


if __name__ == '__main__':
    os.environ["OMP_NUM_THREADS"] = "1"
    set_hparams('egs/os_avatar/audio2secc_vae.yaml')
    ds = Audio2Motion_Dataset("train", 'data/binary/th1kh')
    dl = ds.get_dataloader()
    for b in tqdm.tqdm(dl):
        pass