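"""Distributed batch evaluation: generates audio for every item in an eval dataset.

Intended to be launched with torchrun so that LOCAL_RANK/WORLD_SIZE are set in the
environment, e.g. (a sketch; the actual script name and available config overrides
depend on the repository):

    torchrun --standalone --nproc_per_node=2 batch_eval.py dataset=<name> duration_s=8
"""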
import logging
import os
from pathlib import Path

import hydra
import torch
import torch.distributed as distributed
import torchaudio
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm

from mmaudio.data.data_setup import setup_eval_dataset
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils

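# Allow TF32 matmuls/convolutions: faster on Ampere+ GPUs, with negligible
# precision impact at inference time.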
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

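# torchrun sets these for every spawned process; launching the script directly
# (without torchrun) will raise a KeyError here.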
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
log = logging.getLogger()


@torch.inference_mode()
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig):
    device = 'cuda'
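    # Bind this process to the GPU matching its local rank.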
    torch.cuda.set_device(local_rank)

    if cfg.model not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {cfg.model}')
    model: ModelConfig = all_model_cfg[cfg.model]
    model.download_if_needed()
    seq_cfg = model.seq_cfg

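    # Hydra creates a per-run working directory; outputs are written next to its logs.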
    run_dir = Path(HydraConfig.get().run.dir)
    if cfg.output_name is None:
        output_dir = run_dir / cfg.dataset
    else:
        output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # load the pretrained network; sequence lengths follow from the requested duration
    seq_cfg.duration = cfg.duration_s
    net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
    log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
    log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')

    # sampling setup: seeded RNG and the flow-matching solver
    rng = torch.Generator(device=device)
    rng.manual_seed(cfg.seed)
    fm = FlowMatching(cfg.sampling.min_sigma,
                      inference_mode=cfg.sampling.method,
                      num_steps=cfg.sampling.num_steps)

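    # Feature extractors for the conditioning streams (visual, sync, text). The VAE
    # encoder is skipped since evaluation only decodes generated latents.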
    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device).eval()

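    # Optionally torch.compile the two hot paths; the first batch pays the warm-up cost.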
    if cfg.compile:
        net.preprocess_conditions = torch.compile(net.preprocess_conditions)
        net.predict_flow = torch.compile(net.predict_flow)
        feature_utils.compile()

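    # Build the evaluation dataset and loader (presumably sharded across ranks by
    # setup_eval_dataset, given the distributed setup).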
    dataset, loader = setup_eval_dataset(cfg.dataset, cfg)

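    # Run generation under bfloat16 autocast when cfg.amp is enabled.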
    with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
        for batch in tqdm(loader):
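            # cfg_strength is the classifier-free guidance scale; the batch-size
            # multipliers appear to let the lighter CLIP/Synchformer encoders run
            # with much larger effective batches than the flow network (an
            # assumption based on the parameter names).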
            audios = generate(batch.get('clip_video', None),
                              batch.get('sync_video', None),
                              batch.get('caption', None),
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg.cfg_strength,
                              clip_batch_size_multiplier=64,
                              sync_batch_size_multiplier=64)
            audios = audios.float().cpu()
            names = batch['name']
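            # Write one FLAC per sample at the model's native sampling rate.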
            for audio, name in zip(audios, names):
                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)


def distributed_setup():
    # torchrun provides the rendezvous info (MASTER_ADDR, MASTER_PORT, RANK, ...)
    # through the environment, so init_process_group needs no explicit arguments.
    distributed.init_process_group(backend='nccl')
    # Note: distributed.get_rank() is the global rank, not the local rank.
    log.info(f'Initialized: rank={distributed.get_rank()}, world_size={distributed.get_world_size()}')


if __name__ == '__main__':
    distributed_setup()

    main()

    # clean-up
    distributed.destroy_process_group()