# sonisphere/training/extract_video_training_latents.py
import logging
import os
from argparse import ArgumentParser
from datetime import timedelta
from pathlib import Path

import pandas as pd
import tensordict as td
import torch
import torch.distributed as distributed
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

from mmaudio.data.data_setup import error_avoidance_collate
from mmaudio.data.extraction.vgg_sound import VGGSound
from mmaudio.model.utils.features_utils import FeaturesUtils
from mmaudio.utils.dist_utils import local_rank, world_size

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
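# TF32 matmuls/convolutions trade a little precision for speed on Ampere and newer GPUs;
# that trade-off is generally acceptable for offline feature extraction.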
# for the 16kHz model
SAMPLING_RATE = 16000
DURATION_SEC = 8.0
NUM_SAMPLES = 128000
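# 128000 samples = 16000 Hz * 8.0 s of audio per clip.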
vae_path = './ext_weights/v1-16.pth'
bigvgan_path = './ext_weights/best_netG.pt'
mode = '16k'
# for the 44.1kHz model
"""
NOTE: 352800 (8*44100) is not divisible by (STFT hop size * VAE downsampling ratio) which is 1024.
353280 is the next integer divisible by 1024.
"""
# SAMPLING_RATE = 44100
# DURATION_SEC = 8.0
# NUM_SAMPLES = 353280
# vae_path = './ext_weights/v1-44.pth'
# bigvgan_path = None
# mode = '44k'
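# To extract for the 44.1 kHz model instead, comment out the 16 kHz block above and
# uncomment the block here.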
synchformer_ckpt = './ext_weights/synchformer_state_dict.pth'
# per-GPU
BATCH_SIZE = 16
NUM_WORKERS = 16
log = logging.getLogger()
log.setLevel(logging.INFO)
# uncomment the train/test/val sets to extract latents for them
data_cfg = {
    'example': {
        'root': './training/example_videos',
        'subset_name': './training/example_video.tsv',
        'normalize_audio': True,
    },
    # 'train': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-train.tsv',
    #     'normalize_audio': True,
    # },
    # 'test': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-test.tsv',
    #     'normalize_audio': False,
    # },
    # 'val': {
    #     'root': '../data/video',
    #     'subset_name': './sets/vgg3-val.tsv',
    #     'normalize_audio': False,
    # },
}
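# The extraction loop below runs once per key in data_cfg and writes one latent/memmap
# bundle per split.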


def distributed_setup():
    distributed.init_process_group(backend="nccl", timeout=timedelta(hours=1))
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


def setup_dataset(split: str):
    dataset = VGGSound(
        data_cfg[split]['root'],
        tsv_path=data_cfg[split]['subset_name'],
        sample_rate=SAMPLING_RATE,
        duration_sec=DURATION_SEC,
        audio_samples=NUM_SAMPLES,
        normalize_audio=data_cfg[split]['normalize_audio'],
    )
    # Shard the dataset across ranks without shuffling, so every rank processes a disjoint,
    # deterministic subset of the clips.
    sampler = DistributedSampler(dataset, rank=local_rank, shuffle=False)
    # error_avoidance_collate drops samples that failed to load, so one bad clip does not
    # crash the whole extraction run.
    loader = DataLoader(dataset,
                        batch_size=BATCH_SIZE,
                        num_workers=NUM_WORKERS,
                        sampler=sampler,
                        drop_last=False,
                        collate_fn=error_avoidance_collate)
    return dataset, loader


@torch.inference_mode()
def extract():
    # initial setup
    distributed_setup()

    parser = ArgumentParser()
    parser.add_argument('--latent_dir',
                        type=Path,
                        default='./training/example_output/video-latents')
    parser.add_argument('--output_dir', type=Path, default='./training/example_output/memmap')
    args = parser.parse_args()
    latent_dir = args.latent_dir
    output_dir = args.output_dir

    # cuda setup
    torch.cuda.set_device(local_rank)
    feature_extractor = FeaturesUtils(tod_vae_ckpt=vae_path,
                                      enable_conditions=True,
                                      bigvgan_vocoder_ckpt=bigvgan_path,
                                      synchformer_ckpt=synchformer_ckpt,
                                      mode=mode).eval().cuda()
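    # feature_extractor bundles every encoder used below: the audio VAE (latent mean/std),
    # the CLIP video encoder, the Synchformer encoder, and the text encoder for captions.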

    for split in data_cfg.keys():
        print(f'Extracting latents for the {split} split')
        this_latent_dir = latent_dir / split
        this_latent_dir.mkdir(parents=True, exist_ok=True)

        # setup datasets
        dataset, loader = setup_dataset(split)
        log.info(f'Number of samples: {len(dataset)}')
        log.info(f'Number of batches: {len(loader)}')

        for curr_iter, data in enumerate(tqdm(loader)):
            output = {
                'id': data['id'],
                'caption': data['caption'],
            }

            audio = data['audio'].cuda()
            dist = feature_extractor.encode_audio(audio)
            output['mean'] = dist.mean.detach().cpu().transpose(1, 2)
            output['std'] = dist.std.detach().cpu().transpose(1, 2)

            clip_video = data['clip_video'].cuda()
            clip_features = feature_extractor.encode_video_with_clip(clip_video)
            output['clip_features'] = clip_features.detach().cpu()

            sync_video = data['sync_video'].cuda()
            sync_features = feature_extractor.encode_video_with_sync(sync_video)
            output['sync_features'] = sync_features.detach().cpu()

            caption = data['caption']
            text_features = feature_extractor.encode_text(caption)
            output['text_features'] = text_features.detach().cpu()

            # one .pth file per rank per batch; the rank prefix keeps names unique across GPUs
            torch.save(output, this_latent_dir / f'r{local_rank}_{curr_iter}.pth')

        # wait for every rank to finish writing its .pth shards before rank 0 merges them
        distributed.barrier()

        # combine the results
        if local_rank == 0:
            print('Extraction done. Combining the results.')
            used_id = set()
            list_of_ids_and_labels = []
            output_data = {
                'mean': [],
                'std': [],
                'clip_features': [],
                'sync_features': [],
                'text_features': [],
            }

            for t in tqdm(sorted(os.listdir(this_latent_dir))):
                data = torch.load(this_latent_dir / t, weights_only=True)
                bs = len(data['id'])

                for bi in range(bs):
                    this_id = data['id'][bi]
                    this_caption = data['caption'][bi]
                    if this_id in used_id:
                        print('Duplicate id:', this_id)
                        continue
                    list_of_ids_and_labels.append({'id': this_id, 'label': this_caption})
                    used_id.add(this_id)
                    output_data['mean'].append(data['mean'][bi])
                    output_data['std'].append(data['std'][bi])
                    output_data['clip_features'].append(data['clip_features'][bi])
                    output_data['sync_features'].append(data['sync_features'][bi])
                    output_data['text_features'].append(data['text_features'][bi])

            output_dir.mkdir(parents=True, exist_ok=True)
            output_df = pd.DataFrame(list_of_ids_and_labels)
            output_df.to_csv(output_dir / f'vgg-{split}.tsv', sep='\t', index=False)
            print(f'Output: {len(output_df)}')

            output_data = {k: torch.stack(v) for k, v in output_data.items()}
            td.TensorDict(output_data).memmap_(output_dir / f'vgg-{split}')
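            # The memmapped TensorDict is a directory of per-key tensor files; it can be
            # reloaded later with tensordict.TensorDict.load_memmap(output_dir / f'vgg-{split}').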


if __name__ == '__main__':
    extract()
    distributed.destroy_process_group()
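
# Example multi-GPU launch (an assumption: mmaudio.utils.dist_utils is expected to pick up
# the rank/world size that torchrun sets; adjust --nproc_per_node to your GPU count):
#
#   torchrun --standalone --nproc_per_node=2 training/extract_video_training_latents.py \
#       --latent_dir ./training/example_output/video-latents \
#       --output_dir ./training/example_output/memmap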