Phil Sobrepena committed on
Commit 2c4e2b0 · 1 Parent(s): 977df40
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +4 -6
  2. app.py +24 -103
  3. batch_eval.py +0 -110
  4. config/__init__.py +0 -0
  5. config/base_config.yaml +0 -62
  6. config/data/base.yaml +0 -70
  7. config/eval_config.yaml +0 -17
  8. config/eval_data/base.yaml +0 -22
  9. config/hydra/job_logging/custom-eval.yaml +0 -32
  10. config/hydra/job_logging/custom-no-rank.yaml +0 -32
  11. config/hydra/job_logging/custom-simplest.yaml +0 -26
  12. config/hydra/job_logging/custom.yaml +0 -33
  13. config/train_config.yaml +0 -41
  14. demo.py +1 -7
  15. docs/EVAL.md +0 -22
  16. docs/MODELS.md +0 -50
  17. docs/TRAINING.md +0 -160
  18. docs/index.html +10 -12
  19. gitattributes +0 -35
  20. mmaudio/data/av_utils.py +0 -26
  21. mmaudio/data/data_setup.py +0 -174
  22. mmaudio/data/eval/__init__.py +0 -0
  23. mmaudio/data/eval/audiocaps.py +0 -39
  24. mmaudio/data/eval/moviegen.py +0 -131
  25. mmaudio/data/eval/video_dataset.py +0 -197
  26. mmaudio/data/extracted_audio.py +0 -88
  27. mmaudio/data/extracted_vgg.py +0 -101
  28. mmaudio/data/extraction/__init__.py +0 -0
  29. mmaudio/data/extraction/vgg_sound.py +0 -193
  30. mmaudio/data/extraction/wav_dataset.py +0 -132
  31. mmaudio/data/mm_dataset.py +0 -45
  32. mmaudio/data/utils.py +0 -148
  33. mmaudio/eval_utils.py +9 -47
  34. mmaudio/ext/autoencoder/autoencoder.py +1 -1
  35. mmaudio/ext/autoencoder/vae.py +4 -0
  36. mmaudio/ext/mel_converter.py +9 -33
  37. mmaudio/model/embeddings.py +1 -1
  38. mmaudio/model/flow_matching.py +18 -1
  39. mmaudio/model/networks.py +1 -1
  40. mmaudio/model/transformer_layers.py +1 -0
  41. mmaudio/model/utils/features_utils.py +2 -2
  42. mmaudio/runner.py +0 -609
  43. mmaudio/sample.py +0 -90
  44. mmaudio/utils/email_utils.py +0 -50
  45. mmaudio/utils/log_integrator.py +0 -112
  46. mmaudio/utils/logger.py +0 -231
  47. mmaudio/utils/synthesize_ema.py +0 -19
  48. mmaudio/utils/tensor_utils.py +0 -14
  49. mmaudio/utils/time_estimator.py +0 -72
  50. mmaudio/utils/timezone.py +0 -1
.gitignore CHANGED
@@ -2,18 +2,16 @@ run_*.sh
 log/
 saves
 saves/
-weights/
-weights
 output/
 output
 pretrained/
 workspace
 workspace/
-ext_weights/
-ext_weights
 .checkpoints/
-.vscode/
-training/example_output/
+weights/
+ext_weights/
+*.pth
+*.pt
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
app.py CHANGED
@@ -14,12 +14,13 @@ except ImportError:
     os.system("pip install -e .")
     import mmaudio
 
-from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
-                                load_video, make_video, setup_eval_logging)
+from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
+                                setup_eval_logging)
 from mmaudio.model.flow_matching import FlowMatching
 from mmaudio.model.networks import MMAudio, get_my_mmaudio
 from mmaudio.model.sequence_config import SequenceConfig
 from mmaudio.model.utils.features_utils import FeaturesUtils
+import tempfile
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
@@ -56,6 +57,7 @@ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
 
 net, feature_utils, seq_cfg = get_model()
 
+
 @spaces.GPU(duration=120)
 @torch.inference_mode()
 def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
@@ -88,17 +90,18 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int
     audio = audios.float().cpu()[0]
 
     # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
+    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
     # output_dir.mkdir(exist_ok=True, parents=True)
     # video_save_path = output_dir / f'{current_time_string}.mp4'
-    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
     make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
     log.info(f'Saved video to {video_save_path}')
     return video_save_path
 
+
 @spaces.GPU(duration=120)
 @torch.inference_mode()
-def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
-                   cfg_strength: float, duration: float):
+def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
+                  duration: float):
 
     rng = torch.Generator(device=device)
     if seed >= 0:
@@ -107,11 +110,7 @@ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int
         rng.seed()
     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
 
-    image_info = load_image(image)
-    clip_frames = image_info.clip_frames
-    sync_frames = image_info.sync_frames
-    clip_frames = clip_frames.unsqueeze(0)
-    sync_frames = sync_frames.unsqueeze(0)
+    clip_frames = sync_frames = None
    seq_cfg.duration = duration
     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
 
@@ -122,61 +121,24 @@ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int
                       net=net,
                       fm=fm,
                       rng=rng,
-                      cfg_strength=cfg_strength,
-                      image_input=True)
+                      cfg_strength=cfg_strength)
     audio = audios.float().cpu()[0]
 
-    # current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
-    # output_dir.mkdir(exist_ok=True, parents=True)
-    # video_save_path = output_dir / f'{current_time_string}.mp4'
-    video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
-    video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
-    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
-    log.info(f'Saved video to {video_save_path}')
-    return video_save_path
-
-# @spaces.GPU(duration=120)
-# @torch.inference_mode()
-# def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
-#                   duration: float):
-
-#     rng = torch.Generator(device=device)
-#     if seed >= 0:
-#         rng.manual_seed(seed)
-#     else:
-#         rng.seed()
-#     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
-
-#     clip_frames = sync_frames = None
-#     seq_cfg.duration = duration
-#     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
-
-#     audios = generate(clip_frames,
-#                       sync_frames, [prompt],
-#                       negative_text=[negative_prompt],
-#                       feature_utils=feature_utils,
-#                       net=net,
-#                       fm=fm,
-#                       rng=rng,
-#                       cfg_strength=cfg_strength)
-#     audio = audios.float().cpu()[0]
-
-#     current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
-#     output_dir.mkdir(exist_ok=True, parents=True)
-#     audio_save_path = output_dir / f'{current_time_string}.flac'
-#     torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
-#     gc.collect()
-#     return audio_save_path
+    audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
+    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
+    log.info(f'Saved audio to {audio_save_path}')
+    return audio_save_path
 
 
 video_to_audio_tab = gr.Interface(
     fn=video_to_audio,
-    description=""" Video-to-Audio
+    description="""
+    Sonisphere
+    Video-to-Audio
     NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
     Doing so does not improve results.
 
-    The model has been trained on 8-second videos.
-    Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
+    The model has been trained on 8-second videos. Using much longer or shorter videos will degrade performance. Around 5s~12s should be fine.
     """,
     inputs=[
         gr.Video(),
@@ -189,52 +151,11 @@ video_to_audio_tab = gr.Interface(
     ],
     outputs='playable_video',
     cache_examples=False,
-    title='Sonisphere - Sonic Branding Tool',
-)
-
-# text_to_audio_tab = gr.Interface(
-#     fn=text_to_audio,
-#     description=""" Text-to-Audio
-#     """,
-#     inputs=[
-#         gr.Text(label='Prompt'),
-#         gr.Text(label='Negative prompt'),
-#         gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-#         gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-#         gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-#         gr.Number(label='Duration (sec)', value=8, minimum=1),
-#     ],
-#     outputs='audio',
-#     cache_examples=False,
-#     title='Sonisphere - Sonic Branding Tool',
-# )
-
-image_to_audio_tab = gr.Interface(
-    fn=image_to_audio,
-    description="""
-    Image-to-Audio
-    NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
-    Doing so does not improve results.
-    """,
-    inputs=[
-        gr.Image(type='filepath'),
-        gr.Text(label='Prompt'),
-        gr.Text(label='Negative prompt'),
-        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
-        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
-        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
-        gr.Number(label='Duration (sec)', value=8, minimum=1),
-    ],
-    outputs='playable_video',
-    cache_examples=False,
-    title='Image-to-Audio Synthesis (experimental)',
-)
+    title='MMAudio — Video-to-Audio Synthesis',
+    examples=[
+    ])
 
-if __name__ == "__main__":
-    # parser = ArgumentParser()
-    # parser.add_argument('--port', type=int, default=7860)
-    # args = parser.parse_args()
-
-    gr.TabbedInterface([video_to_audio_tab, image_to_audio_tab],
-                       ['Video-to-Audio', 'Image-to-Audio']).launch(
-                           allowed_paths=[output_dir])
+
+if __name__ == "__main__":
+    gr.TabbedInterface([video_to_audio_tab],
+                       ['Video-to-Audio']).launch(allowed_paths=[output_dir])
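The updated handlers above return a path to a `tempfile.NamedTemporaryFile` instead of writing into `output_dir`. A minimal, self-contained sketch of that output pattern (the sine-wave tensor and the 16 kHz rate are stand-ins for the model output, and it assumes a torchaudio backend with FLAC support):

```python
import tempfile

import torch
import torchaudio

# Stand-in for the generated waveform: 1 second of a 440 Hz tone at 16 kHz,
# shaped (channels, samples) as torchaudio.save expects.
sampling_rate = 16000
t = torch.arange(sampling_rate) / sampling_rate
audio = torch.sin(2 * torch.pi * 440 * t).unsqueeze(0)

# delete=False keeps the file on disk after the handle is closed, so Gradio can
# still serve the path that the handler returns.
audio_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.flac').name
torchaudio.save(audio_save_path, audio, sampling_rate)
print(f'Saved audio to {audio_save_path}')
```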
batch_eval.py DELETED
@@ -1,110 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
-
5
- import hydra
6
- import torch
7
- import torch.distributed as distributed
8
- import torchaudio
9
- from hydra.core.hydra_config import HydraConfig
10
- from omegaconf import DictConfig
11
- from tqdm import tqdm
12
-
13
- from mmaudio.data.data_setup import setup_eval_dataset
14
- from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
15
- from mmaudio.model.flow_matching import FlowMatching
16
- from mmaudio.model.networks import MMAudio, get_my_mmaudio
17
- from mmaudio.model.utils.features_utils import FeaturesUtils
18
-
19
- torch.backends.cuda.matmul.allow_tf32 = True
20
- torch.backends.cudnn.allow_tf32 = True
21
-
22
- local_rank = int(os.environ['LOCAL_RANK'])
23
- world_size = int(os.environ['WORLD_SIZE'])
24
- log = logging.getLogger()
25
-
26
-
27
- @torch.inference_mode()
28
- @hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
29
- def main(cfg: DictConfig):
30
- device = 'cuda'
31
- torch.cuda.set_device(local_rank)
32
-
33
- if cfg.model not in all_model_cfg:
34
- raise ValueError(f'Unknown model variant: {cfg.model}')
35
- model: ModelConfig = all_model_cfg[cfg.model]
36
- model.download_if_needed()
37
- seq_cfg = model.seq_cfg
38
-
39
- run_dir = Path(HydraConfig.get().run.dir)
40
- if cfg.output_name is None:
41
- output_dir = run_dir / cfg.dataset
42
- else:
43
- output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
44
- output_dir.mkdir(parents=True, exist_ok=True)
45
-
46
- # load a pretrained model
47
- seq_cfg.duration = cfg.duration_s
48
- net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
49
- net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
50
- log.info(f'Loaded weights from {model.model_path}')
51
- net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
52
- log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
53
- log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
54
- log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
55
-
56
- # misc setup
57
- rng = torch.Generator(device=device)
58
- rng.manual_seed(cfg.seed)
59
- fm = FlowMatching(cfg.sampling.min_sigma,
60
- inference_mode=cfg.sampling.method,
61
- num_steps=cfg.sampling.num_steps)
62
-
63
- feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
64
- synchformer_ckpt=model.synchformer_ckpt,
65
- enable_conditions=True,
66
- mode=model.mode,
67
- bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
68
- need_vae_encoder=False)
69
- feature_utils = feature_utils.to(device).eval()
70
-
71
- if cfg.compile:
72
- net.preprocess_conditions = torch.compile(net.preprocess_conditions)
73
- net.predict_flow = torch.compile(net.predict_flow)
74
- feature_utils.compile()
75
-
76
- dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
77
-
78
- with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
79
- for batch in tqdm(loader):
80
- audios = generate(batch.get('clip_video', None),
81
- batch.get('sync_video', None),
82
- batch.get('caption', None),
83
- feature_utils=feature_utils,
84
- net=net,
85
- fm=fm,
86
- rng=rng,
87
- cfg_strength=cfg.cfg_strength,
88
- clip_batch_size_multiplier=64,
89
- sync_batch_size_multiplier=64)
90
- audios = audios.float().cpu()
91
- names = batch['name']
92
- for audio, name in zip(audios, names):
93
- torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
94
-
95
-
96
- def distributed_setup():
97
- distributed.init_process_group(backend="nccl")
98
- local_rank = distributed.get_rank()
99
- world_size = distributed.get_world_size()
100
- log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
101
- return local_rank, world_size
102
-
103
-
104
- if __name__ == '__main__':
105
- distributed_setup()
106
-
107
- main()
108
-
109
- # clean-up
110
- distributed.destroy_process_group()
config/__init__.py DELETED
File without changes
config/base_config.yaml DELETED
@@ -1,62 +0,0 @@
1
- defaults:
2
- - data: base
3
- - eval_data: base
4
- - override hydra/job_logging: custom-simplest
5
- - _self_
6
-
7
- hydra:
8
- run:
9
- dir: ./output/${exp_id}
10
- output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
11
-
12
- enable_email: False
13
-
14
- model: small_16k
15
-
16
- exp_id: default
17
- debug: False
18
- cudnn_benchmark: True
19
- compile: True
20
- amp: True
21
- weights: null
22
- checkpoint: null
23
- seed: 14159265
24
- num_workers: 10 # per-GPU
25
- pin_memory: False # set to True if your system can handle it, i.e., have enough memory
26
-
27
- # NOTE: This DOSE NOT affect the model during inference in any way
28
- # they are just for the dataloader to fill in the missing data in multi-modal loading
29
- # to change the sequence length for the model, see networks.py
30
- data_dim:
31
- text_seq_len: 77
32
- clip_dim: 1024
33
- sync_dim: 768
34
- text_dim: 1024
35
-
36
- # ema configuration
37
- ema:
38
- enable: True
39
- sigma_rels: [0.05, 0.1]
40
- update_every: 1
41
- checkpoint_every: 5_000
42
- checkpoint_folder: ${hydra:run.dir}/ema_ckpts
43
- default_output_sigma: 0.05
44
-
45
-
46
- # sampling
47
- sampling:
48
- mean: 0.0
49
- scale: 1.0
50
- min_sigma: 0.0
51
- method: euler
52
- num_steps: 25
53
-
54
- # classifier-free guidance
55
- null_condition_probability: 0.1
56
- cfg_strength: 4.5
57
-
58
- # checkpoint paths to external modules
59
- vae_16k_ckpt: ./ext_weights/v1-16.pth
60
- vae_44k_ckpt: ./ext_weights/v1-44.pth
61
- bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
62
- synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
config/data/base.yaml DELETED
@@ -1,70 +0,0 @@
1
- VGGSound:
2
- root: ../data/video
3
- subset_name: sets/vgg3-train.tsv
4
- fps: 8
5
- height: 384
6
- width: 384
7
- sample_duration_sec: 8.0
8
-
9
- VGGSound_test:
10
- root: ../data/video
11
- subset_name: sets/vgg3-test.tsv
12
- fps: 8
13
- height: 384
14
- width: 384
15
- sample_duration_sec: 8.0
16
-
17
- VGGSound_val:
18
- root: ../data/video
19
- subset_name: sets/vgg3-val.tsv
20
- fps: 8
21
- height: 384
22
- width: 384
23
- sample_duration_sec: 8.0
24
-
25
- ExtractedVGG:
26
- tsv: ../data/v1-16-memmap/vgg-train.tsv
27
- memmap_dir: ../data/v1-16-memmap/vgg-train
28
-
29
- ExtractedVGG_test:
30
- tag: test
31
- gt_cache: ../data/eval-cache/vggsound-test
32
- output_subdir: null
33
- tsv: ../data/v1-16-memmap/vgg-test.tsv
34
- memmap_dir: ../data/v1-16-memmap/vgg-test
35
-
36
- ExtractedVGG_val:
37
- tag: val
38
- gt_cache: ../data/eval-cache/vggsound-val
39
- output_subdir: val
40
- tsv: ../data/v1-16-memmap/vgg-val.tsv
41
- memmap_dir: ../data/v1-16-memmap/vgg-val
42
-
43
- AudioCaps:
44
- tsv: ../data/v1-16-memmap/audiocaps.tsv
45
- memmap_dir: ../data/v1-16-memmap/audiocaps
46
-
47
- AudioSetSL:
48
- tsv: ../data/v1-16-memmap/audioset_sl.tsv
49
- memmap_dir: ../data/v1-16-memmap/audioset_sl
50
-
51
- BBCSound:
52
- tsv: ../data/v1-16-memmap/bbcsound.tsv
53
- memmap_dir: ../data/v1-16-memmap/bbcsound
54
-
55
- FreeSound:
56
- tsv: ../data/v1-16-memmap/freesound.tsv
57
- memmap_dir: ../data/v1-16-memmap/freesound
58
-
59
- Clotho:
60
- tsv: ../data/v1-16-memmap/clotho.tsv
61
- memmap_dir: ../data/v1-16-memmap/clotho
62
-
63
- Example_video:
64
- tsv: ./training/example_output/memmap/vgg-example.tsv
65
- memmap_dir: ./training/example_output/memmap/vgg-example
66
-
67
- Example_audio:
68
- tsv: ./training/example_output/memmap/audio-example.tsv
69
- memmap_dir: ./training/example_output/memmap/audio-example
-
config/eval_config.yaml DELETED
@@ -1,17 +0,0 @@
1
- defaults:
2
- - base_config
3
- - override hydra/job_logging: custom-simplest
4
- - _self_
5
-
6
- hydra:
7
- run:
8
- dir: ./output/${exp_id}
9
- output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
10
-
11
- exp_id: ${model}
12
- dataset: audiocaps
13
- duration_s: 8.0
14
-
15
- # for inference, this is the per-GPU batch size
16
- batch_size: 16
17
- output_name: null
config/eval_data/base.yaml DELETED
@@ -1,22 +0,0 @@
1
- AudioCaps:
2
- audio_path: ../data/AudioCaps-test-audioldm-ver
3
- # a csv file, with a header row of 'name' and 'caption'
4
- # name should match the audio file name without extension
5
- # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
6
- csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
7
-
8
- AudioCaps_full:
9
- audio_path: ../data/AudioCaps-test-full-ver
10
- # a csv file, with a header row of 'name' and 'caption'
11
- # name should match the audio file name without extension
12
- # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
13
- csv_path: ../data/AudioCaps-test-full-ver/data.csv
14
-
15
- MovieGen:
16
- video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
17
- jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
18
-
19
- VGGSound:
20
- video_path: ../data/test-videos
21
- # from the officially released csv file
22
- csv_path: ../data/vggsound.csv
config/hydra/job_logging/custom-eval.yaml DELETED
@@ -1,32 +0,0 @@
1
- # python logging configuration for tasks
2
- version: 1
3
- formatters:
4
- simple:
5
- format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
- datefmt: '%Y-%m-%d %H:%M:%S'
7
- colorlog:
8
- '()': 'colorlog.ColoredFormatter'
9
- format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
- datefmt: '%Y-%m-%d %H:%M:%S'
11
- log_colors:
12
- DEBUG: purple
13
- INFO: green
14
- WARNING: yellow
15
- ERROR: red
16
- CRITICAL: red
17
- handlers:
18
- console:
19
- class: logging.StreamHandler
20
- formatter: colorlog
21
- stream: ext://sys.stdout
22
- file:
23
- class: logging.FileHandler
24
- formatter: simple
25
- # absolute file path
26
- filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
27
- mode: w
28
- root:
29
- level: INFO
30
- handlers: [console, file]
31
-
32
- disable_existing_loggers: false
config/hydra/job_logging/custom-no-rank.yaml DELETED
@@ -1,32 +0,0 @@
1
- # python logging configuration for tasks
2
- version: 1
3
- formatters:
4
- simple:
5
- format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
- datefmt: '%Y-%m-%d %H:%M:%S'
7
- colorlog:
8
- '()': 'colorlog.ColoredFormatter'
9
- format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
- datefmt: '%Y-%m-%d %H:%M:%S'
11
- log_colors:
12
- DEBUG: purple
13
- INFO: green
14
- WARNING: yellow
15
- ERROR: red
16
- CRITICAL: red
17
- handlers:
18
- console:
19
- class: logging.StreamHandler
20
- formatter: colorlog
21
- stream: ext://sys.stdout
22
- file:
23
- class: logging.FileHandler
24
- formatter: simple
25
- # absolute file path
26
- filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
27
- mode: w
28
- root:
29
- level: INFO
30
- handlers: [console, file]
31
-
32
- disable_existing_loggers: false
config/hydra/job_logging/custom-simplest.yaml DELETED
@@ -1,26 +0,0 @@
1
- # python logging configuration for tasks
2
- version: 1
3
- formatters:
4
- simple:
5
- format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
- datefmt: '%Y-%m-%d %H:%M:%S'
7
- colorlog:
8
- '()': 'colorlog.ColoredFormatter'
9
- format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
- datefmt: '%Y-%m-%d %H:%M:%S'
11
- log_colors:
12
- DEBUG: purple
13
- INFO: green
14
- WARNING: yellow
15
- ERROR: red
16
- CRITICAL: red
17
- handlers:
18
- console:
19
- class: logging.StreamHandler
20
- formatter: colorlog
21
- stream: ext://sys.stdout
22
- root:
23
- level: INFO
24
- handlers: [console]
25
-
26
- disable_existing_loggers: false
config/hydra/job_logging/custom.yaml DELETED
@@ -1,33 +0,0 @@
1
- # @package hydra.job_logging
2
- # python logging configuration for tasks
3
- version: 1
4
- formatters:
5
- simple:
6
- format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
7
- datefmt: '%Y-%m-%d %H:%M:%S'
8
- colorlog:
9
- '()': 'colorlog.ColoredFormatter'
10
- format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
11
- datefmt: '%Y-%m-%d %H:%M:%S'
12
- log_colors:
13
- DEBUG: purple
14
- INFO: green
15
- WARNING: yellow
16
- ERROR: red
17
- CRITICAL: red
18
- handlers:
19
- console:
20
- class: logging.StreamHandler
21
- formatter: colorlog
22
- stream: ext://sys.stdout
23
- file:
24
- class: logging.FileHandler
25
- formatter: simple
26
- # absolute file path
27
- filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
28
- mode: w
29
- root:
30
- level: INFO
31
- handlers: [console, file]
32
-
33
- disable_existing_loggers: false
config/train_config.yaml DELETED
@@ -1,41 +0,0 @@
1
- defaults:
2
- - base_config
3
- - override data: base
4
- - override hydra/job_logging: custom
5
- - _self_
6
-
7
- hydra:
8
- run:
9
- dir: ./output/${exp_id}
10
- output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
-
12
- ema:
13
- start: 0
14
-
15
- mini_train: False
16
- example_train: False
17
- enable_grad_scaler: False
18
- vgg_oversample_rate: 5
19
-
20
- log_text_interval: 200
21
- log_extra_interval: 20_000
22
- val_interval: 5_000
23
- eval_interval: 20_000
24
- save_eval_interval: 40_000
25
- save_weights_interval: 10_000
26
- save_checkpoint_interval: 10_000
27
- save_copy_iterations: []
28
-
29
- batch_size: 512
30
- eval_batch_size: 256 # per-GPU
31
-
32
- num_iterations: 300_000
33
- learning_rate: 1.0e-4
34
- linear_warmup_steps: 1_000
35
-
36
- lr_schedule: step
37
- lr_schedule_steps: [240_000, 270_000]
38
- lr_schedule_gamma: 0.1
39
-
40
- clip_grad_norm: 1.0
41
- weight_decay: 1.0e-6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
demo.py CHANGED
@@ -62,13 +62,7 @@ def main():
     skip_video_composite: bool = args.skip_video_composite
     mask_away_clip: bool = args.mask_away_clip
 
-    device = 'cpu'
-    if torch.cuda.is_available():
-        device = 'cuda'
-    elif torch.backends.mps.is_available():
-        device = 'mps'
-    else:
-        log.warning('CUDA/MPS are not available, running on CPU')
+    device = 'cuda'
     dtype = torch.float32 if args.full_precision else torch.bfloat16
 
     output_dir.mkdir(parents=True, exist_ok=True)
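With `device = 'cuda'` hard-coded, demo.py now assumes a CUDA-capable GPU. A guarded device selection in the style of the removed lines (a sketch for reference, not part of this commit) would be:

```python
import logging

import torch

log = logging.getLogger()

# Fall back gracefully when CUDA is unavailable (mirrors the removed block).
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    log.warning('CUDA/MPS are not available, running on CPU')
```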
docs/EVAL.md DELETED
@@ -1,22 +0,0 @@
- # Evaluation
-
- ## Batch Evaluation
-
- To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient in large-scale evaluation compared to `demo.py`, supporting batched inference, multi-GPU inference, torch compilation, and skipping video compositions.
-
- An example of running this script with four GPUs is as follows:
-
- ```bash
- OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
- ```
-
- You may need to update the data paths in `config/eval_data/base.yaml`.
- More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
-
- ## Precomputed Results
-
- Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
-
- ## Obtaining Quantitative Metrics
-
- Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
docs/MODELS.md DELETED
@@ -1,50 +0,0 @@
1
- # Pretrained models
2
-
3
- The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
4
- The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
5
-
6
- | Model | Download link | File size |
7
- | -------- | ------- | ------- |
8
- | Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
9
- | Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
10
- | Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
11
- | Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
12
- | Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
13
- | 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
14
- | 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
15
- | 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
16
- | Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
17
-
18
- To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP, CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz) and not model sizes.
19
- The 44.1kHz vocoder will be downloaded automatically.
20
- The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.
21
-
22
- The expected directory structure (full):
23
-
24
- ```bash
25
- MMAudio
26
- ├── ext_weights
27
- │ ├── best_netG.pt
28
- │ ├── synchformer_state_dict.pth
29
- │ ├── v1-16.pth
30
- │ └── v1-44.pth
31
- ├── weights
32
- │ ├── mmaudio_small_16k.pth
33
- │ ├── mmaudio_small_44k.pth
34
- │ ├── mmaudio_medium_44k.pth
35
- │ ├── mmaudio_large_44k.pth
36
- │ └── mmaudio_large_44k_v2.pth
37
- └── ...
38
- ```
39
-
40
- The expected directory structure (minimal, for the recommended model only):
41
-
42
- ```bash
43
- MMAudio
44
- ├── ext_weights
45
- │ ├── synchformer_state_dict.pth
46
- │ └── v1-44.pth
47
- ├── weights
48
- │ └── mmaudio_large_44k_v2.pth
49
- └── ...
50
- ```
docs/TRAINING.md DELETED
@@ -1,160 +0,0 @@
1
- # Training
2
-
3
- ## Overview
4
-
5
- We have put a large emphasis on making training as fast as possible.
6
- Consequently, some pre-processing steps are required.
7
-
8
- Namely, before starting any training, we
9
-
10
- 1. Obtain training data as videos, audios, and captions.
11
- 2. Encode training audios into spectrograms and then with VAE into mean/std
12
- 3. Extract CLIP and synchronization features from videos
13
- 4. Extract CLIP features from text (captions)
14
- 5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
15
-
16
- **NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.
17
-
18
- The current training script does not support `_v2` training.
19
-
20
- ## Recommended Hardware Configuration
21
-
22
- These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
23
-
24
- - Single-node machine. We did not implement multi-node training
25
- - GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
26
- - System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
27
- - Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
28
-
29
- ## Prerequisites
30
-
31
- 1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
32
- 2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
33
-
34
- 3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
35
-
36
- ```bash
37
- conda install -c conda-forge 'ffmpeg<7'
38
- ```
39
-
40
- 4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://arxiv.org/abs/2303.17395). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio unsuitable for our training.
41
-
42
- 5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md) place them in `ext_weights/`.
43
-
44
- ## Preparing Audio-Video-Text Features
45
-
46
- We have prepared some example data in `training/example_videos`.
47
- `training/extract_video_training_latents.py` extracts audio, video, and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
48
-
49
- To run this script, use the `torchrun` utility:
50
-
51
- ```bash
52
- torchrun --standalone training/extract_video_training_latents.py
53
- ```
54
-
55
- You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
56
- Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
57
- Change the data path definitions in `data_cfg` if necessary.
58
-
59
- Arguments:
60
-
61
- - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
62
- - `output_dir` -- where TensorDict and the metadata file are saved.
63
-
64
- Outputs produced in `output_dir`:
65
-
66
- 1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
67
- a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
68
- b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
69
- c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
70
- d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
71
- e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
72
- f. `meta.json` that contains the metadata for the above memory mappings
73
- 2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
74
-
75
- ## Preparing Audio-Text Features
76
-
77
- We have prepared some example data in `training/example_audios`.
78
-
79
- 1. Run `training/partition_clips` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio onto the disk to save disk space)
80
- 2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
81
-
82
- ### Partitioning the audio files
83
-
84
- Run
85
-
86
- ```bash
87
- python training/partition_clips.py
88
- ```
89
-
90
- Arguments:
91
-
92
- - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
93
- - `output_dir` -- path to the output `.csv` file
94
- - `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
95
- - `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
96
-
97
- ### Extracting audio and text features
98
-
99
- Run
100
-
101
- ```bash
102
- torchrun --standalone training/extract_audio_training_latents.py
103
- ```
104
-
105
- You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
106
- Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
107
-
108
- Arguments:
109
-
110
- - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
111
- - `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file at least with columns `id` and `caption`
112
- - `clips_tsv` -- path to the clips file, generated in the last step
113
- - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
114
- - `output_dir` -- where TensorDict and the metadata file are saved.
115
-
116
- **Reference tsv files (with overlaps removed as mentioned in the paper) can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).**
117
- Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use it as the `captions_tsv` input though -- the script will handle duplicates gracefully.
118
- Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are parts of WavCaps. These subsets might be smaller than the original datasets.
119
- The Clotho data contains both the development set and the validation set.
120
-
121
- Outputs produced in `output_dir`:
122
-
123
- 1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
124
- a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
125
- b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
126
- c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
127
- f. `meta.json` that contains the metadata for the above memory mappings
128
- 2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
129
-
130
- ## Training on Extracted Features
131
-
132
- We use Distributed Data Parallel (DDP) for training.
133
- First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
134
-
135
- To run training on the example data, use the following command:
136
-
137
- ```bash
138
- OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
139
- ```
140
-
141
- This will not train a useful model, but it will check if everything is set up correctly.
142
-
143
- For full training on the base model with two GPUs, use the following command:
144
-
145
- ```bash
146
- OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
147
- ```
148
-
149
- Any outputs from training will be stored in `output/<exp_id>`.
150
-
151
- More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
152
- For the medium and large models, specify `vgg_oversample_rate` to be `3` to reduce overfitting.
153
-
154
- ## Checkpoints
155
-
156
- Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
157
-
158
- ---
159
-
160
- Godspeed!
docs/index.html CHANGED
@@ -40,7 +40,7 @@
       <br>
       <div class="row text-center" style="font-size:28px">
         <div class="col">
-          CVPR 2025
+          arXiv 2024
         </div>
       </div>
       <br>
@@ -83,21 +83,19 @@
       <br>
 
       <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
-        <div class="col-sm-2">
-          <a href="https://arxiv.org/abs/2412.15322">[Paper]</a>
-        </div>
-        <div class="col-sm-2">
-          <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
-        </div>
+        <!-- <div class="col-sm-2">
+          <a href="https://arxiv.org/abs/2310.12982">[arXiv]</a>
+        </div> -->
         <div class="col-sm-3">
-          <a href="https://huggingface.co/spaces/hkchengrex/MMAudio">[Huggingface Demo]</a>
-        </div>
-        <div class="col-sm-2">
-          <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
+          <a href="">[Paper (being prepared)]</a>
         </div>
         <div class="col-sm-3">
-          <a href="https://replicate.com/zsxkib/mmaudio">[Replicate Demo]</a>
+          <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
         </div>
+        <!-- <div class="col-sm-2">
+          <a
+            href="https://colab.research.google.com/drive/1yo43XTbjxuWA7XgCUO9qxAi7wBI6HzvP?usp=sharing">[Colab]</a>
+        </div> -->
       </div>
 
       <br>
gitattributes DELETED
@@ -1,35 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
mmaudio/data/av_utils.py CHANGED
@@ -25,32 +25,6 @@ class VideoInfo:
     def width(self):
         return self.all_frames[0].shape[1]
 
-    @classmethod
-    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
-                        fps: Fraction) -> 'VideoInfo':
-        num_frames = int(duration_sec * fps)
-        all_frames = [image_info.original_frame] * num_frames
-        return cls(duration_sec=duration_sec,
-                   fps=fps,
-                   clip_frames=image_info.clip_frames,
-                   sync_frames=image_info.sync_frames,
-                   all_frames=all_frames)
-
-
-@dataclass
-class ImageInfo:
-    clip_frames: torch.Tensor
-    sync_frames: torch.Tensor
-    original_frame: Optional[np.ndarray]
-
-    @property
-    def height(self):
-        return self.original_frame.shape[0]
-
-    @property
-    def width(self):
-        return self.original_frame.shape[1]
-
 
 def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                 need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
mmaudio/data/data_setup.py DELETED
@@ -1,174 +0,0 @@
1
- import logging
2
- import random
3
-
4
- import numpy as np
5
- import torch
6
- from omegaconf import DictConfig
7
- from torch.utils.data import DataLoader, Dataset
8
- from torch.utils.data.dataloader import default_collate
9
- from torch.utils.data.distributed import DistributedSampler
10
-
11
- from mmaudio.data.eval.audiocaps import AudioCapsData
12
- from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
13
- from mmaudio.data.extracted_audio import ExtractedAudio
14
- from mmaudio.data.extracted_vgg import ExtractedVGG
15
- from mmaudio.data.mm_dataset import MultiModalDataset
16
- from mmaudio.utils.dist_utils import local_rank
17
-
18
- log = logging.getLogger()
19
-
20
-
21
- # Re-seed randomness every time we start a worker
22
- def worker_init_fn(worker_id: int):
23
- worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
24
- np.random.seed(worker_seed)
25
- random.seed(worker_seed)
26
- log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
27
-
28
-
29
- def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
30
- dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
31
- data_dim=cfg.data_dim,
32
- premade_mmap_dir=data_cfg.memmap_dir)
33
-
34
- return dataset
35
-
36
-
37
- def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
38
- dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
39
- data_dim=cfg.data_dim,
40
- premade_mmap_dir=data_cfg.memmap_dir)
41
-
42
- return dataset
43
-
44
-
45
- def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
46
- if cfg.mini_train:
47
- vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
48
- audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
49
- dataset = MultiModalDataset([vgg], [audiocaps])
50
- if cfg.example_train:
51
- video = load_vgg_data(cfg, cfg.data.Example_video)
52
- audio = load_audio_data(cfg, cfg.data.Example_audio)
53
- dataset = MultiModalDataset([video], [audio])
54
- else:
55
- # load the largest one first
56
- freesound = load_audio_data(cfg, cfg.data.FreeSound)
57
- vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
58
- audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
59
- audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
60
- bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
61
- clotho = load_audio_data(cfg, cfg.data.Clotho)
62
- dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
63
- [audiocaps, audioset_sl, bbcsound, freesound, clotho])
64
-
65
- batch_size = cfg.batch_size
66
- num_workers = cfg.num_workers
67
- pin_memory = cfg.pin_memory
68
- sampler, loader = construct_loader(dataset,
69
- batch_size,
70
- num_workers,
71
- shuffle=True,
72
- drop_last=True,
73
- pin_memory=pin_memory)
74
-
75
- return dataset, sampler, loader
76
-
77
-
78
- def setup_test_datasets(cfg):
79
- dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
80
-
81
- batch_size = cfg.batch_size
82
- num_workers = cfg.num_workers
83
- pin_memory = cfg.pin_memory
84
- sampler, loader = construct_loader(dataset,
85
- batch_size,
86
- num_workers,
87
- shuffle=False,
88
- drop_last=False,
89
- pin_memory=pin_memory)
90
-
91
- return dataset, sampler, loader
92
-
93
-
94
- def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
95
- if cfg.example_train:
96
- dataset = load_vgg_data(cfg, cfg.data.Example_video)
97
- else:
98
- dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
99
-
100
- val_batch_size = cfg.batch_size
101
- val_eval_batch_size = cfg.eval_batch_size
102
- num_workers = cfg.num_workers
103
- pin_memory = cfg.pin_memory
104
- _, val_loader = construct_loader(dataset,
105
- val_batch_size,
106
- num_workers,
107
- shuffle=False,
108
- drop_last=False,
109
- pin_memory=pin_memory)
110
- _, eval_loader = construct_loader(dataset,
111
- val_eval_batch_size,
112
- num_workers,
113
- shuffle=False,
114
- drop_last=False,
115
- pin_memory=pin_memory)
116
-
117
- return dataset, val_loader, eval_loader
118
-
119
-
120
- def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
121
- if dataset_name.startswith('audiocaps_full'):
122
- dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
123
- cfg.eval_data.AudioCaps_full.csv_path)
124
- elif dataset_name.startswith('audiocaps'):
125
- dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
126
- cfg.eval_data.AudioCaps.csv_path)
127
- elif dataset_name.startswith('moviegen'):
128
- dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
129
- cfg.eval_data.MovieGen.jsonl_path,
130
- duration_sec=cfg.duration_s)
131
- elif dataset_name.startswith('vggsound'):
132
- dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
133
- cfg.eval_data.VGGSound.csv_path,
134
- duration_sec=cfg.duration_s)
135
- else:
136
- raise ValueError(f'Invalid dataset name: {dataset_name}')
137
-
138
- batch_size = cfg.batch_size
139
- num_workers = cfg.num_workers
140
- pin_memory = cfg.pin_memory
141
- _, loader = construct_loader(dataset,
142
- batch_size,
143
- num_workers,
144
- shuffle=False,
145
- drop_last=False,
146
- pin_memory=pin_memory,
147
- error_avoidance=True)
148
- return dataset, loader
149
-
150
-
151
- def error_avoidance_collate(batch):
152
- batch = list(filter(lambda x: x is not None, batch))
153
- return default_collate(batch)
154
-
155
-
156
- def construct_loader(dataset: Dataset,
157
- batch_size: int,
158
- num_workers: int,
159
- *,
160
- shuffle: bool = True,
161
- drop_last: bool = True,
162
- pin_memory: bool = False,
163
- error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
164
- train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
165
- train_loader = DataLoader(dataset,
166
- batch_size,
167
- sampler=train_sampler,
168
- num_workers=num_workers,
169
- worker_init_fn=worker_init_fn,
170
- drop_last=drop_last,
171
- persistent_workers=num_workers > 0,
172
- pin_memory=pin_memory,
173
- collate_fn=error_avoidance_collate if error_avoidance else None)
174
- return train_sampler, train_loader
mmaudio/data/eval/__init__.py DELETED
File without changes
mmaudio/data/eval/audiocaps.py DELETED
@@ -1,39 +0,0 @@
1
- import logging
2
- import os
3
- from collections import defaultdict
4
- from pathlib import Path
5
- from typing import Union
6
-
7
- import pandas as pd
8
- import torch
9
- from torch.utils.data.dataset import Dataset
10
-
11
- log = logging.getLogger()
12
-
13
-
14
- class AudioCapsData(Dataset):
15
-
16
- def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
17
- df = pd.read_csv(csv_path).to_dict(orient='records')
18
-
19
- audio_files = sorted(os.listdir(audio_path))
20
- audio_files = set(
21
- [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
22
-
23
- self.data = []
24
- for row in df:
25
- self.data.append({
26
- 'name': row['name'],
27
- 'caption': row['caption'],
28
- })
29
-
30
- self.audio_path = Path(audio_path)
31
- self.csv_path = Path(csv_path)
32
-
33
- log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
34
-
35
- def __getitem__(self, idx: int) -> torch.Tensor:
36
- return self.data[idx]
37
-
38
- def __len__(self):
39
- return len(self.data)
mmaudio/data/eval/moviegen.py DELETED
@@ -1,131 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- from pathlib import Path
5
- from typing import Union
6
-
7
- import torch
8
- from torch.utils.data.dataset import Dataset
9
- from torchvision.transforms import v2
10
- from torio.io import StreamingMediaDecoder
11
-
12
- from mmaudio.utils.dist_utils import local_rank
13
-
14
- log = logging.getLogger()
15
-
16
- _CLIP_SIZE = 384
17
- _CLIP_FPS = 8.0
18
-
19
- _SYNC_SIZE = 224
20
- _SYNC_FPS = 25.0
21
-
22
-
23
- class MovieGenData(Dataset):
24
-
25
- def __init__(
26
- self,
27
- video_root: Union[str, Path],
28
- sync_root: Union[str, Path],
29
- jsonl_root: Union[str, Path],
30
- *,
31
- duration_sec: float = 10.0,
32
- read_clip: bool = True,
33
- ):
34
- self.video_root = Path(video_root)
35
- self.sync_root = Path(sync_root)
36
- self.jsonl_root = Path(jsonl_root)
37
- self.read_clip = read_clip
38
-
39
- videos = sorted(os.listdir(self.video_root))
40
- videos = [v[:-4] for v in videos] # remove extensions
41
- self.captions = {}
42
-
43
- for v in videos:
44
- with open(self.jsonl_root / (v + '.jsonl')) as f:
45
- data = json.load(f)
46
- self.captions[v] = data['audio_prompt']
47
-
48
- if local_rank == 0:
49
- log.info(f'{len(videos)} videos found in {video_root}')
50
-
51
- self.duration_sec = duration_sec
52
-
53
- self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
54
- self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
55
-
56
- self.clip_augment = v2.Compose([
57
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
58
- v2.ToImage(),
59
- v2.ToDtype(torch.float32, scale=True),
60
- ])
61
-
62
- self.sync_augment = v2.Compose([
63
- v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
64
- v2.CenterCrop(_SYNC_SIZE),
65
- v2.ToImage(),
66
- v2.ToDtype(torch.float32, scale=True),
67
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
68
- ])
69
-
70
- self.videos = videos
71
-
72
- def sample(self, idx: int) -> dict[str, torch.Tensor]:
73
- video_id = self.videos[idx]
74
- caption = self.captions[video_id]
75
-
76
- reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
77
- reader.add_basic_video_stream(
78
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
79
- frame_rate=_CLIP_FPS,
80
- format='rgb24',
81
- )
82
- reader.add_basic_video_stream(
83
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
84
- frame_rate=_SYNC_FPS,
85
- format='rgb24',
86
- )
87
-
88
- reader.fill_buffer()
89
- data_chunk = reader.pop_chunks()
90
-
91
- clip_chunk = data_chunk[0]
92
- sync_chunk = data_chunk[1]
93
- if clip_chunk is None:
94
- raise RuntimeError(f'CLIP video returned None {video_id}')
95
- if clip_chunk.shape[0] < self.clip_expected_length:
96
- raise RuntimeError(f'CLIP video too short {video_id}')
97
-
98
- if sync_chunk is None:
99
- raise RuntimeError(f'Sync video returned None {video_id}')
100
- if sync_chunk.shape[0] < self.sync_expected_length:
101
- raise RuntimeError(f'Sync video too short {video_id}')
102
-
103
- # truncate the video
104
- clip_chunk = clip_chunk[:self.clip_expected_length]
105
- if clip_chunk.shape[0] != self.clip_expected_length:
106
- raise RuntimeError(f'CLIP video wrong length {video_id}, '
107
- f'expected {self.clip_expected_length}, '
108
- f'got {clip_chunk.shape[0]}')
109
- clip_chunk = self.clip_augment(clip_chunk)
110
-
111
- sync_chunk = sync_chunk[:self.sync_expected_length]
112
- if sync_chunk.shape[0] != self.sync_expected_length:
113
- raise RuntimeError(f'Sync video wrong length {video_id}, '
114
- f'expected {self.sync_expected_length}, '
115
- f'got {sync_chunk.shape[0]}')
116
- sync_chunk = self.sync_augment(sync_chunk)
117
-
118
- data = {
119
- 'name': video_id,
120
- 'caption': caption,
121
- 'clip_video': clip_chunk,
122
- 'sync_video': sync_chunk,
123
- }
124
-
125
- return data
126
-
127
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
128
- return self.sample(idx)
129
-
130
- def __len__(self):
131
- return len(self.captions)
 
mmaudio/data/eval/video_dataset.py DELETED
@@ -1,197 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- from pathlib import Path
5
- from typing import Union
6
-
7
- import pandas as pd
8
- import torch
9
- from torch.utils.data.dataset import Dataset
10
- from torchvision.transforms import v2
11
- from torio.io import StreamingMediaDecoder
12
-
13
- from mmaudio.utils.dist_utils import local_rank
14
-
15
- log = logging.getLogger()
16
-
17
- _CLIP_SIZE = 384
18
- _CLIP_FPS = 8.0
19
-
20
- _SYNC_SIZE = 224
21
- _SYNC_FPS = 25.0
22
-
23
-
24
- class VideoDataset(Dataset):
25
-
26
- def __init__(
27
- self,
28
- video_root: Union[str, Path],
29
- *,
30
- duration_sec: float = 8.0,
31
- ):
32
- self.video_root = Path(video_root)
33
-
34
- self.duration_sec = duration_sec
35
-
36
- self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
37
- self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
38
-
39
- self.clip_transform = v2.Compose([
40
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
41
- v2.ToImage(),
42
- v2.ToDtype(torch.float32, scale=True),
43
- ])
44
-
45
- self.sync_transform = v2.Compose([
46
- v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
47
- v2.CenterCrop(_SYNC_SIZE),
48
- v2.ToImage(),
49
- v2.ToDtype(torch.float32, scale=True),
50
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
51
- ])
52
-
53
- # to be implemented by subclasses
54
- self.captions = {}
55
- self.videos = sorted(list(self.captions.keys()))
56
-
57
- def sample(self, idx: int) -> dict[str, torch.Tensor]:
58
- video_id = self.videos[idx]
59
- caption = self.captions[video_id]
60
-
61
- reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
62
- reader.add_basic_video_stream(
63
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
64
- frame_rate=_CLIP_FPS,
65
- format='rgb24',
66
- )
67
- reader.add_basic_video_stream(
68
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
69
- frame_rate=_SYNC_FPS,
70
- format='rgb24',
71
- )
72
-
73
- reader.fill_buffer()
74
- data_chunk = reader.pop_chunks()
75
-
76
- clip_chunk = data_chunk[0]
77
- sync_chunk = data_chunk[1]
78
- if clip_chunk is None:
79
- raise RuntimeError(f'CLIP video returned None {video_id}')
80
- if clip_chunk.shape[0] < self.clip_expected_length:
81
- raise RuntimeError(
82
- f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
83
- )
84
-
85
- if sync_chunk is None:
86
- raise RuntimeError(f'Sync video returned None {video_id}')
87
- if sync_chunk.shape[0] < self.sync_expected_length:
88
- raise RuntimeError(
89
- f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
90
- )
91
-
92
- # truncate the video
93
- clip_chunk = clip_chunk[:self.clip_expected_length]
94
- if clip_chunk.shape[0] != self.clip_expected_length:
95
- raise RuntimeError(f'CLIP video wrong length {video_id}, '
96
- f'expected {self.clip_expected_length}, '
97
- f'got {clip_chunk.shape[0]}')
98
- clip_chunk = self.clip_transform(clip_chunk)
99
-
100
- sync_chunk = sync_chunk[:self.sync_expected_length]
101
- if sync_chunk.shape[0] != self.sync_expected_length:
102
- raise RuntimeError(f'Sync video wrong length {video_id}, '
103
- f'expected {self.sync_expected_length}, '
104
- f'got {sync_chunk.shape[0]}')
105
- sync_chunk = self.sync_transform(sync_chunk)
106
-
107
- data = {
108
- 'name': video_id,
109
- 'caption': caption,
110
- 'clip_video': clip_chunk,
111
- 'sync_video': sync_chunk,
112
- }
113
-
114
- return data
115
-
116
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
117
- try:
118
- return self.sample(idx)
119
- except Exception as e:
120
- log.error(f'Error loading video {self.videos[idx]}: {e}')
121
- return None
122
-
123
- def __len__(self):
124
- return len(self.captions)
125
-
126
-
127
- class VGGSound(VideoDataset):
128
-
129
- def __init__(
130
- self,
131
- video_root: Union[str, Path],
132
- csv_path: Union[str, Path],
133
- *,
134
- duration_sec: float = 8.0,
135
- ):
136
- super().__init__(video_root, duration_sec=duration_sec)
137
- self.video_root = Path(video_root)
138
- self.csv_path = Path(csv_path)
139
-
140
- videos = sorted(os.listdir(self.video_root))
141
- if local_rank == 0:
142
- log.info(f'{len(videos)} videos found in {video_root}')
143
- self.captions = {}
144
-
145
- df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
146
- 'split']).to_dict(orient='records')
147
-
148
- videos_no_found = []
149
- for row in df:
150
- if row['split'] == 'test':
151
- start_sec = int(row['sec'])
152
- video_id = str(row['id'])
153
- # this is how our videos are named
154
- video_name = f'{video_id}_{start_sec:06d}'
155
- if video_name + '.mp4' not in videos:
156
- videos_no_found.append(video_name)
157
- continue
158
-
159
- self.captions[video_name] = row['caption']
160
-
161
- if local_rank == 0:
162
- log.info(f'{len(videos)} videos found in {video_root}')
163
- log.info(f'{len(self.captions)} useable videos found')
164
- if videos_no_found:
165
- log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
166
- log.info(
167
- 'A small amount is expected, as not all videos are still available on YouTube')
168
-
169
- self.videos = sorted(list(self.captions.keys()))
170
-
171
-
172
- class MovieGen(VideoDataset):
173
-
174
- def __init__(
175
- self,
176
- video_root: Union[str, Path],
177
- jsonl_root: Union[str, Path],
178
- *,
179
- duration_sec: float = 10.0,
180
- ):
181
- super().__init__(video_root, duration_sec=duration_sec)
182
- self.video_root = Path(video_root)
183
- self.jsonl_root = Path(jsonl_root)
184
-
185
- videos = sorted(os.listdir(self.video_root))
186
- videos = [v[:-4] for v in videos] # remove extensions
187
- self.captions = {}
188
-
189
- for v in videos:
190
- with open(self.jsonl_root / (v + '.jsonl')) as f:
191
- data = json.load(f)
192
- self.captions[v] = data['audio_prompt']
193
-
194
- if local_rank == 0:
195
- log.info(f'{len(videos)} videos found in {video_root}')
196
-
197
- self.videos = videos
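
MovieGenData and VideoDataset above both decode each clip twice from the same file: once at 8 fps / 384 px for the CLIP branch and once at 25 fps / 224 px for the synchronisation branch. A condensed sketch of that decoding step, assuming torio is available and that 'example.mp4' (a placeholder path) is at least duration_sec long:

    from torio.io import StreamingMediaDecoder

    _CLIP_FPS, _SYNC_FPS = 8.0, 25.0
    duration_sec = 8.0

    reader = StreamingMediaDecoder('example.mp4')  # placeholder path
    reader.add_basic_video_stream(frames_per_chunk=int(_CLIP_FPS * duration_sec),
                                  frame_rate=_CLIP_FPS, format='rgb24')
    reader.add_basic_video_stream(frames_per_chunk=int(_SYNC_FPS * duration_sec),
                                  frame_rate=_SYNC_FPS, format='rgb24')
    reader.fill_buffer()
    clip_chunk, sync_chunk = reader.pop_chunks()

    # frame-major tensors; truncate to the expected frame counts before the transforms
    clip_chunk = clip_chunk[:int(_CLIP_FPS * duration_sec)]
    sync_chunk = sync_chunk[:int(_SYNC_FPS * duration_sec)]
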
 
mmaudio/data/extracted_audio.py DELETED
@@ -1,88 +0,0 @@
1
- import logging
2
- from pathlib import Path
3
- from typing import Union
4
-
5
- import pandas as pd
6
- import torch
7
- from tensordict import TensorDict
8
- from torch.utils.data.dataset import Dataset
9
-
10
- from mmaudio.utils.dist_utils import local_rank
11
-
12
- log = logging.getLogger()
13
-
14
-
15
- class ExtractedAudio(Dataset):
16
-
17
- def __init__(
18
- self,
19
- tsv_path: Union[str, Path],
20
- *,
21
- premade_mmap_dir: Union[str, Path],
22
- data_dim: dict[str, int],
23
- ):
24
- super().__init__()
25
-
26
- self.data_dim = data_dim
27
- self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
- self.ids = [str(d['id']) for d in self.df_list]
29
-
30
- log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
- # load precomputed memory mapped tensors
32
- premade_mmap_dir = Path(premade_mmap_dir)
33
- td = TensorDict.load_memmap(premade_mmap_dir)
34
- log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
- self.mean = td['mean']
36
- self.std = td['std']
37
- self.text_features = td['text_features']
38
-
39
- log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
40
- log.info(f'Loaded mean: {self.mean.shape}.')
41
- log.info(f'Loaded std: {self.std.shape}.')
42
- log.info(f'Loaded text features: {self.text_features.shape}.')
43
-
44
- assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
45
- f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
46
- assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
47
- f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
48
-
49
- assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
50
- f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
51
- assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
52
- f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
53
-
54
- self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
55
- self.data_dim['clip_dim'])
56
- self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
57
- self.data_dim['sync_dim'])
58
- self.video_exist = torch.tensor(0, dtype=torch.bool)
59
- self.text_exist = torch.tensor(1, dtype=torch.bool)
60
-
61
- def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
62
- latents = self.mean
63
- return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
64
-
65
- def get_memory_mapped_tensor(self) -> TensorDict:
66
- td = TensorDict({
67
- 'mean': self.mean,
68
- 'std': self.std,
69
- 'text_features': self.text_features,
70
- })
71
- return td
72
-
73
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
74
- data = {
75
- 'id': str(self.df_list[idx]['id']),
76
- 'a_mean': self.mean[idx],
77
- 'a_std': self.std[idx],
78
- 'clip_features': self.fake_clip_features,
79
- 'sync_features': self.fake_sync_features,
80
- 'text_features': self.text_features[idx],
81
- 'caption': self.df_list[idx]['caption'],
82
- 'video_exist': self.video_exist,
83
- 'text_exist': self.text_exist,
84
- }
85
- return data
86
-
87
- def __len__(self):
88
- return len(self.ids)
 
mmaudio/data/extracted_vgg.py DELETED
@@ -1,101 +0,0 @@
1
- import logging
2
- from pathlib import Path
3
- from typing import Union
4
-
5
- import pandas as pd
6
- import torch
7
- from tensordict import TensorDict
8
- from torch.utils.data.dataset import Dataset
9
-
10
- from mmaudio.utils.dist_utils import local_rank
11
-
12
- log = logging.getLogger()
13
-
14
-
15
- class ExtractedVGG(Dataset):
16
-
17
- def __init__(
18
- self,
19
- tsv_path: Union[str, Path],
20
- *,
21
- premade_mmap_dir: Union[str, Path],
22
- data_dim: dict[str, int],
23
- ):
24
- super().__init__()
25
-
26
- self.data_dim = data_dim
27
- self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
- self.ids = [d['id'] for d in self.df_list]
29
-
30
- log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
- # load precomputed memory mapped tensors
32
- premade_mmap_dir = Path(premade_mmap_dir)
33
- td = TensorDict.load_memmap(premade_mmap_dir)
34
- log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
- self.mean = td['mean']
36
- self.std = td['std']
37
- self.clip_features = td['clip_features']
38
- self.sync_features = td['sync_features']
39
- self.text_features = td['text_features']
40
-
41
- if local_rank == 0:
42
- log.info(f'Loaded {len(self)} samples.')
43
- log.info(f'Loaded mean: {self.mean.shape}.')
44
- log.info(f'Loaded std: {self.std.shape}.')
45
- log.info(f'Loaded clip_features: {self.clip_features.shape}.')
46
- log.info(f'Loaded sync_features: {self.sync_features.shape}.')
47
- log.info(f'Loaded text_features: {self.text_features.shape}.')
48
-
49
- assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
50
- f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
51
- assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
52
- f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
53
-
54
- assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
55
- f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
56
- assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
57
- f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
58
- assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
59
- f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
60
-
61
- assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
62
- f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
63
- assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
64
- f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
65
- assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
66
- f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
67
-
68
- self.video_exist = torch.tensor(1, dtype=torch.bool)
69
- self.text_exist = torch.tensor(1, dtype=torch.bool)
70
-
71
- def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
72
- latents = self.mean
73
- return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
74
-
75
- def get_memory_mapped_tensor(self) -> TensorDict:
76
- td = TensorDict({
77
- 'mean': self.mean,
78
- 'std': self.std,
79
- 'clip_features': self.clip_features,
80
- 'sync_features': self.sync_features,
81
- 'text_features': self.text_features,
82
- })
83
- return td
84
-
85
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
86
- data = {
87
- 'id': self.df_list[idx]['id'],
88
- 'a_mean': self.mean[idx],
89
- 'a_std': self.std[idx],
90
- 'clip_features': self.clip_features[idx],
91
- 'sync_features': self.sync_features[idx],
92
- 'text_features': self.text_features[idx],
93
- 'caption': self.df_list[idx]['label'],
94
- 'video_exist': self.video_exist,
95
- 'text_exist': self.text_exist,
96
- }
97
-
98
- return data
99
-
100
- def __len__(self):
101
- return len(self.ids)
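
ExtractedAudio and ExtractedVGG keep the VAE posterior (mean and std per latent frame) and the pre-extracted conditioning features in one memory-mapped TensorDict; a latent is re-sampled from that posterior at training time rather than stored directly. A minimal sketch of the re-parameterisation, with purely illustrative shapes (batch 2, latent_seq_len 250, 40 channels):

    import torch

    # illustrative shapes: (batch, latent_seq_len, channels)
    a_mean = torch.zeros(2, 250, 40)
    a_std = torch.ones(2, 250, 40)

    rng = torch.Generator().manual_seed(0)
    a_randn = torch.empty_like(a_mean).normal_(generator=rng)
    x1 = a_mean + a_std * a_randn  # one draw from the stored Gaussian posterior

    # per-channel statistics over the whole set, as in compute_latent_stats()
    latent_mean = a_mean.mean(dim=(0, 1))
    latent_std = a_mean.std(dim=(0, 1))
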
 
mmaudio/data/extraction/__init__.py DELETED
File without changes
mmaudio/data/extraction/vgg_sound.py DELETED
@@ -1,193 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
- from typing import Optional, Union
5
-
6
- import pandas as pd
7
- import torch
8
- import torchaudio
9
- from torch.utils.data.dataset import Dataset
10
- from torchvision.transforms import v2
11
- from torio.io import StreamingMediaDecoder
12
-
13
- from mmaudio.utils.dist_utils import local_rank
14
-
15
- log = logging.getLogger()
16
-
17
- _CLIP_SIZE = 384
18
- _CLIP_FPS = 8.0
19
-
20
- _SYNC_SIZE = 224
21
- _SYNC_FPS = 25.0
22
-
23
-
24
- class VGGSound(Dataset):
25
-
26
- def __init__(
27
- self,
28
- root: Union[str, Path],
29
- *,
30
- tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
31
- sample_rate: int = 16_000,
32
- duration_sec: float = 8.0,
33
- audio_samples: Optional[int] = None,
34
- normalize_audio: bool = False,
35
- ):
36
- self.root = Path(root)
37
- self.normalize_audio = normalize_audio
38
- if audio_samples is None:
39
- self.audio_samples = int(sample_rate * duration_sec)
40
- else:
41
- self.audio_samples = audio_samples
42
- effective_duration = audio_samples / sample_rate
43
- # make sure the duration is close enough, within 15ms
44
- assert abs(effective_duration - duration_sec) < 0.015, \
45
- f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
46
-
47
- videos = sorted(os.listdir(self.root))
48
- videos = set([Path(v).stem for v in videos]) # remove extensions
49
- self.labels = {}
50
- self.videos = []
51
- missing_videos = []
52
-
53
- # read the tsv for subset information
54
- df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
55
- for record in df_list:
56
- id = record['id']
57
- label = record['label']
58
- if id in videos:
59
- self.labels[id] = label
60
- self.videos.append(id)
61
- else:
62
- missing_videos.append(id)
63
-
64
- if local_rank == 0:
65
- log.info(f'{len(videos)} videos found in {root}')
66
- log.info(f'{len(self.videos)} videos found in {tsv_path}')
67
- log.info(f'{len(missing_videos)} videos missing in {root}')
68
-
69
- self.sample_rate = sample_rate
70
- self.duration_sec = duration_sec
71
-
72
- self.expected_audio_length = audio_samples
73
- self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
74
- self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
75
-
76
- self.clip_transform = v2.Compose([
77
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
78
- v2.ToImage(),
79
- v2.ToDtype(torch.float32, scale=True),
80
- ])
81
-
82
- self.sync_transform = v2.Compose([
83
- v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
84
- v2.CenterCrop(_SYNC_SIZE),
85
- v2.ToImage(),
86
- v2.ToDtype(torch.float32, scale=True),
87
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
88
- ])
89
-
90
- self.resampler = {}
91
-
92
- def sample(self, idx: int) -> dict[str, torch.Tensor]:
93
- video_id = self.videos[idx]
94
- label = self.labels[video_id]
95
-
96
- reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
97
- reader.add_basic_video_stream(
98
- frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
99
- frame_rate=_CLIP_FPS,
100
- format='rgb24',
101
- )
102
- reader.add_basic_video_stream(
103
- frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
104
- frame_rate=_SYNC_FPS,
105
- format='rgb24',
106
- )
107
- reader.add_basic_audio_stream(frames_per_chunk=2**30, )
108
-
109
- reader.fill_buffer()
110
- data_chunk = reader.pop_chunks()
111
-
112
- clip_chunk = data_chunk[0]
113
- sync_chunk = data_chunk[1]
114
- audio_chunk = data_chunk[2]
115
-
116
- if clip_chunk is None:
117
- raise RuntimeError(f'CLIP video returned None {video_id}')
118
- if clip_chunk.shape[0] < self.clip_expected_length:
119
- raise RuntimeError(
120
- f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
121
- )
122
-
123
- if sync_chunk is None:
124
- raise RuntimeError(f'Sync video returned None {video_id}')
125
- if sync_chunk.shape[0] < self.sync_expected_length:
126
- raise RuntimeError(
127
- f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
128
- )
129
-
130
- # process audio
131
- sample_rate = int(reader.get_out_stream_info(2).sample_rate)
132
- audio_chunk = audio_chunk.transpose(0, 1)
133
- audio_chunk = audio_chunk.mean(dim=0) # mono
134
- if self.normalize_audio:
135
- abs_max = audio_chunk.abs().max()
136
- audio_chunk = audio_chunk / abs_max * 0.95
137
- if abs_max <= 1e-6:
138
- raise RuntimeError(f'Audio is silent {video_id}')
139
-
140
- # resample
141
- if sample_rate == self.sample_rate:
142
- audio_chunk = audio_chunk
143
- else:
144
- if sample_rate not in self.resampler:
145
- # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
146
- self.resampler[sample_rate] = torchaudio.transforms.Resample(
147
- sample_rate,
148
- self.sample_rate,
149
- lowpass_filter_width=64,
150
- rolloff=0.9475937167399596,
151
- resampling_method='sinc_interp_kaiser',
152
- beta=14.769656459379492,
153
- )
154
- audio_chunk = self.resampler[sample_rate](audio_chunk)
155
-
156
- if audio_chunk.shape[0] < self.expected_audio_length:
157
- raise RuntimeError(f'Audio too short {video_id}')
158
- audio_chunk = audio_chunk[:self.expected_audio_length]
159
-
160
- # truncate the video
161
- clip_chunk = clip_chunk[:self.clip_expected_length]
162
- if clip_chunk.shape[0] != self.clip_expected_length:
163
- raise RuntimeError(f'CLIP video wrong length {video_id}, '
164
- f'expected {self.clip_expected_length}, '
165
- f'got {clip_chunk.shape[0]}')
166
- clip_chunk = self.clip_transform(clip_chunk)
167
-
168
- sync_chunk = sync_chunk[:self.sync_expected_length]
169
- if sync_chunk.shape[0] != self.sync_expected_length:
170
- raise RuntimeError(f'Sync video wrong length {video_id}, '
171
- f'expected {self.sync_expected_length}, '
172
- f'got {sync_chunk.shape[0]}')
173
- sync_chunk = self.sync_transform(sync_chunk)
174
-
175
- data = {
176
- 'id': video_id,
177
- 'caption': label,
178
- 'audio': audio_chunk,
179
- 'clip_video': clip_chunk,
180
- 'sync_video': sync_chunk,
181
- }
182
-
183
- return data
184
-
185
- def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
186
- try:
187
- return self.sample(idx)
188
- except Exception as e:
189
- log.error(f'Error loading video {self.videos[idx]}: {e}')
190
- return None
191
-
192
- def __len__(self):
193
- return len(self.labels)
 
mmaudio/data/extraction/wav_dataset.py DELETED
@@ -1,132 +0,0 @@
1
- import logging
2
- import os
3
- from pathlib import Path
4
- from typing import Union
5
-
6
- import open_clip
7
- import pandas as pd
8
- import torch
9
- import torchaudio
10
- from torch.utils.data.dataset import Dataset
11
-
12
- log = logging.getLogger()
13
-
14
-
15
- class WavTextClipsDataset(Dataset):
16
-
17
- def __init__(
18
- self,
19
- root: Union[str, Path],
20
- *,
21
- captions_tsv: Union[str, Path],
22
- clips_tsv: Union[str, Path],
23
- sample_rate: int,
24
- num_samples: int,
25
- normalize_audio: bool = False,
26
- reject_silent: bool = False,
27
- tokenizer_id: str = 'ViT-H-14-378-quickgelu',
28
- ):
29
- self.root = Path(root)
30
- self.sample_rate = sample_rate
31
- self.num_samples = num_samples
32
- self.normalize_audio = normalize_audio
33
- self.reject_silent = reject_silent
34
- self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
35
-
36
- audios = sorted(os.listdir(self.root))
37
- audios = set([
38
- Path(audio).stem for audio in audios
39
- if audio.endswith('.wav') or audio.endswith('.flac')
40
- ])
41
- self.captions = {}
42
-
43
- # read the caption tsv
44
- df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
45
- for record in df_list:
46
- id = record['id']
47
- caption = record['caption']
48
- self.captions[id] = caption
49
-
50
- # read the clip tsv
51
- df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
52
- 'id': str,
53
- 'name': str
54
- }).to_dict('records')
55
- self.clips = []
56
- for record in df_list:
57
- record['id'] = record['id']
58
- record['name'] = record['name']
59
- id = record['id']
60
- name = record['name']
61
- if name not in self.captions:
62
- log.warning(f'Audio {name} not found in {captions_tsv}')
63
- continue
64
- record['caption'] = self.captions[name]
65
- self.clips.append(record)
66
-
67
- log.info(f'Found {len(self.clips)} audio files in {self.root}')
68
-
69
- self.resampler = {}
70
-
71
- def __getitem__(self, idx: int) -> torch.Tensor:
72
- try:
73
- clip = self.clips[idx]
74
- audio_name = clip['name']
75
- audio_id = clip['id']
76
- caption = clip['caption']
77
- start_sample = clip['start_sample']
78
- end_sample = clip['end_sample']
79
-
80
- audio_path = self.root / f'{audio_name}.flac'
81
- if not audio_path.exists():
82
- audio_path = self.root / f'{audio_name}.wav'
83
- assert audio_path.exists()
84
-
85
- audio_chunk, sample_rate = torchaudio.load(audio_path)
86
- audio_chunk = audio_chunk.mean(dim=0) # mono
87
- abs_max = audio_chunk.abs().max()
88
- if self.normalize_audio:
89
- audio_chunk = audio_chunk / abs_max * 0.95
90
-
91
- if self.reject_silent and abs_max < 1e-6:
92
- log.warning(f'Rejecting silent audio')
93
- return None
94
-
95
- audio_chunk = audio_chunk[start_sample:end_sample]
96
-
97
- # resample
98
- if sample_rate == self.sample_rate:
99
- audio_chunk = audio_chunk
100
- else:
101
- if sample_rate not in self.resampler:
102
- # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
103
- self.resampler[sample_rate] = torchaudio.transforms.Resample(
104
- sample_rate,
105
- self.sample_rate,
106
- lowpass_filter_width=64,
107
- rolloff=0.9475937167399596,
108
- resampling_method='sinc_interp_kaiser',
109
- beta=14.769656459379492,
110
- )
111
- audio_chunk = self.resampler[sample_rate](audio_chunk)
112
-
113
- if audio_chunk.shape[0] < self.num_samples:
114
- raise ValueError('Audio is too short')
115
- audio_chunk = audio_chunk[:self.num_samples]
116
-
117
- tokens = self.tokenizer([caption])[0]
118
-
119
- output = {
120
- 'waveform': audio_chunk,
121
- 'id': audio_id,
122
- 'caption': caption,
123
- 'tokens': tokens,
124
- }
125
-
126
- return output
127
- except Exception as e:
128
- log.error(f'Error reading {audio_path}: {e}')
129
- return None
130
-
131
- def __len__(self):
132
- return len(self.clips)
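
Both extraction datasets lazily build one torchaudio Resample transform per encountered source rate, using the "kaiser best" parameters referenced in the torchaudio resampling tutorial, instead of recreating the kernel for every clip. A self-contained sketch of that caching pattern for a mono waveform:

    import torch
    import torchaudio

    TARGET_SR = 16_000
    _resamplers: dict[int, torchaudio.transforms.Resample] = {}

    def resample_to_target(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        # reuse one Resample module per source rate
        if sample_rate == TARGET_SR:
            return waveform
        if sample_rate not in _resamplers:
            _resamplers[sample_rate] = torchaudio.transforms.Resample(
                sample_rate,
                TARGET_SR,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
        return _resamplers[sample_rate](waveform)

    print(resample_to_target(torch.randn(44_100), 44_100).shape)  # torch.Size([16000])
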
 
mmaudio/data/mm_dataset.py DELETED
@@ -1,45 +0,0 @@
1
- import bisect
2
-
3
- import torch
4
- from torch.utils.data.dataset import Dataset
5
-
6
-
7
- # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
8
- class MultiModalDataset(Dataset):
9
- datasets: list[Dataset]
10
- cumulative_sizes: list[int]
11
-
12
- @staticmethod
13
- def cumsum(sequence):
14
- r, s = [], 0
15
- for e in sequence:
16
- l = len(e)
17
- r.append(l + s)
18
- s += l
19
- return r
20
-
21
- def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
22
- super().__init__()
23
- self.video_datasets = list(video_datasets)
24
- self.audio_datasets = list(audio_datasets)
25
- self.datasets = self.video_datasets + self.audio_datasets
26
-
27
- self.cumulative_sizes = self.cumsum(self.datasets)
28
-
29
- def __len__(self):
30
- return self.cumulative_sizes[-1]
31
-
32
- def __getitem__(self, idx):
33
- if idx < 0:
34
- if -idx > len(self):
35
- raise ValueError("absolute value of index should not exceed dataset length")
36
- idx = len(self) + idx
37
- dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
38
- if dataset_idx == 0:
39
- sample_idx = idx
40
- else:
41
- sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
42
- return self.datasets[dataset_idx][sample_idx]
43
-
44
- def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
45
- return self.video_datasets[0].compute_latent_stats()
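
MultiModalDataset concatenates the video and audio datasets by keeping cumulative sizes and binary-searching them on every lookup, exactly like torch's ConcatDataset. A small worked example of the indexing rule, with two plain Python lists standing in for datasets:

    import bisect

    datasets = [list('abc'), list('defgh')]   # sizes 3 and 5
    cumulative_sizes = [3, 8]                 # running sum of the sizes

    def lookup(idx: int):
        dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
        sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
        return datasets[dataset_idx][sample_idx]

    assert lookup(2) == 'c'   # last item of the first dataset
    assert lookup(3) == 'd'   # first item of the second dataset
    assert lookup(7) == 'h'   # global index 7 -> local index 4 in the second dataset
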
 
mmaudio/data/utils.py DELETED
@@ -1,148 +0,0 @@
1
- import logging
2
- import os
3
- import random
4
- import tempfile
5
- from pathlib import Path
6
- from typing import Any, Optional, Union
7
-
8
- import torch
9
- import torch.distributed as dist
10
- from tensordict import MemoryMappedTensor
11
- from torch.utils.data import DataLoader
12
- from torch.utils.data.dataset import Dataset
13
- from tqdm import tqdm
14
-
15
- from mmaudio.utils.dist_utils import local_rank, world_size
16
-
17
- scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
18
- shm_path = Path('/dev/shm')
19
-
20
- log = logging.getLogger()
21
-
22
-
23
- def reseed(seed):
24
- random.seed(seed)
25
- torch.manual_seed(seed)
26
-
27
-
28
- def local_scatter_torch(obj: Optional[Any]):
29
- if world_size == 1:
30
- # Just one worker. Do nothing.
31
- return obj
32
-
33
- array = [obj] * world_size
34
- target_array = [None]
35
- if local_rank == 0:
36
- dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
37
- else:
38
- dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
39
- return target_array[0]
40
-
41
-
42
- class ShardDataset(Dataset):
43
-
44
- def __init__(self, root):
45
- self.root = root
46
- self.shards = sorted(os.listdir(root))
47
-
48
- def __len__(self):
49
- return len(self.shards)
50
-
51
- def __getitem__(self, idx):
52
- return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
53
-
54
-
55
- def get_tmp_dir(in_memory: bool) -> Path:
56
- return shm_path if in_memory else scratch_path
57
-
58
-
59
- def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
60
- in_memory: bool) -> MemoryMappedTensor:
61
- if local_rank == 0:
62
- with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
63
- log.info(f'Loading shards from {data_path} into {f.name}...')
64
- data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
65
- data = share_tensor_to_all(data)
66
- torch.distributed.barrier()
67
- f.close() # why does the context manager not close the file for me?
68
- else:
69
- log.info('Waiting for the data to be shared with me...')
70
- data = share_tensor_to_all(None)
71
- torch.distributed.barrier()
72
-
73
- return data
74
-
75
-
76
- def load_shards(
77
- data_path: Union[str, Path],
78
- ids: list[int],
79
- *,
80
- tmp_file_path: str,
81
- ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
82
-
83
- id_set = set(ids)
84
- shards = sorted(os.listdir(data_path))
85
- log.info(f'Found {len(shards)} shards in {data_path}.')
86
- first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
87
-
88
- log.info(f'Rank {local_rank} created file {tmp_file_path}')
89
- first_item = next(iter(first_shard.values()))
90
- log.info(f'First item shape: {first_item.shape}')
91
- mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
92
- dtype=torch.float32,
93
- filename=tmp_file_path,
94
- existsok=True)
95
- total_count = 0
96
- used_index = set()
97
- id_indexing = {i: idx for idx, i in enumerate(ids)}
98
- # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
99
- loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
100
- for data in tqdm(loader, desc='Loading shards'):
101
- for i, v in data.items():
102
- if i not in id_set:
103
- continue
104
-
105
- # tensor_index = ids.index(i)
106
- tensor_index = id_indexing[i]
107
- if tensor_index in used_index:
108
- raise ValueError(f'Duplicate id {i} found in {data_path}.')
109
- used_index.add(tensor_index)
110
- mm_tensor[tensor_index] = v
111
- total_count += 1
112
-
113
- assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
114
- log.info(f'Loaded {total_count} tensors from {data_path}.')
115
-
116
- return mm_tensor
117
-
118
-
119
- def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
120
- """
121
- x: the tensor to be shared; None if local_rank != 0
122
- return: the shared tensor
123
- """
124
-
125
- # there is no need to share your stuff with anyone if you are alone; must be in memory
126
- if world_size == 1:
127
- return x
128
-
129
- if local_rank == 0:
130
- assert x is not None, 'x must not be None if local_rank == 0'
131
- else:
132
- assert x is None, 'x must be None if local_rank != 0'
133
-
134
- if local_rank == 0:
135
- filename = x.filename
136
- meta_information = (filename, x.shape, x.dtype)
137
- else:
138
- meta_information = None
139
-
140
- filename, data_shape, data_type = local_scatter_torch(meta_information)
141
- if local_rank == 0:
142
- data = x
143
- else:
144
- data = MemoryMappedTensor.from_filename(filename=filename,
145
- dtype=data_type,
146
- shape=data_shape)
147
-
148
- return data
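
share_tensor_to_all above never sends the tensor itself across ranks: rank 0 scatters only (filename, shape, dtype), and every other rank re-opens the same memory-mapped file. A single-process sketch of that re-open step with tensordict's MemoryMappedTensor (the temporary file name is just an example):

    import tempfile

    import torch
    from tensordict import MemoryMappedTensor

    # "rank 0": create and fill a file-backed tensor
    tmp = tempfile.NamedTemporaryFile(prefix='shared-tensor-', delete=False)
    mm = MemoryMappedTensor.empty(shape=(4, 8), dtype=torch.float32,
                                  filename=tmp.name, existsok=True)
    mm.copy_(torch.arange(32, dtype=torch.float32).reshape(4, 8))

    # "other ranks": given (filename, shape, dtype), map the same file; no data is copied
    same = MemoryMappedTensor.from_filename(filename=tmp.name, dtype=torch.float32, shape=(4, 8))
    assert torch.equal(mm, same)
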
 
mmaudio/eval_utils.py CHANGED
@@ -3,16 +3,14 @@ import logging
3
  from pathlib import Path
4
  from typing import Optional
5
 
6
- import numpy as np
7
  import torch
8
  from colorlog import ColoredFormatter
9
- from PIL import Image
10
  from torchvision.transforms import v2
11
 
12
- from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
13
  from mmaudio.model.flow_matching import FlowMatching
14
  from mmaudio.model.networks import MMAudio
15
- from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
16
  from mmaudio.model.utils.features_utils import FeaturesUtils
17
  from mmaudio.utils.download_utils import download_model_if_needed
18
 
@@ -90,7 +88,6 @@ def generate(
90
  cfg_strength: float,
91
  clip_batch_size_multiplier: int = 40,
92
  sync_batch_size_multiplier: int = 40,
93
- image_input: bool = False,
94
  ) -> torch.Tensor:
95
  device = feature_utils.device
96
  dtype = feature_utils.dtype
@@ -101,12 +98,10 @@ def generate(
101
  clip_features = feature_utils.encode_video_with_clip(clip_video,
102
  batch_size=bs *
103
  clip_batch_size_multiplier)
104
- if image_input:
105
- clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
106
  else:
107
  clip_features = net.get_empty_clip_sequence(bs)
108
 
109
- if sync_video is not None and not image_input:
110
  sync_video = sync_video.to(device, dtype, non_blocking=True)
111
  sync_features = feature_utils.encode_video_with_sync(sync_video,
112
  batch_size=bs *
@@ -144,7 +139,7 @@ def generate(
144
  return audio
145
 
146
 
147
- LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
148
 
149
 
150
  def setup_eval_logging(log_level: int = logging.INFO):
@@ -158,14 +153,12 @@ def setup_eval_logging(log_level: int = logging.INFO):
158
  log.addHandler(stream)
159
 
160
 
161
- _CLIP_SIZE = 384
162
- _CLIP_FPS = 8.0
163
-
164
- _SYNC_SIZE = 224
165
- _SYNC_FPS = 25.0
166
-
167
-
168
  def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
 
 
 
 
 
169
 
170
  clip_transform = v2.Compose([
171
  v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
@@ -220,36 +213,5 @@ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = Tr
220
  return video_info
221
 
222
 
223
- def load_image(image_path: Path) -> VideoInfo:
224
- clip_transform = v2.Compose([
225
- v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
226
- v2.ToImage(),
227
- v2.ToDtype(torch.float32, scale=True),
228
- ])
229
-
230
- sync_transform = v2.Compose([
231
- v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
232
- v2.CenterCrop(_SYNC_SIZE),
233
- v2.ToImage(),
234
- v2.ToDtype(torch.float32, scale=True),
235
- v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
236
- ])
237
-
238
- frame = np.array(Image.open(image_path))
239
-
240
- clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
241
- sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
242
-
243
- clip_frames = clip_transform(clip_chunk)
244
- sync_frames = sync_transform(sync_chunk)
245
-
246
- video_info = ImageInfo(
247
- clip_frames=clip_frames,
248
- sync_frames=sync_frames,
249
- original_frame=frame,
250
- )
251
- return video_info
252
-
253
-
254
  def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
255
  reencode_with_audio(video_info, output_path, audio, sampling_rate)
 
3
  from pathlib import Path
4
  from typing import Optional
5
 
 
6
  import torch
7
  from colorlog import ColoredFormatter
 
8
  from torchvision.transforms import v2
9
 
10
+ from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
11
  from mmaudio.model.flow_matching import FlowMatching
12
  from mmaudio.model.networks import MMAudio
13
+ from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
14
  from mmaudio.model.utils.features_utils import FeaturesUtils
15
  from mmaudio.utils.download_utils import download_model_if_needed
16
 
 
88
  cfg_strength: float,
89
  clip_batch_size_multiplier: int = 40,
90
  sync_batch_size_multiplier: int = 40,
 
91
  ) -> torch.Tensor:
92
  device = feature_utils.device
93
  dtype = feature_utils.dtype
 
98
  clip_features = feature_utils.encode_video_with_clip(clip_video,
99
  batch_size=bs *
100
  clip_batch_size_multiplier)
 
 
101
  else:
102
  clip_features = net.get_empty_clip_sequence(bs)
103
 
104
+ if sync_video is not None:
105
  sync_video = sync_video.to(device, dtype, non_blocking=True)
106
  sync_features = feature_utils.encode_video_with_sync(sync_video,
107
  batch_size=bs *
 
139
  return audio
140
 
141
 
142
+ LOGFORMAT = " %(log_color)s%(levelname)-8s%(reset)s | %(log_color)s%(message)s%(reset)s"
143
 
144
 
145
  def setup_eval_logging(log_level: int = logging.INFO):
 
153
  log.addHandler(stream)
154
 
155
 
 
 
 
 
 
 
 
156
  def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
157
+ _CLIP_SIZE = 384
158
+ _CLIP_FPS = 8.0
159
+
160
+ _SYNC_SIZE = 224
161
+ _SYNC_FPS = 25.0
162
 
163
  clip_transform = v2.Compose([
164
  v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
 
213
  return video_info
214
 
215
 
216
  def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
217
  reencode_with_audio(video_info, output_path, audio, sampling_rate)
mmaudio/ext/autoencoder/autoencoder.py CHANGED
@@ -20,7 +20,7 @@ class AutoEncoderModule(nn.Module):
20
  super().__init__()
21
  self.vae: VAE = get_my_vae(mode).eval()
22
  vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
23
- self.vae.load_state_dict(vae_state_dict)
24
  self.vae.remove_weight_norm()
25
 
26
  if mode == '16k':
 
20
  super().__init__()
21
  self.vae: VAE = get_my_vae(mode).eval()
22
  vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
23
+ self.vae.load_state_dict(vae_state_dict, strict=False)
24
  self.vae.remove_weight_norm()
25
 
26
  if mode == '16k':
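
With strict=False the call silently ignores mismatched keys, so a load that drops weights will not raise. Where that tolerance is meant only for a known set of keys, the usual check looks like the hedged sketch below (standard torch API, not code from this repository):

    import torch
    import torch.nn as nn

    def load_loosely(module: nn.Module, ckpt_path: str) -> None:
        state_dict = torch.load(ckpt_path, weights_only=True, map_location='cpu')
        # with strict=False, load_state_dict returns the keys it could not match
        result = module.load_state_dict(state_dict, strict=False)
        if result.missing_keys:
            print(f'missing keys: {result.missing_keys}')
        if result.unexpected_keys:
            print(f'unexpected keys: {result.unexpected_keys}')
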
mmaudio/ext/autoencoder/vae.py CHANGED
@@ -75,9 +75,13 @@ class VAE(nn.Module):
75
  super().__init__()
76
 
77
  if data_dim == 80:
 
 
78
  self.register_buffer('data_mean', torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
79
  self.register_buffer('data_std', torch.tensor(DATA_STD_80D, dtype=torch.float32))
80
  elif data_dim == 128:
 
 
81
  self.register_buffer('data_mean', torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
82
  self.register_buffer('data_std', torch.tensor(DATA_STD_128D, dtype=torch.float32))
83
 
 
75
  super().__init__()
76
 
77
  if data_dim == 80:
78
+ # self.data_mean = torch.tensor(DATA_MEAN_80D, dtype=torch.float32).cuda()
79
+ # self.data_std = torch.tensor(DATA_STD_80D, dtype=torch.float32).cuda()
80
  self.register_buffer('data_mean', torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
81
  self.register_buffer('data_std', torch.tensor(DATA_STD_80D, dtype=torch.float32))
82
  elif data_dim == 128:
83
+ # torch.tensor(DATA_MEAN_128D, dtype=torch.float32).cuda()
84
+ # self.data_std = torch.tensor(DATA_STD_128D, dtype=torch.float32).cuda()
85
  self.register_buffer('data_mean', torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
86
  self.register_buffer('data_std', torch.tensor(DATA_STD_128D, dtype=torch.float32))
87
 
mmaudio/ext/mel_converter.py CHANGED
@@ -1,12 +1,11 @@
1
  # Reference: # https://github.com/bytedance/Make-An-Audio-2
2
- from typing import Literal
3
 
4
  import torch
5
  import torch.nn as nn
6
  from librosa.filters import mel as librosa_mel_fn
7
 
8
 
9
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
10
  return norm_fn(torch.clamp(x, min=clip_val) * C)
11
 
12
 
@@ -20,14 +19,14 @@ class MelConverter(nn.Module):
20
  def __init__(
21
  self,
22
  *,
23
- sampling_rate: float,
24
- n_fft: int,
25
- num_mels: int,
26
- hop_size: int,
27
- win_size: int,
28
- fmin: float,
29
- fmax: float,
30
- norm_fn,
31
  ):
32
  super().__init__()
33
  self.sampling_rate = sampling_rate
@@ -81,26 +80,3 @@ class MelConverter(nn.Module):
81
  spec = spectral_normalize_torch(spec, self.norm_fn)
82
 
83
  return spec
84
-
85
-
86
- def get_mel_converter(mode: Literal['16k', '44k']) -> MelConverter:
87
- if mode == '16k':
88
- return MelConverter(sampling_rate=16_000,
89
- n_fft=1024,
90
- num_mels=80,
91
- hop_size=256,
92
- win_size=1024,
93
- fmin=0,
94
- fmax=8_000,
95
- norm_fn=torch.log10)
96
- elif mode == '44k':
97
- return MelConverter(sampling_rate=44_100,
98
- n_fft=2048,
99
- num_mels=128,
100
- hop_size=512,
101
- win_size=2048,
102
- fmin=0,
103
- fmax=44100 / 2,
104
- norm_fn=torch.log)
105
- else:
106
- raise ValueError(f'Unknown mode: {mode}')
 
1
  # Reference: # https://github.com/bytedance/Make-An-Audio-2
 
2
 
3
  import torch
4
  import torch.nn as nn
5
  from librosa.filters import mel as librosa_mel_fn
6
 
7
 
8
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, norm_fn=torch.log10):
9
  return norm_fn(torch.clamp(x, min=clip_val) * C)
10
 
11
 
 
19
  def __init__(
20
  self,
21
  *,
22
+ sampling_rate: float = 16_000,
23
+ n_fft: int = 1024,
24
+ num_mels: int = 80,
25
+ hop_size: int = 256,
26
+ win_size: int = 1024,
27
+ fmin: float = 0,
28
+ fmax: float = 8_000,
29
+ norm_fn=torch.log10,
30
  ):
31
  super().__init__()
32
  self.sampling_rate = sampling_rate
 
80
  spec = spectral_normalize_torch(spec, self.norm_fn)
81
 
82
  return spec
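
The deleted get_mel_converter helper carried two presets; the new MelConverter defaults reproduce only the 16k preset (16 kHz, 1024-point FFT, 80 mels, hop 256, log10). Callers that need the 44.1 kHz configuration now have to pass the parameters explicitly; a sketch using the values from the deleted 44k branch:

    import torch

    from mmaudio.ext.mel_converter import MelConverter

    mel_16k = MelConverter()  # defaults match the old 16k preset

    # parameters copied from the deleted 44k branch of get_mel_converter
    mel_44k = MelConverter(sampling_rate=44_100,
                           n_fft=2048,
                           num_mels=128,
                           hop_size=512,
                           win_size=2048,
                           fmin=0,
                           fmax=44_100 / 2,
                           norm_fn=torch.log)
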
 
mmaudio/model/embeddings.py CHANGED
@@ -1,9 +1,9 @@
1
  import torch
2
  import torch.nn as nn
3
- import math
4
 
5
  # https://github.com/facebookresearch/DiT
6
 
 
7
  class TimestepEmbedder(nn.Module):
8
  """
9
  Embeds scalar timesteps into vector representations.
 
1
  import torch
2
  import torch.nn as nn
 
3
 
4
  # https://github.com/facebookresearch/DiT
5
 
6
+
7
  class TimestepEmbedder(nn.Module):
8
  """
9
  Embeds scalar timesteps into vector representations.
mmaudio/model/flow_matching.py CHANGED
@@ -1,9 +1,11 @@
1
  import logging
2
- from typing import Callable, Optional
3
 
4
  import torch
5
  from torchdiffeq import odeint
6
 
 
 
7
  log = logging.getLogger()
8
 
9
 
@@ -43,8 +45,12 @@ class FlowMatching:
43
  Cs: list[torch.Tensor],
44
  generator: Optional[torch.Generator] = None
45
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 
46
  x0 = torch.empty_like(x1).normal_(generator=generator)
47
 
 
 
 
48
  xt = self.get_conditional_flow(x0, x1, t)
49
  return x0, x1, xt, Cs
50
 
@@ -68,4 +74,15 @@ class FlowMatching:
68
  dt = next_t - t
69
  x = x + dt * flow
70
 
 
 
 
 
 
 
 
 
 
 
 
71
  return x
 
1
  import logging
2
+ from typing import Callable, Iterable, Optional
3
 
4
  import torch
5
  from torchdiffeq import odeint
6
 
7
+ # from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher
8
+
9
  log = logging.getLogger()
10
 
11
 
 
45
  Cs: list[torch.Tensor],
46
  generator: Optional[torch.Generator] = None
47
  ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
48
+ # x0 = torch.randn_like(x1, generator=generator)
49
  x0 = torch.empty_like(x1).normal_(generator=generator)
50
 
51
+ # find mini-batch optimal transport
52
+ # x0, x1, _, Cs = self.fm.ot_sampler.sample_plan_with_labels(x0, x1, None, Cs, replace=True)
53
+
54
  xt = self.get_conditional_flow(x0, x1, t)
55
  return x0, x1, xt, Cs
56
 
 
74
  dt = next_t - t
75
  x = x + dt * flow
76
 
77
+ # return odeint(fn,
78
+ # x0,
79
+ # torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype),
80
+ # method='rk4',
81
+ # options=dict(step_size=(t1 - t0) / self.num_steps))[-1]
82
+ # return odeint(fn,
83
+ # x0,
84
+ # torch.tensor([t0, t1], device=x0.device, dtype=x0.dtype),
85
+ # method='euler',
86
+ # options=dict(step_size=(t1 - t0) / self.num_steps))[-1]
87
+
88
  return x
mmaudio/model/networks.py CHANGED
@@ -468,4 +468,4 @@ if __name__ == '__main__':
468
 
469
  # print the number of parameters in terms of millions
470
  num_params = sum(p.numel() for p in network.parameters()) / 1e6
471
- print(f'Number of parameters: {num_params:.2f}M')
 
468
 
469
  # print the number of parameters in terms of millions
470
  num_params = sum(p.numel() for p in network.parameters()) / 1e6
471
+ print(f'Number of parameters: {num_params:.2f}M')
mmaudio/model/transformer_layers.py CHANGED
@@ -5,6 +5,7 @@ import torch.nn as nn
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from einops.layers.torch import Rearrange
 
8
 
9
  from mmaudio.ext.rotary_embeddings import apply_rope
10
  from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP
 
5
  import torch.nn.functional as F
6
  from einops import rearrange
7
  from einops.layers.torch import Rearrange
8
+ from torch.nn.attention import SDPBackend, sdpa_kernel
9
 
10
  from mmaudio.ext.rotary_embeddings import apply_rope
11
  from mmaudio.model.low_level import MLP, ChannelLastConv1d, ConvMLP
mmaudio/model/utils/features_utils.py CHANGED
@@ -9,7 +9,7 @@ from open_clip import create_model_from_pretrained
9
  from torchvision.transforms import Normalize
10
 
11
  from mmaudio.ext.autoencoder import AutoEncoderModule
12
- from mmaudio.ext.mel_converter import get_mel_converter
13
  from mmaudio.ext.synchformer import Synchformer
14
  from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
15
 
@@ -63,13 +63,13 @@ class FeaturesUtils(nn.Module):
63
  self.tokenizer = None
64
 
65
  if tod_vae_ckpt is not None:
66
- self.mel_converter = get_mel_converter(mode)
67
  self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
68
  vocoder_ckpt_path=bigvgan_vocoder_ckpt,
69
  mode=mode,
70
  need_vae_encoder=need_vae_encoder)
71
  else:
72
  self.tod = None
 
73
 
74
  def compile(self):
75
  if self.clip_model is not None:
 
9
  from torchvision.transforms import Normalize
10
 
11
  from mmaudio.ext.autoencoder import AutoEncoderModule
12
+ from mmaudio.ext.mel_converter import MelConverter
13
  from mmaudio.ext.synchformer import Synchformer
14
  from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
15
 
 
63
  self.tokenizer = None
64
 
65
  if tod_vae_ckpt is not None:
 
66
  self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
67
  vocoder_ckpt_path=bigvgan_vocoder_ckpt,
68
  mode=mode,
69
  need_vae_encoder=need_vae_encoder)
70
  else:
71
  self.tod = None
72
+ self.mel_converter = MelConverter()
73
 
74
  def compile(self):
75
  if self.clip_model is not None:
mmaudio/runner.py DELETED
@@ -1,609 +0,0 @@
1
- """
2
- trainer.py - wrapper and utility functions for network training
3
- Compute loss, back-prop, update parameters, logging, etc.
4
- """
5
- import os
6
- from pathlib import Path
7
- from typing import Optional, Union
8
-
9
- import torch
10
- import torch.distributed
11
- import torch.optim as optim
12
- from av_bench.evaluate import evaluate
13
- from av_bench.extract import extract
14
- from nitrous_ema import PostHocEMA
15
- from omegaconf import DictConfig
16
- from torch.nn.parallel import DistributedDataParallel as DDP
17
-
18
- from mmaudio.model.flow_matching import FlowMatching
19
- from mmaudio.model.networks import get_my_mmaudio
20
- from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K
21
- from mmaudio.model.utils.features_utils import FeaturesUtils
22
- from mmaudio.model.utils.parameter_groups import get_parameter_groups
23
- from mmaudio.model.utils.sample_utils import log_normal_sample
24
- from mmaudio.utils.dist_utils import (info_if_rank_zero, local_rank, string_if_rank_zero)
25
- from mmaudio.utils.log_integrator import Integrator
26
- from mmaudio.utils.logger import TensorboardLogger
27
- from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator
28
- from mmaudio.utils.video_joiner import VideoJoiner
29
-
30
-
31
- class Runner:
32
-
33
- def __init__(self,
34
- cfg: DictConfig,
35
- log: TensorboardLogger,
36
- run_path: Union[str, Path],
37
- for_training: bool = True,
38
- latent_mean: Optional[torch.Tensor] = None,
39
- latent_std: Optional[torch.Tensor] = None):
40
- self.exp_id = cfg.exp_id
41
- self.use_amp = cfg.amp
42
- self.enable_grad_scaler = cfg.enable_grad_scaler
43
- self.for_training = for_training
44
- self.cfg = cfg
45
-
46
- if cfg.model.endswith('16k'):
47
- self.seq_cfg = CONFIG_16K
48
- mode = '16k'
49
- elif cfg.model.endswith('44k'):
50
- self.seq_cfg = CONFIG_44K
51
- mode = '44k'
52
- else:
53
- raise ValueError(f'Unknown model: {cfg.model}')
54
-
55
- self.sample_rate = self.seq_cfg.sampling_rate
56
- self.duration_sec = self.seq_cfg.duration
57
-
58
- # setting up the model
59
- empty_string_feat = torch.load('./ext_weights/empty_string.pth', weights_only=True)[0]
60
- self.network = DDP(get_my_mmaudio(cfg.model,
61
- latent_mean=latent_mean,
62
- latent_std=latent_std,
63
- empty_string_feat=empty_string_feat).cuda(),
64
- device_ids=[local_rank],
65
- broadcast_buffers=False)
66
- if cfg.compile:
67
- # NOTE: though train_fn and val_fn are very similar
68
- # (early on they are implemented as a single function)
69
- # keeping them separate and compiling them separately are CRUCIAL for high performance
70
- self.train_fn = torch.compile(self.train_fn)
71
- self.val_fn = torch.compile(self.val_fn)
72
-
73
- self.fm = FlowMatching(cfg.sampling.min_sigma,
74
- inference_mode=cfg.sampling.method,
75
- num_steps=cfg.sampling.num_steps)
76
-
77
- # ema profile
78
- if for_training and cfg.ema.enable and local_rank == 0:
79
- self.ema = PostHocEMA(self.network.module,
80
- sigma_rels=cfg.ema.sigma_rels,
81
- update_every=cfg.ema.update_every,
82
- checkpoint_every_num_steps=cfg.ema.checkpoint_every,
83
- checkpoint_folder=cfg.ema.checkpoint_folder,
84
- step_size_correction=True).cuda()
85
- self.ema_start = cfg.ema.start
86
- else:
87
- self.ema = None
88
-
89
- self.rng = torch.Generator(device='cuda')
90
- self.rng.manual_seed(cfg['seed'] + local_rank)
91
-
92
- # setting up feature extractors and VAEs
93
- if mode == '16k':
94
- self.features = FeaturesUtils(
95
- tod_vae_ckpt=cfg['vae_16k_ckpt'],
96
- bigvgan_vocoder_ckpt=cfg['bigvgan_vocoder_ckpt'],
97
- synchformer_ckpt=cfg['synchformer_ckpt'],
98
- enable_conditions=True,
99
- mode=mode,
100
- need_vae_encoder=False,
101
- )
102
- elif mode == '44k':
103
- self.features = FeaturesUtils(
104
- tod_vae_ckpt=cfg['vae_44k_ckpt'],
105
- synchformer_ckpt=cfg['synchformer_ckpt'],
106
- enable_conditions=True,
107
- mode=mode,
108
- need_vae_encoder=False,
109
- )
110
- self.features = self.features.cuda().eval()
111
-
112
- if cfg.compile:
113
- self.features.compile()
114
-
115
- # hyperparameters
116
- self.log_normal_sampling_mean = cfg.sampling.mean
117
- self.log_normal_sampling_scale = cfg.sampling.scale
118
- self.null_condition_probability = cfg.null_condition_probability
119
- self.cfg_strength = cfg.cfg_strength
120
-
121
- # setting up logging
122
- self.log = log
123
- self.run_path = Path(run_path)
124
- vgg_cfg = cfg.data.VGGSound
125
- if for_training:
126
- self.val_video_joiner = VideoJoiner(vgg_cfg.root, self.run_path / 'val-sampled-videos',
127
- self.sample_rate, self.duration_sec)
128
- else:
129
- self.test_video_joiner = VideoJoiner(vgg_cfg.root,
130
- self.run_path / 'test-sampled-videos',
131
- self.sample_rate, self.duration_sec)
132
- string_if_rank_zero(self.log, 'model_size',
133
- f'{sum([param.nelement() for param in self.network.parameters()])}')
134
- string_if_rank_zero(
135
- self.log, 'number_of_parameters_that_require_gradient: ',
136
- str(
137
- sum([
138
- param.nelement()
139
- for param in filter(lambda p: p.requires_grad, self.network.parameters())
140
- ])))
141
- info_if_rank_zero(self.log, 'torch version: ' + torch.__version__)
142
- self.train_integrator = Integrator(self.log, distributed=True)
143
- self.val_integrator = Integrator(self.log, distributed=True)
144
-
145
- # setting up optimizer and loss
146
- if for_training:
147
- self.enter_train()
148
- parameter_groups = get_parameter_groups(self.network, cfg, print_log=(local_rank == 0))
149
- self.optimizer = optim.AdamW(parameter_groups,
150
- lr=cfg['learning_rate'],
151
- weight_decay=cfg['weight_decay'],
152
- betas=[0.9, 0.95],
153
- eps=1e-6 if self.use_amp else 1e-8,
154
- fused=True)
155
- if self.enable_grad_scaler:
156
- self.scaler = torch.amp.GradScaler(init_scale=2048)
157
- self.clip_grad_norm = cfg['clip_grad_norm']
158
-
159
- # linearly warmup learning rate
160
- linear_warmup_steps = cfg['linear_warmup_steps']
161
-
162
- def warmup(currrent_step: int):
163
- return (currrent_step + 1) / (linear_warmup_steps + 1)
164
-
165
- warmup_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup)
166
-
167
- # setting up learning rate scheduler
168
- if cfg['lr_schedule'] == 'constant':
169
- next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1)
170
- elif cfg['lr_schedule'] == 'poly':
171
- total_num_iter = cfg['iterations']
172
- next_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer,
173
- lr_lambda=lambda x:
174
- (1 - (x / total_num_iter))**0.9)
175
- elif cfg['lr_schedule'] == 'step':
176
- next_scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
177
- cfg['lr_schedule_steps'],
178
- cfg['lr_schedule_gamma'])
179
- else:
180
- raise NotImplementedError
181
-
182
- self.scheduler = optim.lr_scheduler.SequentialLR(self.optimizer,
183
- [warmup_scheduler, next_scheduler],
184
- [linear_warmup_steps])
185
-
186
- # Logging info
187
- self.log_text_interval = cfg['log_text_interval']
188
- self.log_extra_interval = cfg['log_extra_interval']
189
- self.save_weights_interval = cfg['save_weights_interval']
190
- self.save_checkpoint_interval = cfg['save_checkpoint_interval']
191
- self.save_copy_iterations = cfg['save_copy_iterations']
192
- self.num_iterations = cfg['num_iterations']
193
- if cfg['debug']:
194
- self.log_text_interval = self.log_extra_interval = 1
195
-
196
- # update() is called when we log metrics, within the logger
197
- self.log.batch_timer = TimeEstimator(self.num_iterations, self.log_text_interval)
198
- # update() is called every iteration, in this script
199
- self.log.data_timer = PartialTimeEstimator(self.num_iterations, 1, ema_alpha=0.9)
200
- else:
201
- self.enter_val()
202
-
203
- def train_fn(
204
- self,
205
- clip_f: torch.Tensor,
206
- sync_f: torch.Tensor,
207
- text_f: torch.Tensor,
208
- a_mean: torch.Tensor,
209
- a_std: torch.Tensor,
210
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
211
- # sample
212
- a_randn = torch.empty_like(a_mean).normal_(generator=self.rng)
213
- x1 = a_mean + a_std * a_randn
214
- bs = x1.shape[0] # batch_size * seq_len * num_channels
215
-
216
- # normalize the latents
217
- x1 = self.network.module.normalize(x1)
218
-
219
- t = log_normal_sample(x1,
220
- generator=self.rng,
221
- m=self.log_normal_sampling_mean,
222
- s=self.log_normal_sampling_scale)
223
- x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1,
224
- t,
225
- Cs=[clip_f, sync_f, text_f],
226
- generator=self.rng)
227
-
228
- # classifier-free training
229
- samples = torch.rand(bs, device=x1.device, generator=self.rng)
230
- null_video = (samples < self.null_condition_probability)
231
- clip_f[null_video] = self.network.module.empty_clip_feat
232
- sync_f[null_video] = self.network.module.empty_sync_feat
233
-
234
- samples = torch.rand(bs, device=x1.device, generator=self.rng)
235
- null_text = (samples < self.null_condition_probability)
236
- text_f[null_text] = self.network.module.empty_string_feat
237
-
238
- pred_v = self.network(xt, clip_f, sync_f, text_f, t)
239
- loss = self.fm.loss(pred_v, x0, x1)
240
- mean_loss = loss.mean()
241
- return x1, loss, mean_loss, t
242
-
243
- def val_fn(
244
- self,
245
- clip_f: torch.Tensor,
246
- sync_f: torch.Tensor,
247
- text_f: torch.Tensor,
248
- x1: torch.Tensor,
249
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
250
- bs = x1.shape[0] # batch_size * seq_len * num_channels
251
- # normalize the latents
252
- x1 = self.network.module.normalize(x1)
253
- t = log_normal_sample(x1,
254
- generator=self.rng,
255
- m=self.log_normal_sampling_mean,
256
- s=self.log_normal_sampling_scale)
257
- x0, x1, xt, (clip_f, sync_f, text_f) = self.fm.get_x0_xt_c(x1,
258
- t,
259
- Cs=[clip_f, sync_f, text_f],
260
- generator=self.rng)
261
-
262
- # classifier-free training
263
- samples = torch.rand(bs, device=x1.device, generator=self.rng)
264
- # null mask is for when a video is provided but we decided to ignore it
265
- null_video = (samples < self.null_condition_probability)
266
- # complete mask is for when a video is not provided or we decided to ignore it
267
- clip_f[null_video] = self.network.module.empty_clip_feat
268
- sync_f[null_video] = self.network.module.empty_sync_feat
269
-
270
- samples = torch.rand(bs, device=x1.device, generator=self.rng)
271
- null_text = (samples < self.null_condition_probability)
272
- text_f[null_text] = self.network.module.empty_string_feat
273
-
274
- pred_v = self.network(xt, clip_f, sync_f, text_f, t)
275
-
276
- loss = self.fm.loss(pred_v, x0, x1)
277
- mean_loss = loss.mean()
278
- return loss, mean_loss, t
279
-
280
- def train_pass(self, data, it: int = 0):
281
-
282
- if not self.for_training:
283
- raise ValueError('train_pass() should not be called when not training.')
284
-
285
- self.enter_train()
286
- with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16):
287
- clip_f = data['clip_features'].cuda(non_blocking=True)
288
- sync_f = data['sync_features'].cuda(non_blocking=True)
289
- text_f = data['text_features'].cuda(non_blocking=True)
290
- video_exist = data['video_exist'].cuda(non_blocking=True)
291
- text_exist = data['text_exist'].cuda(non_blocking=True)
292
- a_mean = data['a_mean'].cuda(non_blocking=True)
293
- a_std = data['a_std'].cuda(non_blocking=True)
294
-
295
- # these masks are for non-existent data; masking for CFG training is in train_fn
296
- clip_f[~video_exist] = self.network.module.empty_clip_feat
297
- sync_f[~video_exist] = self.network.module.empty_sync_feat
298
- text_f[~text_exist] = self.network.module.empty_string_feat
299
-
300
- self.log.data_timer.end()
301
- if it % self.log_extra_interval == 0:
302
- unmasked_clip_f = clip_f.clone()
303
- unmasked_sync_f = sync_f.clone()
304
- unmasked_text_f = text_f.clone()
305
- x1, loss, mean_loss, t = self.train_fn(clip_f, sync_f, text_f, a_mean, a_std)
306
-
307
- self.train_integrator.add_dict({'loss': mean_loss})
308
-
309
- if it % self.log_text_interval == 0 and it != 0:
310
- self.train_integrator.add_scalar('lr', self.scheduler.get_last_lr()[0])
311
- self.train_integrator.add_binned_tensor('binned_loss', loss, t)
312
- self.train_integrator.finalize('train', it)
313
- self.train_integrator.reset_except_hooks()
314
-
315
- # Backward pass
316
- self.optimizer.zero_grad(set_to_none=True)
317
- if self.enable_grad_scaler:
318
- self.scaler.scale(mean_loss).backward()
319
- self.scaler.unscale_(self.optimizer)
320
- grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(),
321
- self.clip_grad_norm)
322
- self.scaler.step(self.optimizer)
323
- self.scaler.update()
324
- else:
325
- mean_loss.backward()
326
- grad_norm = torch.nn.utils.clip_grad_norm_(self.network.parameters(),
327
- self.clip_grad_norm)
328
- self.optimizer.step()
329
-
330
- if self.ema is not None and it >= self.ema_start:
331
- self.ema.update()
332
- self.scheduler.step()
333
- self.integrator.add_scalar('grad_norm', grad_norm)
334
-
335
- self.enter_val()
336
- with torch.amp.autocast('cuda', enabled=self.use_amp,
337
- dtype=torch.bfloat16), torch.inference_mode():
338
- try:
339
- if it % self.log_extra_interval == 0:
340
- # save GT audio
341
- # unnormalize the latents
342
- x1 = self.network.module.unnormalize(x1[0:1])
343
- mel = self.features.decode(x1)
344
- audio = self.features.vocode(mel).cpu()[0] # 1 * num_samples
345
- self.log.log_spectrogram('train', f'spec-gt-r{local_rank}', mel.cpu()[0], it)
346
- self.log.log_audio('train',
347
- f'audio-gt-r{local_rank}',
348
- audio,
349
- it,
350
- sample_rate=self.sample_rate)
351
-
352
- # save audio from sampling
353
- x0 = torch.empty_like(x1[0:1]).normal_(generator=self.rng)
354
- clip_f = unmasked_clip_f[0:1]
355
- sync_f = unmasked_sync_f[0:1]
356
- text_f = unmasked_text_f[0:1]
357
- conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f)
358
- empty_conditions = self.network.module.get_empty_conditions(x0.shape[0])
359
- cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper(
360
- t, x, conditions, empty_conditions, self.cfg_strength)
361
- x1_hat = self.fm.to_data(cfg_ode_wrapper, x0)
362
- x1_hat = self.network.module.unnormalize(x1_hat)
363
- mel = self.features.decode(x1_hat)
364
- audio = self.features.vocode(mel).cpu()[0]
365
- self.log.log_spectrogram('train', f'spec-r{local_rank}', mel.cpu()[0], it)
366
- self.log.log_audio('train',
367
- f'audio-r{local_rank}',
368
- audio,
369
- it,
370
- sample_rate=self.sample_rate)
371
- except Exception as e:
372
- self.log.warning(f'Error in extra logging: {e}')
373
- if self.cfg.debug:
374
- raise
375
-
376
- # Save network weights and checkpoint if needed
377
- save_copy = it in self.save_copy_iterations
378
-
379
- if (it % self.save_weights_interval == 0 and it != 0) or save_copy:
380
- self.save_weights(it)
381
-
382
- if it % self.save_checkpoint_interval == 0 and it != 0:
383
- self.save_checkpoint(it, save_copy=save_copy)
384
-
385
- self.log.data_timer.start()
386
-
387
- @torch.inference_mode()
388
- def validation_pass(self, data, it: int = 0):
389
- self.enter_val()
390
- with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16):
391
- clip_f = data['clip_features'].cuda(non_blocking=True)
392
- sync_f = data['sync_features'].cuda(non_blocking=True)
393
- text_f = data['text_features'].cuda(non_blocking=True)
394
- video_exist = data['video_exist'].cuda(non_blocking=True)
395
- text_exist = data['text_exist'].cuda(non_blocking=True)
396
- a_mean = data['a_mean'].cuda(non_blocking=True)
397
- a_std = data['a_std'].cuda(non_blocking=True)
398
-
399
- clip_f[~video_exist] = self.network.module.empty_clip_feat
400
- sync_f[~video_exist] = self.network.module.empty_sync_feat
401
- text_f[~text_exist] = self.network.module.empty_string_feat
402
- a_randn = torch.empty_like(a_mean).normal_(generator=self.rng)
403
- x1 = a_mean + a_std * a_randn
404
-
405
- self.log.data_timer.end()
406
- loss, mean_loss, t = self.val_fn(clip_f.clone(), sync_f.clone(), text_f.clone(), x1)
407
-
408
- self.val_integrator.add_binned_tensor('binned_loss', loss, t)
409
- self.val_integrator.add_dict({'loss': mean_loss})
410
-
411
- self.log.data_timer.start()
412
-
413
- @torch.inference_mode()
414
- def inference_pass(self,
415
- data,
416
- it: int,
417
- data_cfg: DictConfig,
418
- *,
419
- save_eval: bool = True) -> Path:
420
- self.enter_val()
421
- with torch.amp.autocast('cuda', enabled=self.use_amp, dtype=torch.bfloat16):
422
- clip_f = data['clip_features'].cuda(non_blocking=True)
423
- sync_f = data['sync_features'].cuda(non_blocking=True)
424
- text_f = data['text_features'].cuda(non_blocking=True)
425
- video_exist = data['video_exist'].cuda(non_blocking=True)
426
- text_exist = data['text_exist'].cuda(non_blocking=True)
427
- a_mean = data['a_mean'].cuda(non_blocking=True) # for the shape only
428
-
429
- clip_f[~video_exist] = self.network.module.empty_clip_feat
430
- sync_f[~video_exist] = self.network.module.empty_sync_feat
431
- text_f[~text_exist] = self.network.module.empty_string_feat
432
-
433
- # sample
434
- x0 = torch.empty_like(a_mean).normal_(generator=self.rng)
435
- conditions = self.network.module.preprocess_conditions(clip_f, sync_f, text_f)
436
- empty_conditions = self.network.module.get_empty_conditions(x0.shape[0])
437
- cfg_ode_wrapper = lambda t, x: self.network.module.ode_wrapper(
438
- t, x, conditions, empty_conditions, self.cfg_strength)
439
- x1_hat = self.fm.to_data(cfg_ode_wrapper, x0)
440
- x1_hat = self.network.module.unnormalize(x1_hat)
441
- mel = self.features.decode(x1_hat)
442
- audio = self.features.vocode(mel).cpu()
443
- for i in range(audio.shape[0]):
444
- video_id = data['id'][i]
445
- if (not self.for_training) and i == 0:
446
- # save very few videos
447
- self.test_video_joiner.join(video_id, f'{video_id}', audio[i].transpose(0, 1))
448
-
449
- if data_cfg.output_subdir is not None:
450
- # validation
451
- if save_eval:
452
- iter_naming = f'{it:09d}'
453
- else:
454
- iter_naming = 'val-cache'
455
- audio_dir = self.log.log_audio(iter_naming,
456
- f'{video_id}',
457
- audio[i],
458
- it=None,
459
- sample_rate=self.sample_rate,
460
- subdir=Path(data_cfg.output_subdir))
461
- if save_eval and i == 0:
462
- self.val_video_joiner.join(video_id, f'{iter_naming}-{video_id}',
463
- audio[i].transpose(0, 1))
464
- else:
465
- # full test set, usually
466
- audio_dir = self.log.log_audio(f'{data_cfg.tag}-sampled',
467
- f'{video_id}',
468
- audio[i],
469
- it=None,
470
- sample_rate=self.sample_rate)
471
-
472
- return Path(audio_dir)
473
-
474
- @torch.inference_mode()
475
- def eval(self, audio_dir: Path, it: int, data_cfg: DictConfig) -> dict[str, float]:
476
- with torch.amp.autocast('cuda', enabled=False):
477
- if local_rank == 0:
478
- extract(audio_path=audio_dir,
479
- output_path=audio_dir / 'cache',
480
- device='cuda',
481
- batch_size=32,
482
- audio_length=8)
483
- output_metrics = evaluate(gt_audio_cache=Path(data_cfg.gt_cache),
484
- pred_audio_cache=audio_dir / 'cache')
485
- for k, v in output_metrics.items():
486
- # pad k to 10 characters
487
- # pad v to 10 decimal places
488
- self.log.log_scalar(f'{data_cfg.tag}/{k}', v, it)
489
- self.log.info(f'{data_cfg.tag}/{k:<10}: {v:.10f}')
490
- else:
491
- output_metrics = None
492
-
493
- return output_metrics
494
-
495
- def save_weights(self, it, save_copy=False):
496
- if local_rank != 0:
497
- return
498
-
499
- os.makedirs(self.run_path, exist_ok=True)
500
- if save_copy:
501
- model_path = self.run_path / f'{self.exp_id}_{it}.pth'
502
- torch.save(self.network.module.state_dict(), model_path)
503
- self.log.info(f'Network weights saved to {model_path}.')
504
-
505
- # if last exists, move it to a shadow copy
506
- model_path = self.run_path / f'{self.exp_id}_last.pth'
507
- if model_path.exists():
508
- shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow'))
509
- model_path.replace(shadow_path)
510
- self.log.info(f'Network weights shadowed to {shadow_path}.')
511
-
512
- torch.save(self.network.module.state_dict(), model_path)
513
- self.log.info(f'Network weights saved to {model_path}.')
514
-
515
- def save_checkpoint(self, it, save_copy=False):
516
- if local_rank != 0:
517
- return
518
-
519
- checkpoint = {
520
- 'it': it,
521
- 'weights': self.network.module.state_dict(),
522
- 'optimizer': self.optimizer.state_dict(),
523
- 'scheduler': self.scheduler.state_dict(),
524
- 'ema': self.ema.state_dict() if self.ema is not None else None,
525
- }
526
-
527
- os.makedirs(self.run_path, exist_ok=True)
528
- if save_copy:
529
- model_path = self.run_path / f'{self.exp_id}_ckpt_{it}.pth'
530
- torch.save(checkpoint, model_path)
531
- self.log.info(f'Checkpoint saved to {model_path}.')
532
-
533
- # if ckpt_last exists, move it to a shadow copy
534
- model_path = self.run_path / f'{self.exp_id}_ckpt_last.pth'
535
- if model_path.exists():
536
- shadow_path = model_path.with_name(model_path.name.replace('last', 'shadow'))
537
- model_path.replace(shadow_path) # moves the file
538
- self.log.info(f'Checkpoint shadowed to {shadow_path}.')
539
-
540
- torch.save(checkpoint, model_path)
541
- self.log.info(f'Checkpoint saved to {model_path}.')
542
-
543
- def get_latest_checkpoint_path(self):
544
- ckpt_path = self.run_path / f'{self.exp_id}_ckpt_last.pth'
545
- if not ckpt_path.exists():
546
- info_if_rank_zero(self.log, f'No checkpoint found at {ckpt_path}.')
547
- return None
548
- return ckpt_path
549
-
550
- def get_latest_weight_path(self):
551
- weight_path = self.run_path / f'{self.exp_id}_last.pth'
552
- if not weight_path.exists():
553
- self.log.info(f'No weight found at {weight_path}.')
554
- return None
555
- return weight_path
556
-
557
- def get_final_ema_weight_path(self):
558
- weight_path = self.run_path / f'{self.exp_id}_ema_final.pth'
559
- if not weight_path.exists():
560
- self.log.info(f'No weight found at {weight_path}.')
561
- return None
562
- return weight_path
563
-
564
- def load_checkpoint(self, path):
565
- # This method loads everything and should be used to resume training
566
- map_location = 'cuda:%d' % local_rank
567
- checkpoint = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True)
568
-
569
- it = checkpoint['it']
570
- weights = checkpoint['weights']
571
- optimizer = checkpoint['optimizer']
572
- scheduler = checkpoint['scheduler']
573
- if self.ema is not None:
574
- self.ema.load_state_dict(checkpoint['ema'])
575
- self.log.info(f'EMA states loaded from step {self.ema.step}')
576
-
577
- map_location = 'cuda:%d' % local_rank
578
- self.network.module.load_state_dict(weights)
579
- self.optimizer.load_state_dict(optimizer)
580
- self.scheduler.load_state_dict(scheduler)
581
-
582
- self.log.info(f'Global iteration {it} loaded.')
583
- self.log.info('Network weights, optimizer states, and scheduler states loaded.')
584
-
585
- return it
586
-
587
- def load_weights_in_memory(self, src_dict):
588
- self.network.module.load_weights(src_dict)
589
- self.log.info('Network weights loaded from memory.')
590
-
591
- def load_weights(self, path):
592
- # This method loads only the network weight and should be used to load a pretrained model
593
- map_location = 'cuda:%d' % local_rank
594
- src_dict = torch.load(path, map_location={'cuda:0': map_location}, weights_only=True)
595
-
596
- self.log.info(f'Importing network weights from {path}...')
597
- self.load_weights_in_memory(src_dict)
598
-
599
- def weights(self):
600
- return self.network.module.state_dict()
601
-
602
- def enter_train(self):
603
- self.integrator = self.train_integrator
604
- self.network.train()
605
- return self
606
-
607
- def enter_val(self):
608
- self.network.eval()
609
- return self
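For reference, the removed Runner already contains the pieces needed to resume training (get_latest_checkpoint_path and load_checkpoint above). A minimal sketch of how that could be wired up, assuming the same Hydra-provided cfg, log and run_dir used in sample.py below and a distributed launcher that sets LOCAL_RANK; resume_or_start is a hypothetical helper, not part of the repository:

    from mmaudio.runner import Runner

    def resume_or_start(cfg, log, run_dir):
        # Build a training Runner and resume from the last checkpoint if one exists.
        runner = Runner(cfg, log=log, run_path=run_dir, for_training=True)
        ckpt = runner.get_latest_checkpoint_path()
        start_it = runner.load_checkpoint(ckpt) if ckpt is not None else 0
        return runner, start_it
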
mmaudio/sample.py DELETED
@@ -1,90 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import random
5
-
6
- import numpy as np
7
- import torch
8
- from hydra.core.hydra_config import HydraConfig
9
- from omegaconf import DictConfig, open_dict
10
- from tqdm import tqdm
11
-
12
- from mmaudio.data.data_setup import setup_test_datasets
13
- from mmaudio.runner import Runner
14
- from mmaudio.utils.dist_utils import info_if_rank_zero
15
- from mmaudio.utils.logger import TensorboardLogger
16
-
17
- local_rank = int(os.environ['LOCAL_RANK'])
18
- world_size = int(os.environ['WORLD_SIZE'])
19
-
20
-
21
- def sample(cfg: DictConfig):
22
- # initial setup
23
- num_gpus = world_size
24
- run_dir = HydraConfig.get().run.dir
25
-
26
- # wrap python logger with a tensorboard logger
27
- log = TensorboardLogger(cfg.exp_id,
28
- run_dir,
29
- logging.getLogger(),
30
- is_rank0=(local_rank == 0),
31
- enable_email=cfg.enable_email and not cfg.debug)
32
-
33
- info_if_rank_zero(log, f'All configuration: {cfg}')
34
- info_if_rank_zero(log, f'Number of GPUs detected: {num_gpus}')
35
-
36
- # cuda setup
37
- torch.cuda.set_device(local_rank)
38
- torch.backends.cudnn.benchmark = cfg.cudnn_benchmark
39
-
40
- # number of dataloader workers
41
- info_if_rank_zero(log, f'Number of dataloader workers (per GPU): {cfg.num_workers}')
42
-
43
- # Set seeds to ensure the same initialization
44
- torch.manual_seed(cfg.seed)
45
- np.random.seed(cfg.seed)
46
- random.seed(cfg.seed)
47
-
48
- # setting up configurations
49
- info_if_rank_zero(log, f'Configuration: {cfg}')
50
- info_if_rank_zero(log, f'Batch size (per GPU): {cfg.batch_size}')
51
-
52
- # construct the trainer
53
- runner = Runner(cfg, log=log, run_path=run_dir, for_training=False).enter_val()
54
-
55
- # load the last weights if needed
56
- if cfg['weights'] is not None:
57
- info_if_rank_zero(log, f'Loading weights from the disk: {cfg["weights"]}')
58
- runner.load_weights(cfg['weights'])
59
- cfg['weights'] = None
60
- else:
61
- weights = runner.get_final_ema_weight_path()
62
- if weights is not None:
63
- info_if_rank_zero(log, f'Automatically finding weight: {weights}')
64
- runner.load_weights(weights)
65
-
66
- # setup datasets
67
- dataset, sampler, loader = setup_test_datasets(cfg)
68
- data_cfg = cfg.data.ExtractedVGG_test
69
- with open_dict(data_cfg):
70
- if cfg.output_name is not None:
71
- # append to the tag
72
- data_cfg.tag = f'{data_cfg.tag}-{cfg.output_name}'
73
-
74
- # loop
75
- audio_path = None
76
- for curr_iter, data in enumerate(tqdm(loader)):
77
- new_audio_path = runner.inference_pass(data, curr_iter, data_cfg)
78
- if audio_path is None:
79
- audio_path = new_audio_path
80
- else:
81
- assert audio_path == new_audio_path, 'Different audio path detected'
82
-
83
- info_if_rank_zero(log, f'Inference completed. Audio path: {audio_path}')
84
- output_metrics = runner.eval(audio_path, curr_iter, data_cfg)
85
-
86
- if local_rank == 0:
87
- # write the output metrics to run_dir
88
- output_metrics_path = os.path.join(run_dir, f'{data_cfg.tag}-output_metrics.json')
89
- with open(output_metrics_path, 'w') as f:
90
- json.dump(output_metrics, f, indent=4)
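Note that the removed sample.py reads LOCAL_RANK and WORLD_SIZE from the environment at import time, so it assumes a distributed launcher such as torchrun. A minimal single-process fallback (an assumption for illustration, not something the diff provides) would be to set those variables before importing the module:

    import os

    # torchrun normally sets these; define defaults for a single-process run.
    os.environ.setdefault('LOCAL_RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')
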
mmaudio/utils/email_utils.py DELETED
@@ -1,50 +0,0 @@
1
- import logging
2
- import os
3
- from datetime import datetime
4
-
5
- import requests
6
- from dotenv import load_dotenv
7
- from pytz import timezone
8
-
9
- from mmaudio.utils.timezone import my_timezone
10
-
11
- _source = 'USE YOURS'
12
- _target = 'USE YOURS'
13
-
14
- log = logging.getLogger()
15
-
16
- _fmt = "%Y-%m-%d %H:%M:%S %Z%z"
17
-
18
-
19
- class EmailSender:
20
-
21
- def __init__(self, exp_id: str, enable: bool):
22
- self.exp_id = exp_id
23
- self.enable = enable
24
- if enable:
25
- load_dotenv()
26
- self.MAILGUN_API_KEY = os.getenv('MAILGUN_API_KEY')
27
- if self.MAILGUN_API_KEY is None:
28
- log.warning('MAILGUN_API_KEY is not set')
29
- self.enable = False
30
-
31
- def send(self, subject, content):
32
- if self.enable:
33
- subject = str(subject)
34
- content = str(content)
35
- try:
36
- return requests.post(f'https://api.mailgun.net/v3/{_source}/messages',
37
- auth=('api', self.MAILGUN_API_KEY),
38
- data={
39
- 'from':
40
- f'<agent name>🤖 <mailgun@{_source}>',
41
- 'to': [f'{_target}'],
42
- 'subject':
43
- f'[{self.exp_id}] {subject}',
44
- 'text':
45
- ('\n\n' + content + '\n\n<your sign off>\n' +
46
- datetime.now(timezone(my_timezone)).strftime(_fmt)),
47
- },
48
- timeout=20)
49
- except Exception as e:
50
- log.error(f'Failed to send email: {e}')
mmaudio/utils/log_integrator.py DELETED
@@ -1,112 +0,0 @@
1
- """
2
- Integrate numerical values for some iterations
3
- Typically used for loss computation / logging to tensorboard
4
- Call finalize and create a new Integrator when you want to display/log
5
- """
6
- from typing import Callable, Union
7
-
8
- import torch
9
-
10
- from mmaudio.utils.logger import TensorboardLogger
11
- from mmaudio.utils.tensor_utils import distribute_into_histogram
12
-
13
-
14
- class Integrator:
15
-
16
- def __init__(self, logger: TensorboardLogger, distributed: bool = True):
17
- self.values = {}
18
- self.counts = {}
19
- self.hooks = [] # List is used here to maintain insertion order
20
-
21
- # for binned tensors
22
- self.binned_tensors = {}
23
- self.binned_tensor_indices = {}
24
-
25
- self.logger = logger
26
-
27
- self.distributed = distributed
28
- self.local_rank = torch.distributed.get_rank()
29
- self.world_size = torch.distributed.get_world_size()
30
-
31
- def add_scalar(self, key: str, x: Union[torch.Tensor, int, float]):
32
- if isinstance(x, torch.Tensor):
33
- x = x.detach()
34
- if x.dtype in [torch.long, torch.int, torch.bool]:
35
- x = x.float()
36
-
37
- if key not in self.values:
38
- self.counts[key] = 1
39
- self.values[key] = x
40
- else:
41
- self.counts[key] += 1
42
- self.values[key] += x
43
-
44
- def add_dict(self, tensor_dict: dict[str, torch.Tensor]):
45
- for k, v in tensor_dict.items():
46
- self.add_scalar(k, v)
47
-
48
- def add_binned_tensor(self, key: str, x: torch.Tensor, indices: torch.Tensor):
49
- if key not in self.binned_tensors:
50
- self.binned_tensors[key] = [x.detach().flatten()]
51
- self.binned_tensor_indices[key] = [indices.detach().flatten()]
52
- else:
53
- self.binned_tensors[key].append(x.detach().flatten())
54
- self.binned_tensor_indices[key].append(indices.detach().flatten())
55
-
56
- def add_hook(self, hook: Callable[[torch.Tensor], tuple[str, torch.Tensor]]):
57
- """
58
- Adds a custom hook, i.e. compute new metrics using values in the dict
59
- The hook takes the dict as argument, and returns a (k, v) tuple
60
- e.g. for computing IoU
61
- """
62
- self.hooks.append(hook)
63
-
64
- def reset_except_hooks(self):
65
- self.values = {}
66
- self.counts = {}
67
-
68
- # Average and output the metrics
69
- def finalize(self, prefix: str, it: int, ignore_timer: bool = False) -> None:
70
-
71
- for hook in self.hooks:
72
- k, v = hook(self.values)
73
- self.add_scalar(k, v)
74
-
75
- # for the metrics
76
- outputs = {}
77
- for k, v in self.values.items():
78
- avg = v / self.counts[k]
79
- if self.distributed:
80
- # Inplace operation
81
- if isinstance(avg, torch.Tensor):
82
- avg = avg.cuda()
83
- else:
84
- avg = torch.tensor(avg).cuda()
85
- torch.distributed.reduce(avg, dst=0)
86
-
87
- if self.local_rank == 0:
88
- avg = (avg / self.world_size).cpu().item()
89
- outputs[k] = avg
90
- else:
91
- # Simple does it
92
- outputs[k] = avg
93
-
94
- if (not self.distributed) or (self.local_rank == 0):
95
- self.logger.log_metrics(prefix, outputs, it, ignore_timer=ignore_timer)
96
-
97
- # for the binned tensors
98
- for k, v in self.binned_tensors.items():
99
- x = torch.cat(v, dim=0)
100
- indices = torch.cat(self.binned_tensor_indices[k], dim=0)
101
- hist, count = distribute_into_histogram(x, indices)
102
-
103
- if self.distributed:
104
- torch.distributed.reduce(hist, dst=0)
105
- torch.distributed.reduce(count, dst=0)
106
- if self.local_rank == 0:
107
- hist = hist / count
108
- else:
109
- hist = hist / count
110
-
111
- if (not self.distributed) or (self.local_rank == 0):
112
- self.logger.log_histogram(f'{prefix}/{k}', hist, it)
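The removed Integrator follows the accumulate-then-flush pattern visible in runner.py above. A minimal usage sketch, assuming an initialised torch.distributed process group, a TensorboardLogger named log, and placeholder names loader, train_step and log_text_interval:

    from mmaudio.utils.log_integrator import Integrator

    # distributed=True requires torch.distributed to be initialised already.
    integrator = Integrator(log, distributed=True)
    for it, batch in enumerate(loader):
        loss = train_step(batch)                 # hypothetical step returning a scalar tensor
        integrator.add_dict({'loss': loss})      # accumulate every iteration
        if it % log_text_interval == 0 and it != 0:
            integrator.finalize('train', it)     # reduce across ranks, then log the averages
            integrator.reset_except_hooks()      # start a fresh accumulation window
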
mmaudio/utils/logger.py DELETED
@@ -1,231 +0,0 @@
1
- """
2
- Dumps things to tensorboard and console
3
- """
4
-
5
- import datetime
6
- import logging
7
- import math
8
- import os
9
- from collections import defaultdict
10
- from pathlib import Path
11
- from typing import Optional, Union
12
-
13
- import matplotlib.pyplot as plt
14
- import numpy as np
15
- import torch
16
- import torchaudio
17
- from PIL import Image
18
- from pytz import timezone
19
- from torch.utils.tensorboard import SummaryWriter
20
-
21
- from mmaudio.utils.email_utils import EmailSender
22
- from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator
23
- from mmaudio.utils.timezone import my_timezone
24
-
25
-
26
- def tensor_to_numpy(image: torch.Tensor):
27
- image_np = (image.numpy() * 255).astype('uint8')
28
- return image_np
29
-
30
-
31
- def detach_to_cpu(x: torch.Tensor):
32
- return x.detach().cpu()
33
-
34
-
35
- def fix_width_trunc(x: float):
36
- return ('{:.9s}'.format('{:0.9f}'.format(x)))
37
-
38
-
39
- def plot_spectrogram(spectrogram: np.ndarray, title=None, ylabel="freq_bin", ax=None):
40
- if ax is None:
41
- _, ax = plt.subplots(1, 1)
42
- if title is not None:
43
- ax.set_title(title)
44
- ax.set_ylabel(ylabel)
45
- ax.imshow(spectrogram, origin="lower", aspect="auto", interpolation="nearest")
46
-
47
-
48
- class TensorboardLogger:
49
-
50
- def __init__(self,
51
- exp_id: str,
52
- run_dir: Union[Path, str],
53
- py_logger: logging.Logger,
54
- *,
55
- is_rank0: bool = False,
56
- enable_email: bool = False):
57
- self.exp_id = exp_id
58
- self.run_dir = Path(run_dir)
59
- self.py_log = py_logger
60
- self.email_sender = EmailSender(exp_id, enable=(is_rank0 and enable_email))
61
- if is_rank0:
62
- self.tb_log = SummaryWriter(run_dir)
63
- else:
64
- self.tb_log = None
65
-
66
- # Get current git info for logging
67
- try:
68
- import git
69
- repo = git.Repo(".")
70
- git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
71
- except (ImportError, RuntimeError, TypeError):
72
- print('Failed to fetch git info. Defaulting to None')
73
- git_info = 'None'
74
-
75
- self.log_string('git', git_info)
76
-
77
- # log the SLURM job id if available
78
- job_id = os.environ.get('SLURM_JOB_ID', None)
79
- if job_id is not None:
80
- self.log_string('slurm_job_id', job_id)
81
- self.email_sender.send(f'Job {job_id} started', f'Job started {run_dir}')
82
-
83
- # used when logging metrics
84
- self.batch_timer: TimeEstimator = None
85
- self.data_timer: PartialTimeEstimator = None
86
-
87
- self.nan_count = defaultdict(int)
88
-
89
- def log_scalar(self, tag: str, x: float, it: int):
90
- if self.tb_log is None:
91
- return
92
- if math.isnan(x) and 'grad_norm' not in tag:
93
- self.nan_count[tag] += 1
94
- if self.nan_count[tag] == 10:
95
- self.email_sender.send(
96
- f'Nan detected in {tag} @ {self.run_dir}',
97
- f'Nan detected in {tag} at iteration {it}; run_dir: {self.run_dir}')
98
- else:
99
- self.nan_count[tag] = 0
100
- self.tb_log.add_scalar(tag, x, it)
101
-
102
- def log_metrics(self,
103
- prefix: str,
104
- metrics: dict[str, float],
105
- it: int,
106
- ignore_timer: bool = False):
107
- msg = f'{self.exp_id}-{prefix} - it {it:6d}: '
108
- metrics_msg = ''
109
- for k, v in sorted(metrics.items()):
110
- self.log_scalar(f'{prefix}/{k}', v, it)
111
- metrics_msg += f'{k: >10}:{v:.7f},\t'
112
-
113
- if self.batch_timer is not None and not ignore_timer:
114
- self.batch_timer.update()
115
- avg_time = self.batch_timer.get_and_reset_avg_time()
116
- data_time = self.data_timer.get_and_reset_avg_time()
117
-
118
- # add time to tensorboard
119
- self.log_scalar(f'{prefix}/avg_time', avg_time, it)
120
- self.log_scalar(f'{prefix}/data_time', data_time, it)
121
-
122
- est = self.batch_timer.get_est_remaining(it)
123
- est = datetime.timedelta(seconds=est)
124
- if est.days > 0:
125
- remaining_str = f'{est.days}d {est.seconds // 3600}h'
126
- else:
127
- remaining_str = f'{est.seconds // 3600}h {(est.seconds%3600) // 60}m'
128
- eta = datetime.datetime.now(timezone(my_timezone)) + est
129
- eta_str = eta.strftime('%Y-%m-%d %H:%M:%S %Z%z')
130
- time_msg = f'avg_time:{avg_time:.3f},data:{data_time:.3f},remaining:{remaining_str},eta:{eta_str},\t'
131
- msg = f'{msg} {time_msg}'
132
-
133
- msg = f'{msg} {metrics_msg}'
134
- self.py_log.info(msg)
135
-
136
- def log_histogram(self, tag: str, hist: torch.Tensor, it: int):
137
- if self.tb_log is None:
138
- return
139
- # hist should be a 1D tensor
140
- hist = hist.cpu().numpy()
141
- fig, ax = plt.subplots()
142
- x_range = np.linspace(0, 1, len(hist))
143
- ax.bar(x_range, hist, width=1 / (len(hist) - 1))
144
- ax.set_xticks(x_range)
145
- ax.set_xticklabels(x_range)
146
- plt.tight_layout()
147
- self.tb_log.add_figure(tag, fig, it)
148
- plt.close()
149
-
150
- def log_image(self, prefix: str, tag: str, image: np.ndarray, it: int):
151
- image_dir = self.run_dir / f'{prefix}_images'
152
- image_dir.mkdir(exist_ok=True, parents=True)
153
-
154
- image = Image.fromarray(image)
155
- image.save(image_dir / f'{it:09d}_{tag}.png')
156
-
157
- def log_audio(self,
158
- prefix: str,
159
- tag: str,
160
- waveform: torch.Tensor,
161
- it: Optional[int] = None,
162
- *,
163
- subdir: Optional[Path] = None,
164
- sample_rate: int = 16000) -> Path:
165
- if subdir is None:
166
- audio_dir = self.run_dir / prefix
167
- else:
168
- audio_dir = self.run_dir / subdir / prefix
169
- audio_dir.mkdir(exist_ok=True, parents=True)
170
-
171
- if it is None:
172
- name = f'{tag}.flac'
173
- else:
174
- name = f'{it:09d}_{tag}.flac'
175
-
176
- torchaudio.save(audio_dir / name,
177
- waveform.cpu().float(),
178
- sample_rate=sample_rate,
179
- channels_first=True)
180
- return Path(audio_dir)
181
-
182
- def log_spectrogram(
183
- self,
184
- prefix: str,
185
- tag: str,
186
- spec: torch.Tensor,
187
- it: Optional[int],
188
- *,
189
- subdir: Optional[Path] = None,
190
- ):
191
- if subdir is None:
192
- spec_dir = self.run_dir / prefix
193
- else:
194
- spec_dir = self.run_dir / subdir / prefix
195
- spec_dir.mkdir(exist_ok=True, parents=True)
196
-
197
- if it is None:
198
- name = f'{tag}.png'
199
- else:
200
- name = f'{it:09d}_{tag}.png'
201
-
202
- plot_spectrogram(spec.cpu().float())
203
- plt.tight_layout()
204
- plt.savefig(spec_dir / name)
205
- plt.close()
206
-
207
- def log_string(self, tag: str, x: str):
208
- self.py_log.info(f'{tag} - {x}')
209
- if self.tb_log is None:
210
- return
211
- self.tb_log.add_text(tag, x)
212
-
213
- def debug(self, x):
214
- self.py_log.debug(x)
215
-
216
- def info(self, x):
217
- self.py_log.info(x)
218
-
219
- def warning(self, x):
220
- self.py_log.warning(x)
221
-
222
- def error(self, x):
223
- self.py_log.error(x)
224
-
225
- def critical(self, x):
226
- self.py_log.critical(x)
227
-
228
- self.email_sender.send(f'Error occurred in {self.run_dir}', x)
229
-
230
- def complete(self):
231
- self.email_sender.send(f'Job completed in {self.run_dir}', 'Job completed')
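The removed TensorboardLogger is constructed the same way sample.py does above. A small sketch of the scalar and audio logging calls it exposes; the experiment id, output directory and tag names are placeholders:

    import logging

    import torch
    from mmaudio.utils.logger import TensorboardLogger

    log = TensorboardLogger('exp-001', 'output/exp-001', logging.getLogger(),
                            is_rank0=True, enable_email=False)
    log.log_scalar('train/loss', 0.123, 100)
    waveform = torch.zeros(1, 16000)             # (channels, samples), channels-first
    log.log_audio('train', 'audio-demo', waveform, 100, sample_rate=16000)
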
mmaudio/utils/synthesize_ema.py DELETED
@@ -1,19 +0,0 @@
1
- from typing import Optional
2
-
3
- from nitrous_ema import PostHocEMA
4
- from omegaconf import DictConfig
5
-
6
- from mmaudio.model.networks import get_my_mmaudio
7
-
8
-
9
- def synthesize_ema(cfg: DictConfig, sigma: float, step: Optional[int]):
10
- vae = get_my_mmaudio(cfg.model)
11
- emas = PostHocEMA(vae,
12
- sigma_rels=cfg.ema.sigma_rels,
13
- update_every=cfg.ema.update_every,
14
- checkpoint_every_num_steps=cfg.ema.checkpoint_every,
15
- checkpoint_folder=cfg.ema.checkpoint_folder)
16
-
17
- synthesized_ema = emas.synthesize_ema_model(sigma_rel=sigma, step=step, device='cpu')
18
- state_dict = synthesized_ema.ema_model.state_dict()
19
- return state_dict
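The removed synthesize_ema helper rebuilds a post-hoc EMA model from the checkpoints written to cfg.ema.checkpoint_folder during training. A sketch of turning its output into the file that Runner.get_final_ema_weight_path() looks for; the sigma value and run_dir path are placeholders, not values from the diff:

    import torch
    from mmaudio.utils.synthesize_ema import synthesize_ema

    # cfg is the Hydra training config; run_dir is assumed to be a pathlib.Path.
    state_dict = synthesize_ema(cfg, sigma=0.05, step=None)
    torch.save(state_dict, run_dir / f'{cfg.exp_id}_ema_final.pth')
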
mmaudio/utils/tensor_utils.py DELETED
@@ -1,14 +0,0 @@
1
- import torch
2
-
3
-
4
- def distribute_into_histogram(loss: torch.Tensor,
5
- t: torch.Tensor,
6
- num_bins: int = 25) -> tuple[torch.Tensor, torch.Tensor]:
7
- loss = loss.detach().flatten()
8
- t = t.detach().flatten()
9
- t = (t * num_bins).long()
10
- hist = torch.zeros(num_bins, device=loss.device)
11
- count = torch.zeros(num_bins, device=loss.device)
12
- hist.scatter_add_(0, t, loss)
13
- count.scatter_add_(0, t, torch.ones_like(loss))
14
- return hist, count
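distribute_into_histogram bins per-element losses by their normalised timestep, which must lie in [0, 1) so that the computed bin index stays below num_bins. A small self-contained sketch with made-up shapes:

    import torch
    from mmaudio.utils.tensor_utils import distribute_into_histogram

    loss = torch.rand(8, 128)                      # per-element losses
    t = torch.rand(8, 1).expand(8, 128)            # timesteps in [0, 1), one per sample
    hist, count = distribute_into_histogram(loss, t, num_bins=25)
    mean_loss_per_bin = hist / count.clamp(min=1)  # guard against empty bins
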
mmaudio/utils/time_estimator.py DELETED
@@ -1,72 +0,0 @@
1
- import time
2
-
3
-
4
- class TimeEstimator:
5
-
6
- def __init__(self, total_iter: int, step_size: int, ema_alpha: float = 0.7):
7
- self.avg_time_window = [] # window-based average
8
- self.exp_avg_time = None # exponential moving average
9
- self.alpha = ema_alpha # for exponential moving average
10
-
11
- self.last_time = time.time() # would not be accurate for the first iteration but well
12
- self.total_iter = total_iter
13
- self.step_size = step_size
14
-
15
- self._buffering_exp = True
16
-
17
- # call this at a fixed interval
18
- # does not have to be every step
19
- def update(self):
20
- curr_time = time.time()
21
- time_per_iter = curr_time - self.last_time
22
- self.last_time = curr_time
23
-
24
- self.avg_time_window.append(time_per_iter)
25
-
26
- if self._buffering_exp:
27
- if self.exp_avg_time is not None:
28
- # discard the first iteration call to not pollute the ema
29
- self._buffering_exp = False
30
- self.exp_avg_time = time_per_iter
31
- else:
32
- self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter
33
-
34
- def get_est_remaining(self, it: int):
35
- if self.exp_avg_time is None:
36
- return 0
37
-
38
- remaining_iter = self.total_iter - it
39
- return remaining_iter * self.exp_avg_time / self.step_size
40
-
41
- def get_and_reset_avg_time(self):
42
- avg = sum(self.avg_time_window) / len(self.avg_time_window) / self.step_size
43
- self.avg_time_window = []
44
- return avg
45
-
46
-
47
- class PartialTimeEstimator(TimeEstimator):
48
- """
49
- Used where the start_time and the end_time do not align
50
- """
51
-
52
- def update(self):
53
- raise RuntimeError('Please use start() and end() for PartialTimeEstimator')
54
-
55
- def start(self):
56
- self.last_time = time.time()
57
-
58
- def end(self):
59
- assert self.last_time is not None, 'Please call start() before calling end()'
60
- curr_time = time.time()
61
- time_per_iter = curr_time - self.last_time
62
- self.last_time = None
63
-
64
- self.avg_time_window.append(time_per_iter)
65
-
66
- if self._buffering_exp:
67
- if self.exp_avg_time is not None:
68
- # discard the first iteration call to not pollute the ema
69
- self._buffering_exp = False
70
- self.exp_avg_time = time_per_iter
71
- else:
72
- self.exp_avg_time = self.alpha * self.exp_avg_time + (1 - self.alpha) * time_per_iter
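The two removed timers were driven from runner.py above: the batch timer is updated once per logging interval, while the partial timer brackets only the dataloading wait. A sketch of that pattern with placeholder values for the iteration count, logging interval and loader:

    from mmaudio.utils.time_estimator import PartialTimeEstimator, TimeEstimator

    num_iterations, log_interval = 10_000, 100
    batch_timer = TimeEstimator(num_iterations, log_interval)
    data_timer = PartialTimeEstimator(num_iterations, 1, ema_alpha=0.9)

    data_timer.start()
    for it, batch in enumerate(loader):            # loader is assumed to exist
        data_timer.end()                           # time spent waiting on data
        ...                                        # forward/backward step goes here
        if it % log_interval == 0 and it != 0:
            batch_timer.update()                   # once per logging interval
            print(batch_timer.get_and_reset_avg_time(),
                  data_timer.get_and_reset_avg_time())
        data_timer.start()
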
mmaudio/utils/timezone.py DELETED
@@ -1 +0,0 @@
1
- my_timezone = 'US/Central'