Commit 2c4e2b0
Phil Sobrepena committed
Parent(s): 977df40

This view is limited to 50 files because it contains too many changes.
- .gitignore +4 -6
- app.py +24 -103
- batch_eval.py +0 -110
- config/__init__.py +0 -0
- config/base_config.yaml +0 -62
- config/data/base.yaml +0 -70
- config/eval_config.yaml +0 -17
- config/eval_data/base.yaml +0 -22
- config/hydra/job_logging/custom-eval.yaml +0 -32
- config/hydra/job_logging/custom-no-rank.yaml +0 -32
- config/hydra/job_logging/custom-simplest.yaml +0 -26
- config/hydra/job_logging/custom.yaml +0 -33
- config/train_config.yaml +0 -41
- demo.py +1 -7
- docs/EVAL.md +0 -22
- docs/MODELS.md +0 -50
- docs/TRAINING.md +0 -160
- docs/index.html +10 -12
- gitattributes +0 -35
- mmaudio/data/av_utils.py +0 -26
- mmaudio/data/data_setup.py +0 -174
- mmaudio/data/eval/__init__.py +0 -0
- mmaudio/data/eval/audiocaps.py +0 -39
- mmaudio/data/eval/moviegen.py +0 -131
- mmaudio/data/eval/video_dataset.py +0 -197
- mmaudio/data/extracted_audio.py +0 -88
- mmaudio/data/extracted_vgg.py +0 -101
- mmaudio/data/extraction/__init__.py +0 -0
- mmaudio/data/extraction/vgg_sound.py +0 -193
- mmaudio/data/extraction/wav_dataset.py +0 -132
- mmaudio/data/mm_dataset.py +0 -45
- mmaudio/data/utils.py +0 -148
- mmaudio/eval_utils.py +9 -47
- mmaudio/ext/autoencoder/autoencoder.py +1 -1
- mmaudio/ext/autoencoder/vae.py +4 -0
- mmaudio/ext/mel_converter.py +9 -33
- mmaudio/model/embeddings.py +1 -1
- mmaudio/model/flow_matching.py +18 -1
- mmaudio/model/networks.py +1 -1
- mmaudio/model/transformer_layers.py +1 -0
- mmaudio/model/utils/features_utils.py +2 -2
- mmaudio/runner.py +0 -609
- mmaudio/sample.py +0 -90
- mmaudio/utils/email_utils.py +0 -50
- mmaudio/utils/log_integrator.py +0 -112
- mmaudio/utils/logger.py +0 -231
- mmaudio/utils/synthesize_ema.py +0 -19
- mmaudio/utils/tensor_utils.py +0 -14
- mmaudio/utils/time_estimator.py +0 -72
- mmaudio/utils/timezone.py +0 -1
.gitignore
CHANGED
@@ -2,18 +2,16 @@ run_*.sh
 log/
 saves
 saves/
-weights/
-weights
 output/
 output
 pretrained/
 workspace
 workspace/
-ext_weights/
-ext_weights
 .checkpoints/
-
-
+weights/
+ext_weights/
+*.pth
+*.pt
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
app.py
CHANGED
@@ -14,12 +14,13 @@ except ImportError:
     os.system("pip install -e .")
 import mmaudio
 
-from mmaudio.eval_utils import (ModelConfig,
-
+from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
+                                setup_eval_logging)
 from mmaudio.model.flow_matching import FlowMatching
 from mmaudio.model.networks import MMAudio, get_my_mmaudio
 from mmaudio.model.sequence_config import SequenceConfig
 from mmaudio.model.utils.features_utils import FeaturesUtils
+import tempfile
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
@@ -56,6 +57,7 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
 
 net, feature_utils, seq_cfg = get_model()
 
+
 @spaces.GPU(duration=120)
 @torch.inference_mode()
 def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
@@ -88,17 +90,18 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
     audio = audios.float().cpu()[0]
 
     # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
+    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
     # output_dir.mkdir(exist_ok=True, parents=True)
     # video_save_path = output_dir / f'{current_time_string}.mp4'
-    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
     make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
     log.info(f'Saved video to {video_save_path}')
     return video_save_path
 
+
 @spaces.GPU(duration=120)
 @torch.inference_mode()
-def
-
+def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
+                  duration: float):
 
     rng = torch.Generator(device=device)
     if seed >= 0:
@@ -107,11 +110,7 @@ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int
         rng.seed()
     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
 
-
-    clip_frames = image_info.clip_frames
-    sync_frames = image_info.sync_frames
-    clip_frames = clip_frames.unsqueeze(0)
-    sync_frames = sync_frames.unsqueeze(0)
+    clip_frames = sync_frames = None
     seq_cfg.duration = duration
     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
 
@@ -122,61 +121,24 @@ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int
                       net=net,
                       fm=fm,
                       rng=rng,
-                      cfg_strength=cfg_strength
-                      image_input=True)
+                      cfg_strength=cfg_strength)
     audio = audios.float().cpu()[0]
 
-
-
-
-
-    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
-    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
-    log.info(f'Saved video to {video_save_path}')
-    return video_save_path
-
-# @spaces.GPU(duration=120)
-# @torch.inference_mode()
-# def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
-#                   duration: float):
-
-#     rng = torch.Generator(device=device)
-#     if seed >= 0:
-#         rng.manual_seed(seed)
-#     else:
-#         rng.seed()
-#     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
-
-#     clip_frames = sync_frames = None
-#     seq_cfg.duration = duration
-#     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
-
-#     audios = generate(clip_frames,
-#                       sync_frames, [prompt],
-#                       negative_text=[negative_prompt],
-#                       feature_utils=feature_utils,
-#                       net=net,
-#                       fm=fm,
-#                       rng=rng,
-#                       cfg_strength=cfg_strength)
-#     audio = audios.float().cpu()[0]
-
-#     current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
-#     output_dir.mkdir(exist_ok=True, parents=True)
-#     audio_save_path = output_dir / f'{current_time_string}.flac'
-#     torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
-#     gc.collect()
-#     return audio_save_path
+    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
+    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
+    log.info(f'Saved audio to {audio_save_path}')
+    return audio_save_path
 
 
 video_to_audio_tab = gr.Interface(
     fn=video_to_audio,
-    description="""
+    description="""
+    Sonisphere
+    Video-to-Audio
     NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
     Doing so does not improve results.
 
-    The model has been trained on 8-second videos.
-    Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
+    The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
     """,
     inputs=[
         gr.Video(),
@@ -189,52 +151,11 @@ video_to_audio_tab = gr.Interface(
     ],
     outputs='playable_video',
     cache_examples=False,
-    title='
-
-
-# text_to_audio_tab = gr.Interface(
-#     fn=text_to_audio,
-#     description=""" Text-to-Audio
-#     """,
-#     inputs=[
-#         gr.Text(label='Prompt'),
-#         gr.Text(label='Negative prompt'),
-#         gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-#         gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-#         gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-#         gr.Number(label='Duration (sec)', value=8, minimum=1),
-#     ],
-#     outputs='audio',
-#     cache_examples=False,
-#     title='Sonisphere - Sonic Branding Tool',
-# )
-
-image_to_audio_tab = gr.Interface(
-    fn=image_to_audio,
-    description="""
-    Image-to-Audio
-    NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
-    Doing so does not improve results.
-    """,
-    inputs=[
-        gr.Image(type='filepath'),
-        gr.Text(label='Prompt'),
-        gr.Text(label='Negative prompt'),
-        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-        gr.Number(label='Duration (sec)', value=8, minimum=1),
-    ],
-    outputs='playable_video',
-    cache_examples=False,
-    title='Image-to-Audio Synthesis (experimental)',
-)
+    title='MMAudio — Video-to-Audio Synthesis',
+    examples=[
+    ])
 
-if __name__ == "__main__":
-    # parser = ArgumentParser()
-    # parser.add_argument('--port', type=int, default=7860)
-    # args = parser.parse_args()
 
-
-
-
+if __name__ == "__main__":
+    gr.TabbedInterface([video_to_audio_tab],
+                       ['Video-to-Audio']).launch(allowed_paths=[output_dir])
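For reference, the save-and-return pattern the new endpoints rely on can be exercised on its own. The sketch below is based on the lines added above; the random tensor and the 16 kHz rate are placeholders standing in for the model output and `seq_cfg.sampling_rate`.

```python
# Minimal sketch of the temporary-file output pattern used by the new endpoints:
# write the waveform to a NamedTemporaryFile and hand the path back to Gradio.
import tempfile

import torch
import torchaudio

sampling_rate = 16000                      # placeholder for seq_cfg.sampling_rate
audio = torch.randn(1, sampling_rate * 8)  # placeholder for the generated 8 s clip

audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
torchaudio.save(audio_save_path, audio, sampling_rate)
print(f'Saved audio to {audio_save_path}')
```

Writing to a temporary file instead of `output_dir` avoids accumulating outputs on the Space's disk; with `delete=False`, cleanup is left to the OS temp directory.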
batch_eval.py
DELETED
@@ -1,110 +0,0 @@
(file removed; its contents were:)
import logging
import os
from pathlib import Path

import hydra
import torch
import torch.distributed as distributed
import torchaudio
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm

from mmaudio.data.data_setup import setup_eval_dataset
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
log = logging.getLogger()


@torch.inference_mode()
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig):
    device = 'cuda'
    torch.cuda.set_device(local_rank)

    if cfg.model not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {cfg.model}')
    model: ModelConfig = all_model_cfg[cfg.model]
    model.download_if_needed()
    seq_cfg = model.seq_cfg

    run_dir = Path(HydraConfig.get().run.dir)
    if cfg.output_name is None:
        output_dir = run_dir / cfg.dataset
    else:
        output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # load a pretrained model
    seq_cfg.duration = cfg.duration_s
    net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
    log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
    log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')

    # misc setup
    rng = torch.Generator(device=device)
    rng.manual_seed(cfg.seed)
    fm = FlowMatching(cfg.sampling.min_sigma,
                      inference_mode=cfg.sampling.method,
                      num_steps=cfg.sampling.num_steps)

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device).eval()

    if cfg.compile:
        net.preprocess_conditions = torch.compile(net.preprocess_conditions)
        net.predict_flow = torch.compile(net.predict_flow)
        feature_utils.compile()

    dataset, loader = setup_eval_dataset(cfg.dataset, cfg)

    with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
        for batch in tqdm(loader):
            audios = generate(batch.get('clip_video', None),
                              batch.get('sync_video', None),
                              batch.get('caption', None),
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg.cfg_strength,
                              clip_batch_size_multiplier=64,
                              sync_batch_size_multiplier=64)
            audios = audios.float().cpu()
            names = batch['name']
            for audio, name in zip(audios, names):
                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)


def distributed_setup():
    distributed.init_process_group(backend="nccl")
    local_rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


if __name__ == '__main__':
    distributed_setup()

    main()

    # clean-up
    distributed.destroy_process_group()
config/__init__.py
DELETED
File without changes
config/base_config.yaml
DELETED
@@ -1,62 +0,0 @@
(file removed; its contents were:)
defaults:
  - data: base
  - eval_data: base
  - override hydra/job_logging: custom-simplest
  - _self_

hydra:
  run:
    dir: ./output/${exp_id}
  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra

enable_email: False

model: small_16k

exp_id: default
debug: False
cudnn_benchmark: True
compile: True
amp: True
weights: null
checkpoint: null
seed: 14159265
num_workers: 10 # per-GPU
pin_memory: False # set to True if your system can handle it, i.e., have enough memory

# NOTE: This DOSE NOT affect the model during inference in any way
# they are just for the dataloader to fill in the missing data in multi-modal loading
# to change the sequence length for the model, see networks.py
data_dim:
  text_seq_len: 77
  clip_dim: 1024
  sync_dim: 768
  text_dim: 1024

# ema configuration
ema:
  enable: True
  sigma_rels: [0.05, 0.1]
  update_every: 1
  checkpoint_every: 5_000
  checkpoint_folder: ${hydra:run.dir}/ema_ckpts
  default_output_sigma: 0.05

# sampling
sampling:
  mean: 0.0
  scale: 1.0
  min_sigma: 0.0
  method: euler
  num_steps: 25

# classifier-free guidance
null_condition_probability: 0.1
cfg_strength: 4.5

# checkpoint paths to external modules
vae_16k_ckpt: ./ext_weights/v1-16.pth
vae_44k_ckpt: ./ext_weights/v1-44.pth
bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
config/data/base.yaml
DELETED
@@ -1,70 +0,0 @@
(file removed; its contents were:)
VGGSound:
  root: ../data/video
  subset_name: sets/vgg3-train.tsv
  fps: 8
  height: 384
  width: 384
  sample_duration_sec: 8.0

VGGSound_test:
  root: ../data/video
  subset_name: sets/vgg3-test.tsv
  fps: 8
  height: 384
  width: 384
  sample_duration_sec: 8.0

VGGSound_val:
  root: ../data/video
  subset_name: sets/vgg3-val.tsv
  fps: 8
  height: 384
  width: 384
  sample_duration_sec: 8.0

ExtractedVGG:
  tsv: ../data/v1-16-memmap/vgg-train.tsv
  memmap_dir: ../data/v1-16-memmap/vgg-train

ExtractedVGG_test:
  tag: test
  gt_cache: ../data/eval-cache/vggsound-test
  output_subdir: null
  tsv: ../data/v1-16-memmap/vgg-test.tsv
  memmap_dir: ../data/v1-16-memmap/vgg-test

ExtractedVGG_val:
  tag: val
  gt_cache: ../data/eval-cache/vggsound-val
  output_subdir: val
  tsv: ../data/v1-16-memmap/vgg-val.tsv
  memmap_dir: ../data/v1-16-memmap/vgg-val

AudioCaps:
  tsv: ../data/v1-16-memmap/audiocaps.tsv
  memmap_dir: ../data/v1-16-memmap/audiocaps

AudioSetSL:
  tsv: ../data/v1-16-memmap/audioset_sl.tsv
  memmap_dir: ../data/v1-16-memmap/audioset_sl

BBCSound:
  tsv: ../data/v1-16-memmap/bbcsound.tsv
  memmap_dir: ../data/v1-16-memmap/bbcsound

FreeSound:
  tsv: ../data/v1-16-memmap/freesound.tsv
  memmap_dir: ../data/v1-16-memmap/freesound

Clotho:
  tsv: ../data/v1-16-memmap/clotho.tsv
  memmap_dir: ../data/v1-16-memmap/clotho

Example_video:
  tsv: ./training/example_output/memmap/vgg-example.tsv
  memmap_dir: ./training/example_output/memmap/vgg-example

Example_audio:
  tsv: ./training/example_output/memmap/audio-example.tsv
  memmap_dir: ./training/example_output/memmap/audio-example
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/eval_config.yaml
DELETED
@@ -1,17 +0,0 @@
(file removed; its contents were:)
defaults:
  - base_config
  - override hydra/job_logging: custom-simplest
  - _self_

hydra:
  run:
    dir: ./output/${exp_id}
  output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra

exp_id: ${model}
dataset: audiocaps
duration_s: 8.0

# for inference, this is the per-GPU batch size
batch_size: 16
output_name: null
config/eval_data/base.yaml
DELETED
@@ -1,22 +0,0 @@
(file removed; its contents were:)
AudioCaps:
  audio_path: ../data/AudioCaps-test-audioldm-ver
  # a csv file, with a header row of 'name' and 'caption'
  # name should match the audio file name without extension
  # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
  csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv

AudioCaps_full:
  audio_path: ../data/AudioCaps-test-full-ver
  # a csv file, with a header row of 'name' and 'caption'
  # name should match the audio file name without extension
  # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
  csv_path: ../data/AudioCaps-test-full-ver/data.csv

MovieGen:
  video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
  jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata

VGGSound:
  video_path: ../data/test-videos
  # from the officially released csv file
  csv_path: ../data/vggsound.csv
config/hydra/job_logging/custom-eval.yaml
DELETED
@@ -1,32 +0,0 @@
(file removed; its contents were:)
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
  colorlog:
    '()': 'colorlog.ColoredFormatter'
    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
    log_colors:
      DEBUG: purple
      INFO: green
      WARNING: yellow
      ERROR: red
      CRITICAL: red
handlers:
  console:
    class: logging.StreamHandler
    formatter: colorlog
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    # absolute file path
    filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
    mode: w
root:
  level: INFO
  handlers: [console, file]

disable_existing_loggers: false
config/hydra/job_logging/custom-no-rank.yaml
DELETED
@@ -1,32 +0,0 @@
(file removed; its contents were:)
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
  colorlog:
    '()': 'colorlog.ColoredFormatter'
    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
    log_colors:
      DEBUG: purple
      INFO: green
      WARNING: yellow
      ERROR: red
      CRITICAL: red
handlers:
  console:
    class: logging.StreamHandler
    formatter: colorlog
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    # absolute file path
    filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
    mode: w
root:
  level: INFO
  handlers: [console, file]

disable_existing_loggers: false
config/hydra/job_logging/custom-simplest.yaml
DELETED
@@ -1,26 +0,0 @@
(file removed; its contents were:)
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
  colorlog:
    '()': 'colorlog.ColoredFormatter'
    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
    log_colors:
      DEBUG: purple
      INFO: green
      WARNING: yellow
      ERROR: red
      CRITICAL: red
handlers:
  console:
    class: logging.StreamHandler
    formatter: colorlog
    stream: ext://sys.stdout
root:
  level: INFO
  handlers: [console]

disable_existing_loggers: false
config/hydra/job_logging/custom.yaml
DELETED
@@ -1,33 +0,0 @@
(file removed; its contents were:)
# @package hydra.job_logging
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
  colorlog:
    '()': 'colorlog.ColoredFormatter'
    format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
    log_colors:
      DEBUG: purple
      INFO: green
      WARNING: yellow
      ERROR: red
      CRITICAL: red
handlers:
  console:
    class: logging.StreamHandler
    formatter: colorlog
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    # absolute file path
    filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
    mode: w
root:
  level: INFO
  handlers: [console, file]

disable_existing_loggers: false
config/train_config.yaml
DELETED
@@ -1,41 +0,0 @@
(file removed; its contents were:)
defaults:
  - base_config
  - override data: base
  - override hydra/job_logging: custom
  - _self_

hydra:
  run:
    dir: ./output/${exp_id}
  output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra

ema:
  start: 0

mini_train: False
example_train: False
enable_grad_scaler: False
vgg_oversample_rate: 5

log_text_interval: 200
log_extra_interval: 20_000
val_interval: 5_000
eval_interval: 20_000
save_eval_interval: 40_000
save_weights_interval: 10_000
save_checkpoint_interval: 10_000
save_copy_iterations: []

batch_size: 512
eval_batch_size: 256 # per-GPU

num_iterations: 300_000
learning_rate: 1.0e-4
linear_warmup_steps: 1_000

lr_schedule: step
lr_schedule_steps: [240_000, 270_000]
lr_schedule_gamma: 0.1

clip_grad_norm: 1.0
weight_decay: 1.0e-6
demo.py
CHANGED
@@ -62,13 +62,7 @@ def main():
     skip_video_composite: bool = args.skip_video_composite
     mask_away_clip: bool = args.mask_away_clip
 
-    device = '
-    if torch.cuda.is_available():
-        device = 'cuda'
-    elif torch.backends.mps.is_available():
-        device = 'mps'
-    else:
-        log.warning('CUDA/MPS are not available, running on CPU')
+    device = 'cuda'
     dtype = torch.float32 if args.full_precision else torch.bfloat16
 
     output_dir.mkdir(parents=True, exist_ok=True)
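The hunk above replaces demo.py's device auto-detection with a hard-coded `'cuda'`, so the script now assumes a CUDA device is present. For reference, the removed fallback looked like the sketch below; the initial `'cpu'` default is an assumption, since the first removed line is truncated in this view.

```python
# Device selection removed by this commit: prefer CUDA, then Apple MPS, else CPU.
import logging

import torch

log = logging.getLogger(__name__)

device = 'cpu'  # assumed default; the original line is truncated in the diff view
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    log.warning('CUDA/MPS are not available, running on CPU')
```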
docs/EVAL.md
DELETED
@@ -1,22 +0,0 @@
(file removed; its contents were:)
# Evaluation

## Batch Evaluation

To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient in large-scale evaluation compared to `demo.py`, supporting batched inference, multi-GPU inference, torch compilation, and skipping video compositions.

An example of running this script with four GPUs is as follows:

```bash
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
```

You may need to update the data paths in `config/eval_data/base.yaml`.
More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.

## Precomputed Results

Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results

## Obtaining Quantitative Metrics

Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
docs/MODELS.md
DELETED
@@ -1,50 +0,0 @@
(file removed; its contents were:)
# Pretrained models

The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main

| Model | Download link | File size |
| -------- | ------- | ------- |
| Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
| Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
| Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
| Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
| Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
| 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
| 16kHz BigVGAN vocoder (from Make-An-Audio 2) | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
| 44.1kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
| Synchformer visual encoder | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |

To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP, CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz) and not model sizes.
The 44.1kHz vocoder will be downloaded automatically.
The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.

The expected directory structure (full):

```bash
MMAudio
├── ext_weights
│   ├── best_netG.pt
│   ├── synchformer_state_dict.pth
│   ├── v1-16.pth
│   └── v1-44.pth
├── weights
│   ├── mmaudio_small_16k.pth
│   ├── mmaudio_small_44k.pth
│   ├── mmaudio_medium_44k.pth
│   ├── mmaudio_large_44k.pth
│   └── mmaudio_large_44k_v2.pth
└── ...
```

The expected directory structure (minimal, for the recommended model only):

```bash
MMAudio
├── ext_weights
│   ├── synchformer_state_dict.pth
│   └── v1-44.pth
├── weights
│   └── mmaudio_large_44k_v2.pth
└── ...
```
docs/TRAINING.md
DELETED
@@ -1,160 +0,0 @@
(file removed; its contents were:)
# Training

## Overview

We have put a large emphasis on making training as fast as possible.
Consequently, some pre-processing steps are required.

Namely, before starting any training, we

1. Obtain training data as videos, audios, and captions.
2. Encode training audios into spectrograms and then with VAE into mean/std
3. Extract CLIP and synchronization features from videos
4. Extract CLIP features from text (captions)
5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)

**NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.

The current training script does not support `_v2` training.

## Recommended Hardware Configuration

These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.

- Single-node machine. We did not implement multi-node training
- GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
- System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
- Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.

## Prerequisites

1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:

```bash
conda install -c conda-forge 'ffmpeg<7'
```

4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://arxiv.org/abs/2303.17395). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio unsuitable for our training.

5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md) place them in `ext_weights/`.

## Preparing Audio-Video-Text Features

We have prepared some example data in `training/example_videos`.
`training/extract_video_training_latents.py` extracts audio, video, and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.

To run this script, use the `torchrun` utility:

```bash
torchrun --standalone training/extract_video_training_latents.py
```

You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
Change the data path definitions in `data_cfg` if necessary.

Arguments:

- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
- `output_dir` -- where TensorDict and the metadata file are saved.

Outputs produced in `output_dir`:

1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
    a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
    b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
    c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
    d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
    e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
    f. `meta.json` that contains the metadata for the above memory mappings
2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)

## Preparing Audio-Text Features

We have prepared some example data in `training/example_audios`.

1. Run `training/partition_clips` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio onto the disk to save disk space)
2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.

### Partitioning the audio files

Run

```bash
python training/partition_clips.py
```

Arguments:

- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
- `output_dir` -- path to the output `.csv` file
- `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
- `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed

### Extracting audio and text features

Run

```bash
torchrun --standalone training/extract_audio_training_latents.py
```

You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.

Arguments:

- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
- `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file at least with columns `id` and `caption`
- `clips_tsv` -- path to the clips file, generated in the last step
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
- `output_dir` -- where TensorDict and the metadata file are saved.

**Reference tsv files (with overlaps removed as mentioned in the paper) can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).**
Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use it as the `captions_tsv` input though -- the script will handle duplicates gracefully.
Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are parts of WavCaps. These subsets might be smaller than the original datasets.
The Clotho data contains both the development set and the validation set.

Outputs produced in `output_dir`:

1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
    a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
    b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
    c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
    f. `meta.json` that contains the metadata for the above memory mappings
2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)

## Training on Extracted Features

We use Distributed Data Parallel (DDP) for training.
First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.

To run training on the example data, use the following command:

```bash
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
```

This will not train a useful model, but it will check if everything is set up correctly.

For full training on the base model with two GPUs, use the following command:

```bash
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
```

Any outputs from training will be stored in `output/<exp_id>`.

More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
For the medium and large models, specify `vgg_oversample_rate` to be `3` to reduce overfitting.

## Checkpoints

Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio

---

Godspeed!
docs/index.html
CHANGED
@@ -40,7 +40,7 @@
 <br>
 <div class="row text-center" style="font-size:28px">
     <div class="col">
-
+        arXiv 2024
     </div>
 </div>
 <br>
@@ -83,21 +83,19 @@
 <br>
 
 <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
-    <div class="col-sm-2">
-        <a href="https://arxiv.org/abs/
-    </div>
-    <div class="col-sm-2">
-        <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
-    </div>
+    <!-- <div class="col-sm-2">
+        <a href="https://arxiv.org/abs/2310.12982">[arXiv]</a>
+    </div> -->
     <div class="col-sm-3">
-        <a href="
-    </div>
-    <div class="col-sm-2">
-        <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
+        <a href="">[Paper (being prepared)]</a>
     </div>
     <div class="col-sm-3">
-        <a href="https://
+        <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
     </div>
+    <!-- <div class="col-sm-2">
+        <a
+            href="https://colab.research.google.com/drive/1yo43XTbjxuWA7XgCUO9qxAi7wBI6HzvP?usp=sharing">[Colab]</a>
+    </div> -->
 </div>
 
 <br>
gitattributes
DELETED
@@ -1,35 +0,0 @@
(file removed; its contents were:)
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
mmaudio/data/av_utils.py
CHANGED
@@ -25,32 +25,6 @@ class VideoInfo:
     def width(self):
         return self.all_frames[0].shape[1]
 
-    @classmethod
-    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
-                        fps: Fraction) -> 'VideoInfo':
-        num_frames = int(duration_sec * fps)
-        all_frames = [image_info.original_frame] * num_frames
-        return cls(duration_sec=duration_sec,
-                   fps=fps,
-                   clip_frames=image_info.clip_frames,
-                   sync_frames=image_info.sync_frames,
-                   all_frames=all_frames)
-
-
-@dataclass
-class ImageInfo:
-    clip_frames: torch.Tensor
-    sync_frames: torch.Tensor
-    original_frame: Optional[np.ndarray]
-
-    @property
-    def height(self):
-        return self.original_frame.shape[0]
-
-    @property
-    def width(self):
-        return self.original_frame.shape[1]
-
 
 def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                 need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
mmaudio/data/data_setup.py
DELETED
@@ -1,174 +0,0 @@
|
|
1 |
-
import logging
|
-import random
-
-import numpy as np
-import torch
-from omegaconf import DictConfig
-from torch.utils.data import DataLoader, Dataset
-from torch.utils.data.dataloader import default_collate
-from torch.utils.data.distributed import DistributedSampler
-
-from mmaudio.data.eval.audiocaps import AudioCapsData
-from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
-from mmaudio.data.extracted_audio import ExtractedAudio
-from mmaudio.data.extracted_vgg import ExtractedVGG
-from mmaudio.data.mm_dataset import MultiModalDataset
-from mmaudio.utils.dist_utils import local_rank
-
-log = logging.getLogger()
-
-
-# Re-seed randomness every time we start a worker
-def worker_init_fn(worker_id: int):
-    worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
-    np.random.seed(worker_seed)
-    random.seed(worker_seed)
-    log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
-
-
-def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
-    dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
-                           data_dim=cfg.data_dim,
-                           premade_mmap_dir=data_cfg.memmap_dir)
-
-    return dataset
-
-
-def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
-    dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
-                             data_dim=cfg.data_dim,
-                             premade_mmap_dir=data_cfg.memmap_dir)
-
-    return dataset
-
-
-def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
-    if cfg.mini_train:
-        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
-        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
-        dataset = MultiModalDataset([vgg], [audiocaps])
-    if cfg.example_train:
-        video = load_vgg_data(cfg, cfg.data.Example_video)
-        audio = load_audio_data(cfg, cfg.data.Example_audio)
-        dataset = MultiModalDataset([video], [audio])
-    else:
-        # load the largest one first
-        freesound = load_audio_data(cfg, cfg.data.FreeSound)
-        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
-        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
-        audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
-        bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
-        clotho = load_audio_data(cfg, cfg.data.Clotho)
-        dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
-                                    [audiocaps, audioset_sl, bbcsound, freesound, clotho])
-
-    batch_size = cfg.batch_size
-    num_workers = cfg.num_workers
-    pin_memory = cfg.pin_memory
-    sampler, loader = construct_loader(dataset,
-                                       batch_size,
-                                       num_workers,
-                                       shuffle=True,
-                                       drop_last=True,
-                                       pin_memory=pin_memory)
-
-    return dataset, sampler, loader
-
-
-def setup_test_datasets(cfg):
-    dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
-
-    batch_size = cfg.batch_size
-    num_workers = cfg.num_workers
-    pin_memory = cfg.pin_memory
-    sampler, loader = construct_loader(dataset,
-                                       batch_size,
-                                       num_workers,
-                                       shuffle=False,
-                                       drop_last=False,
-                                       pin_memory=pin_memory)
-
-    return dataset, sampler, loader
-
-
-def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
-    if cfg.example_train:
-        dataset = load_vgg_data(cfg, cfg.data.Example_video)
-    else:
-        dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
-
-    val_batch_size = cfg.batch_size
-    val_eval_batch_size = cfg.eval_batch_size
-    num_workers = cfg.num_workers
-    pin_memory = cfg.pin_memory
-    _, val_loader = construct_loader(dataset,
-                                     val_batch_size,
-                                     num_workers,
-                                     shuffle=False,
-                                     drop_last=False,
-                                     pin_memory=pin_memory)
-    _, eval_loader = construct_loader(dataset,
-                                      val_eval_batch_size,
-                                      num_workers,
-                                      shuffle=False,
-                                      drop_last=False,
-                                      pin_memory=pin_memory)
-
-    return dataset, val_loader, eval_loader
-
-
-def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
-    if dataset_name.startswith('audiocaps_full'):
-        dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
-                                cfg.eval_data.AudioCaps_full.csv_path)
-    elif dataset_name.startswith('audiocaps'):
-        dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
-                                cfg.eval_data.AudioCaps.csv_path)
-    elif dataset_name.startswith('moviegen'):
-        dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
-                           cfg.eval_data.MovieGen.jsonl_path,
-                           duration_sec=cfg.duration_s)
-    elif dataset_name.startswith('vggsound'):
-        dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
-                           cfg.eval_data.VGGSound.csv_path,
-                           duration_sec=cfg.duration_s)
-    else:
-        raise ValueError(f'Invalid dataset name: {dataset_name}')
-
-    batch_size = cfg.batch_size
-    num_workers = cfg.num_workers
-    pin_memory = cfg.pin_memory
-    _, loader = construct_loader(dataset,
-                                 batch_size,
-                                 num_workers,
-                                 shuffle=False,
-                                 drop_last=False,
-                                 pin_memory=pin_memory,
-                                 error_avoidance=True)
-    return dataset, loader
-
-
-def error_avoidance_collate(batch):
-    batch = list(filter(lambda x: x is not None, batch))
-    return default_collate(batch)
-
-
-def construct_loader(dataset: Dataset,
-                     batch_size: int,
-                     num_workers: int,
-                     *,
-                     shuffle: bool = True,
-                     drop_last: bool = True,
-                     pin_memory: bool = False,
-                     error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
-    train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
-    train_loader = DataLoader(dataset,
-                              batch_size,
-                              sampler=train_sampler,
-                              num_workers=num_workers,
-                              worker_init_fn=worker_init_fn,
-                              drop_last=drop_last,
-                              persistent_workers=num_workers > 0,
-                              pin_memory=pin_memory,
-                              collate_fn=error_avoidance_collate if error_avoidance else None)
-    return train_sampler, train_loader
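Aside: the skip-None collate trick used by `error_avoidance_collate` above is easy to reproduce outside this repo. A minimal sketch follows; the toy dataset and its values are made up for illustration and are not part of this codebase.

```python
# Minimal sketch of the skip-None collate pattern; the dataset here is hypothetical.
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class MaybeBrokenDataset(Dataset):
    """Toy dataset whose __getitem__ may yield None (e.g. a corrupt sample)."""

    def __init__(self, items):
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def skip_none_collate(batch):
    # drop failed samples, then collate the survivors as usual
    batch = [b for b in batch if b is not None]
    return default_collate(batch)


loader = DataLoader(MaybeBrokenDataset([1.0, None, 2.0, 3.0]),
                    batch_size=2,
                    collate_fn=skip_none_collate)
for batch in loader:
    print(batch)  # tensors containing only the non-None items
```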
mmaudio/data/eval/__init__.py
DELETED
File without changes
mmaudio/data/eval/audiocaps.py
DELETED
@@ -1,39 +0,0 @@
-import logging
-import os
-from collections import defaultdict
-from pathlib import Path
-from typing import Union
-
-import pandas as pd
-import torch
-from torch.utils.data.dataset import Dataset
-
-log = logging.getLogger()
-
-
-class AudioCapsData(Dataset):
-
-    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
-        df = pd.read_csv(csv_path).to_dict(orient='records')
-
-        audio_files = sorted(os.listdir(audio_path))
-        audio_files = set(
-            [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
-
-        self.data = []
-        for row in df:
-            self.data.append({
-                'name': row['name'],
-                'caption': row['caption'],
-            })
-
-        self.audio_path = Path(audio_path)
-        self.csv_path = Path(csv_path)
-
-        log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
-
-    def __getitem__(self, idx: int) -> torch.Tensor:
-        return self.data[idx]
-
-    def __len__(self):
-        return len(self.data)
mmaudio/data/eval/moviegen.py
DELETED
@@ -1,131 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-from typing import Union
-
-import torch
-from torch.utils.data.dataset import Dataset
-from torchvision.transforms import v2
-from torio.io import StreamingMediaDecoder
-
-from mmaudio.utils.dist_utils import local_rank
-
-log = logging.getLogger()
-
-_CLIP_SIZE = 384
-_CLIP_FPS = 8.0
-
-_SYNC_SIZE = 224
-_SYNC_FPS = 25.0
-
-
-class MovieGenData(Dataset):
-
-    def __init__(
-        self,
-        video_root: Union[str, Path],
-        sync_root: Union[str, Path],
-        jsonl_root: Union[str, Path],
-        *,
-        duration_sec: float = 10.0,
-        read_clip: bool = True,
-    ):
-        self.video_root = Path(video_root)
-        self.sync_root = Path(sync_root)
-        self.jsonl_root = Path(jsonl_root)
-        self.read_clip = read_clip
-
-        videos = sorted(os.listdir(self.video_root))
-        videos = [v[:-4] for v in videos]  # remove extensions
-        self.captions = {}
-
-        for v in videos:
-            with open(self.jsonl_root / (v + '.jsonl')) as f:
-                data = json.load(f)
-                self.captions[v] = data['audio_prompt']
-
-        if local_rank == 0:
-            log.info(f'{len(videos)} videos found in {video_root}')
-
-        self.duration_sec = duration_sec
-
-        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
-        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
-
-        self.clip_augment = v2.Compose([
-            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
-            v2.ToImage(),
-            v2.ToDtype(torch.float32, scale=True),
-        ])
-
-        self.sync_augment = v2.Compose([
-            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
-            v2.CenterCrop(_SYNC_SIZE),
-            v2.ToImage(),
-            v2.ToDtype(torch.float32, scale=True),
-            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-        ])
-
-        self.videos = videos
-
-    def sample(self, idx: int) -> dict[str, torch.Tensor]:
-        video_id = self.videos[idx]
-        caption = self.captions[video_id]
-
-        reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
-        reader.add_basic_video_stream(
-            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
-            frame_rate=_CLIP_FPS,
-            format='rgb24',
-        )
-        reader.add_basic_video_stream(
-            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
-            frame_rate=_SYNC_FPS,
-            format='rgb24',
-        )
-
-        reader.fill_buffer()
-        data_chunk = reader.pop_chunks()
-
-        clip_chunk = data_chunk[0]
-        sync_chunk = data_chunk[1]
-        if clip_chunk is None:
-            raise RuntimeError(f'CLIP video returned None {video_id}')
-        if clip_chunk.shape[0] < self.clip_expected_length:
-            raise RuntimeError(f'CLIP video too short {video_id}')
-
-        if sync_chunk is None:
-            raise RuntimeError(f'Sync video returned None {video_id}')
-        if sync_chunk.shape[0] < self.sync_expected_length:
-            raise RuntimeError(f'Sync video too short {video_id}')
-
-        # truncate the video
-        clip_chunk = clip_chunk[:self.clip_expected_length]
-        if clip_chunk.shape[0] != self.clip_expected_length:
-            raise RuntimeError(f'CLIP video wrong length {video_id}, '
-                               f'expected {self.clip_expected_length}, '
-                               f'got {clip_chunk.shape[0]}')
-        clip_chunk = self.clip_augment(clip_chunk)
-
-        sync_chunk = sync_chunk[:self.sync_expected_length]
-        if sync_chunk.shape[0] != self.sync_expected_length:
-            raise RuntimeError(f'Sync video wrong length {video_id}, '
-                               f'expected {self.sync_expected_length}, '
-                               f'got {sync_chunk.shape[0]}')
-        sync_chunk = self.sync_augment(sync_chunk)
-
-        data = {
-            'name': video_id,
-            'caption': caption,
-            'clip_video': clip_chunk,
-            'sync_video': sync_chunk,
-        }
-
-        return data
-
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        return self.sample(idx)
-
-    def __len__(self):
-        return len(self.captions)
mmaudio/data/eval/video_dataset.py
DELETED
@@ -1,197 +0,0 @@
-import json
-import logging
-import os
-from pathlib import Path
-from typing import Union
-
-import pandas as pd
-import torch
-from torch.utils.data.dataset import Dataset
-from torchvision.transforms import v2
-from torio.io import StreamingMediaDecoder
-
-from mmaudio.utils.dist_utils import local_rank
-
-log = logging.getLogger()
-
-_CLIP_SIZE = 384
-_CLIP_FPS = 8.0
-
-_SYNC_SIZE = 224
-_SYNC_FPS = 25.0
-
-
-class VideoDataset(Dataset):
-
-    def __init__(
-        self,
-        video_root: Union[str, Path],
-        *,
-        duration_sec: float = 8.0,
-    ):
-        self.video_root = Path(video_root)
-
-        self.duration_sec = duration_sec
-
-        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
-        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
-
-        self.clip_transform = v2.Compose([
-            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
-            v2.ToImage(),
-            v2.ToDtype(torch.float32, scale=True),
-        ])
-
-        self.sync_transform = v2.Compose([
-            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
-            v2.CenterCrop(_SYNC_SIZE),
-            v2.ToImage(),
-            v2.ToDtype(torch.float32, scale=True),
-            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-        ])
-
-        # to be implemented by subclasses
-        self.captions = {}
-        self.videos = sorted(list(self.captions.keys()))
-
-    def sample(self, idx: int) -> dict[str, torch.Tensor]:
-        video_id = self.videos[idx]
-        caption = self.captions[video_id]
-
-        reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
-        reader.add_basic_video_stream(
-            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
-            frame_rate=_CLIP_FPS,
-            format='rgb24',
-        )
-        reader.add_basic_video_stream(
-            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
-            frame_rate=_SYNC_FPS,
-            format='rgb24',
-        )
-
-        reader.fill_buffer()
-        data_chunk = reader.pop_chunks()
-
-        clip_chunk = data_chunk[0]
-        sync_chunk = data_chunk[1]
-        if clip_chunk is None:
-            raise RuntimeError(f'CLIP video returned None {video_id}')
-        if clip_chunk.shape[0] < self.clip_expected_length:
-            raise RuntimeError(
-                f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
-            )
-
-        if sync_chunk is None:
-            raise RuntimeError(f'Sync video returned None {video_id}')
-        if sync_chunk.shape[0] < self.sync_expected_length:
-            raise RuntimeError(
-                f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
-            )
-
-        # truncate the video
-        clip_chunk = clip_chunk[:self.clip_expected_length]
-        if clip_chunk.shape[0] != self.clip_expected_length:
-            raise RuntimeError(f'CLIP video wrong length {video_id}, '
-                               f'expected {self.clip_expected_length}, '
-                               f'got {clip_chunk.shape[0]}')
-        clip_chunk = self.clip_transform(clip_chunk)
-
-        sync_chunk = sync_chunk[:self.sync_expected_length]
-        if sync_chunk.shape[0] != self.sync_expected_length:
-            raise RuntimeError(f'Sync video wrong length {video_id}, '
-                               f'expected {self.sync_expected_length}, '
-                               f'got {sync_chunk.shape[0]}')
-        sync_chunk = self.sync_transform(sync_chunk)
-
-        data = {
-            'name': video_id,
-            'caption': caption,
-            'clip_video': clip_chunk,
-            'sync_video': sync_chunk,
-        }
-
-        return data
-
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        try:
-            return self.sample(idx)
-        except Exception as e:
-            log.error(f'Error loading video {self.videos[idx]}: {e}')
-            return None
-
-    def __len__(self):
-        return len(self.captions)
-
-
-class VGGSound(VideoDataset):
-
-    def __init__(
-        self,
-        video_root: Union[str, Path],
-        csv_path: Union[str, Path],
-        *,
-        duration_sec: float = 8.0,
-    ):
-        super().__init__(video_root, duration_sec=duration_sec)
-        self.video_root = Path(video_root)
-        self.csv_path = Path(csv_path)
-
-        videos = sorted(os.listdir(self.video_root))
-        if local_rank == 0:
-            log.info(f'{len(videos)} videos found in {video_root}')
-        self.captions = {}
-
-        df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
-                                                       'split']).to_dict(orient='records')
-
-        videos_no_found = []
-        for row in df:
-            if row['split'] == 'test':
-                start_sec = int(row['sec'])
-                video_id = str(row['id'])
-                # this is how our videos are named
-                video_name = f'{video_id}_{start_sec:06d}'
-                if video_name + '.mp4' not in videos:
-                    videos_no_found.append(video_name)
-                    continue
-
-                self.captions[video_name] = row['caption']
-
-        if local_rank == 0:
-            log.info(f'{len(videos)} videos found in {video_root}')
-            log.info(f'{len(self.captions)} useable videos found')
-            if videos_no_found:
-                log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
-                log.info(
-                    'A small amount is expected, as not all videos are still available on YouTube')
-
-        self.videos = sorted(list(self.captions.keys()))
-
-
-class MovieGen(VideoDataset):
-
-    def __init__(
-        self,
-        video_root: Union[str, Path],
-        jsonl_root: Union[str, Path],
-        *,
-        duration_sec: float = 10.0,
-    ):
-        super().__init__(video_root, duration_sec=duration_sec)
-        self.video_root = Path(video_root)
-        self.jsonl_root = Path(jsonl_root)
-
-        videos = sorted(os.listdir(self.video_root))
-        videos = [v[:-4] for v in videos]  # remove extensions
-        self.captions = {}
-
-        for v in videos:
-            with open(self.jsonl_root / (v + '.jsonl')) as f:
-                data = json.load(f)
-                self.captions[v] = data['audio_prompt']
-
-        if local_rank == 0:
-            log.info(f'{len(videos)} videos found in {video_root}')
-
-        self.videos = videos
mmaudio/data/extracted_audio.py
DELETED
@@ -1,88 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Union
-
-import pandas as pd
-import torch
-from tensordict import TensorDict
-from torch.utils.data.dataset import Dataset
-
-from mmaudio.utils.dist_utils import local_rank
-
-log = logging.getLogger()
-
-
-class ExtractedAudio(Dataset):
-
-    def __init__(
-        self,
-        tsv_path: Union[str, Path],
-        *,
-        premade_mmap_dir: Union[str, Path],
-        data_dim: dict[str, int],
-    ):
-        super().__init__()
-
-        self.data_dim = data_dim
-        self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
-        self.ids = [str(d['id']) for d in self.df_list]
-
-        log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
-        # load precomputed memory mapped tensors
-        premade_mmap_dir = Path(premade_mmap_dir)
-        td = TensorDict.load_memmap(premade_mmap_dir)
-        log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
-        self.mean = td['mean']
-        self.std = td['std']
-        self.text_features = td['text_features']
-
-        log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
-        log.info(f'Loaded mean: {self.mean.shape}.')
-        log.info(f'Loaded std: {self.std.shape}.')
-        log.info(f'Loaded text features: {self.text_features.shape}.')
-
-        assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
-            f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
-        assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
-            f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
-
-        assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
-            f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
-        assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
-            f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
-
-        self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
-                                              self.data_dim['clip_dim'])
-        self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
-                                              self.data_dim['sync_dim'])
-        self.video_exist = torch.tensor(0, dtype=torch.bool)
-        self.text_exist = torch.tensor(1, dtype=torch.bool)
-
-    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
-        latents = self.mean
-        return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
-
-    def get_memory_mapped_tensor(self) -> TensorDict:
-        td = TensorDict({
-            'mean': self.mean,
-            'std': self.std,
-            'text_features': self.text_features,
-        })
-        return td
-
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        data = {
-            'id': str(self.df_list[idx]['id']),
-            'a_mean': self.mean[idx],
-            'a_std': self.std[idx],
-            'clip_features': self.fake_clip_features,
-            'sync_features': self.fake_sync_features,
-            'text_features': self.text_features[idx],
-            'caption': self.df_list[idx]['caption'],
-            'video_exist': self.video_exist,
-            'text_exist': self.text_exist,
-        }
-        return data
-
-    def __len__(self):
-        return len(self.ids)
mmaudio/data/extracted_vgg.py
DELETED
@@ -1,101 +0,0 @@
-import logging
-from pathlib import Path
-from typing import Union
-
-import pandas as pd
-import torch
-from tensordict import TensorDict
-from torch.utils.data.dataset import Dataset
-
-from mmaudio.utils.dist_utils import local_rank
-
-log = logging.getLogger()
-
-
-class ExtractedVGG(Dataset):
-
-    def __init__(
-        self,
-        tsv_path: Union[str, Path],
-        *,
-        premade_mmap_dir: Union[str, Path],
-        data_dim: dict[str, int],
-    ):
-        super().__init__()
-
-        self.data_dim = data_dim
-        self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
-        self.ids = [d['id'] for d in self.df_list]
-
-        log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
-        # load precomputed memory mapped tensors
-        premade_mmap_dir = Path(premade_mmap_dir)
-        td = TensorDict.load_memmap(premade_mmap_dir)
-        log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
-        self.mean = td['mean']
-        self.std = td['std']
-        self.clip_features = td['clip_features']
-        self.sync_features = td['sync_features']
-        self.text_features = td['text_features']
-
-        if local_rank == 0:
-            log.info(f'Loaded {len(self)} samples.')
-            log.info(f'Loaded mean: {self.mean.shape}.')
-            log.info(f'Loaded std: {self.std.shape}.')
-            log.info(f'Loaded clip_features: {self.clip_features.shape}.')
-            log.info(f'Loaded sync_features: {self.sync_features.shape}.')
-            log.info(f'Loaded text_features: {self.text_features.shape}.')
-
-        assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
-            f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
-        assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
-            f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
-
-        assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
-            f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
-        assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
-            f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
-        assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
-            f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
-
-        assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
-            f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
-        assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
-            f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
-        assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
-            f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
-
-        self.video_exist = torch.tensor(1, dtype=torch.bool)
-        self.text_exist = torch.tensor(1, dtype=torch.bool)
-
-    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
-        latents = self.mean
-        return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
-
-    def get_memory_mapped_tensor(self) -> TensorDict:
-        td = TensorDict({
-            'mean': self.mean,
-            'std': self.std,
-            'clip_features': self.clip_features,
-            'sync_features': self.sync_features,
-            'text_features': self.text_features,
-        })
-        return td
-
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        data = {
-            'id': self.df_list[idx]['id'],
-            'a_mean': self.mean[idx],
-            'a_std': self.std[idx],
-            'clip_features': self.clip_features[idx],
-            'sync_features': self.sync_features[idx],
-            'text_features': self.text_features[idx],
-            'caption': self.df_list[idx]['label'],
-            'video_exist': self.video_exist,
-            'text_exist': self.text_exist,
-        }
-
-        return data
-
-    def __len__(self):
-        return len(self.ids)
mmaudio/data/extraction/__init__.py
DELETED
File without changes
mmaudio/data/extraction/vgg_sound.py
DELETED
@@ -1,193 +0,0 @@
-import logging
-import os
-from pathlib import Path
-from typing import Optional, Union
-
-import pandas as pd
-import torch
-import torchaudio
-from torch.utils.data.dataset import Dataset
-from torchvision.transforms import v2
-from torio.io import StreamingMediaDecoder
-
-from mmaudio.utils.dist_utils import local_rank
-
-log = logging.getLogger()
-
-_CLIP_SIZE = 384
-_CLIP_FPS = 8.0
-
-_SYNC_SIZE = 224
-_SYNC_FPS = 25.0
-
-
-class VGGSound(Dataset):
-
-    def __init__(
-        self,
-        root: Union[str, Path],
-        *,
-        tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
-        sample_rate: int = 16_000,
-        duration_sec: float = 8.0,
-        audio_samples: Optional[int] = None,
-        normalize_audio: bool = False,
-    ):
-        self.root = Path(root)
-        self.normalize_audio = normalize_audio
-        if audio_samples is None:
-            self.audio_samples = int(sample_rate * duration_sec)
-        else:
-            self.audio_samples = audio_samples
-            effective_duration = audio_samples / sample_rate
-            # make sure the duration is close enough, within 15ms
-            assert abs(effective_duration - duration_sec) < 0.015, \
-                f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
-
-        videos = sorted(os.listdir(self.root))
-        videos = set([Path(v).stem for v in videos])  # remove extensions
-        self.labels = {}
-        self.videos = []
-        missing_videos = []
-
-        # read the tsv for subset information
-        df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
-        for record in df_list:
-            id = record['id']
-            label = record['label']
-            if id in videos:
-                self.labels[id] = label
-                self.videos.append(id)
-            else:
-                missing_videos.append(id)
-
-        if local_rank == 0:
-            log.info(f'{len(videos)} videos found in {root}')
-            log.info(f'{len(self.videos)} videos found in {tsv_path}')
-            log.info(f'{len(missing_videos)} videos missing in {root}')
-
-        self.sample_rate = sample_rate
-        self.duration_sec = duration_sec
-
-        self.expected_audio_length = audio_samples
-        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
-        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
-
-        self.clip_transform = v2.Compose([
-            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
-            v2.ToImage(),
-            v2.ToDtype(torch.float32, scale=True),
-        ])
-
-        self.sync_transform = v2.Compose([
-            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
-            v2.CenterCrop(_SYNC_SIZE),
-            v2.ToImage(),
-            v2.ToDtype(torch.float32, scale=True),
-            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-        ])
-
-        self.resampler = {}
-
-    def sample(self, idx: int) -> dict[str, torch.Tensor]:
-        video_id = self.videos[idx]
-        label = self.labels[video_id]
-
-        reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
-        reader.add_basic_video_stream(
-            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
-            frame_rate=_CLIP_FPS,
-            format='rgb24',
-        )
-        reader.add_basic_video_stream(
-            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
-            frame_rate=_SYNC_FPS,
-            format='rgb24',
-        )
-        reader.add_basic_audio_stream(frames_per_chunk=2**30, )
-
-        reader.fill_buffer()
-        data_chunk = reader.pop_chunks()
-
-        clip_chunk = data_chunk[0]
-        sync_chunk = data_chunk[1]
-        audio_chunk = data_chunk[2]
-
-        if clip_chunk is None:
-            raise RuntimeError(f'CLIP video returned None {video_id}')
-        if clip_chunk.shape[0] < self.clip_expected_length:
-            raise RuntimeError(
-                f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
-            )
-
-        if sync_chunk is None:
-            raise RuntimeError(f'Sync video returned None {video_id}')
-        if sync_chunk.shape[0] < self.sync_expected_length:
-            raise RuntimeError(
-                f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
-            )
-
-        # process audio
-        sample_rate = int(reader.get_out_stream_info(2).sample_rate)
-        audio_chunk = audio_chunk.transpose(0, 1)
-        audio_chunk = audio_chunk.mean(dim=0)  # mono
-        if self.normalize_audio:
-            abs_max = audio_chunk.abs().max()
-            audio_chunk = audio_chunk / abs_max * 0.95
-            if abs_max <= 1e-6:
-                raise RuntimeError(f'Audio is silent {video_id}')
-
-        # resample
-        if sample_rate == self.sample_rate:
-            audio_chunk = audio_chunk
-        else:
-            if sample_rate not in self.resampler:
-                # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
-                self.resampler[sample_rate] = torchaudio.transforms.Resample(
-                    sample_rate,
-                    self.sample_rate,
-                    lowpass_filter_width=64,
-                    rolloff=0.9475937167399596,
-                    resampling_method='sinc_interp_kaiser',
-                    beta=14.769656459379492,
-                )
-            audio_chunk = self.resampler[sample_rate](audio_chunk)
-
-        if audio_chunk.shape[0] < self.expected_audio_length:
-            raise RuntimeError(f'Audio too short {video_id}')
-        audio_chunk = audio_chunk[:self.expected_audio_length]
-
-        # truncate the video
-        clip_chunk = clip_chunk[:self.clip_expected_length]
-        if clip_chunk.shape[0] != self.clip_expected_length:
-            raise RuntimeError(f'CLIP video wrong length {video_id}, '
-                               f'expected {self.clip_expected_length}, '
-                               f'got {clip_chunk.shape[0]}')
-        clip_chunk = self.clip_transform(clip_chunk)
-
-        sync_chunk = sync_chunk[:self.sync_expected_length]
-        if sync_chunk.shape[0] != self.sync_expected_length:
-            raise RuntimeError(f'Sync video wrong length {video_id}, '
-                               f'expected {self.sync_expected_length}, '
-                               f'got {sync_chunk.shape[0]}')
-        sync_chunk = self.sync_transform(sync_chunk)
-
-        data = {
-            'id': video_id,
-            'caption': label,
-            'audio': audio_chunk,
-            'clip_video': clip_chunk,
-            'sync_video': sync_chunk,
-        }
-
-        return data
-
-    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-        try:
-            return self.sample(idx)
-        except Exception as e:
-            log.error(f'Error loading video {self.videos[idx]}: {e}')
-            return None
-
-    def __len__(self):
-        return len(self.labels)
mmaudio/data/extraction/wav_dataset.py
DELETED
@@ -1,132 +0,0 @@
-import logging
-import os
-from pathlib import Path
-from typing import Union
-
-import open_clip
-import pandas as pd
-import torch
-import torchaudio
-from torch.utils.data.dataset import Dataset
-
-log = logging.getLogger()
-
-
-class WavTextClipsDataset(Dataset):
-
-    def __init__(
-        self,
-        root: Union[str, Path],
-        *,
-        captions_tsv: Union[str, Path],
-        clips_tsv: Union[str, Path],
-        sample_rate: int,
-        num_samples: int,
-        normalize_audio: bool = False,
-        reject_silent: bool = False,
-        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
-    ):
-        self.root = Path(root)
-        self.sample_rate = sample_rate
-        self.num_samples = num_samples
-        self.normalize_audio = normalize_audio
-        self.reject_silent = reject_silent
-        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
-
-        audios = sorted(os.listdir(self.root))
-        audios = set([
-            Path(audio).stem for audio in audios
-            if audio.endswith('.wav') or audio.endswith('.flac')
-        ])
-        self.captions = {}
-
-        # read the caption tsv
-        df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
-        for record in df_list:
-            id = record['id']
-            caption = record['caption']
-            self.captions[id] = caption
-
-        # read the clip tsv
-        df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
-            'id': str,
-            'name': str
-        }).to_dict('records')
-        self.clips = []
-        for record in df_list:
-            record['id'] = record['id']
-            record['name'] = record['name']
-            id = record['id']
-            name = record['name']
-            if name not in self.captions:
-                log.warning(f'Audio {name} not found in {captions_tsv}')
-                continue
-            record['caption'] = self.captions[name]
-            self.clips.append(record)
-
-        log.info(f'Found {len(self.clips)} audio files in {self.root}')
-
-        self.resampler = {}
-
-    def __getitem__(self, idx: int) -> torch.Tensor:
-        try:
-            clip = self.clips[idx]
-            audio_name = clip['name']
-            audio_id = clip['id']
-            caption = clip['caption']
-            start_sample = clip['start_sample']
-            end_sample = clip['end_sample']
-
-            audio_path = self.root / f'{audio_name}.flac'
-            if not audio_path.exists():
-                audio_path = self.root / f'{audio_name}.wav'
-                assert audio_path.exists()
-
-            audio_chunk, sample_rate = torchaudio.load(audio_path)
-            audio_chunk = audio_chunk.mean(dim=0)  # mono
-            abs_max = audio_chunk.abs().max()
-            if self.normalize_audio:
-                audio_chunk = audio_chunk / abs_max * 0.95
-
-            if self.reject_silent and abs_max < 1e-6:
-                log.warning(f'Rejecting silent audio')
-                return None
-
-            audio_chunk = audio_chunk[start_sample:end_sample]
-
-            # resample
-            if sample_rate == self.sample_rate:
-                audio_chunk = audio_chunk
-            else:
-                if sample_rate not in self.resampler:
-                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
-                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
-                        sample_rate,
-                        self.sample_rate,
-                        lowpass_filter_width=64,
-                        rolloff=0.9475937167399596,
-                        resampling_method='sinc_interp_kaiser',
-                        beta=14.769656459379492,
-                    )
-                audio_chunk = self.resampler[sample_rate](audio_chunk)
-
-            if audio_chunk.shape[0] < self.num_samples:
-                raise ValueError('Audio is too short')
-            audio_chunk = audio_chunk[:self.num_samples]
-
-            tokens = self.tokenizer([caption])[0]
-
-            output = {
-                'waveform': audio_chunk,
-                'id': audio_id,
-                'caption': caption,
-                'tokens': tokens,
-            }
-
-            return output
-        except Exception as e:
-            log.error(f'Error reading {audio_path}: {e}')
-            return None
-
-    def __len__(self):
-        return len(self.clips)
mmaudio/data/mm_dataset.py
DELETED
@@ -1,45 +0,0 @@
-import bisect
-
-import torch
-from torch.utils.data.dataset import Dataset
-
-
-# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
-class MultiModalDataset(Dataset):
-    datasets: list[Dataset]
-    cumulative_sizes: list[int]
-
-    @staticmethod
-    def cumsum(sequence):
-        r, s = [], 0
-        for e in sequence:
-            l = len(e)
-            r.append(l + s)
-            s += l
-        return r
-
-    def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
-        super().__init__()
-        self.video_datasets = list(video_datasets)
-        self.audio_datasets = list(audio_datasets)
-        self.datasets = self.video_datasets + self.audio_datasets
-
-        self.cumulative_sizes = self.cumsum(self.datasets)
-
-    def __len__(self):
-        return self.cumulative_sizes[-1]
-
-    def __getitem__(self, idx):
-        if idx < 0:
-            if -idx > len(self):
-                raise ValueError("absolute value of index should not exceed dataset length")
-            idx = len(self) + idx
-        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
-        if dataset_idx == 0:
-            sample_idx = idx
-        else:
-            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
-        return self.datasets[dataset_idx][sample_idx]
-
-    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
-        return self.video_datasets[0].compute_latent_stats()
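The index arithmetic above follows the usual ConcatDataset recipe: a running total of dataset sizes plus a bisect lookup. A tiny standalone illustration with made-up lengths:

```python
# Toy walk-through of the cumulative-size lookup used by MultiModalDataset.
import bisect

lengths = [5, 3, 4]             # pretend we concatenated three datasets
cumulative, running = [], 0
for n in lengths:
    running += n
    cumulative.append(running)  # -> [5, 8, 12]


def locate(idx):
    dataset_idx = bisect.bisect_right(cumulative, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative[dataset_idx - 1]
    return dataset_idx, sample_idx


print(locate(0))    # (0, 0): first item of the first dataset
print(locate(6))    # (1, 1): second item of the second dataset
print(locate(11))   # (2, 3): last item of the third dataset
```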
mmaudio/data/utils.py
DELETED
@@ -1,148 +0,0 @@
-import logging
-import os
-import random
-import tempfile
-from pathlib import Path
-from typing import Any, Optional, Union
-
-import torch
-import torch.distributed as dist
-from tensordict import MemoryMappedTensor
-from torch.utils.data import DataLoader
-from torch.utils.data.dataset import Dataset
-from tqdm import tqdm
-
-from mmaudio.utils.dist_utils import local_rank, world_size
-
-scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
-shm_path = Path('/dev/shm')
-
-log = logging.getLogger()
-
-
-def reseed(seed):
-    random.seed(seed)
-    torch.manual_seed(seed)
-
-
-def local_scatter_torch(obj: Optional[Any]):
-    if world_size == 1:
-        # Just one worker. Do nothing.
-        return obj
-
-    array = [obj] * world_size
-    target_array = [None]
-    if local_rank == 0:
-        dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
-    else:
-        dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
-    return target_array[0]
-
-
-class ShardDataset(Dataset):
-
-    def __init__(self, root):
-        self.root = root
-        self.shards = sorted(os.listdir(root))
-
-    def __len__(self):
-        return len(self.shards)
-
-    def __getitem__(self, idx):
-        return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
-
-
-def get_tmp_dir(in_memory: bool) -> Path:
-    return shm_path if in_memory else scratch_path
-
-
-def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
-                          in_memory: bool) -> MemoryMappedTensor:
-    if local_rank == 0:
-        with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
-            log.info(f'Loading shards from {data_path} into {f.name}...')
-            data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
-            data = share_tensor_to_all(data)
-            torch.distributed.barrier()
-            f.close()  # why does the context manager not close the file for me?
-    else:
-        log.info('Waiting for the data to be shared with me...')
-        data = share_tensor_to_all(None)
-        torch.distributed.barrier()
-
-    return data
-
-
-def load_shards(
-    data_path: Union[str, Path],
-    ids: list[int],
-    *,
-    tmp_file_path: str,
-) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
-
-    id_set = set(ids)
-    shards = sorted(os.listdir(data_path))
-    log.info(f'Found {len(shards)} shards in {data_path}.')
-    first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
-
-    log.info(f'Rank {local_rank} created file {tmp_file_path}')
-    first_item = next(iter(first_shard.values()))
-    log.info(f'First item shape: {first_item.shape}')
-    mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
-                                         dtype=torch.float32,
-                                         filename=tmp_file_path,
-                                         existsok=True)
-    total_count = 0
-    used_index = set()
-    id_indexing = {i: idx for idx, i in enumerate(ids)}
-    # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
-    loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
-    for data in tqdm(loader, desc='Loading shards'):
-        for i, v in data.items():
-            if i not in id_set:
-                continue
-
-            # tensor_index = ids.index(i)
-            tensor_index = id_indexing[i]
-            if tensor_index in used_index:
-                raise ValueError(f'Duplicate id {i} found in {data_path}.')
-            used_index.add(tensor_index)
-            mm_tensor[tensor_index] = v
-            total_count += 1
-
-    assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
-    log.info(f'Loaded {total_count} tensors from {data_path}.')
-
-    return mm_tensor
-
-
-def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
-    """
-    x: the tensor to be shared; None if local_rank != 0
-    return: the shared tensor
-    """
-
-    # there is no need to share your stuff with anyone if you are alone; must be in memory
-    if world_size == 1:
-        return x
-
-    if local_rank == 0:
-        assert x is not None, 'x must not be None if local_rank == 0'
-    else:
-        assert x is None, 'x must be None if local_rank != 0'
-
-    if local_rank == 0:
-        filename = x.filename
-        meta_information = (filename, x.shape, x.dtype)
-    else:
-        meta_information = None
-
-    filename, data_shape, data_type = local_scatter_torch(meta_information)
-    if local_rank == 0:
-        data = x
-    else:
-        data = MemoryMappedTensor.from_filename(filename=filename,
-                                                dtype=data_type,
-                                                shape=data_shape)
-
-    return data
mmaudio/eval_utils.py
CHANGED
@@ -3,16 +3,14 @@ import logging
 from pathlib import Path
 from typing import Optional
 
-import numpy as np
 import torch
 from colorlog import ColoredFormatter
-from PIL import Image
 from torchvision.transforms import v2
 
-from mmaudio.data.av_utils import
+from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
 from mmaudio.model.flow_matching import FlowMatching
 from mmaudio.model.networks import MMAudio
-from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
+from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
 from mmaudio.model.utils.features_utils import FeaturesUtils
 from mmaudio.utils.download_utils import download_model_if_needed
 
@@ -90,7 +88,6 @@ def generate(
         cfg_strength: float,
         clip_batch_size_multiplier: int = 40,
         sync_batch_size_multiplier: int = 40,
-        image_input: bool = False,
 ) -> torch.Tensor:
     device = feature_utils.device
     dtype = feature_utils.dtype
@@ -101,12 +98,10 @@ def generate(
         clip_features = feature_utils.encode_video_with_clip(clip_video,
                                                              batch_size=bs *
                                                              clip_batch_size_multiplier)
-        if image_input:
-            clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
     else:
        clip_features = net.get_empty_clip_sequence(bs)
 
-    if sync_video is not None
+    if sync_video is not None:
        sync_video = sync_video.to(device, dtype, non_blocking=True)
        sync_features = feature_utils.encode_video_with_sync(sync_video,
                                                             batch_size=bs *
@@ -144,7 +139,7 @@ def generate(
     return audio
 
 
-LOGFORMAT = "
+LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
 
 
 def setup_eval_logging(log_level: int = logging.INFO):
@@ -158,14 +153,12 @@ def setup_eval_logging(log_level: int = logging.INFO):
     log.addHandler(stream)
 
 
-_CLIP_SIZE = 384
-_CLIP_FPS = 8.0
-
-_SYNC_SIZE = 224
-_SYNC_FPS = 25.0
-
-
 def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
+    _CLIP_SIZE = 384
+    _CLIP_FPS = 8.0
+
+    _SYNC_SIZE = 224
+    _SYNC_FPS = 25.0
 
     clip_transform = v2.Compose([
         v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
@@ -220,36 +213,5 @@ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = Tr
     return video_info
 
 
-def load_image(image_path: Path) -> VideoInfo:
-    clip_transform = v2.Compose([
-        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
-        v2.ToImage(),
-        v2.ToDtype(torch.float32, scale=True),
-    ])
-
-    sync_transform = v2.Compose([
-        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
-        v2.CenterCrop(_SYNC_SIZE),
-        v2.ToImage(),
-        v2.ToDtype(torch.float32, scale=True),
-        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
-    ])
-
-    frame = np.array(Image.open(image_path))
-
-    clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
-    sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
-
-    clip_frames = clip_transform(clip_chunk)
-    sync_frames = sync_transform(sync_chunk)
-
-    video_info = ImageInfo(
-        clip_frames=clip_frames,
-        sync_frames=sync_frames,
-        original_frame=frame,
-    )
-    return video_info
-
-
 def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
     reencode_with_audio(video_info, output_path, audio, sampling_rate)
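The restored LOGFORMAT string is a colorlog format; a minimal, self-contained way to wire it up (logger name and level here are arbitrary choices for the example, not from this repo) looks like:

```python
# Sketch: attach the LOGFORMAT string above to a colorlog handler.
import logging

from colorlog import ColoredFormatter

LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"

handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter(LOGFORMAT))

log = logging.getLogger('demo')
log.setLevel(logging.INFO)
log.addHandler(handler)
log.info('colored logging is configured')
```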
mmaudio/ext/autoencoder/autoencoder.py
CHANGED
@@ -20,7 +20,7 @@ class AutoEncoderModule(nn.Module):
         super().__init__()
         self.vae: VAE = get_my_vae(mode).eval()
         vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
-        self.vae.load_state_dict(vae_state_dict)
+        self.vae.load_state_dict(vae_state_dict, strict=False)
         self.vae.remove_weight_norm()
 
         if mode == '16k':
mmaudio/ext/autoencoder/vae.py
CHANGED
@@ -75,9 +75,13 @@ class VAE(nn.Module):
         super().__init__()
 
         if data_dim == 80:
+            # self.data_mean = torch.tensor(DATA_MEAN_80D, dtype=torch.float32).cuda()
+            # self.data_std = torch.tensor(DATA_STD_80D, dtype=torch.float32).cuda()
             self.register_buffer('data_mean', torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
             self.register_buffer('data_std', torch.tensor(DATA_STD_80D, dtype=torch.float32))
         elif data_dim == 128:
+            # torch.tensor(DATA_MEAN_128D, dtype=torch.float32).cuda()
+            # self.data_std = torch.tensor(DATA_STD_128D, dtype=torch.float32).cuda()
             self.register_buffer('data_mean', torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
             self.register_buffer('data_std', torch.tensor(DATA_STD_128D, dtype=torch.float32))
 
mmaudio/ext/mel_converter.py
CHANGED
@@ -1,12 +1,11 @@
 # Reference: # https://github.com/bytedance/Make-An-Audio-2
-from typing import Literal
 
 import torch
 import torch.nn as nn
 from librosa.filters import mel as librosa_mel_fn
 
 
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5,
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10):
     return norm_fn(torch.clamp(x, min=clip_val) * C)
 
 
@@ -20,14 +19,14 @@ class MelConverter(nn.Module):
     def __init__(
         self,
         *,
-        sampling_rate: float,
-        n_fft: int,
-        num_mels: int,
-        hop_size: int,
-        win_size: int,
-        fmin: float,
-        fmax: float,
-        norm_fn,
+        sampling_rate: float = 16_000,
+        n_fft: int = 1024,
+        num_mels: int = 80,
+        hop_size: int = 256,
+        win_size: int = 1024,
+        fmin: float = 0,
+        fmax: float = 8_000,
+        norm_fn=torch.log10,
     ):
         super().__init__()
         self.sampling_rate = sampling_rate
@@ -81,26 +80,3 @@ class MelConverter(nn.Module):
         spec = spectral_normalize_torch(spec, self.norm_fn)
 
         return spec
-
-
-def get_mel_converter(mode: Literal['16k', '44k']) -> MelConverter:
-    if mode == '16k':
-        return MelConverter(sampling_rate=16_000,
-                            n_fft=1024,
-                            num_mels=80,
-                            hop_size=256,
-                            win_size=1024,
-                            fmin=0,
-                            fmax=8_000,
-                            norm_fn=torch.log10)
-    elif mode == '44k':
-        return MelConverter(sampling_rate=44_100,
-                            n_fft=2048,
-                            num_mels=128,
-                            hop_size=512,
-                            win_size=2048,
-                            fmin=0,
-                            fmax=44100 / 2,
-                            norm_fn=torch.log)
-    else:
-        raise ValueError(f'Unknown mode: {mode}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
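With these defaults, `MelConverter()` reproduces the removed `get_mel_converter('16k')` configuration; a 44.1 kHz setup now has to be spelled out by the caller. A sketch: the 44 k constructor values are taken from the deleted helper, and the assumption that `forward` maps a waveform batch to a mel spectrogram follows the class shown above.

import torch
from mmaudio.ext.mel_converter import MelConverter

mel_16k = MelConverter()  # sampling_rate=16_000, n_fft=1024, num_mels=80, hop_size=256, ...
mel_44k = MelConverter(sampling_rate=44_100, n_fft=2048, num_mels=128,
                       hop_size=512, win_size=2048, fmin=0, fmax=44100 / 2,
                       norm_fn=torch.log)

waveform = torch.randn(1, 16_000 * 8)   # dummy 8-second clip at 16 kHz
spec = mel_16k(waveform)                # assumed: (batch, samples) -> (batch, num_mels, frames)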
mmaudio/model/embeddings.py
CHANGED
@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
-import math

 # https://github.com/facebookresearch/DiT

+
 class TimestepEmbedder(nn.Module):
     """
     Embeds scalar timesteps into vector representations.
mmaudio/model/flow_matching.py
CHANGED
@@ -1,9 +1,11 @@
 import logging
-from typing import Callable, Optional
+from typing import Callable, Iterable, Optional

 import torch
 from torchdiffeq import odeint

+# from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
+
 log = logging.getLogger()


@@ -43,8 +45,12 @@ class FlowMatching:
             Cs: list[torch.Tensor],
             generator: Optional[torch.Generator] = None
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # x0 = torch.randn_like(x1, generator=generator)
         x0 = torch.empty_like(x1).normal_(generator=generator)

+        # find mini-batch optimal transport
+        # x0, x1, _, Cs = self.fm.ot_sampler.sample_plan_with_labels(x0, x1, None, Cs, replace=True)
+
        xt = self.get_conditional_flow(x0, x1, t)
        return x0, x1, xt, Cs

@@ -68,4 +74,15 @@ class FlowMatching:
             dt = next_t - t
             x = x + dt * flow

+        # return odeint(fn,
+        #               x0,
+        #               torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype),
+        #               method='rk4',
+        #               options=dict(step_size=(t1 - t0) / self.num_steps))[-1]
+        # return odeint(fn,
+        #               x0,
+        #               torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype),
+        #               method='euler',
+        #               options=dict(step_size=(t1 - t0) / self.num_steps))[-1]
+
         return x
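The adaptive `odeint` variants stay commented out, so sampling keeps the hand-rolled fixed-step Euler loop shown in the last hunk. A standalone sketch of that integration scheme (generic code, not tied to the FlowMatching class):

from typing import Callable

import torch

def euler_integrate(fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                    x0: torch.Tensor, num_steps: int,
                    t0: float = 0.0, t1: float = 1.0) -> torch.Tensor:
    # dx/dt = fn(t, x), integrated with fixed steps from t0 (noise) to t1 (data)
    x = x0
    ts = torch.linspace(t0, t1, num_steps + 1, device=x0.device, dtype=x0.dtype)
    for t, next_t in zip(ts[:-1], ts[1:]):
        flow = fn(t, x)
        x = x + (next_t - t) * flow
    return x

# toy usage with a dummy vector field
x1_hat = euler_integrate(lambda t, x: -x, torch.randn(2, 16), num_steps=25)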
mmaudio/model/networks.py
CHANGED
@@ -468,4 +468,4 @@ if __name__ == '__main__':

     # print the number of parameters in terms of millions
     num_params = sum(p.numel() for p in network.parameters()) / 1e6
-    print(f'Number of parameters: {num_params:.2f}M')
+    print(f'Number of parameters: {num_params:.2f}M')
mmaudio/model/transformer_layers.py
CHANGED
@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from einops.layers.torch import Rearrange
+from torch.nn.attention import SDPBackend, sdpa_kernel

 from mmaudio.ext.rotary_embeddings import apply_rope
 from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP
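The new import gives explicit control over which scaled-dot-product-attention backend PyTorch may use. This hunk only shows the import; a typical pattern (a sketch, not necessarily how transformer_layers.py applies it) is:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 8, 128, 64)   # (batch, heads, seq_len, head_dim), dummy data

# allow flash / memory-efficient kernels, with the math kernel as a fallback
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)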
mmaudio/model/utils/features_utils.py
CHANGED
@@ -9,7 +9,7 @@ from open_clip import create_model_from_pretrained
 from torchvision.transforms import Normalize

 from mmaudio.ext.autoencoder import AutoEncoderModule
-from mmaudio.ext.mel_converter import get_mel_converter
+from mmaudio.ext.mel_converter import MelConverter
 from mmaudio.ext.synchformer import Synchformer
 from mmaudio.model.utils.distributions import DiagonalGaussianDistribution

@@ -63,13 +63,13 @@ class FeaturesUtils(nn.Module):
         self.tokenizer = None

         if tod_vae_ckpt is not None:
-            self.mel_converter = get_mel_converter(mode)
             self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                          vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                          mode=mode,
                                          need_vae_encoder=need_vae_encoder)
         else:
             self.tod = None
+        self.mel_converter = MelConverter()

     def compile(self):
         if self.clip_model is not None:
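One consequence of this hunk: `self.mel_converter` is now created unconditionally with the default (16 kHz) settings instead of via the removed `get_mel_converter(mode)`, so it no longer tracks `mode`. A hedged construction sketch, with keyword names taken from the deleted runner.py and placeholder checkpoint paths:

from mmaudio.model.utils.features_utils import FeaturesUtils

features = FeaturesUtils(
    tod_vae_ckpt='./ext_weights/vae_16k.pth',           # placeholder path
    bigvgan_vocoder_ckpt='./ext_weights/bigvgan.pth',   # placeholder path
    synchformer_ckpt='./ext_weights/synchformer.pth',   # placeholder path
    enable_conditions=True,
    mode='16k',                 # with '44k', mel_converter still uses the 16 kHz defaults
    need_vae_encoder=False,
).eval()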
mmaudio/runner.py
DELETED
@@ -1,609 +0,0 @@
-"""
-trainer.py - wrapper and utility functions for network training
-Compute loss, back-prop, update parameters, logging, etc.
-"""
(the remaining 605 lines: the Runner class covering DDP model setup, flow-matching train/val/inference passes, EMA, evaluation, and checkpoint/weight I/O)
mmaudio/sample.py
DELETED
@@ -1,90 +0,0 @@
(90 lines removed: the distributed test-set sampling entry point that builds a Runner, loads weights, runs inference_pass over the test loader, and writes the evaluation metrics to JSON)
mmaudio/utils/email_utils.py
DELETED
@@ -1,50 +0,0 @@
(50 lines removed: the Mailgun-based EmailSender used for job status and error notifications)
mmaudio/utils/log_integrator.py
DELETED
@@ -1,112 +0,0 @@
-"""
-Integrate numerical values for some iterations
-Typically used for loss computation / logging to tensorboard
-Call finalize and create a new Integrator when you want to display/log
-"""
(the remaining 107 lines: the Integrator class that accumulates scalars, binned tensors, and hooks, reduces them across ranks, and forwards them to TensorboardLogger)
mmaudio/utils/logger.py
DELETED
@@ -1,231 +0,0 @@
-"""
-Dumps things to tensorboard and console
-"""
(the remaining 228 lines: the TensorboardLogger wrapper for scalars, metrics, histograms, images, audio, and spectrograms, plus its email notification hooks)
mmaudio/utils/synthesize_ema.py
DELETED
@@ -1,19 +0,0 @@
-from typing import Optional
-
-from nitrous_ema import PostHocEMA
-from omegaconf import DictConfig
-
-from mmaudio.model.networks import get_my_mmaudio
-
-
-def synthesize_ema(cfg: DictConfig, sigma: float, step: Optional[int]):
-    vae = get_my_mmaudio(cfg.model)
-    emas = PostHocEMA(vae,
-                      sigma_rels=cfg.ema.sigma_rels,
-                      update_every=cfg.ema.update_every,
-                      checkpoint_every_num_steps=cfg.ema.checkpoint_every,
-                      checkpoint_folder=cfg.ema.checkpoint_folder)
-
-    synthesized_ema = emas.synthesize_ema_model(sigma_rel=sigma, step=step, device='cpu')
-    state_dict = synthesized_ema.ema_model.state_dict()
-    return state_dict
mmaudio/utils/tensor_utils.py
DELETED
@@ -1,14 +0,0 @@
-import torch
-
-
-def distribute_into_histogram(loss: torch.Tensor,
-                              t: torch.Tensor,
-                              num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
-    loss = loss.detach().flatten()
-    t = t.detach().flatten()
-    t = (t * num_bins).long()
-    hist = torch.zeros(num_bins, device=loss.device)
-    count = torch.zeros(num_bins, device=loss.device)
-    hist.scatter_add_(0, t, loss)
-    count.scatter_add_(0, t, torch.ones_like(loss))
-    return hist, count
mmaudio/utils/time_estimator.py
DELETED
@@ -1,72 +0,0 @@
(72 lines removed: the TimeEstimator and PartialTimeEstimator helpers that track per-iteration timing and ETA for the logger)
mmaudio/utils/timezone.py
DELETED
@@ -1 +0,0 @@
-my_timezone = 'US/Central'