import os
import random

import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset, DataLoader

from data_util.face3d_helper import Face3DHelper
from utils.commons.dataset_utils import collate_xd
from utils.commons.hparams import hparams, set_hparams
from utils.commons.indexed_datasets import IndexedDataset
from utils.commons.tensor_utils import convert_to_tensor

# Built lazily inside __getitem__ so that each DataLoader worker process
# constructs its own Face3DHelper.
face3d_helper = None


def erosion_1d(arr):
    """Erode each continuous run of 1s down to a single 1 at its midpoint."""
    result = arr.copy()
    start_index = None
    continuous_length = 0

    for i, num in enumerate(arr):
        if num == 1:
            if continuous_length == 0:
                start_index = i
            continuous_length += 1
        else:
            if continuous_length > 0:
                # Zero out the finished run, keeping only its midpoint.
                for j in range(start_index, start_index + continuous_length):
                    result[j] = 0
                result[start_index + continuous_length // 2] = 1
                continuous_length = 0
    if continuous_length > 0:
        # A run that reaches the end of the array is handled the same way:
        # zero it out and keep its midpoint, mirroring the in-loop case.
        for j in range(start_index, start_index + continuous_length):
            result[j] = 0
        result[start_index + continuous_length // 2] = 1

    return result
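
# Quick illustration (hypothetical input): with the trailing-run handling
# above, erosion_1d(np.array([0, 1, 1, 1, 0, 1, 1])) returns
# array([0, 0, 1, 0, 0, 0, 1]) -- each run of 1s collapses to its midpoint.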


def get_mouth_amp(ldm):
    """
    ldm: [T, 68/468, 3]
    """
    is_mediapipe = ldm.shape[1] != 68
    is_torch = isinstance(ldm, torch.Tensor)
    if not is_torch:
        ldm = torch.FloatTensor(ldm)
    if is_mediapipe:
        assert ldm.shape[1] in [468, 478]
        # MediaPipe topology: landmarks 0 (upper lip) and 17 (lower lip).
        mouth_d = (ldm[:, 0] - ldm[:, 17]).abs().sum(-1)
    else:
        # 68-point topology: landmarks 51 (upper lip) and 57 (lower lip).
        mouth_d = (ldm[:, 51] - ldm[:, 57]).abs().sum(-1)

    # The 90th percentile over time serves as a robust per-clip mouth amplitude.
    mouth_amp = torch.quantile(mouth_d, 0.9, dim=0)
    return mouth_amp


def get_eye_amp(ldm):
    """
    ldm: [T, 68/468, 3]
    """
    is_mediapipe = ldm.shape[1] != 68
    is_torch = isinstance(ldm, torch.Tensor)
    if not is_torch:
        ldm = torch.FloatTensor(ldm)
    if is_mediapipe:
        assert ldm.shape[1] in [468, 478]
        # MediaPipe: upper/lower eyelid distances of both eyes.
        eye_d = (ldm[:, 159] - ldm[:, 145]).abs().sum(-1) + (ldm[:, 386] - ldm[:, 374]).abs().sum(-1)
    else:
        # 68-point: vertical eyelid distances of both eyes.
        eye_d = ((ldm[:, 41] - ldm[:, 37]).abs().sum(-1) + (ldm[:, 40] - ldm[:, 38]).abs().sum(-1)
                 + (ldm[:, 47] - ldm[:, 43]).abs().sum(-1) + (ldm[:, 46] - ldm[:, 44]).abs().sum(-1))

    # The 90th percentile over time serves as a robust per-clip eye amplitude.
    eye_amp = torch.quantile(eye_d, 0.9, dim=0)
    return eye_amp


def get_blink(ldm):
    """
    ldm: [T, 68/468, 3]
    """
    is_mediapipe = ldm.shape[1] != 68
    is_torch = isinstance(ldm, torch.Tensor)
    if not is_torch:
        ldm = torch.FloatTensor(ldm)
    if is_mediapipe:
        assert ldm.shape[1] in [468, 478]
        eye_d = (ldm[:, 159] - ldm[:, 145]).abs().sum(-1) + (ldm[:, 386] - ldm[:, 374]).abs().sum(-1)
    else:
        eye_d = ((ldm[:, 41] - ldm[:, 37]).abs().sum(-1) + (ldm[:, 40] - ldm[:, 38]).abs().sum(-1)
                 + (ldm[:, 47] - ldm[:, 43]).abs().sum(-1) + (ldm[:, 46] - ldm[:, 44]).abs().sum(-1))

    # Normalize by the 75th percentile (a proxy for the eyes-open distance),
    # mark frames below 85% of it as blinks, then erode each blink run to
    # its single midpoint frame.
    eye_d_qtl = torch.quantile(eye_d, 0.75, dim=0)
    blink = eye_d / eye_d_qtl
    blink = (blink < 0.85).long().numpy()
    blink = erosion_1d(blink)
    if is_torch:
        blink = torch.LongTensor(blink)
    return blink
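
# Putting the pieces together (hypothetical values): an eye-distance trace
# [1.0, 1.0, 0.6, 0.5, 0.6, 1.0] has a 0.75-quantile of 1.0, thresholds to the
# raw mask [0, 0, 1, 1, 1, 0], and erosion_1d keeps only the midpoint frame:
# [0, 0, 0, 1, 0, 0].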


class Audio2Motion_Dataset(Dataset):
    def __init__(self, prefix='train', data_dir=None):
        self.hparams = hparams
        self.db_key = prefix
        self.ds_path = self.hparams['binary_data_dir'] if data_dir is None else data_dir
        self.ds = None
        self.sizes = None
        self.x_maxframes = 200
        self.x_multiply = 8  # audio features are truncated to a multiple of this

    def __len__(self):
        # Open a throwaway handle rather than caching it on self, so the
        # per-worker lazy open in _get_item still happens after forking.
        ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return len(ds)

    def _get_item(self, index):
        """
        Open the file handle lazily, so each DataLoader worker gets its own.
        """
        if self.ds is None:
            self.ds = IndexedDataset(f'{self.ds_path}/{self.db_key}')
        return self.ds[index]

    def __getitem__(self, idx):
        raw_item = self._get_item(idx)
        if raw_item is None:
            print("loading from binary data failed!")
            return None
        item = {
            'idx': idx,
            'item_id': raw_item['img_dir'],
            'id': torch.from_numpy(raw_item['id']).float(),
            'exp': torch.from_numpy(raw_item['exp']).float(),
        }
        if item['id'].shape[0] == 1:
            # Broadcast a single identity code to every frame of the clip.
            item['id'] = item['id'].repeat([item['exp'].shape[0], 1])
        item['hubert'] = torch.from_numpy(raw_item['hubert']).float()
        item['f0'] = torch.from_numpy(raw_item['f0']).float()

        global face3d_helper
        if face3d_helper is None:
            face3d_helper = Face3DHelper(use_gpu=False)
        cano_lm3d = face3d_helper.reconstruct_cano_lm3d(item['id'], item['exp'])
        item['blink_unit'] = get_blink(cano_lm3d)
        item['eye_amp'] = get_eye_amp(cano_lm3d)
        item['mouth_amp'] = get_mouth_amp(cano_lm3d)

        # Truncate the audio features to a multiple of x_multiply; the motion
        # sequence runs at half the audio feature rate (y_len = x_len // 2).
        x_len = len(item['hubert'])
        x_len = x_len // self.x_multiply * self.x_multiply
        y_len = x_len // 2
        item['hubert'] = item['hubert'][:x_len]
        item['f0'] = item['f0'][:x_len]

        item['id'] = item['id'][:y_len]
        item['exp'] = item['exp'][:y_len]
        item['euler'] = convert_to_tensor(raw_item['euler'][:y_len])
        item['trans'] = convert_to_tensor(raw_item['trans'][:y_len])
        item['blink_unit'] = item['blink_unit'][:y_len].reshape([-1, 1])
        item['eye_amp'] = item['eye_amp'].reshape([1, ])
        item['mouth_amp'] = item['mouth_amp'].reshape([1, ])
        return item

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        sizes_fname = os.path.join(self.ds_path, f"sizes_{self.db_key}.npy")
        if os.path.exists(sizes_fname):
            sizes = np.load(sizes_fname, allow_pickle=True)
            self.sizes = sizes
        if self.sizes is None:
            self.sizes = []
            print("Counting the size of each item in the dataset...")
            ds = IndexedDataset(f"{self.ds_path}/{self.db_key}")
            for i_sample in tqdm.trange(len(ds)):
                sample = ds[i_sample]
                if sample is None:
                    size = 0
                else:
                    # Item size is measured on the mel spectrogram's time axis.
                    x = sample['mel']
                    size = x.shape[-1]
                self.sizes.append(size)
            np.save(sizes_fname, self.sizes)
        indices = np.arange(len(self))
        # Stable sort, so equally-sized items keep their original order.
        indices = indices[np.argsort(np.array(self.sizes)[indices], kind='mergesort')]
        return indices

    def batch_by_size(self, indices, max_tokens=None, max_sentences=None,
                      required_batch_size_multiple=1):
        """
        Return mini-batches of indices bucketed by size. Batches may contain
        sequences of different lengths.

        Args:
            indices (List[int]): ordered list of dataset indices
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
        """
        def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            if len(batch) == 0:
                return 0
            if len(batch) == max_sentences:
                return 1
            if num_tokens > max_tokens:
                return 1
            return 0

        num_tokens_fn = lambda x: self.sizes[x]
        max_tokens = max_tokens if max_tokens is not None else 60000
        max_sentences = max_sentences if max_sentences is not None else 512
        bsz_mult = required_batch_size_multiple

        sample_len = 0
        sample_lens = []
        batch = []
        batches = []
        for i in range(len(indices)):
            idx = indices[i]
            num_tokens = num_tokens_fn(idx)
            sample_lens.append(num_tokens)
            sample_len = max(sample_len, num_tokens)

            assert sample_len <= max_tokens, (
                "sentence at index {} of size {} exceeds max_tokens "
                "limit of {}!".format(idx, sample_len, max_tokens)
            )
            # Cost of the batch if this sample were added: batches are padded
            # to the longest sample, so the cost is batch size times the
            # running maximum length.
            num_tokens = (len(batch) + 1) * sample_len

            if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
                # Close the batch at a multiple of bsz_mult and carry the
                # remainder over into the next batch.
                mod_len = max(
                    bsz_mult * (len(batch) // bsz_mult),
                    len(batch) % bsz_mult,
                )
                batches.append(batch[:mod_len])
                batch = batch[mod_len:]
                sample_lens = sample_lens[mod_len:]
                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
            batch.append(idx)
        if len(batch) > 0:
            batches.append(batch)
        return batches
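
    # Illustration (hypothetical sizes): given four ordered indices whose
    # lengths are [10, 10, 12, 50] and max_tokens=60, adding the fourth item
    # would cost (3 + 1) * 50 = 200 padded tokens, so the first batch closes
    # as [i0, i1, i2] and the result is [[i0, i1, i2], [i3]].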

    def get_dataloader(self, num_workers=0):
        batches_idx = self.batch_by_size(self.ordered_indices(),
                                         max_tokens=hparams['max_tokens_per_batch'],
                                         max_sentences=hparams['max_sentences_per_batch'])
        # Repeat the batch list to stretch a single pass over the data into a
        # longer epoch, then shuffle the batch order.
        batches_idx = batches_idx * 50
        random.shuffle(batches_idx)
        # batch_sampler yields whole batches of indices, so no batch_size is
        # passed to the DataLoader.
        loader = DataLoader(self, pin_memory=True, collate_fn=self.collater,
                            batch_sampler=batches_idx, num_workers=num_workers)
        return loader

    def collater(self, samples):
        if len(samples) == 0:
            return {}

        batch = {}
        item_names = [s['item_id'] for s in samples]
        x_len = max(s['hubert'].size(0) for s in samples)
        assert x_len % self.x_multiply == 0
        y_len = x_len // 2

        batch['hubert'] = collate_xd([s["hubert"] for s in samples], max_len=x_len, pad_idx=0)
        # Valid-frame mask: padded frames are all-zero, so a nonzero feature
        # sum marks a real frame.
        batch['x_mask'] = (batch['hubert'].abs().sum(dim=-1) > 0).float()
        batch['f0'] = collate_xd([s["f0"].reshape([-1, 1]) for s in samples], max_len=x_len, pad_idx=0).squeeze(-1)

        batch.update({
            'item_id': item_names,
        })

        batch['id'] = collate_xd([s["id"] for s in samples], max_len=y_len, pad_idx=0)
        batch['exp'] = collate_xd([s["exp"] for s in samples], max_len=y_len, pad_idx=0)
        batch['euler'] = collate_xd([s["euler"] for s in samples], max_len=y_len, pad_idx=0)
        batch['trans'] = collate_xd([s["trans"] for s in samples], max_len=y_len, pad_idx=0)
        batch['blink_unit'] = collate_xd([s["blink_unit"] for s in samples], max_len=y_len, pad_idx=0)
        batch['eye_amp'] = collate_xd([s["eye_amp"] for s in samples], max_len=1, pad_idx=0)
        batch['mouth_amp'] = collate_xd([s["mouth_amp"] for s in samples], max_len=1, pad_idx=0)
        batch['y_mask'] = (batch['id'].abs().sum(dim=-1) > 0).float()
        return batch
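
    # Note: collate_xd (from utils.commons.dataset_utils) is assumed here to
    # stack a list of [T_i, C] tensors into one [B, max_len, C] tensor, padding
    # short sequences with pad_idx; x_mask / y_mask then flag non-padded frames.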


if __name__ == '__main__':
    os.environ["OMP_NUM_THREADS"] = "1"
    set_hparams('egs/os_avatar/audio2secc_vae.yaml')
    ds = Audio2Motion_Dataset("train", 'data/binary/th1kh')
    dl = ds.get_dataloader()
    for b in tqdm.tqdm(dl):
        pass