import logging
import os
from argparse import ArgumentParser
from datetime import timedelta
from pathlib import Path

import pandas as pd
import tensordict as td
import torch
import torch.distributed as distributed
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from mmaudio.data.data_setup import error_avoidance_collate
from mmaudio.data.extraction.vgg_sound import VGGSound
from mmaudio.model.utils.features_utils import FeaturesUtils
from mmaudio.utils.dist_utils import local_rank, world_size

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# for the 16kHz model
SAMPLING_RATE = 16000
DURATION_SEC = 8.0
NUM_SAMPLES = 128000  # SAMPLING_RATE * DURATION_SEC
vae_path = './ext_weights/v1-16.pth'
bigvgan_path = './ext_weights/best_netG.pt'
mode = '16k'

# for the 44.1kHz model
"""
NOTE: 352800 (8*44100) is not divisible by (STFT hop size * VAE downsampling ratio), which is 1024.
353280 is the next integer divisible by 1024.
"""
# SAMPLING_RATE = 44100
# DURATION_SEC = 8.0
# NUM_SAMPLES = 353280
# vae_path = './ext_weights/v1-44.pth'
# bigvgan_path = None
# mode = '44k'
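
# A quick sanity check of the arithmetic in the NOTE above (illustrative helper,
# not used elsewhere in this script): round a sample count up to the next
# multiple of the latent hop.
def _next_multiple(n: int, k: int = 1024) -> int:
    return -(-n // k) * k


assert _next_multiple(8 * 44100) == 353280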

synchformer_ckpt = './ext_weights/synchformer_state_dict.pth'

# per-GPU; the effective global batch size is BATCH_SIZE * world_size
BATCH_SIZE = 16
NUM_WORKERS = 16

log = logging.getLogger()
log.setLevel(logging.INFO)

# uncomment the train/test/val sets to extract latents for them
data_cfg = {
    'example': {
        'root': './training/example_videos',
        'subset_name': './training/example_video.tsv',
        'normalize_audio': True,
    },
    # 'train': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-train.tsv',
    #     'normalize_audio': True,
    # },
    # 'test': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-test.tsv',
    #     'normalize_audio': False,
    # },
    # 'val': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-val.tsv',
    #     'normalize_audio': False,
    # },
}


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=1))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size
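

# NOTE: init_process_group relies on the usual torch.distributed environment
# variables (RANK, WORLD_SIZE, MASTER_ADDR, ...), so launch this script with
# torchrun; the command below is illustrative:
#   torchrun --standalone --nproc_per_node=8 extract_video_training_latents.py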


def setup_dataset(split: str):
    dataset = VGGSound(
        data_cfg[split]['root'],
        tsv_path=data_cfg[split]['subset_name'],
        sample_rate=SAMPLING_RATE,
        duration_sec=DURATION_SEC,
        audio_samples=NUM_SAMPLES,
        normalize_audio=data_cfg[split]['normalize_audio'],
    )
    # shuffle=False gives a deterministic order, so each rank processes a fixed,
    # disjoint shard of the dataset
    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=False)
    loader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS,
                        sampler=sampler,
                        drop_last=False,
                        collate_fn=error_avoidance_collate)
    return dataset, loader
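
# For reference: a collate function like error_avoidance_collate typically just
# drops samples that a worker failed to load before falling back to the default
# collation. An illustrative sketch (see mmaudio.data.data_setup for the actual
# definition):
#
#   from torch.utils.data.dataloader import default_collate
#
#   def error_avoidance_collate(batch):
#       batch = list(filter(lambda x: x is not None, batch))
#       return default_collate(batch)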


def extract():
    # initial setup
    distributed_setup()

    parser = ArgumentParser()
    parser.add_argument('--latent_dir',
                        type=Path,
                        default='./training/example_output/video-latents')
    parser.add_argument('--output_dir', type=Path, default='./training/example_output/memmap')
    args = parser.parse_args()

    latent_dir = args.latent_dir
    output_dir = args.output_dir

    # cuda setup
    torch.cuda.set_device(local_rank)
    feature_extractor = FeaturesUtils(tod_vae_ckpt=vae_path,
                                      enable_conditions=True,
                                      bigvgan_vocoder_ckpt=bigvgan_path,
                                      synchformer_ckpt=synchformer_ckpt,
                                      mode=mode).eval().cuda()

    for split in data_cfg.keys():
        print(f'Extracting latents for the {split} split')
        this_latent_dir = latent_dir / split
        this_latent_dir.mkdir(parents=True, exist_ok=True)

        # setup datasets
        dataset, loader = setup_dataset(split)
        log.info(f'Number of samples: {len(dataset)}')
        log.info(f'Number of batches: {len(loader)}')

        for curr_iter, data in enumerate(tqdm(loader)):
            output = {
                'id': data['id'],
                'caption': data['caption'],
            }

            # audio -> VAE latent distribution; store per-frame mean/std
            audio = data['audio'].cuda()
            dist = feature_extractor.encode_audio(audio)
            output['mean'] = dist.mean.detach().cpu().transpose(1, 2)
            output['std'] = dist.std.detach().cpu().transpose(1, 2)

            # visual conditioning features (CLIP and Synchformer)
            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['clip_features'] = clip_features.detach().cpu()

            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()

            # text conditioning features
            caption = data['caption']
            text_features = feature_extractor.encode_text(caption)
            output['text_features'] = text_features.detach().cpu()

            # each rank writes its own shard; rank 0 merges them after the barrier
            torch.save(output, this_latent_dir / f'r{local_rank}_{curr_iter}.pth')

        # wait until every rank has finished writing before combining
        distributed.barrier()

        # combine the results
        if local_rank == 0:
            print('Extraction done. Combining the results.')
            used_id = set()
            list_of_ids_and_labels = []
            output_data = {
                'mean': [],
                'std': [],
                'clip_features': [],
                'sync_features': [],
                'text_features': [],
            }

            for t in tqdm(sorted(os.listdir(this_latent_dir))):
                data = torch.load(this_latent_dir / t, weights_only=True)
                bs = len(data['id'])
                for bi in range(bs):
                    this_id = data['id'][bi]
                    this_caption = data['caption'][bi]
                    # DistributedSampler pads the dataset to a multiple of
                    # world_size, so the trailing batches can repeat samples
                    if this_id in used_id:
                        print('Duplicate id:', this_id)
                        continue

                    list_of_ids_and_labels.append({'id': this_id, 'label': this_caption})
                    used_id.add(this_id)
                    output_data['mean'].append(data['mean'][bi])
                    output_data['std'].append(data['std'][bi])
                    output_data['clip_features'].append(data['clip_features'][bi])
                    output_data['sync_features'].append(data['sync_features'][bi])
                    output_data['text_features'].append(data['text_features'][bi])

            output_dir.mkdir(parents=True, exist_ok=True)
            output_df = pd.DataFrame(list_of_ids_and_labels)
            output_df.to_csv(output_dir / f'vgg-{split}.tsv', sep='\t', index=False)
            print(f'Output: {len(output_df)}')

            # one stacked tensor per key, written as a memory-mapped TensorDict
            output_data = {k: torch.stack(v) for k, v in output_data.items()}
            td.TensorDict(output_data).memmap_(output_dir / f'vgg-{split}')


if __name__ == '__main__':
    extract()
    distributed.destroy_process_group()
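
# Loading the combined output back for training (illustrative; the key names
# match what memmap_() wrote above, and the paths assume the default
# --output_dir with the 'example' split):
#
#   latents = td.TensorDict.load_memmap('./training/example_output/memmap/vgg-example')
#   mean, std = latents['mean'], latents['std']
#   df = pd.read_csv('./training/example_output/memmap/vgg-example.tsv', sep='\t')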