Phil Sobrepena committed
Commit 73ed896 · 1 Parent(s): ddb444b

initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +146 -0
  2. Dockerfile +39 -0
  3. LICENSE +21 -0
  4. README.md +56 -0
  5. app.py +128 -0
  6. batch_eval.py +110 -0
  7. config/__init__.py +0 -0
  8. config/base_config.yaml +62 -0
  9. config/data/base.yaml +70 -0
  10. config/eval_config.yaml +17 -0
  11. config/eval_data/base.yaml +22 -0
  12. config/hydra/job_logging/custom-eval.yaml +32 -0
  13. config/hydra/job_logging/custom-no-rank.yaml +32 -0
  14. config/hydra/job_logging/custom-simplest.yaml +26 -0
  15. config/hydra/job_logging/custom.yaml +33 -0
  16. config/train_config.yaml +41 -0
  17. demo.py +141 -0
  18. docs/EVAL.md +22 -0
  19. docs/MODELS.md +50 -0
  20. docs/TRAINING.md +160 -0
  21. docs/images/icon.png +0 -0
  22. docs/index.html +149 -0
  23. docs/style.css +78 -0
  24. docs/style_videos.css +52 -0
  25. docs/video_gen.html +254 -0
  26. docs/video_main.html +98 -0
  27. docs/video_vgg.html +452 -0
  28. gitattributes +35 -0
  29. gradio_demo.py +343 -0
  30. mmaudio/__init__.py +0 -0
  31. mmaudio/data/__init__.py +0 -0
  32. mmaudio/data/av_utils.py +162 -0
  33. mmaudio/data/data_setup.py +174 -0
  34. mmaudio/data/eval/__init__.py +0 -0
  35. mmaudio/data/eval/audiocaps.py +39 -0
  36. mmaudio/data/eval/moviegen.py +131 -0
  37. mmaudio/data/eval/video_dataset.py +197 -0
  38. mmaudio/data/extracted_audio.py +88 -0
  39. mmaudio/data/extracted_vgg.py +101 -0
  40. mmaudio/data/extraction/__init__.py +0 -0
  41. mmaudio/data/extraction/vgg_sound.py +193 -0
  42. mmaudio/data/extraction/wav_dataset.py +132 -0
  43. mmaudio/data/mm_dataset.py +45 -0
  44. mmaudio/data/utils.py +148 -0
  45. mmaudio/eval_utils.py +255 -0
  46. mmaudio/ext/__init__.py +1 -0
  47. mmaudio/ext/autoencoder/__init__.py +1 -0
  48. mmaudio/ext/autoencoder/autoencoder.py +52 -0
  49. mmaudio/ext/autoencoder/edm2_utils.py +168 -0
  50. mmaudio/ext/autoencoder/vae.py +369 -0
.gitignore ADDED
@@ -0,0 +1,146 @@
+ run_*.sh
+ log/
+ saves
+ saves/
+ weights/
+ weights
+ output/
+ output
+ pretrained/
+ workspace
+ workspace/
+ ext_weights/
+ ext_weights
+ .checkpoints/
+ .vscode/
+ training/example_output/
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ pip-wheel-metadata/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
Dockerfile ADDED
@@ -0,0 +1,39 @@
+ FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
+
+ WORKDIR /code
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     python3.9 \
+     python3-pip \
+     git \
+     ffmpeg \
+     libsm6 \
+     libxext6 \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Install Python dependencies
+ COPY requirements.txt .
+ RUN pip3 install --no-cache-dir -r requirements.txt
+
+ # Clone and install MMAudio
+ RUN git clone https://github.com/hkchengrex/MMAudio.git && \
+     cd MMAudio && \
+     pip3 install -e .
+
+ # Copy the application files
+ COPY app.py .
+
+ # Create output directory
+ RUN mkdir -p output/gradio && chmod 777 output/gradio
+
+ # Set environment variables for Hugging Face Spaces
+ ENV PYTHONUNBUFFERED=1
+ ENV GRADIO_SERVER_NAME=0.0.0.0
+ ENV GRADIO_SERVER_PORT=7860
+
+ # Expose the port
+ EXPOSE 7860
+
+ # Run the Gradio app
+ CMD ["python3", "app.py"]
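For local testing, a minimal sketch of building and serving this image from the repository root; the image tag `mmaudio-demo` and the `--gpus all` flag are assumptions, adjust them to your Docker setup:

```python
import subprocess

# Build the image from the directory that contains the Dockerfile above.
subprocess.run(["docker", "build", "-t", "mmaudio-demo", "."], check=True)

# Run it with GPU access and publish the Gradio port declared by EXPOSE 7860.
subprocess.run([
    "docker", "run", "--rm", "--gpus", "all",
    "-p", "7860:7860", "mmaudio-demo",
], check=True)
```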
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Sony Research Inc.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,56 @@
+ ---
+ title: Sonisphere
+ emoji: 🐢
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.20.0
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Sonisphere Demo
+
+ This is a Hugging Face Spaces demo for [MMAudio](https://hkchengrex.com/MMAudio/), a powerful model for generating realistic audio for videos.
+
+ ## 🎥 Features
+
+ - Upload any video and generate matching audio
+ - Control the generation with text prompts
+ - Adjust generation parameters like steps and guidance strength
+ - Process videos up to 30 seconds in length
+
+ ## 🚀 Usage
+
+ 1. Upload a video or use one of the example videos
+ 2. Enter a text prompt describing the desired audio
+ 3. (Optional) Add a negative prompt to specify what you don't want
+ 4. Adjust the generation parameters if needed
+ 5. Click "Submit" and wait for the generation to complete
+
+ ## ⚙️ Parameters
+
+ - **Prompt**: Describe the audio you want to generate
+ - **Negative prompt**: Specify what you don't want in the audio (default: "music")
+ - **Seed**: Control randomness (-1 for a random seed)
+ - **Number of steps**: More steps = better quality but slower (default: 25)
+ - **Guidance Strength**: Controls how closely the generation follows the prompt (default: 4.5)
+ - **Duration**: Length of the generated audio in seconds (default: 8)
+
+ ## 📝 Notes
+
+ - Processing high-resolution videos (>384px on the shorter side) takes longer and doesn't improve results
+ - The model works best with videos between 5 and 30 seconds
+ - Generation time depends on video length and the number of steps
+
+ ## 🔗 Links
+
+ - [Project Page](https://hkchengrex.com/MMAudio/)
+ - [GitHub Repository](https://github.com/hkchengrex/MMAudio)
+ - [Paper](https://arxiv.org/abs/2412.15322)
+
+ ## 📜 License
+
+ This demo uses the MMAudio model, which is released under the [MIT license](https://github.com/hkchengrex/MMAudio/blob/main/LICENSE).
app.py ADDED
@@ -0,0 +1,128 @@
+ import gc
+ import logging
+ from datetime import datetime
+ from fractions import Fraction
+ from pathlib import Path
+
+ import gradio as gr
+ import torch
+ import torchaudio
+
+ from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
+                                 load_video, make_video, setup_eval_logging)
+ from mmaudio.model.flow_matching import FlowMatching
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
+ from mmaudio.model.sequence_config import SequenceConfig
+ from mmaudio.model.utils.features_utils import FeaturesUtils
+
+ # Setup logging
+ setup_eval_logging()
+ log = logging.getLogger()
+
+ # Configure device and dtype
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ if device == 'cpu':
+     log.warning('CUDA is not available, running on CPU')
+ dtype = torch.bfloat16
+
+ # Configure model and paths
+ model: ModelConfig = all_model_cfg['large_44k_v2']
+ model.download_if_needed()
+ output_dir = Path('./output/gradio')
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
+     seq_cfg = model.seq_cfg
+
+     net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
+     net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
+     log.info(f'Loaded weights from {model.model_path}')
+
+     feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
+                                   synchformer_ckpt=model.synchformer_ckpt,
+                                   enable_conditions=True,
+                                   mode=model.mode,
+                                   bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                   need_vae_encoder=False)
+     feature_utils = feature_utils.to(device, dtype).eval()
+
+     return net, feature_utils, seq_cfg
+
+
+ # Load model once at startup
+ net, feature_utils, seq_cfg = get_model()
+
+
+ @torch.inference_mode()
+ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
+                    cfg_strength: float, duration: float):
+     try:
+         rng = torch.Generator(device=device)
+         if seed >= 0:
+             rng.manual_seed(seed)
+         else:
+             rng.seed()
+         fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
+
+         video_info = load_video(video, duration)
+         clip_frames = video_info.clip_frames.unsqueeze(0)
+         sync_frames = video_info.sync_frames.unsqueeze(0)
+         duration = video_info.duration_sec
+
+         seq_cfg.duration = duration
+         net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
+
+         audios = generate(clip_frames, sync_frames, [prompt],
+                           negative_text=[negative_prompt],
+                           feature_utils=feature_utils,
+                           net=net,
+                           fm=fm,
+                           rng=rng,
+                           cfg_strength=cfg_strength)
+         audio = audios.float().cpu()[0]
+
+         current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
+         video_save_path = output_dir / f'{current_time_string}.mp4'
+         make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
+
+         gc.collect()
+         torch.cuda.empty_cache()
+
+         return video_save_path
+     except Exception as e:
+         log.error(f"Error in video_to_audio: {str(e)}")
+         raise gr.Error(f"An error occurred: {str(e)}")
+
+
+ # Create the Gradio interface
+ demo = gr.Interface(
+     fn=video_to_audio,
+     title="MMAudio — Video-to-Audio Synthesis",
+     description="""
+     Generate realistic audio for your videos using MMAudio!
+
+     Project page: [MMAudio](https://hkchengrex.com/MMAudio/)
+     Code: [GitHub](https://github.com/hkchengrex/MMAudio)
+
+     Note: Processing high-resolution videos (>384px on shorter side) takes longer and doesn't improve results.
+     """,
+     inputs=[
+         gr.Video(label="Upload Video"),
+         gr.Text(label="Prompt", placeholder="Describe the audio you want to generate..."),
+         gr.Text(label="Negative prompt", value="music", placeholder="What you don't want in the audio..."),
+         gr.Number(label="Seed (-1: random)", value=-1, precision=0, minimum=-1),
+         gr.Number(label="Number of steps", value=25, precision=0, minimum=1),
+         gr.Slider(label="Guidance Strength", value=4.5, minimum=1, maximum=10, step=0.5),
+         gr.Slider(label="Duration (seconds)", value=8, minimum=1, maximum=30, step=1),
+     ],
+     outputs=gr.Video(label="Generated Result"),
+     examples=[
+         ["https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4",
+          "waves, seagulls", "", 0, 25, 4.5, 10],
+         ["https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4",
+          "", "music", 0, 25, 4.5, 10],
+     ],
+     cache_examples=True,
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
+
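Once the app is running, it can also be driven programmatically. A minimal sketch using `gradio_client`; the local URL and the `/predict` endpoint name are assumptions based on the default `gr.Interface` routing, and the example file name is a placeholder:

```python
from gradio_client import Client, handle_file

client = Client("http://localhost:7860")

# Inputs follow the order declared in gr.Interface above:
# video, prompt, negative prompt, seed, steps, guidance strength, duration.
result = client.predict(
    handle_file("example.mp4"),  # placeholder local video file
    "waves, seagulls",           # prompt
    "music",                     # negative prompt
    -1,                          # seed (-1 = random)
    25,                          # number of steps
    4.5,                         # guidance strength
    8,                           # duration in seconds
    api_name="/predict",
)
print("Generated video returned at:", result)
```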
batch_eval.py ADDED
@@ -0,0 +1,110 @@
+ import logging
+ import os
+ from pathlib import Path
+
+ import hydra
+ import torch
+ import torch.distributed as distributed
+ import torchaudio
+ from hydra.core.hydra_config import HydraConfig
+ from omegaconf import DictConfig
+ from tqdm import tqdm
+
+ from mmaudio.data.data_setup import setup_eval_dataset
+ from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
+ from mmaudio.model.flow_matching import FlowMatching
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
+ from mmaudio.model.utils.features_utils import FeaturesUtils
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ local_rank = int(os.environ['LOCAL_RANK'])
+ world_size = int(os.environ['WORLD_SIZE'])
+ log = logging.getLogger()
+
+
+ @torch.inference_mode()
+ @hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
+ def main(cfg: DictConfig):
+     device = 'cuda'
+     torch.cuda.set_device(local_rank)
+
+     if cfg.model not in all_model_cfg:
+         raise ValueError(f'Unknown model variant: {cfg.model}')
+     model: ModelConfig = all_model_cfg[cfg.model]
+     model.download_if_needed()
+     seq_cfg = model.seq_cfg
+
+     run_dir = Path(HydraConfig.get().run.dir)
+     if cfg.output_name is None:
+         output_dir = run_dir / cfg.dataset
+     else:
+         output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # load a pretrained model
+     seq_cfg.duration = cfg.duration_s
+     net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
+     net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
+     log.info(f'Loaded weights from {model.model_path}')
+     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
+     log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
+     log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
+     log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
+
+     # misc setup
+     rng = torch.Generator(device=device)
+     rng.manual_seed(cfg.seed)
+     fm = FlowMatching(cfg.sampling.min_sigma,
+                       inference_mode=cfg.sampling.method,
+                       num_steps=cfg.sampling.num_steps)
+
+     feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
+                                   synchformer_ckpt=model.synchformer_ckpt,
+                                   enable_conditions=True,
+                                   mode=model.mode,
+                                   bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                   need_vae_encoder=False)
+     feature_utils = feature_utils.to(device).eval()
+
+     if cfg.compile:
+         net.preprocess_conditions = torch.compile(net.preprocess_conditions)
+         net.predict_flow = torch.compile(net.predict_flow)
+         feature_utils.compile()
+
+     dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
+
+     with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
+         for batch in tqdm(loader):
+             audios = generate(batch.get('clip_video', None),
+                               batch.get('sync_video', None),
+                               batch.get('caption', None),
+                               feature_utils=feature_utils,
+                               net=net,
+                               fm=fm,
+                               rng=rng,
+                               cfg_strength=cfg.cfg_strength,
+                               clip_batch_size_multiplier=64,
+                               sync_batch_size_multiplier=64)
+             audios = audios.float().cpu()
+             names = batch['name']
+             for audio, name in zip(audios, names):
+                 torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
+
+
+ def distributed_setup():
+     distributed.init_process_group(backend="nccl")
+     local_rank = distributed.get_rank()
+     world_size = distributed.get_world_size()
+     log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
+     return local_rank, world_size
+
+
+ if __name__ == '__main__':
+     distributed_setup()
+
+     main()
+
+     # clean-up
+     distributed.destroy_process_group()
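After a run, the generated `.flac` files can be sanity-checked with `torchaudio`; the output path below is an assumption, substitute your actual Hydra run/output directory:

```python
from pathlib import Path

import torchaudio

output_dir = Path("output/small_16k/vggsound")  # assumed <run_dir>/<dataset> output folder
for flac in sorted(output_dir.glob("*.flac"))[:5]:
    waveform, sample_rate = torchaudio.load(str(flac))
    print(f"{flac.name}: {waveform.shape[-1] / sample_rate:.2f} s at {sample_rate} Hz")
```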
config/__init__.py ADDED
File without changes
config/base_config.yaml ADDED
@@ -0,0 +1,62 @@
+ defaults:
+   - data: base
+   - eval_data: base
+   - override hydra/job_logging: custom-simplest
+   - _self_
+
+ hydra:
+   run:
+     dir: ./output/${exp_id}
+   output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+ enable_email: False
+
+ model: small_16k
+
+ exp_id: default
+ debug: False
+ cudnn_benchmark: True
+ compile: True
+ amp: True
+ weights: null
+ checkpoint: null
+ seed: 14159265
+ num_workers: 10 # per-GPU
+ pin_memory: False # set to True if your system can handle it, i.e., have enough memory
+
+ # NOTE: This DOES NOT affect the model during inference in any way
+ # these values are only used by the dataloader to fill in missing data in multi-modal loading
+ # to change the sequence length for the model, see networks.py
+ data_dim:
+   text_seq_len: 77
+   clip_dim: 1024
+   sync_dim: 768
+   text_dim: 1024
+
+ # ema configuration
+ ema:
+   enable: True
+   sigma_rels: [0.05, 0.1]
+   update_every: 1
+   checkpoint_every: 5_000
+   checkpoint_folder: ${hydra:run.dir}/ema_ckpts
+   default_output_sigma: 0.05
+
+
+ # sampling
+ sampling:
+   mean: 0.0
+   scale: 1.0
+   min_sigma: 0.0
+   method: euler
+   num_steps: 25
+
+ # classifier-free guidance
+ null_condition_probability: 0.1
+ cfg_strength: 4.5
+
+ # checkpoint paths to external modules
+ vae_16k_ckpt: ./ext_weights/v1-16.pth
+ vae_44k_ckpt: ./ext_weights/v1-44.pth
+ bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
+ synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
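For quick inspection outside of a `@hydra.main` entry point (e.g. in a notebook), this config can be composed with Hydra's compose API. A minimal sketch, assuming Hydra ≥ 1.3 and that it is run from the repository root; note that interpolations depending on the Hydra runtime (such as `${hydra:run.dir}`) only resolve inside an actual Hydra run:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# config_path is interpreted relative to the calling module.
with initialize(version_base="1.3.2", config_path="config"):
    cfg = compose(config_name="base_config",
                  overrides=["model=small_16k", "num_workers=2"])  # placeholder overrides

print(OmegaConf.to_yaml(cfg.sampling))  # inspect the sampling block shown above
```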
config/data/base.yaml ADDED
@@ -0,0 +1,70 @@
+ VGGSound:
+   root: ../data/video
+   subset_name: sets/vgg3-train.tsv
+   fps: 8
+   height: 384
+   width: 384
+   sample_duration_sec: 8.0
+
+ VGGSound_test:
+   root: ../data/video
+   subset_name: sets/vgg3-test.tsv
+   fps: 8
+   height: 384
+   width: 384
+   sample_duration_sec: 8.0
+
+ VGGSound_val:
+   root: ../data/video
+   subset_name: sets/vgg3-val.tsv
+   fps: 8
+   height: 384
+   width: 384
+   sample_duration_sec: 8.0
+
+ ExtractedVGG:
+   tsv: ../data/v1-16-memmap/vgg-train.tsv
+   memmap_dir: ../data/v1-16-memmap/vgg-train
+
+ ExtractedVGG_test:
+   tag: test
+   gt_cache: ../data/eval-cache/vggsound-test
+   output_subdir: null
+   tsv: ../data/v1-16-memmap/vgg-test.tsv
+   memmap_dir: ../data/v1-16-memmap/vgg-test
+
+ ExtractedVGG_val:
+   tag: val
+   gt_cache: ../data/eval-cache/vggsound-val
+   output_subdir: val
+   tsv: ../data/v1-16-memmap/vgg-val.tsv
+   memmap_dir: ../data/v1-16-memmap/vgg-val
+
+ AudioCaps:
+   tsv: ../data/v1-16-memmap/audiocaps.tsv
+   memmap_dir: ../data/v1-16-memmap/audiocaps
+
+ AudioSetSL:
+   tsv: ../data/v1-16-memmap/audioset_sl.tsv
+   memmap_dir: ../data/v1-16-memmap/audioset_sl
+
+ BBCSound:
+   tsv: ../data/v1-16-memmap/bbcsound.tsv
+   memmap_dir: ../data/v1-16-memmap/bbcsound
+
+ FreeSound:
+   tsv: ../data/v1-16-memmap/freesound.tsv
+   memmap_dir: ../data/v1-16-memmap/freesound
+
+ Clotho:
+   tsv: ../data/v1-16-memmap/clotho.tsv
+   memmap_dir: ../data/v1-16-memmap/clotho
+
+ Example_video:
+   tsv: ./training/example_output/memmap/vgg-example.tsv
+   memmap_dir: ./training/example_output/memmap/vgg-example
+
+ Example_audio:
+   tsv: ./training/example_output/memmap/audio-example.tsv
+   memmap_dir: ./training/example_output/memmap/audio-example
+
config/eval_config.yaml ADDED
@@ -0,0 +1,17 @@
+ defaults:
+   - base_config
+   - override hydra/job_logging: custom-simplest
+   - _self_
+
+ hydra:
+   run:
+     dir: ./output/${exp_id}
+   output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+ exp_id: ${model}
+ dataset: audiocaps
+ duration_s: 8.0
+
+ # for inference, this is the per-GPU batch size
+ batch_size: 16
+ output_name: null
config/eval_data/base.yaml ADDED
@@ -0,0 +1,22 @@
+ AudioCaps:
+   audio_path: ../data/AudioCaps-test-audioldm-ver
+   # a csv file, with a header row of 'name' and 'caption'
+   # name should match the audio file name without extension
+   # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
+   csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
+
+ AudioCaps_full:
+   audio_path: ../data/AudioCaps-test-full-ver
+   # a csv file, with a header row of 'name' and 'caption'
+   # name should match the audio file name without extension
+   # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
+   csv_path: ../data/AudioCaps-test-full-ver/data.csv
+
+ MovieGen:
+   video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
+   jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
+
+ VGGSound:
+   video_path: ../data/test-videos
+   # from the officially released csv file
+   csv_path: ../data/vggsound.csv
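For a custom evaluation set, the expected csv (header row `name`, `caption`, with `name` matching the audio file name without extension) can be produced with the standard library. A minimal sketch; the folder path and captions are placeholders:

```python
import csv
from pathlib import Path

audio_dir = Path("../data/my-audio-set")  # placeholder folder with .flac/.wav files
rows = [{"name": p.stem, "caption": "a placeholder caption"}
        for p in sorted(audio_dir.glob("*.flac"))]

with open(audio_dir / "data.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["name", "caption"])
    writer.writeheader()
    writer.writerows(rows)
```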
config/hydra/job_logging/custom-eval.yaml ADDED
@@ -0,0 +1,32 @@
+ # python logging configuration for tasks
+ version: 1
+ formatters:
+   simple:
+     format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+   colorlog:
+     '()': 'colorlog.ColoredFormatter'
+     format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+     log_colors:
+       DEBUG: purple
+       INFO: green
+       WARNING: yellow
+       ERROR: red
+       CRITICAL: red
+ handlers:
+   console:
+     class: logging.StreamHandler
+     formatter: colorlog
+     stream: ext://sys.stdout
+   file:
+     class: logging.FileHandler
+     formatter: simple
+     # absolute file path
+     filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
+     mode: w
+ root:
+   level: INFO
+   handlers: [console, file]
+
+ disable_existing_loggers: false
config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,32 @@
+ # python logging configuration for tasks
+ version: 1
+ formatters:
+   simple:
+     format: '[%(asctime)s][%(levelname)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+   colorlog:
+     '()': 'colorlog.ColoredFormatter'
+     format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+     log_colors:
+       DEBUG: purple
+       INFO: green
+       WARNING: yellow
+       ERROR: red
+       CRITICAL: red
+ handlers:
+   console:
+     class: logging.StreamHandler
+     formatter: colorlog
+     stream: ext://sys.stdout
+   file:
+     class: logging.FileHandler
+     formatter: simple
+     # absolute file path
+     filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
+     mode: w
+ root:
+   level: INFO
+   handlers: [console, file]
+
+ disable_existing_loggers: false
config/hydra/job_logging/custom-simplest.yaml ADDED
@@ -0,0 +1,26 @@
+ # python logging configuration for tasks
+ version: 1
+ formatters:
+   simple:
+     format: '[%(asctime)s][%(levelname)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+   colorlog:
+     '()': 'colorlog.ColoredFormatter'
+     format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+     log_colors:
+       DEBUG: purple
+       INFO: green
+       WARNING: yellow
+       ERROR: red
+       CRITICAL: red
+ handlers:
+   console:
+     class: logging.StreamHandler
+     formatter: colorlog
+     stream: ext://sys.stdout
+ root:
+   level: INFO
+   handlers: [console]
+
+ disable_existing_loggers: false
config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,33 @@
+ # @package hydra.job_logging
+ # python logging configuration for tasks
+ version: 1
+ formatters:
+   simple:
+     format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+   colorlog:
+     '()': 'colorlog.ColoredFormatter'
+     format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+     datefmt: '%Y-%m-%d %H:%M:%S'
+     log_colors:
+       DEBUG: purple
+       INFO: green
+       WARNING: yellow
+       ERROR: red
+       CRITICAL: red
+ handlers:
+   console:
+     class: logging.StreamHandler
+     formatter: colorlog
+     stream: ext://sys.stdout
+   file:
+     class: logging.FileHandler
+     formatter: simple
+     # absolute file path
+     filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
+     mode: w
+ root:
+   level: INFO
+   handlers: [console, file]
+
+ disable_existing_loggers: false
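Outside of Hydra, an equivalent console logger can be configured with the standard library. A minimal sketch mirroring the `simple` formatter above (colorlog omitted):

```python
import logging
import logging.config

logging.config.dictConfig({
    "version": 1,
    "formatters": {
        "simple": {
            "format": "[%(asctime)s][%(levelname)s] - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        },
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler",
                    "formatter": "simple",
                    "stream": "ext://sys.stdout"},
    },
    "root": {"level": "INFO", "handlers": ["console"]},
    "disable_existing_loggers": False,
})
logging.getLogger().info("logging configured")
```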
config/train_config.yaml ADDED
@@ -0,0 +1,41 @@
+ defaults:
+   - base_config
+   - override data: base
+   - override hydra/job_logging: custom
+   - _self_
+
+ hydra:
+   run:
+     dir: ./output/${exp_id}
+   output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+ ema:
+   start: 0
+
+ mini_train: False
+ example_train: False
+ enable_grad_scaler: False
+ vgg_oversample_rate: 5
+
+ log_text_interval: 200
+ log_extra_interval: 20_000
+ val_interval: 5_000
+ eval_interval: 20_000
+ save_eval_interval: 40_000
+ save_weights_interval: 10_000
+ save_checkpoint_interval: 10_000
+ save_copy_iterations: []
+
+ batch_size: 512
+ eval_batch_size: 256 # per-GPU
+
+ num_iterations: 300_000
+ learning_rate: 1.0e-4
+ linear_warmup_steps: 1_000
+
+ lr_schedule: step
+ lr_schedule_steps: [240_000, 270_000]
+ lr_schedule_gamma: 0.1
+
+ clip_grad_norm: 1.0
+ weight_decay: 1.0e-6
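For reference, the warmup and step schedule described by `linear_warmup_steps`, `lr_schedule_steps`, and `lr_schedule_gamma` can be reproduced with a stock PyTorch `LambdaLR`. This is a minimal sketch under those settings, not necessarily the exact scheduler implementation used by `train.py`; the parameters are placeholders:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.zeros(1))]  # placeholder parameters
optimizer = torch.optim.AdamW(params, lr=1.0e-4, weight_decay=1.0e-6)

def lr_lambda(step: int) -> float:
    # 1k-step linear warmup, then 10x decays at 240k and 270k iterations.
    if step < 1_000:
        return step / 1_000
    return 0.1 ** sum(step >= milestone for milestone in (240_000, 270_000))

scheduler = LambdaLR(optimizer, lr_lambda)

for _ in range(300_000):
    optimizer.step()
    scheduler.step()
```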
demo.py ADDED
@@ -0,0 +1,141 @@
+ import logging
+ from argparse import ArgumentParser
+ from pathlib import Path
+
+ import torch
+ import torchaudio
+
+ from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
+                                 setup_eval_logging)
+ from mmaudio.model.flow_matching import FlowMatching
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
+ from mmaudio.model.utils.features_utils import FeaturesUtils
+
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.backends.cudnn.allow_tf32 = True
+
+ log = logging.getLogger()
+
+
+ @torch.inference_mode()
+ def main():
+     setup_eval_logging()
+
+     parser = ArgumentParser()
+     parser.add_argument('--variant',
+                         type=str,
+                         default='large_44k_v2',
+                         help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
+     parser.add_argument('--video', type=Path, help='Path to the video file')
+     parser.add_argument('--prompt', type=str, help='Input prompt', default='')
+     parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
+     parser.add_argument('--duration', type=float, default=8.0)
+     parser.add_argument('--cfg_strength', type=float, default=4.5)
+     parser.add_argument('--num_steps', type=int, default=25)
+
+     parser.add_argument('--mask_away_clip', action='store_true')
+
+     parser.add_argument('--output', type=Path, help='Output directory', default='./output')
+     parser.add_argument('--seed', type=int, help='Random seed', default=42)
+     parser.add_argument('--skip_video_composite', action='store_true')
+     parser.add_argument('--full_precision', action='store_true')
+
+     args = parser.parse_args()
+
+     if args.variant not in all_model_cfg:
+         raise ValueError(f'Unknown model variant: {args.variant}')
+     model: ModelConfig = all_model_cfg[args.variant]
+     model.download_if_needed()
+     seq_cfg = model.seq_cfg
+
+     if args.video:
+         video_path: Path = Path(args.video).expanduser()
+     else:
+         video_path = None
+     prompt: str = args.prompt
+     negative_prompt: str = args.negative_prompt
+     output_dir: str = args.output.expanduser()
+     seed: int = args.seed
+     num_steps: int = args.num_steps
+     duration: float = args.duration
+     cfg_strength: float = args.cfg_strength
+     skip_video_composite: bool = args.skip_video_composite
+     mask_away_clip: bool = args.mask_away_clip
+
+     device = 'cpu'
+     if torch.cuda.is_available():
+         device = 'cuda'
+     elif torch.backends.mps.is_available():
+         device = 'mps'
+     else:
+         log.warning('CUDA/MPS are not available, running on CPU')
+     dtype = torch.float32 if args.full_precision else torch.bfloat16
+
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     # load a pretrained model
+     net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
+     net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
+     log.info(f'Loaded weights from {model.model_path}')
+
+     # misc setup
+     rng = torch.Generator(device=device)
+     rng.manual_seed(seed)
+     fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
+
+     feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
+                                   synchformer_ckpt=model.synchformer_ckpt,
+                                   enable_conditions=True,
+                                   mode=model.mode,
+                                   bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                   need_vae_encoder=False)
+     feature_utils = feature_utils.to(device, dtype).eval()
+
+     if video_path is not None:
+         log.info(f'Using video {video_path}')
+         video_info = load_video(video_path, duration)
+         clip_frames = video_info.clip_frames
+         sync_frames = video_info.sync_frames
+         duration = video_info.duration_sec
+         if mask_away_clip:
+             clip_frames = None
+         else:
+             clip_frames = clip_frames.unsqueeze(0)
+         sync_frames = sync_frames.unsqueeze(0)
+     else:
+         log.info('No video provided -- text-to-audio mode')
+         clip_frames = sync_frames = None
+
+     seq_cfg.duration = duration
+     net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
+
+     log.info(f'Prompt: {prompt}')
+     log.info(f'Negative prompt: {negative_prompt}')
+
+     audios = generate(clip_frames,
+                       sync_frames, [prompt],
+                       negative_text=[negative_prompt],
+                       feature_utils=feature_utils,
+                       net=net,
+                       fm=fm,
+                       rng=rng,
+                       cfg_strength=cfg_strength)
+     audio = audios.float().cpu()[0]
+     if video_path is not None:
+         save_path = output_dir / f'{video_path.stem}.flac'
+     else:
+         safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
+         save_path = output_dir / f'{safe_filename}.flac'
+     torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
+
+     log.info(f'Audio saved to {save_path}')
+     if video_path is not None and not skip_video_composite:
+         video_save_path = output_dir / f'{video_path.stem}.mp4'
+         make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
+         log.info(f'Video saved to {video_save_path}')
+
+     log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
+
+
+ if __name__ == '__main__':
+     main()
docs/EVAL.md ADDED
@@ -0,0 +1,22 @@
+ # Evaluation
+
+ ## Batch Evaluation
+
+ To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient in large-scale evaluation compared to `demo.py`, supporting batched inference, multi-GPU inference, torch compilation, and skipping video compositions.
+
+ An example of running this script with four GPUs is as follows:
+
+ ```bash
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
+ ```
+
+ You may need to update the data paths in `config/eval_data/base.yaml`.
+ More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
+
+ ## Precomputed Results
+
+ Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
+
+ ## Obtaining Quantitative Metrics
+
+ Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
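A minimal sketch of fetching the precomputed results with `huggingface_hub`; the local destination directory is a placeholder:

```python
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="hkchengrex/MMAudio-precomputed-results",
    repo_type="dataset",
    local_dir="./precomputed-results",  # placeholder destination
)
```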
docs/MODELS.md ADDED
@@ -0,0 +1,50 @@
+ # Pretrained models
+
+ The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
+ The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
+
+ | Model | Download link | File size |
+ | -------- | ------- | ------- |
+ | Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
+ | Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
+ | Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
+ | Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
+ | Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
+ | 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
+ | 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
+ | 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
+ | Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
+
+ To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP, CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz) and not to model sizes.
+ The 44.1kHz vocoder will be downloaded automatically.
+ The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance), but, in my experience, generalizes better to new data.
+
+ The expected directory structure (full):
+
+ ```bash
+ MMAudio
+ ├── ext_weights
+ │   ├── best_netG.pt
+ │   ├── synchformer_state_dict.pth
+ │   ├── v1-16.pth
+ │   └── v1-44.pth
+ ├── weights
+ │   ├── mmaudio_small_16k.pth
+ │   ├── mmaudio_small_44k.pth
+ │   ├── mmaudio_medium_44k.pth
+ │   ├── mmaudio_large_44k.pth
+ │   └── mmaudio_large_44k_v2.pth
+ └── ...
+ ```
+
+ The expected directory structure (minimal, for the recommended model only):
+
+ ```bash
+ MMAudio
+ ├── ext_weights
+ │   ├── synchformer_state_dict.pth
+ │   └── v1-44.pth
+ ├── weights
+ │   └── mmaudio_large_44k_v2.pth
+ └── ...
+ ```
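A minimal sketch for checking that the minimal layout above is in place before running the demo (the repository-root path is a placeholder):

```python
from pathlib import Path

root = Path(".")  # placeholder: the MMAudio repository root
required = [
    "weights/mmaudio_large_44k_v2.pth",
    "ext_weights/v1-44.pth",
    "ext_weights/synchformer_state_dict.pth",
]
missing = [p for p in required if not (root / p).exists()]
print("All required weights present." if not missing else f"Missing: {missing}")
```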
docs/TRAINING.md ADDED
@@ -0,0 +1,160 @@
+ # Training
+
+ ## Overview
+
+ We have put a large emphasis on making training as fast as possible.
+ Consequently, some pre-processing steps are required.
+
+ Namely, before starting any training, we
+
+ 1. Obtain training data as videos, audios, and captions.
+ 2. Encode training audios into spectrograms and then with VAE into mean/std
+ 3. Extract CLIP and synchronization features from videos
+ 4. Extract CLIP features from text (captions)
+ 5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
+
+ **NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.
+
+ The current training script does not support `_v2` training.
+
+ ## Recommended Hardware Configuration
+
+ These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
+
+ - Single-node machine. We did not implement multi-node training
+ - GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
+ - System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
+ - Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
+
+ ## Prerequisites
+
+ 1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
+ 2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
+
+ 3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
+
+ ```bash
+ conda install -c conda-forge 'ffmpeg<7'
+ ```
+
+ 4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://arxiv.org/abs/2303.17395). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio clips unsuitable for our training.
+
+ 5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md), and place them in `ext_weights/`.
+
+ ## Preparing Audio-Video-Text Features
+
+ We have prepared some example data in `training/example_videos`.
+ `training/extract_video_training_latents.py` extracts audio, video, and text features and saves them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
+
+ To run this script, use the `torchrun` utility:
+
+ ```bash
+ torchrun --standalone training/extract_video_training_latents.py
+ ```
+
+ You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
+ Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
+ Change the data path definitions in `data_cfg` if necessary.
+
+ Arguments:
+
+ - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
+ - `output_dir` -- where TensorDict and the metadata file are saved.
+
+ Outputs produced in `output_dir`:
+
+ 1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
+    a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
+    b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
+    c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
+    d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
+    e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
+    f. `meta.json` that contains the metadata for the above memory mappings
+ 2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
+
+ ## Preparing Audio-Text Features
+
+ We have prepared some example data in `training/example_audios`.
+
+ 1. Run `training/partition_clips.py` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio onto the disk to save disk space)
+ 2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
+
+ ### Partitioning the audio files
+
+ Run
+
+ ```bash
+ python training/partition_clips.py
+ ```
+
+ Arguments:
+
+ - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
+ - `output_dir` -- path to the output `.csv` file
+ - `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
+ - `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
+
+ ### Extracting audio and text features
+
+ Run
+
+ ```bash
+ torchrun --standalone training/extract_audio_training_latents.py
+ ```
+
+ You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
+ Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
+
+ Arguments:
+
+ - `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
+ - `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file with at least the columns `id` and `caption`
+ - `clips_tsv` -- path to the clips file, generated in the last step
+ - `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
+ - `output_dir` -- where TensorDict and the metadata file are saved.
+
+ **Reference tsv files (with overlaps removed as mentioned in the paper) can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).**
+ Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use them as the `captions_tsv` input though -- the script will handle duplicates gracefully.
+ Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are parts of WavCaps. These subsets might be smaller than the original datasets.
+ The Clotho data contains both the development set and the validation set.
+
+ Outputs produced in `output_dir`:
+
+ 1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
+    a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
+    b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
+    c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
+    d. `meta.json` that contains the metadata for the above memory mappings
+ 2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
+
+ ## Training on Extracted Features
+
+ We use Distributed Data Parallel (DDP) for training.
+ First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
+
+ To run training on the example data, use the following command:
+
+ ```bash
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
+ ```
+
+ This will not train a useful model, but it will check if everything is set up correctly.
+
+ For full training on the base model with two GPUs, use the following command:
+
+ ```bash
+ OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
+ ```
+
+ Any outputs from training will be stored in `output/<exp_id>`.
+
+ More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
+ For the medium and large models, specify `vgg_oversample_rate` to be `3` to reduce overfitting.
+
+ ## Checkpoints
+
+ Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio
+
+ ---
+
+ Godspeed!
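For reference, a minimal sketch of inspecting the extracted features described above with `tensordict`, assuming the keys follow the memmap file names (`mean`, `std`, `text_features`, ...) and using the example-data output paths as placeholders:

```python
import csv

from tensordict import TensorDict

memmap_dir = "./training/example_output/memmap/vgg-example"   # produced by the extraction script
tsv_path = "./training/example_output/memmap/vgg-example.tsv"

td = TensorDict.load_memmap(memmap_dir)
print({key: tuple(value.shape) for key, value in td.items()})  # e.g. mean / std / clip_features / ...

with open(tsv_path, newline="") as f:
    for row in list(csv.DictReader(f, delimiter="\t"))[:3]:
        print(row["id"], row["label"])
```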
docs/images/icon.png ADDED
docs/index.html ADDED
@@ -0,0 +1,149 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+ <!-- Google tag (gtag.js) -->
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
+ <script>
+ window.dataLayer = window.dataLayer || [];
+ function gtag(){dataLayer.push(arguments);}
+ gtag('js', new Date());
+ gtag('config', 'G-0JKBJ3WRJZ');
+ </script>
+
+ <link rel="preconnect" href="https://fonts.googleapis.com">
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
+ <link href="https://fonts.googleapis.com/css2?family=Source+Sans+3&display=swap" rel="stylesheet">
+ <meta charset="UTF-8">
+ <title>MMAudio</title>
+
+ <link rel="icon" type="image/png" href="images/icon.png">
+
+ <meta name="viewport" content="width=device-width, initial-scale=1">
+ <!-- CSS only -->
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
+
+ <link rel="stylesheet" href="style.css">
+ </head>
+ <body>
+
+ <body>
+ <br><br><br><br>
+ <div class="container">
+ <div class="row text-center" style="font-size:38px">
+ <div class="col strong">
+ Taming Multimodal Joint Training for High-Quality <br>Video-to-Audio Synthesis
+ </div>
+ </div>
+
+ <br>
+ <div class="row text-center" style="font-size:28px">
+ <div class="col">
+ CVPR 2025
+ </div>
+ </div>
+ <br>
+
+ <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
+ <div class="col-sm-auto px-lg-2">
+ <a href="https://hkchengrex.github.io/">Ho Kei Cheng<sup>1</sup></a>
+ </div>
+ <div class="col-sm-auto px-lg-2">
+ <nobr><a href="https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ">Masato Ishii<sup>2</sup></a></nobr>
+ </div>
+ <div class="col-sm-auto px-lg-2">
+ <nobr><a href="https://scholar.google.com/citations?user=sXAjHFIAAAAJ">Akio Hayakawa<sup>2</sup></a></nobr>
+ </div>
+ <div class="col-sm-auto px-lg-2">
+ <nobr><a href="https://scholar.google.com/citations?user=XCRO260AAAAJ">Takashi Shibuya<sup>2</sup></a></nobr>
+ </div>
+ <div class="col-sm-auto px-lg-2">
+ <nobr><a href="https://www.alexander-schwing.de/">Alexander Schwing<sup>1</sup></a></nobr>
+ </div>
+ <div class="col-sm-auto px-lg-2" >
+ <nobr><a href="https://www.yukimitsufuji.com/">Yuki Mitsufuji<sup>2,3</sup></a></nobr>
+ </div>
+ </div>
+
+ <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
+ <div class="col-sm-auto px-lg-2">
+ <sup>1</sup>University of Illinois Urbana-Champaign
+ </div>
+ <div class="col-sm-auto px-lg-2">
+ <sup>2</sup>Sony AI
+ </div>
+ <div class="col-sm-auto px-lg-2">
+ <sup>3</sup>Sony Group Corporation
+ </div>
+ </div>
+
+ <br>
+
+ <br>
+
+ <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
+ <div class="col-sm-2">
+ <a href="https://arxiv.org/abs/2412.15322">[Paper]</a>
+ </div>
+ <div class="col-sm-2">
+ <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
+ </div>
+ <div class="col-sm-3">
+ <a href="https://huggingface.co/spaces/hkchengrex/MMAudio">[Huggingface Demo]</a>
+ </div>
+ <div class="col-sm-2">
+ <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
+ </div>
+ <div class="col-sm-3">
+ <a href="https://replicate.com/zsxkib/mmaudio">[Replicate Demo]</a>
+ </div>
+ </div>
+
+ <br>
+
+ <hr>
+
+ <div class="row" style="font-size:32px">
+ <div class="col strong">
+ TL;DR
+ </div>
+ </div>
+ <br>
+ <div class="row">
+ <div class="col">
+ <p class="light" style="text-align: left;">
+ MMAudio generates synchronized audio given video and/or text inputs.
+ </p>
+ </div>
+ </div>
+
+ <br>
+ <hr>
+ <br>
+
+ <div class="row" style="font-size:32px">
+ <div class="col strong">
+ Demo
+ </div>
+ </div>
+ <br>
+ <div class="row" style="font-size:48px">
+ <div class="col strong text-center">
+ <a href="video_main.html" style="text-decoration: underline;">&lt;More results&gt;</a>
+ </div>
+ </div>
+ <br>
+ <div class="video-container" style="text-align: center;">
+ <iframe src="https://youtube.com/embed/YElewUT2M4M"></iframe>
+ </div>
+
+ <br>
+
+ <br><br>
+ <br><br>
+
+ </div>
+
+ </body>
+ </html>
docs/style.css ADDED
@@ -0,0 +1,78 @@
+ body {
+     font-family: 'Source Sans 3', sans-serif;
+     font-size: 18px;
+     margin-left: auto;
+     margin-right: auto;
+     font-weight: 400;
+     height: 100%;
+     max-width: 1000px;
+ }
+
+ table {
+     width: 100%;
+     border-collapse: collapse;
+ }
+ th, td {
+     border: 1px solid #ddd;
+     padding: 8px;
+     text-align: center;
+ }
+ th {
+     background-color: #f2f2f2;
+ }
+ video {
+     width: 100%;
+     height: auto;
+ }
+ p {
+     font-size: 28px;
+ }
+ h2 {
+     font-size: 36px;
+ }
+
+ .strong {
+     font-weight: 700;
+ }
+
+ .light {
+     font-weight: 100;
+ }
+
+ .heavy {
+     font-weight: 900;
+ }
+
+ .column {
+     float: left;
+ }
+
+ a:link,
+ a:visited {
+     color: #05538f;
+     text-decoration: none;
+ }
+
+ a:hover {
+     color: #63cbdd;
+ }
+
+ hr {
+     border: 0;
+     height: 1px;
+     background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
+ }
+
+ .video-container {
+     position: relative;
+     padding-bottom: 56.25%; /* 16:9 */
+     height: 0;
+ }
+
+ .video-container iframe {
+     position: absolute;
+     top: 0;
+     left: 0;
+     width: 100%;
+     height: 100%;
+ }
docs/style_videos.css ADDED
@@ -0,0 +1,52 @@
1
+ body {
2
+ font-family: 'Source Sans 3', sans-serif;
3
+ font-size: 1.5vh;
4
+ font-weight: 400;
5
+ }
6
+
7
+ table {
8
+ width: 100%;
9
+ border-collapse: collapse;
10
+ }
11
+ th, td {
12
+ border: 1px solid #ddd;
13
+ padding: 8px;
14
+ text-align: center;
15
+ }
16
+ th {
17
+ background-color: #f2f2f2;
18
+ }
19
+ video {
20
+ width: 100%;
21
+ height: auto;
22
+ }
23
+ p {
24
+ font-size: 1.5vh;
25
+ font-weight: bold;
26
+ }
27
+ h2 {
28
+ font-size: 2vh;
29
+ font-weight: bold;
30
+ }
31
+
32
+ .video-container {
33
+ position: relative;
34
+ padding-bottom: 56.25%; /* 16:9 */
35
+ height: 0;
36
+ }
37
+
38
+ .video-container iframe {
39
+ position: absolute;
40
+ top: 0;
41
+ left: 0;
42
+ width: 100%;
43
+ height: 100%;
44
+ }
45
+
46
+ .video-header {
47
+ background-color: #f2f2f2;
48
+ text-align: center;
49
+ font-size: 1.5vh;
50
+ font-weight: bold;
51
+ padding: 8px;
52
+ }
docs/video_gen.html ADDED
@@ -0,0 +1,254 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <link rel="icon" type="image/png" href="images/icon.png">
18
+
19
+ <meta name="viewport" content="width=device-width, initial-scale=1">
20
+ <!-- CSS only -->
21
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
22
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
23
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
24
+
25
+ <link rel="stylesheet" href="style_videos.css">
26
+ </head>
27
+ <body>
28
+
29
+ <div id="moviegen_all">
30
+ <h2 id="moviegen" style="text-align: center;">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</h2>
31
+ <p id="moviegen1" style="overflow: hidden;">
32
+ Example 1: Ice cracking with sharp snapping sound, and metal tool scraping against the ice surface.
33
+ <span style="float: right;"><a href="#index">Back to index</a></span>
34
+ </p>
35
+
36
+ <div class="row g-1">
37
+ <div class="col-sm-6">
38
+ <div class="video-header">Movie Gen Audio</div>
39
+ <div class="video-container">
40
+ <iframe src="https://youtube.com/embed/d7Lb0ihtGcE"></iframe>
41
+ </div>
42
+ </div>
43
+ <div class="col-sm-6">
44
+ <div class="video-header">Ours</div>
45
+ <div class="video-container">
46
+ <iframe src="https://youtube.com/embed/F4JoJ2r2m8U"></iframe>
47
+ </div>
48
+ </div>
49
+ </div>
50
+ <br>
51
+
77
+ <p id="moviegen2" style="overflow: hidden;">
78
+ Example 2: Rhythmic splashing and lapping of water.
79
+ <span style="float:right;"><a href="#index">Back to index</a></span>
80
+ </p>
81
+ <div class="row g-1">
82
+ <div class="col-sm-6">
83
+ <div class="video-header">Movie Gen Audio</div>
84
+ <div class="video-container">
85
+ <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
86
+ </div>
87
+ </div>
88
+ <div class="col-sm-6">
89
+ <div class="video-header">Ours</div>
90
+ <div class="video-container">
91
+ <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
92
+ </div>
93
+ </div>
94
+ </div>
95
+ <br>
96
+
97
+ <p id="moviegen3" style="overflow: hidden;">
98
+ Example 3: Shovel scrapes against dry earth.
99
+ <span style="float:right;"><a href="#index">Back to index</a></span>
100
+ </p>
101
+ <div class="row g-1">
102
+ <div class="col-sm-6">
103
+ <div class="video-header">Movie Gen Audio</div>
104
+ <div class="video-container">
105
+ <iframe src="https://youtube.com/embed/PUKGyEve7XQ"></iframe>
106
+ </div>
107
+ </div>
108
+ <div class="col-sm-6">
109
+ <div class="video-header">Ours</div>
110
+ <div class="video-container">
111
+ <iframe src="https://youtube.com/embed/CNn7i8VNkdc"></iframe>
112
+ </div>
113
+ </div>
114
+ </div>
115
+ <br>
116
+
117
+
118
+ <p id="moviegen4" style="overflow: hidden;">
119
+ (Failure case) Example 4: Creamy sound of mashed potatoes being scooped.
120
+ <span style="float:right;"><a href="#index">Back to index</a></span>
121
+ </p>
122
+ <div class="row g-1">
123
+ <div class="col-sm-6">
124
+ <div class="video-header">Movie Gen Audio</div>
125
+ <div class="video-container">
126
+ <iframe src="https://youtube.com/embed/PJv1zxR9JjQ"></iframe>
127
+ </div>
128
+ </div>
129
+ <div class="col-sm-6">
130
+ <div class="video-header">Ours</div>
131
+ <div class="video-container">
132
+ <iframe src="https://youtube.com/embed/c3-LJ1lNsPQ"></iframe>
133
+ </div>
134
+ </div>
135
+ </div>
136
+ <br>
137
+
138
+ </div>
139
+
140
+ <div id="hunyuan_sora_all">
141
+
142
+ <h2 id="hunyuan" style="text-align: center;">Results on Videos Generated by Hunyuan</h2>
143
+ <p style="overflow: hidden;">
144
+ <span style="float:right;"><a href="#index">Back to index</a></span>
145
+ </p>
146
+ <div class="row g-1">
147
+ <div class="col-sm-6">
148
+ <div class="video-header">Typing</div>
149
+ <div class="video-container">
150
+ <iframe src="https://youtube.com/embed/8ln_9hhH_nk"></iframe>
151
+ </div>
152
+ </div>
153
+ <div class="col-sm-6">
154
+ <div class="video-header">Water is rushing down a stream and pouring</div>
155
+ <div class="video-container">
156
+ <iframe src="https://youtube.com/embed/5df1FZFQj30"></iframe>
157
+ </div>
158
+ </div>
159
+ </div>
160
+ <div class="row g-1">
161
+ <div class="col-sm-6">
162
+ <div class="video-header">Waves on beach</div>
163
+ <div class="video-container">
164
+ <iframe src="https://youtube.com/embed/7wQ9D5WgpFc"></iframe>
165
+ </div>
166
+ </div>
167
+ <div class="col-sm-6">
168
+ <div class="video-header">Water droplet</div>
169
+ <div class="video-container">
170
+ <iframe src="https://youtube.com/embed/q7M2nsalGjM"></iframe>
171
+ </div>
172
+ </div>
173
+ </div>
174
+ <br>
175
+
176
+ <h2 id="sora" style="text-align: center;">Results on Videos Generated by Sora</h2>
177
+ <p style="overflow: hidden;">
178
+ <span style="float:right;"><a href="#index">Back to index</a></span>
179
+ </p>
180
+ <div class="row g-1">
181
+ <div class="col-sm-6">
182
+ <div class="video-header">Ships riding waves</div>
183
+ <div class="video-container">
184
+ <iframe src="https://youtube.com/embed/JbgQzHHytk8"></iframe>
185
+ </div>
186
+ </div>
187
+ <div class="col-sm-6">
188
+ <div class="video-header">Train (no text prompt given)</div>
189
+ <div class="video-container">
190
+ <iframe src="https://youtube.com/embed/xOW7zrjpWC8"></iframe>
191
+ </div>
192
+ </div>
193
+ </div>
194
+ <div class="row g-1">
195
+ <div class="col-sm-6">
196
+ <div class="video-header">Seashore (no text prompt given)</div>
197
+ <div class="video-container">
198
+ <iframe src="https://youtube.com/embed/fIuw5Y8ZZ9E"></iframe>
199
+ </div>
200
+ </div>
201
+ <div class="col-sm-6">
202
+ <div class="video-header">Surfing (failure: unprompted music)</div>
203
+ <div class="video-container">
204
+ <iframe src="https://youtube.com/embed/UcSTk-v0M_s"></iframe>
205
+ </div>
206
+ </div>
207
+ </div>
208
+ <br>
209
+
210
+ </div>
+
+ <div id="mochi_ltx_all">
211
+ <h2 id="mochi" style="text-align: center;">Results on Videos Generated by Mochi 1</h2>
212
+ <p style="overflow: hidden;">
213
+ <span style="float:right;"><a href="#index">Back to index</a></span>
214
+ </p>
215
+ <div class="row g-1">
216
+ <div class="col-sm-6">
217
+ <div class="video-header">Magical fire and lightning (no text prompt given)</div>
218
+ <div class="video-container">
219
+ <iframe src="https://youtube.com/embed/tTlRZaSMNwY"></iframe>
220
+ </div>
221
+ </div>
222
+ <div class="col-sm-6">
223
+ <div class="video-header">Storm (no text prompt given)</div>
224
+ <div class="video-container">
225
+ <iframe src="https://youtube.com/embed/4hrZTMJUy3w"></iframe>
226
+ </div>
227
+ </div>
228
+ </div>
229
+ <br>
230
+
231
+ <h2 id="ltx" style="text-align: center;">Results on Videos Generated by LTX-Video</h2>
232
+ <p style="overflow: hidden;">
233
+ <span style="float:right;"><a href="#index">Back to index</a></span>
234
+ </p>
235
+ <div class="row g-1">
236
+ <div class="col-sm-6">
237
+ <div class="video-header">Firewood burning and cracking</div>
238
+ <div class="video-container">
239
+ <iframe src="https://youtube.com/embed/P7_DDpgev0g"></iframe>
240
+ </div>
241
+ </div>
242
+ <div class="col-sm-6">
243
+ <div class="video-header">Waterfall, water splashing</div>
244
+ <div class="video-container">
245
+ <iframe src="https://youtube.com/embed/4MvjceYnIO0"></iframe>
246
+ </div>
247
+ </div>
248
+ </div>
249
+ <br>
250
+
251
+ </div>
252
+
253
+ </body>
254
+ </html>
docs/video_main.html ADDED
@@ -0,0 +1,98 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <link rel="icon" type="image/png" href="images/icon.png">
18
+
19
+ <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
20
+ <!-- CSS only -->
21
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
22
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
23
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>
24
+
25
+ <link rel="stylesheet" href="style_videos.css">
26
+
27
+ <script type="text/javascript">
28
+ $(document).ready(function(){
29
+ $("#content").load("video_gen.html #moviegen_all");
30
+ $("#load_moviegen").click(function(){
31
+ $("#content").load("video_gen.html #moviegen_all");
32
+ });
33
+ $("#load_hunyuan_sora").click(function(){
34
+ $("#content").load("video_gen.html #hunyuan_sora_all");
35
+ });
36
+ $("#load_mochi_ltx").click(function(){
37
+ $("#content").load("video_gen.html #mochi_ltx_all");
38
+ });
39
+ $("#load_vgg1").click(function(){
40
+ $("#content").load("video_vgg.html #vgg1");
41
+ });
42
+ $("#load_vgg2").click(function(){
43
+ $("#content").load("video_vgg.html #vgg2");
44
+ });
45
+ $("#load_vgg3").click(function(){
46
+ $("#content").load("video_vgg.html #vgg3");
47
+ });
48
+ $("#load_vgg4").click(function(){
49
+ $("#content").load("video_vgg.html #vgg4");
50
+ });
51
+ $("#load_vgg5").click(function(){
52
+ $("#content").load("video_vgg.html #vgg5");
53
+ });
54
+ $("#load_vgg6").click(function(){
55
+ $("#content").load("video_vgg.html #vgg6");
56
+ });
57
+ $("#load_vgg_extra").click(function(){
58
+ $("#content").load("video_vgg.html #vgg_extra");
59
+ });
60
+ });
61
+ </script>
62
+ </head>
63
+ <body>
64
+ <h1 id="index" style="text-align: center;">Index</h1>
65
+ <p><b>(Click on the links to load the corresponding videos)</b> <span style="float:right;"><a href="index.html">Back to project page</a></span></p>
66
+
67
+ <ol>
68
+ <li>
69
+ <a href="#" id="load_moviegen">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</a>
70
+ </li>
71
+ <li>
72
+ <a href="#" id="load_hunyuan_sora">Results on Videos Generated by Hunyuan and Sora</a>
73
+ </li>
74
+ <li>
75
+ <a href="#" id="load_mochi_ltx">Results on Videos Generated by Mochi 1 and LTX-Video</a>
76
+ </li>
77
+ <li>
78
+ On VGGSound
79
+ <ol>
80
+ <li><a id='load_vgg1' href="#">Example 1: Wolf howling</a></li>
81
+ <li><a id='load_vgg2' href="#">Example 2: Striking a golf ball</a></li>
82
+ <li><a id='load_vgg3' href="#">Example 3: Hitting a drum</a></li>
83
+ <li><a id='load_vgg4' href="#">Example 4: Dog barking</a></li>
84
+ <li><a id='load_vgg5' href="#">Example 5: Playing a string instrument</a></li>
85
+ <li><a id='load_vgg6' href="#">Example 6: A group of people playing tambourines</a></li>
86
+ <li><a id='load_vgg_extra' href="#">Extra results & failure cases</a></li>
87
+ </ol>
88
+ </li>
89
+ </ol>
90
+
91
+ <div id="content" class="container-fluid">
92
+
93
+ </div>
94
+ <br>
95
+ <br>
96
+
97
+ </body>
98
+ </html>
docs/video_vgg.html ADDED
@@ -0,0 +1,452 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <!-- Google tag (gtag.js) -->
5
+ <script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
6
+ <script>
7
+ window.dataLayer = window.dataLayer || [];
8
+ function gtag(){dataLayer.push(arguments);}
9
+ gtag('js', new Date());
10
+ gtag('config', 'G-0JKBJ3WRJZ');
11
+ </script>
12
+
13
+ <link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
14
+ <meta charset="UTF-8">
15
+ <title>MMAudio</title>
16
+
17
+ <meta name="viewport" content="width=device-width, initial-scale=1">
18
+ <!-- CSS only -->
19
+ <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
20
+ integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
21
+ <script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
22
+
23
+ <link rel="stylesheet" href="style_videos.css">
24
+ </head>
25
+ <body>
26
+
27
+ <div id="vgg1">
28
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
29
+ <p style="overflow: hidden;">
30
+ Example 1: Wolf howling.
31
+ <span style="float:right;"><a href="#index">Back to index</a></span>
32
+ </p>
33
+ <div class="row g-1">
34
+ <div class="col-sm-3">
35
+ <div class="video-header">Ground-truth</div>
36
+ <div class="video-container">
37
+ <iframe src="https://youtube.com/embed/9J_V74gqMUA"></iframe>
38
+ </div>
39
+ </div>
40
+ <div class="col-sm-3">
41
+ <div class="video-header">Ours</div>
42
+ <div class="video-container">
43
+ <iframe src="https://youtube.com/embed/P6O8IpjErPc"></iframe>
44
+ </div>
45
+ </div>
46
+ <div class="col-sm-3">
47
+ <div class="video-header">V2A-Mapper</div>
48
+ <div class="video-container">
49
+ <iframe src="https://youtube.com/embed/w-5eyqepvTk"></iframe>
50
+ </div>
51
+ </div>
52
+ <div class="col-sm-3">
53
+ <div class="video-header">FoleyCrafter</div>
54
+ <div class="video-container">
55
+ <iframe src="https://youtube.com/embed/VOLfoZlRkzo"></iframe>
56
+ </div>
57
+ </div>
58
+ </div>
59
+ <div class="row g-1">
60
+ <div class="col-sm-3">
61
+ <div class="video-header">Frieren</div>
62
+ <div class="video-container">
63
+ <iframe src="https://youtube.com/embed/49owKyA5Pa8"></iframe>
64
+ </div>
65
+ </div>
66
+ <div class="col-sm-3">
67
+ <div class="video-header">VATT</div>
68
+ <div class="video-container">
69
+ <iframe src="https://youtube.com/embed/QVtrFgbeGDM"></iframe>
70
+ </div>
71
+ </div>
72
+ <div class="col-sm-3">
73
+ <div class="video-header">V-AURA</div>
74
+ <div class="video-container">
75
+ <iframe src="https://youtube.com/embed/8r0uEfSNjvI"></iframe>
76
+ </div>
77
+ </div>
78
+ <div class="col-sm-3">
79
+ <div class="video-header">Seeing and Hearing</div>
80
+ <div class="video-container">
81
+ <iframe src="https://youtube.com/embed/bn-sLg2qulk"></iframe>
82
+ </div>
83
+ </div>
84
+ </div>
85
+ </div>
86
+
87
+ <div id="vgg2">
88
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
89
+ <p style="overflow: hidden;">
90
+ Example 2: Striking a golf ball.
91
+ <span style="float:right;"><a href="#index">Back to index</a></span>
92
+ </p>
93
+
94
+ <div class="row g-1">
95
+ <div class="col-sm-3">
96
+ <div class="video-header">Ground-truth</div>
97
+ <div class="video-container">
98
+ <iframe src="https://youtube.com/embed/1hwSu42kkho"></iframe>
99
+ </div>
100
+ </div>
101
+ <div class="col-sm-3">
102
+ <div class="video-header">Ours</div>
103
+ <div class="video-container">
104
+ <iframe src="https://youtube.com/embed/kZibDoDCNxI"></iframe>
105
+ </div>
106
+ </div>
107
+ <div class="col-sm-3">
108
+ <div class="video-header">V2A-Mapper</div>
109
+ <div class="video-container">
110
+ <iframe src="https://youtube.com/embed/jgKfLBLhh7Y"></iframe>
111
+ </div>
112
+ </div>
113
+ <div class="col-sm-3">
114
+ <div class="video-header">FoleyCrafter</div>
115
+ <div class="video-container">
116
+ <iframe src="https://youtube.com/embed/Lfsx8mOPcJo"></iframe>
117
+ </div>
118
+ </div>
119
+ </div>
120
+ <div class="row g-1">
121
+ <div class="col-sm-3">
122
+ <div class="video-header">Frieren</div>
123
+ <div class="video-container">
124
+ <iframe src="https://youtube.com/embed/tz-LpbB0MBc"></iframe>
125
+ </div>
126
+ </div>
127
+ <div class="col-sm-3">
128
+ <div class="video-header">VATT</div>
129
+ <div class="video-container">
130
+ <iframe src="https://youtube.com/embed/RTDUHMi08n4"></iframe>
131
+ </div>
132
+ </div>
133
+ <div class="col-sm-3">
134
+ <div class="video-header">V-AURA</div>
135
+ <div class="video-container">
136
+ <iframe src="https://youtube.com/embed/N-3TDOsPnZQ"></iframe>
137
+ </div>
138
+ </div>
139
+ <div class="col-sm-3">
140
+ <div class="video-header">Seeing and Hearing</div>
141
+ <div class="video-container">
142
+ <iframe src="https://youtube.com/embed/QnsHnLn4gB0"></iframe>
143
+ </div>
144
+ </div>
145
+ </div>
146
+ </div>
147
+
148
+ <div id="vgg3">
149
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
150
+ <p style="overflow: hidden;">
151
+ Example 3: Hitting a drum.
152
+ <span style="float:right;"><a href="#index">Back to index</a></span>
153
+ </p>
154
+
155
+ <div class="row g-1">
156
+ <div class="col-sm-3">
157
+ <div class="video-header">Ground-truth</div>
158
+ <div class="video-container">
159
+ <iframe src="https://youtube.com/embed/0oeIwq77w0Q"></iframe>
160
+ </div>
161
+ </div>
162
+ <div class="col-sm-3">
163
+ <div class="video-header">Ours</div>
164
+ <div class="video-container">
165
+ <iframe src="https://youtube.com/embed/-UtPV9ohuIM"></iframe>
166
+ </div>
167
+ </div>
168
+ <div class="col-sm-3">
169
+ <div class="video-header">V2A-Mapper</div>
170
+ <div class="video-container">
171
+ <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
172
+ </div>
173
+ </div>
174
+ <div class="col-sm-3">
175
+ <div class="video-header">FoleyCrafter</div>
176
+ <div class="video-container">
177
+ <iframe src="https://youtube.com/embed/kkCsXPOlBvY"></iframe>
178
+ </div>
179
+ </div>
180
+ </div>
181
+ <div class="row g-1">
182
+ <div class="col-sm-3">
183
+ <div class="video-header">Frieren</div>
184
+ <div class="video-container">
185
+ <iframe src="https://youtube.com/embed/MbNKsVsuvig"></iframe>
186
+ </div>
187
+ </div>
188
+ <div class="col-sm-3">
189
+ <div class="video-header">VATT</div>
190
+ <div class="video-container">
191
+ <iframe src="https://youtube.com/embed/2yYviBjrpBw"></iframe>
192
+ </div>
193
+ </div>
194
+ <div class="col-sm-3">
195
+ <div class="video-header">V-AURA</div>
196
+ <div class="video-container">
197
+ <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
198
+ </div>
199
+ </div>
200
+ <div class="col-sm-3">
201
+ <div class="video-header">Seeing and Hearing</div>
202
+ <div class="video-container">
203
+ <iframe src="https://youtube.com/embed/6dnyQt4Fuhs"></iframe>
204
+ </div>
205
+ </div>
206
+ </div>
207
+ </div>
209
+
210
+ <div id="vgg4">
211
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
212
+ <p style="overflow: hidden;">
213
+ Example 4: Dog barking.
214
+ <span style="float:right;"><a href="#index">Back to index</a></span>
215
+ </p>
216
+
217
+ <div class="row g-1">
218
+ <div class="col-sm-3">
219
+ <div class="video-header">Ground-truth</div>
220
+ <div class="video-container">
221
+ <iframe src="https://youtube.com/embed/ckaqvTyMYAw"></iframe>
222
+ </div>
223
+ </div>
224
+ <div class="col-sm-3">
225
+ <div class="video-header">Ours</div>
226
+ <div class="video-container">
227
+ <iframe src="https://youtube.com/embed/_aRndFZzZ-I"></iframe>
228
+ </div>
229
+ </div>
230
+ <div class="col-sm-3">
231
+ <div class="video-header">V2A-Mapper</div>
232
+ <div class="video-container">
233
+ <iframe src="https://youtube.com/embed/mNCISP3LBl0"></iframe>
234
+ </div>
235
+ </div>
236
+ <div class="col-sm-3">
237
+ <div class="video-header">FoleyCrafter</div>
238
+ <div class="video-container">
239
+ <iframe src="https://youtube.com/embed/phZBQ3L7foE"></iframe>
240
+ </div>
241
+ </div>
242
+ </div>
243
+ <div class="row g-1">
244
+ <div class="col-sm-3">
245
+ <div class="video-header">Frieren</div>
246
+ <div class="video-container">
247
+ <iframe src="https://youtube.com/embed/Sb5Mg1-ORao"></iframe>
248
+ </div>
249
+ </div>
250
+ <div class="col-sm-3">
251
+ <div class="video-header">VATT</div>
252
+ <div class="video-container">
253
+ <iframe src="https://youtube.com/embed/eHmAGOmtDDg"></iframe>
254
+ </div>
255
+ </div>
256
+ <div class="col-sm-3">
257
+ <div class="video-header">V-AURA</div>
258
+ <div class="video-container">
259
+ <iframe src="https://youtube.com/embed/NEGa3krBrm0"></iframe>
260
+ </div>
261
+ </div>
262
+ <div class="col-sm-3">
263
+ <div class="video-header">Seeing and Hearing</div>
264
+ <div class="video-container">
265
+ <iframe src="https://youtube.com/embed/aO0EAXlwE7A"></iframe>
266
+ </div>
267
+ </div>
268
+ </div>
269
+ </div>
270
+
271
+ <div id="vgg5">
272
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
273
+ <p style="overflow: hidden;">
274
+ Example 5: Playing a string instrument.
275
+ <span style="float:right;"><a href="#index">Back to index</a></span>
276
+ </p>
277
+
278
+ <div class="row g-1">
279
+ <div class="col-sm-3">
280
+ <div class="video-header">Ground-truth</div>
281
+ <div class="video-container">
282
+ <iframe src="https://youtube.com/embed/KP1QhWauIOc"></iframe>
283
+ </div>
284
+ </div>
285
+ <div class="col-sm-3">
286
+ <div class="video-header">Ours</div>
287
+ <div class="video-container">
288
+ <iframe src="https://youtube.com/embed/ovaJhWSquYE"></iframe>
289
+ </div>
290
+ </div>
291
+ <div class="col-sm-3">
292
+ <div class="video-header">V2A-Mapper</div>
293
+ <div class="video-container">
294
+ <iframe src="https://youtube.com/embed/N723FS9lcy8"></iframe>
295
+ </div>
296
+ </div>
297
+ <div class="col-sm-3">
298
+ <div class="video-header">FoleyCrafter</div>
299
+ <div class="video-container">
300
+ <iframe src="https://youtube.com/embed/t0N4ZAAXo58"></iframe>
301
+ </div>
302
+ </div>
303
+ </div>
304
+ <div class="row g-1">
305
+ <div class="col-sm-3">
306
+ <div class="video-header">Frieren</div>
307
+ <div class="video-container">
308
+ <iframe src="https://youtube.com/embed/8YSRs03QNNA"></iframe>
309
+ </div>
310
+ </div>
311
+ <div class="col-sm-3">
312
+ <div class="video-header">VATT</div>
313
+ <div class="video-container">
314
+ <iframe src="https://youtube.com/embed/vOpMz55J1kY"></iframe>
315
+ </div>
316
+ </div>
317
+ <div class="col-sm-3">
318
+ <div class="video-header">V-AURA</div>
319
+ <div class="video-container">
320
+ <iframe src="https://youtube.com/embed/9JHC75vr9h0"></iframe>
321
+ </div>
322
+ </div>
323
+ <div class="col-sm-3">
324
+ <div class="video-header">Seeing and Hearing</div>
325
+ <div class="video-container">
326
+ <iframe src="https://youtube.com/embed/9w0JckNzXmY"></iframe>
327
+ </div>
328
+ </div>
329
+ </div>
330
+ </div>
331
+
332
+ <div id="vgg6">
333
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
334
+ <p style="overflow: hidden;">
335
+ Example 6: A group of people playing tambourines.
336
+ <span style="float:right;"><a href="#index">Back to index</a></span>
337
+ </p>
338
+
339
+ <div class="row g-1">
340
+ <div class="col-sm-3">
341
+ <div class="video-header">Ground-truth</div>
342
+ <div class="video-container">
343
+ <iframe src="https://youtube.com/embed/mx6JLxzUkRc"></iframe>
344
+ </div>
345
+ </div>
346
+ <div class="col-sm-3">
347
+ <div class="video-header">Ours</div>
348
+ <div class="video-container">
349
+ <iframe src="https://youtube.com/embed/oLirHhP9Su8"></iframe>
350
+ </div>
351
+ </div>
352
+ <div class="col-sm-3">
353
+ <div class="video-header">V2A-Mapper</div>
354
+ <div class="video-container">
355
+ <iframe src="https://youtube.com/embed/HkLkHMqptv0"></iframe>
356
+ </div>
357
+ </div>
358
+ <div class="col-sm-3">
359
+ <div class="video-header">FoleyCrafter</div>
360
+ <div class="video-container">
361
+ <iframe src="https://youtube.com/embed/rpHiiODjmNU"></iframe>
362
+ </div>
363
+ </div>
364
+ </div>
365
+ <div class="row g-1">
366
+ <div class="col-sm-3">
367
+ <div class="video-header">Frieren</div>
368
+ <div class="video-container">
369
+ <iframe src="https://youtube.com/embed/1mVD3fJ0LpM"></iframe>
370
+ </div>
371
+ </div>
372
+ <div class="col-sm-3">
373
+ <div class="video-header">VATT</div>
374
+ <div class="video-container">
375
+ <iframe src="https://youtube.com/embed/yjVFnJiEJlw"></iframe>
376
+ </div>
377
+ </div>
378
+ <div class="col-sm-3">
379
+ <div class="video-header">V-AURA</div>
380
+ <div class="video-container">
381
+ <iframe src="https://youtube.com/embed/neVeMSWtRkU"></iframe>
382
+ </div>
383
+ </div>
384
+ <div class="col-sm-3">
385
+ <div class="video-header">Seeing and Hearing</div>
386
+ <div class="video-container">
387
+ <iframe src="https://youtube.com/embed/EUE7YwyVWz8"></iframe>
388
+ </div>
389
+ </div>
390
+ </div>
391
+ </div>
392
+
393
+ <div id="vgg_extra">
394
+ <h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
395
+ <p style="overflow: hidden;">
396
+ <span style="float:right;"><a href="#index">Back to index</a></span>
397
+ </p>
398
+
399
+ <div class="row g-1">
400
+ <div class="col-sm-3">
401
+ <div class="video-header">Moving train</div>
402
+ <div class="video-container">
403
+ <iframe src="https://youtube.com/embed/Ta6H45rBzJc"></iframe>
404
+ </div>
405
+ </div>
406
+ <div class="col-sm-3">
407
+ <div class="video-header">Water splashing</div>
408
+ <div class="video-container">
409
+ <iframe src="https://youtube.com/embed/hl6AtgHXpb4"></iframe>
410
+ </div>
411
+ </div>
412
+ <div class="col-sm-3">
413
+ <div class="video-header">Skateboarding</div>
414
+ <div class="video-container">
415
+ <iframe src="https://youtube.com/embed/n4sCNi_9buI"></iframe>
416
+ </div>
417
+ </div>
418
+ <div class="col-sm-3">
419
+ <div class="video-header">Synchronized clapping</div>
420
+ <div class="video-container">
421
+ <iframe src="https://youtube.com/embed/oxexfpLn7FE"></iframe>
422
+ </div>
423
+ </div>
424
+ </div>
425
+
426
+ <br><br>
427
+
428
+ <div id="extra-failure">
429
+ <h2 style="text-align: center;">Failure cases</h2>
430
+ <p style="overflow: hidden;">
431
+ <span style="float:right;"><a href="#index">Back to index</a></span>
432
+ </p>
433
+
434
+ <div class="row g-1">
435
+ <div class="col-sm-6">
436
+ <div class="video-header">Human speech</div>
437
+ <div class="video-container">
438
+ <iframe src="https://youtube.com/embed/nx0CyrDu70Y"></iframe>
439
+ </div>
440
+ </div>
441
+ <div class="col-sm-6">
442
+ <div class="video-header">Unfamiliar vision input</div>
443
+ <div class="video-container">
444
+ <iframe src="https://youtube.com/embed/hfnAqmK3X7w"></iframe>
445
+ </div>
446
+ </div>
447
+ </div>
448
+ </div>
449
+ </div>
450
+
451
+ </body>
452
+ </html>
gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
gradio_demo.py ADDED
@@ -0,0 +1,343 @@
1
+ import gc
2
+ import logging
3
+ from argparse import ArgumentParser
4
+ from datetime import datetime
5
+ from fractions import Fraction
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torchaudio
11
+
12
+ from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
13
+ load_video, make_video, setup_eval_logging)
14
+ from mmaudio.model.flow_matching import FlowMatching
15
+ from mmaudio.model.networks import MMAudio, get_my_mmaudio
16
+ from mmaudio.model.sequence_config import SequenceConfig
17
+ from mmaudio.model.utils.features_utils import FeaturesUtils
18
+
19
+ torch.backends.cuda.matmul.allow_tf32 = True
20
+ torch.backends.cudnn.allow_tf32 = True
21
+
22
+ log = logging.getLogger()
23
+
24
+ device = 'cpu'
25
+ if torch.cuda.is_available():
26
+ device = 'cuda'
27
+ elif torch.backends.mps.is_available():
28
+ device = 'mps'
29
+ else:
30
+ log.warning('CUDA/MPS are not available, running on CPU')
31
+ dtype = torch.bfloat16
32
+
33
+ model: ModelConfig = all_model_cfg['large_44k_v2']
34
+ model.download_if_needed()
35
+ output_dir = Path('./output/gradio')
36
+
37
+ setup_eval_logging()
38
+
39
+
40
+ def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
41
+ seq_cfg = model.seq_cfg
42
+
43
+ net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
44
+ net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
45
+ log.info(f'Loaded weights from {model.model_path}')
46
+
47
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
48
+ synchformer_ckpt=model.synchformer_ckpt,
49
+ enable_conditions=True,
50
+ mode=model.mode,
51
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
52
+ need_vae_encoder=False)
53
+ feature_utils = feature_utils.to(device, dtype).eval()
54
+
55
+ return net, feature_utils, seq_cfg
56
+
57
+
58
+ net, feature_utils, seq_cfg = get_model()
59
+
60
+
61
+ @torch.inference_mode()
62
+ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
63
+ cfg_strength: float, duration: float):
64
+
65
+ rng = torch.Generator(device=device)
66
+ if seed >= 0:
67
+ rng.manual_seed(seed)
68
+ else:
69
+ rng.seed()
70
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
71
+
72
+ video_info = load_video(video, duration)
73
+ clip_frames = video_info.clip_frames
74
+ sync_frames = video_info.sync_frames
75
+ duration = video_info.duration_sec
76
+ clip_frames = clip_frames.unsqueeze(0)
77
+ sync_frames = sync_frames.unsqueeze(0)
78
+ seq_cfg.duration = duration
79
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
80
+
81
+ audios = generate(clip_frames,
82
+ sync_frames, [prompt],
83
+ negative_text=[negative_prompt],
84
+ feature_utils=feature_utils,
85
+ net=net,
86
+ fm=fm,
87
+ rng=rng,
88
+ cfg_strength=cfg_strength)
89
+ audio = audios.float().cpu()[0]
90
+
91
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
92
+ output_dir.mkdir(exist_ok=True, parents=True)
93
+ video_save_path = output_dir / f'{current_time_string}.mp4'
94
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
95
+ gc.collect()
96
+ return video_save_path
97
+
98
+
99
+ @torch.inference_mode()
100
+ def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
101
+ cfg_strength: float, duration: float):
102
+
103
+ rng = torch.Generator(device=device)
104
+ if seed >= 0:
105
+ rng.manual_seed(seed)
106
+ else:
107
+ rng.seed()
108
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
109
+
110
+ image_info = load_image(image)
111
+ clip_frames = image_info.clip_frames
112
+ sync_frames = image_info.sync_frames
113
+ clip_frames = clip_frames.unsqueeze(0)
114
+ sync_frames = sync_frames.unsqueeze(0)
115
+ seq_cfg.duration = duration
116
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
117
+
118
+ audios = generate(clip_frames,
119
+ sync_frames, [prompt],
120
+ negative_text=[negative_prompt],
121
+ feature_utils=feature_utils,
122
+ net=net,
123
+ fm=fm,
124
+ rng=rng,
125
+ cfg_strength=cfg_strength,
126
+ image_input=True)
127
+ audio = audios.float().cpu()[0]
128
+
129
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
130
+ output_dir.mkdir(exist_ok=True, parents=True)
131
+ video_save_path = output_dir / f'{current_time_string}.mp4'
132
+ video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
133
+ make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
134
+ gc.collect()
135
+ return video_save_path
136
+
137
+
138
+ @torch.inference_mode()
139
+ def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
140
+ duration: float):
141
+
142
+ rng = torch.Generator(device=device)
143
+ if seed >= 0:
144
+ rng.manual_seed(seed)
145
+ else:
146
+ rng.seed()
147
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
148
+
149
+ clip_frames = sync_frames = None
150
+ seq_cfg.duration = duration
151
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
152
+
153
+ audios = generate(clip_frames,
154
+ sync_frames, [prompt],
155
+ negative_text=[negative_prompt],
156
+ feature_utils=feature_utils,
157
+ net=net,
158
+ fm=fm,
159
+ rng=rng,
160
+ cfg_strength=cfg_strength)
161
+ audio = audios.float().cpu()[0]
162
+
163
+ current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
164
+ output_dir.mkdir(exist_ok=True, parents=True)
165
+ audio_save_path = output_dir / f'{current_time_string}.flac'
166
+ torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
167
+ gc.collect()
168
+ return audio_save_path
169
+
170
+
171
+ video_to_audio_tab = gr.Interface(
172
+ fn=video_to_audio,
173
+ description="""
174
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
175
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
176
+
177
+ NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
178
+ Doing so does not improve results.
179
+ """,
180
+ inputs=[
181
+ gr.Video(),
182
+ gr.Text(label='Prompt'),
183
+ gr.Text(label='Negative prompt', value='music'),
184
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
185
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
186
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
187
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
188
+ ],
189
+ outputs='playable_video',
190
+ cache_examples=False,
191
+ title='MMAudio — Video-to-Audio Synthesis',
192
+ examples=[
193
+ [
194
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
195
+ 'waves, seagulls',
196
+ '',
197
+ 0,
198
+ 25,
199
+ 4.5,
200
+ 10,
201
+ ],
202
+ [
203
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
204
+ '',
205
+ 'music',
206
+ 0,
207
+ 25,
208
+ 4.5,
209
+ 10,
210
+ ],
211
+ [
212
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
213
+ 'bubbles',
214
+ '',
215
+ 0,
216
+ 25,
217
+ 4.5,
218
+ 10,
219
+ ],
220
+ [
221
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
222
+ 'Indian holy music',
223
+ '',
224
+ 0,
225
+ 25,
226
+ 4.5,
227
+ 10,
228
+ ],
229
+ [
230
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
231
+ 'galloping',
232
+ '',
233
+ 0,
234
+ 25,
235
+ 4.5,
236
+ 10,
237
+ ],
238
+ [
239
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
240
+ 'waves, storm',
241
+ '',
242
+ 0,
243
+ 25,
244
+ 4.5,
245
+ 10,
246
+ ],
247
+ [
248
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
249
+ 'storm',
250
+ '',
251
+ 0,
252
+ 25,
253
+ 4.5,
254
+ 10,
255
+ ],
256
+ [
257
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
258
+ '',
259
+ '',
260
+ 0,
261
+ 25,
262
+ 4.5,
263
+ 10,
264
+ ],
265
+ [
266
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
267
+ 'typing',
268
+ '',
269
+ 0,
270
+ 25,
271
+ 4.5,
272
+ 10,
273
+ ],
274
+ [
275
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
276
+ '',
277
+ '',
278
+ 0,
279
+ 25,
280
+ 4.5,
281
+ 10,
282
+ ],
283
+ [
284
+ 'https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
285
+ '',
286
+ '',
287
+ 0,
288
+ 25,
289
+ 4.5,
290
+ 10,
291
+ ],
292
+ ])
293
+
294
+ text_to_audio_tab = gr.Interface(
295
+ fn=text_to_audio,
296
+ description="""
297
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
298
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
299
+ """,
300
+ inputs=[
301
+ gr.Text(label='Prompt'),
302
+ gr.Text(label='Negative prompt'),
303
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
304
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
305
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
306
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
307
+ ],
308
+ outputs='audio',
309
+ cache_examples=False,
310
+ title='MMAudio — Text-to-Audio Synthesis',
311
+ )
312
+
313
+ image_to_audio_tab = gr.Interface(
314
+ fn=image_to_audio,
315
+ description="""
316
+ Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
317
+ Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
318
+
319
+ NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
320
+ Doing so does not improve results.
321
+ """,
322
+ inputs=[
323
+ gr.Image(type='filepath'),
324
+ gr.Text(label='Prompt'),
325
+ gr.Text(label='Negative prompt'),
326
+ gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
327
+ gr.Number(label='Num steps', value=25, precision=0, minimum=1),
328
+ gr.Number(label='Guidance Strength', value=4.5, minimum=1),
329
+ gr.Number(label='Duration (sec)', value=8, minimum=1),
330
+ ],
331
+ outputs='playable_video',
332
+ cache_examples=False,
333
+ title='MMAudio — Image-to-Audio Synthesis (experimental)',
334
+ )
335
+
336
+ if __name__ == "__main__":
337
+ parser = ArgumentParser()
338
+ parser.add_argument('--port', type=int, default=7860)
339
+ args = parser.parse_args()
340
+
341
+ gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
342
+ ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)']).launch(
343
+ server_port=args.port, allowed_paths=[output_dir])
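
The three interfaces above are combined into a gr.TabbedInterface, normally launched with "python gradio_demo.py --port 7860". A minimal sketch of driving the text-to-audio path without the UI follows; it assumes that importing gradio_demo is acceptable (the model is loaded at import time) and that the 'large_44k_v2' weights can be fetched by download_if_needed().

# Hypothetical usage sketch, not part of the repository.
# Importing gradio_demo runs its module-level model loading, so the weights for
# all_model_cfg['large_44k_v2'] must be downloadable or already cached.
from gradio_demo import text_to_audio

flac_path = text_to_audio(prompt='rain on a tin roof',
                          negative_prompt='music',
                          seed=42,
                          num_steps=25,
                          cfg_strength=4.5,
                          duration=8.0)
print(f'Audio written to {flac_path}')
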
mmaudio/__init__.py ADDED
File without changes
mmaudio/data/__init__.py ADDED
File without changes
mmaudio/data/av_utils.py ADDED
@@ -0,0 +1,162 @@
1
+ from dataclasses import dataclass
2
+ from fractions import Fraction
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import av
7
+ import numpy as np
8
+ import torch
9
+ from av import AudioFrame
10
+
11
+
12
+ @dataclass
13
+ class VideoInfo:
14
+ duration_sec: float
15
+ fps: Fraction
16
+ clip_frames: torch.Tensor
17
+ sync_frames: torch.Tensor
18
+ all_frames: Optional[list[np.ndarray]]
19
+
20
+ @property
21
+ def height(self):
22
+ return self.all_frames[0].shape[0]
23
+
24
+ @property
25
+ def width(self):
26
+ return self.all_frames[0].shape[1]
27
+
28
+ @classmethod
29
+ def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
30
+ fps: Fraction) -> 'VideoInfo':
31
+ num_frames = int(duration_sec * fps)
32
+ all_frames = [image_info.original_frame] * num_frames
33
+ return cls(duration_sec=duration_sec,
34
+ fps=fps,
35
+ clip_frames=image_info.clip_frames,
36
+ sync_frames=image_info.sync_frames,
37
+ all_frames=all_frames)
38
+
39
+
40
+ @dataclass
41
+ class ImageInfo:
42
+ clip_frames: torch.Tensor
43
+ sync_frames: torch.Tensor
44
+ original_frame: Optional[np.ndarray]
45
+
46
+ @property
47
+ def height(self):
48
+ return self.original_frame.shape[0]
49
+
50
+ @property
51
+ def width(self):
52
+ return self.original_frame.shape[1]
53
+
54
+
55
+ def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
56
+ need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
57
+ output_frames = [[] for _ in list_of_fps]
58
+ next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
59
+ time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
60
+ all_frames = []
61
+
63
+ with av.open(video_path) as container:
64
+ stream = container.streams.video[0]
65
+ fps = stream.guessed_rate
66
+ stream.thread_type = 'AUTO'
67
+ for packet in container.demux(stream):
68
+ for frame in packet.decode():
69
+ frame_time = frame.time
70
+ if frame_time < start_sec:
71
+ continue
72
+ if frame_time > end_sec:
73
+ break
74
+
75
+ frame_np = None
76
+ if need_all_frames:
77
+ frame_np = frame.to_ndarray(format='rgb24')
78
+ all_frames.append(frame_np)
79
+
80
+ for i, _ in enumerate(list_of_fps):
81
+ this_time = frame_time
82
+ while this_time >= next_frame_time_for_each_fps[i]:
83
+ if frame_np is None:
84
+ frame_np = frame.to_ndarray(format='rgb24')
85
+
86
+ output_frames[i].append(frame_np)
87
+ next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
88
+
89
+ output_frames = [np.stack(frames) for frames in output_frames]
90
+ return output_frames, all_frames, fps
91
+
92
+
93
+ def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
94
+ sampling_rate: int):
95
+ container = av.open(output_path, 'w')
96
+ output_video_stream = container.add_stream('h264', video_info.fps)
97
+ output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
98
+ output_video_stream.width = video_info.width
99
+ output_video_stream.height = video_info.height
100
+ output_video_stream.pix_fmt = 'yuv420p'
101
+
102
+ output_audio_stream = container.add_stream('aac', sampling_rate)
103
+
104
+ # encode video
105
+ for image in video_info.all_frames:
106
+ image = av.VideoFrame.from_ndarray(image)
107
+ packet = output_video_stream.encode(image)
108
+ container.mux(packet)
109
+
110
+ for packet in output_video_stream.encode():
111
+ container.mux(packet)
112
+
113
+ # convert float tensor audio to numpy array
114
+ audio_np = audio.numpy().astype(np.float32)
115
+ audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
116
+ audio_frame.sample_rate = sampling_rate
117
+
118
+ for packet in output_audio_stream.encode(audio_frame):
119
+ container.mux(packet)
120
+
121
+ for packet in output_audio_stream.encode():
122
+ container.mux(packet)
123
+
124
+ container.close()
125
+
126
+
127
+ def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
128
+ """
129
+ NOTE: I don't think we can get the exact video duration right without re-encoding
130
+ so we are not using this but keeping it here for reference
131
+ """
132
+ video = av.open(video_path)
133
+ output = av.open(output_path, 'w')
134
+ input_video_stream = video.streams.video[0]
135
+ output_video_stream = output.add_stream(template=input_video_stream)
136
+ output_audio_stream = output.add_stream('aac', sampling_rate)
137
+
138
+ duration_sec = audio.shape[-1] / sampling_rate
139
+
140
+ for packet in video.demux(input_video_stream):
141
+ # We need to skip the "flushing" packets that `demux` generates.
142
+ if packet.dts is None:
143
+ continue
144
+ # We need to assign the packet to the new stream.
145
+ packet.stream = output_video_stream
146
+ output.mux(packet)
147
+
148
+ # convert float tensor audio to numpy array
149
+ audio_np = audio.numpy().astype(np.float32)
150
+ audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
151
+ audio_frame.sample_rate = sampling_rate
152
+
153
+ for packet in output_audio_stream.encode(audio_frame):
154
+ output.mux(packet)
155
+
156
+ for packet in output_audio_stream.encode():
157
+ output.mux(packet)
158
+
159
+ video.close()
160
+ output.close()
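
read_frames decodes the video once and resamples it to every requested frame rate in a single pass, which is how the CLIP-rate and Synchformer-rate frame streams are produced from the same file. A minimal usage sketch, where the example path and the 8/25 fps rates are placeholder assumptions:

# Hypothetical usage sketch, not part of the repository.
from pathlib import Path

from mmaudio.data.av_utils import read_frames

# Decode the first 8 seconds once, resampled to 8 fps (CLIP) and 25 fps (sync).
# 'example.mp4' is a placeholder path.
(clip_np, sync_np), all_frames, native_fps = read_frames(Path('example.mp4'),
                                                         list_of_fps=[8.0, 25.0],
                                                         start_sec=0.0,
                                                         end_sec=8.0,
                                                         need_all_frames=False)
print(clip_np.shape, sync_np.shape, native_fps)  # (n_clip, H, W, 3), (n_sync, H, W, 3), native fps
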
mmaudio/data/data_setup.py ADDED
@@ -0,0 +1,174 @@
1
+ import logging
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from torch.utils.data.dataloader import default_collate
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from mmaudio.data.eval.audiocaps import AudioCapsData
12
+ from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
13
+ from mmaudio.data.extracted_audio import ExtractedAudio
14
+ from mmaudio.data.extracted_vgg import ExtractedVGG
15
+ from mmaudio.data.mm_dataset import MultiModalDataset
16
+ from mmaudio.utils.dist_utils import local_rank
17
+
18
+ log = logging.getLogger()
19
+
20
+
21
+ # Re-seed randomness every time we start a worker
22
+ def worker_init_fn(worker_id: int):
23
+ worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
24
+ np.random.seed(worker_seed)
25
+ random.seed(worker_seed)
26
+ log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
27
+
28
+
29
+ def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
30
+ dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
31
+ data_dim=cfg.data_dim,
32
+ premade_mmap_dir=data_cfg.memmap_dir)
33
+
34
+ return dataset
35
+
36
+
37
+ def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
38
+ dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
39
+ data_dim=cfg.data_dim,
40
+ premade_mmap_dir=data_cfg.memmap_dir)
41
+
42
+ return dataset
43
+
44
+
45
+ def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
46
+ if cfg.mini_train:
47
+ vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
48
+ audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
49
+ dataset = MultiModalDataset([vgg], [audiocaps])
50
+ elif cfg.example_train:
51
+ video = load_vgg_data(cfg, cfg.data.Example_video)
52
+ audio = load_audio_data(cfg, cfg.data.Example_audio)
53
+ dataset = MultiModalDataset([video], [audio])
54
+ else:
55
+ # load the largest one first
56
+ freesound = load_audio_data(cfg, cfg.data.FreeSound)
57
+ vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
58
+ audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
59
+ audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
60
+ bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
61
+ clotho = load_audio_data(cfg, cfg.data.Clotho)
62
+ dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
63
+ [audiocaps, audioset_sl, bbcsound, freesound, clotho])
64
+
65
+ batch_size = cfg.batch_size
66
+ num_workers = cfg.num_workers
67
+ pin_memory = cfg.pin_memory
68
+ sampler, loader = construct_loader(dataset,
69
+ batch_size,
70
+ num_workers,
71
+ shuffle=True,
72
+ drop_last=True,
73
+ pin_memory=pin_memory)
74
+
75
+ return dataset, sampler, loader
76
+
77
+
78
+ def setup_test_datasets(cfg):
79
+ dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)
80
+
81
+ batch_size = cfg.batch_size
82
+ num_workers = cfg.num_workers
83
+ pin_memory = cfg.pin_memory
84
+ sampler, loader = construct_loader(dataset,
85
+ batch_size,
86
+ num_workers,
87
+ shuffle=False,
88
+ drop_last=False,
89
+ pin_memory=pin_memory)
90
+
91
+ return dataset, sampler, loader
92
+
93
+
94
+ def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
95
+ if cfg.example_train:
96
+ dataset = load_vgg_data(cfg, cfg.data.Example_video)
97
+ else:
98
+ dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
99
+
100
+ val_batch_size = cfg.batch_size
101
+ val_eval_batch_size = cfg.eval_batch_size
102
+ num_workers = cfg.num_workers
103
+ pin_memory = cfg.pin_memory
104
+ _, val_loader = construct_loader(dataset,
105
+ val_batch_size,
106
+ num_workers,
107
+ shuffle=False,
108
+ drop_last=False,
109
+ pin_memory=pin_memory)
110
+ _, eval_loader = construct_loader(dataset,
111
+ val_eval_batch_size,
112
+ num_workers,
113
+ shuffle=False,
114
+ drop_last=False,
115
+ pin_memory=pin_memory)
116
+
117
+ return dataset, val_loader, eval_loader
118
+
119
+
120
+ def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
121
+ if dataset_name.startswith('audiocaps_full'):
122
+ dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
123
+ cfg.eval_data.AudioCaps_full.csv_path)
124
+ elif dataset_name.startswith('audiocaps'):
125
+ dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
126
+ cfg.eval_data.AudioCaps.csv_path)
127
+ elif dataset_name.startswith('moviegen'):
128
+ dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
129
+ cfg.eval_data.MovieGen.jsonl_path,
130
+ duration_sec=cfg.duration_s)
131
+ elif dataset_name.startswith('vggsound'):
132
+ dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
133
+ cfg.eval_data.VGGSound.csv_path,
134
+ duration_sec=cfg.duration_s)
135
+ else:
136
+ raise ValueError(f'Invalid dataset name: {dataset_name}')
137
+
138
+ batch_size = cfg.batch_size
139
+ num_workers = cfg.num_workers
140
+ pin_memory = cfg.pin_memory
141
+ _, loader = construct_loader(dataset,
142
+ batch_size,
143
+ num_workers,
144
+ shuffle=False,
145
+ drop_last=False,
146
+ pin_memory=pin_memory,
147
+ error_avoidance=True)
148
+ return dataset, loader
149
+
150
+
151
+ def error_avoidance_collate(batch):
152
+ batch = list(filter(lambda x: x is not None, batch))
153
+ return default_collate(batch)
154
+
155
+
156
+ def construct_loader(dataset: Dataset,
157
+ batch_size: int,
158
+ num_workers: int,
159
+ *,
160
+ shuffle: bool = True,
161
+ drop_last: bool = True,
162
+ pin_memory: bool = False,
163
+ error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
164
+ train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
165
+ train_loader = DataLoader(dataset,
166
+ batch_size,
167
+ sampler=train_sampler,
168
+ num_workers=num_workers,
169
+ worker_init_fn=worker_init_fn,
170
+ drop_last=drop_last,
171
+ persistent_workers=num_workers > 0,
172
+ pin_memory=pin_memory,
173
+ collate_fn=error_avoidance_collate if error_avoidance else None)
174
+ return train_sampler, train_loader
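
construct_loader always wraps the dataset in a DistributedSampler, so a torch.distributed process group must be initialised even for single-GPU runs. A minimal sketch under that assumption; the dummy dataset and single-process gloo init are illustrative only:

# Hypothetical usage sketch, not part of the repository.
import os

import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset

from mmaudio.data.data_setup import construct_loader

# Single-process "distributed" init so DistributedSampler has a process group.
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

dataset = TensorDataset(torch.randn(64, 16))  # stand-in for the extracted-feature datasets
sampler, loader = construct_loader(dataset, batch_size=8, num_workers=0,
                                   shuffle=True, drop_last=True)
for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle differently each epoch
    for (features, ) in loader:
        pass  # training/eval step would go here

dist.destroy_process_group()
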
mmaudio/data/eval/__init__.py ADDED
File without changes
mmaudio/data/eval/audiocaps.py ADDED
@@ -0,0 +1,39 @@
1
+ import logging
2
+ import os
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ log = logging.getLogger()
12
+
13
+
14
+ class AudioCapsData(Dataset):
15
+
16
+ def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
17
+ df = pd.read_csv(csv_path).to_dict(orient='records')
18
+
19
+ audio_files = sorted(os.listdir(audio_path))
20
+ audio_files = set(
21
+ [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
22
+
23
+ self.data = []
24
+ for row in df:
25
+ self.data.append({
26
+ 'name': row['name'],
27
+ 'caption': row['caption'],
28
+ })
29
+
30
+ self.audio_path = Path(audio_path)
31
+ self.csv_path = Path(csv_path)
32
+
33
+ log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
34
+
35
+ def __getitem__(self, idx: int) -> torch.Tensor:
36
+ return self.data[idx]
37
+
38
+ def __len__(self):
39
+ return len(self.data)
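The CSV is expected to provide 'name' and 'caption' columns, and each item comes back as a plain dict rather than a tensor. A short sketch with placeholder paths:

    from mmaudio.data.eval.audiocaps import AudioCapsData

    dataset = AudioCapsData('data/audiocaps/audio',     # directory of .wav / .flac files
                            'data/audiocaps/test.csv')  # columns: name, caption
    item = dataset[0]
    print(item['name'], item['caption'])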
mmaudio/data/eval/moviegen.py ADDED
@@ -0,0 +1,131 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import torch
8
+ from torch.utils.data.dataset import Dataset
9
+ from torchvision.transforms import v2
10
+ from torio.io import StreamingMediaDecoder
11
+
12
+ from mmaudio.utils.dist_utils import local_rank
13
+
14
+ log = logging.getLogger()
15
+
16
+ _CLIP_SIZE = 384
17
+ _CLIP_FPS = 8.0
18
+
19
+ _SYNC_SIZE = 224
20
+ _SYNC_FPS = 25.0
21
+
22
+
23
+ class MovieGenData(Dataset):
24
+
25
+ def __init__(
26
+ self,
27
+ video_root: Union[str, Path],
28
+ sync_root: Union[str, Path],
29
+ jsonl_root: Union[str, Path],
30
+ *,
31
+ duration_sec: float = 10.0,
32
+ read_clip: bool = True,
33
+ ):
34
+ self.video_root = Path(video_root)
35
+ self.sync_root = Path(sync_root)
36
+ self.jsonl_root = Path(jsonl_root)
37
+ self.read_clip = read_clip
38
+
39
+ videos = sorted(os.listdir(self.video_root))
40
+ videos = [v[:-4] for v in videos] # remove extensions
41
+ self.captions = {}
42
+
43
+ for v in videos:
44
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
45
+ data = json.load(f)
46
+ self.captions[v] = data['audio_prompt']
47
+
48
+ if local_rank == 0:
49
+ log.info(f'{len(videos)} videos found in {video_root}')
50
+
51
+ self.duration_sec = duration_sec
52
+
53
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
54
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
55
+
56
+ self.clip_augment = v2.Compose([
57
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
58
+ v2.ToImage(),
59
+ v2.ToDtype(torch.float32, scale=True),
60
+ ])
61
+
62
+ self.sync_augment = v2.Compose([
63
+ v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
64
+ v2.CenterCrop(_SYNC_SIZE),
65
+ v2.ToImage(),
66
+ v2.ToDtype(torch.float32, scale=True),
67
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
68
+ ])
69
+
70
+ self.videos = videos
71
+
72
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
73
+ video_id = self.videos[idx]
74
+ caption = self.captions[video_id]
75
+
76
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
77
+ reader.add_basic_video_stream(
78
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
79
+ frame_rate=_CLIP_FPS,
80
+ format='rgb24',
81
+ )
82
+ reader.add_basic_video_stream(
83
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
84
+ frame_rate=_SYNC_FPS,
85
+ format='rgb24',
86
+ )
87
+
88
+ reader.fill_buffer()
89
+ data_chunk = reader.pop_chunks()
90
+
91
+ clip_chunk = data_chunk[0]
92
+ sync_chunk = data_chunk[1]
93
+ if clip_chunk is None:
94
+ raise RuntimeError(f'CLIP video returned None {video_id}')
95
+ if clip_chunk.shape[0] < self.clip_expected_length:
96
+ raise RuntimeError(f'CLIP video too short {video_id}')
97
+
98
+ if sync_chunk is None:
99
+ raise RuntimeError(f'Sync video returned None {video_id}')
100
+ if sync_chunk.shape[0] < self.sync_expected_length:
101
+ raise RuntimeError(f'Sync video too short {video_id}')
102
+
103
+ # truncate the video
104
+ clip_chunk = clip_chunk[:self.clip_expected_length]
105
+ if clip_chunk.shape[0] != self.clip_expected_length:
106
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
107
+ f'expected {self.clip_expected_length}, '
108
+ f'got {clip_chunk.shape[0]}')
109
+ clip_chunk = self.clip_augment(clip_chunk)
110
+
111
+ sync_chunk = sync_chunk[:self.sync_expected_length]
112
+ if sync_chunk.shape[0] != self.sync_expected_length:
113
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
114
+ f'expected {self.sync_expected_length}, '
115
+ f'got {sync_chunk.shape[0]}')
116
+ sync_chunk = self.sync_augment(sync_chunk)
117
+
118
+ data = {
119
+ 'name': video_id,
120
+ 'caption': caption,
121
+ 'clip_video': clip_chunk,
122
+ 'sync_video': sync_chunk,
123
+ }
124
+
125
+ return data
126
+
127
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
128
+ return self.sample(idx)
129
+
130
+ def __len__(self):
131
+ return len(self.captions)
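A sketch of the expected layout and output shapes, with placeholder paths: each <id>.mp4 under the video root is paired with a <id>.jsonl file whose JSON object carries an 'audio_prompt' field.

    from mmaudio.data.eval.moviegen import MovieGenData

    dataset = MovieGenData('data/moviegen/videos',   # <id>.mp4 files
                           'data/moviegen/videos',   # sync root (same videos here)
                           'data/moviegen/prompts',  # <id>.jsonl files
                           duration_sec=10.0)
    sample = dataset[0]
    print(sample['clip_video'].shape)  # (80, 3, 384, 384): 8 fps x 10 s at CLIP size
    print(sample['sync_video'].shape)  # (250, 3, 224, 224): 25 fps x 10 s at Synchformer size
    print(sample['caption'])           # the audio prompt read from the .jsonl file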
mmaudio/data/eval/video_dataset.py ADDED
@@ -0,0 +1,197 @@
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from mmaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VideoDataset(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ video_root: Union[str, Path],
29
+ *,
30
+ duration_sec: float = 8.0,
31
+ ):
32
+ self.video_root = Path(video_root)
33
+
34
+ self.duration_sec = duration_sec
35
+
36
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
37
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
38
+
39
+ self.clip_transform = v2.Compose([
40
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
41
+ v2.ToImage(),
42
+ v2.ToDtype(torch.float32, scale=True),
43
+ ])
44
+
45
+ self.sync_transform = v2.Compose([
46
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
47
+ v2.CenterCrop(_SYNC_SIZE),
48
+ v2.ToImage(),
49
+ v2.ToDtype(torch.float32, scale=True),
50
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
51
+ ])
52
+
53
+ # to be implemented by subclasses
54
+ self.captions = {}
55
+ self.videos = sorted(list(self.captions.keys()))
56
+
57
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
58
+ video_id = self.videos[idx]
59
+ caption = self.captions[video_id]
60
+
61
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
62
+ reader.add_basic_video_stream(
63
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
64
+ frame_rate=_CLIP_FPS,
65
+ format='rgb24',
66
+ )
67
+ reader.add_basic_video_stream(
68
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
69
+ frame_rate=_SYNC_FPS,
70
+ format='rgb24',
71
+ )
72
+
73
+ reader.fill_buffer()
74
+ data_chunk = reader.pop_chunks()
75
+
76
+ clip_chunk = data_chunk[0]
77
+ sync_chunk = data_chunk[1]
78
+ if clip_chunk is None:
79
+ raise RuntimeError(f'CLIP video returned None {video_id}')
80
+ if clip_chunk.shape[0] < self.clip_expected_length:
81
+ raise RuntimeError(
82
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
83
+ )
84
+
85
+ if sync_chunk is None:
86
+ raise RuntimeError(f'Sync video returned None {video_id}')
87
+ if sync_chunk.shape[0] < self.sync_expected_length:
88
+ raise RuntimeError(
89
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
90
+ )
91
+
92
+ # truncate the video
93
+ clip_chunk = clip_chunk[:self.clip_expected_length]
94
+ if clip_chunk.shape[0] != self.clip_expected_length:
95
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
96
+ f'expected {self.clip_expected_length}, '
97
+ f'got {clip_chunk.shape[0]}')
98
+ clip_chunk = self.clip_transform(clip_chunk)
99
+
100
+ sync_chunk = sync_chunk[:self.sync_expected_length]
101
+ if sync_chunk.shape[0] != self.sync_expected_length:
102
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
103
+ f'expected {self.sync_expected_length}, '
104
+ f'got {sync_chunk.shape[0]}')
105
+ sync_chunk = self.sync_transform(sync_chunk)
106
+
107
+ data = {
108
+ 'name': video_id,
109
+ 'caption': caption,
110
+ 'clip_video': clip_chunk,
111
+ 'sync_video': sync_chunk,
112
+ }
113
+
114
+ return data
115
+
116
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
117
+ try:
118
+ return self.sample(idx)
119
+ except Exception as e:
120
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
121
+ return None
122
+
123
+ def __len__(self):
124
+ return len(self.captions)
125
+
126
+
127
+ class VGGSound(VideoDataset):
128
+
129
+ def __init__(
130
+ self,
131
+ video_root: Union[str, Path],
132
+ csv_path: Union[str, Path],
133
+ *,
134
+ duration_sec: float = 8.0,
135
+ ):
136
+ super().__init__(video_root, duration_sec=duration_sec)
137
+ self.video_root = Path(video_root)
138
+ self.csv_path = Path(csv_path)
139
+
140
+ videos = sorted(os.listdir(self.video_root))
141
+ if local_rank == 0:
142
+ log.info(f'{len(videos)} videos found in {video_root}')
143
+ self.captions = {}
144
+
145
+ df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
146
+ 'split']).to_dict(orient='records')
147
+
148
+ videos_no_found = []
149
+ for row in df:
150
+ if row['split'] == 'test':
151
+ start_sec = int(row['sec'])
152
+ video_id = str(row['id'])
153
+ # this is how our videos are named
154
+ video_name = f'{video_id}_{start_sec:06d}'
155
+ if video_name + '.mp4' not in videos:
156
+ videos_no_found.append(video_name)
157
+ continue
158
+
159
+ self.captions[video_name] = row['caption']
160
+
161
+ if local_rank == 0:
162
+ log.info(f'{len(videos)} videos found in {video_root}')
163
+ log.info(f'{len(self.captions)} usable videos found')
164
+ if videos_no_found:
165
+ log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
166
+ log.info(
167
+ 'A small number of missing videos is expected, as not all of them are still available on YouTube')
168
+
169
+ self.videos = sorted(list(self.captions.keys()))
170
+
171
+
172
+ class MovieGen(VideoDataset):
173
+
174
+ def __init__(
175
+ self,
176
+ video_root: Union[str, Path],
177
+ jsonl_root: Union[str, Path],
178
+ *,
179
+ duration_sec: float = 10.0,
180
+ ):
181
+ super().__init__(video_root, duration_sec=duration_sec)
182
+ self.video_root = Path(video_root)
183
+ self.jsonl_root = Path(jsonl_root)
184
+
185
+ videos = sorted(os.listdir(self.video_root))
186
+ videos = [v[:-4] for v in videos] # remove extensions
187
+ self.captions = {}
188
+
189
+ for v in videos:
190
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
191
+ data = json.load(f)
192
+ self.captions[v] = data['audio_prompt']
193
+
194
+ if local_rank == 0:
195
+ log.info(f'{len(videos)} videos found in {video_root}')
196
+
197
+ self.videos = videos
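Because __getitem__ returns None for unreadable clips, these evaluation datasets are meant to be paired with the error-avoidance collate defined in data_setup.py. A sketch with placeholder paths (the import paths are assumptions):

    from torch.utils.data import DataLoader

    from mmaudio.data.data_setup import error_avoidance_collate
    from mmaudio.data.eval.video_dataset import VGGSound

    dataset = VGGSound('data/vggsound/videos',        # <id>_<start sec>.mp4 files
                       'data/vggsound/vggsound.csv',  # headerless: id, sec, caption, split
                       duration_sec=8.0)
    loader = DataLoader(dataset, batch_size=4, num_workers=2,
                        collate_fn=error_avoidance_collate)  # drops clips that failed to decode
    for batch in loader:
        clip, sync = batch['clip_video'], batch['sync_video']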
mmaudio/data/extracted_audio.py ADDED
@@ -0,0 +1,88 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ from mmaudio.utils.dist_utils import local_rank
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class ExtractedAudio(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ tsv_path: Union[str, Path],
20
+ *,
21
+ premade_mmap_dir: Union[str, Path],
22
+ data_dim: dict[str, int],
23
+ ):
24
+ super().__init__()
25
+
26
+ self.data_dim = data_dim
27
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
+ self.ids = [str(d['id']) for d in self.df_list]
29
+
30
+ log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
+ # load precomputed memory mapped tensors
32
+ premade_mmap_dir = Path(premade_mmap_dir)
33
+ td = TensorDict.load_memmap(premade_mmap_dir)
34
+ log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
+ self.mean = td['mean']
36
+ self.std = td['std']
37
+ self.text_features = td['text_features']
38
+
39
+ log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
40
+ log.info(f'Loaded mean: {self.mean.shape}.')
41
+ log.info(f'Loaded std: {self.std.shape}.')
42
+ log.info(f'Loaded text features: {self.text_features.shape}.')
43
+
44
+ assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
45
+ f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
46
+ assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
47
+ f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
48
+
49
+ assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
50
+ f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
51
+ assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
52
+ f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
53
+
54
+ self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
55
+ self.data_dim['clip_dim'])
56
+ self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
57
+ self.data_dim['sync_dim'])
58
+ self.video_exist = torch.tensor(0, dtype=torch.bool)
59
+ self.text_exist = torch.tensor(1, dtype=torch.bool)
60
+
61
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
62
+ latents = self.mean
63
+ return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
64
+
65
+ def get_memory_mapped_tensor(self) -> TensorDict:
66
+ td = TensorDict({
67
+ 'mean': self.mean,
68
+ 'std': self.std,
69
+ 'text_features': self.text_features,
70
+ })
71
+ return td
72
+
73
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
74
+ data = {
75
+ 'id': str(self.df_list[idx]['id']),
76
+ 'a_mean': self.mean[idx],
77
+ 'a_std': self.std[idx],
78
+ 'clip_features': self.fake_clip_features,
79
+ 'sync_features': self.fake_sync_features,
80
+ 'text_features': self.text_features[idx],
81
+ 'caption': self.df_list[idx]['caption'],
82
+ 'video_exist': self.video_exist,
83
+ 'text_exist': self.text_exist,
84
+ }
85
+ return data
86
+
87
+ def __len__(self):
88
+ return len(self.ids)
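The dataset is only a thin view over features that were memory-mapped ahead of time, and the data_dim entries must match the sequence configuration used during extraction. The numbers and paths below are placeholders, not the repository's actual settings:

    from mmaudio.data.extracted_audio import ExtractedAudio

    data_dim = {   # placeholder values; must agree with the extraction run
        'latent_seq_len': 250,
        'clip_seq_len': 64, 'clip_dim': 1024,
        'sync_seq_len': 192, 'sync_dim': 768,
        'text_seq_len': 77, 'text_dim': 1024,
    }
    dataset = ExtractedAudio('sets/audio-train.tsv',            # columns: id, caption
                             premade_mmap_dir='memmap/audio-train',
                             data_dim=data_dim)
    mean, std = dataset.compute_latent_stats()  # per-channel stats over the latent means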
mmaudio/data/extracted_vgg.py ADDED
@@ -0,0 +1,101 @@
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+
10
+ from mmaudio.utils.dist_utils import local_rank
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class ExtractedVGG(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ tsv_path: Union[str, Path],
20
+ *,
21
+ premade_mmap_dir: Union[str, Path],
22
+ data_dim: dict[str, int],
23
+ ):
24
+ super().__init__()
25
+
26
+ self.data_dim = data_dim
27
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
28
+ self.ids = [d['id'] for d in self.df_list]
29
+
30
+ log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
31
+ # load precomputed memory mapped tensors
32
+ premade_mmap_dir = Path(premade_mmap_dir)
33
+ td = TensorDict.load_memmap(premade_mmap_dir)
34
+ log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
35
+ self.mean = td['mean']
36
+ self.std = td['std']
37
+ self.clip_features = td['clip_features']
38
+ self.sync_features = td['sync_features']
39
+ self.text_features = td['text_features']
40
+
41
+ if local_rank == 0:
42
+ log.info(f'Loaded {len(self)} samples.')
43
+ log.info(f'Loaded mean: {self.mean.shape}.')
44
+ log.info(f'Loaded std: {self.std.shape}.')
45
+ log.info(f'Loaded clip_features: {self.clip_features.shape}.')
46
+ log.info(f'Loaded sync_features: {self.sync_features.shape}.')
47
+ log.info(f'Loaded text_features: {self.text_features.shape}.')
48
+
49
+ assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
50
+ f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
51
+ assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
52
+ f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'
53
+
54
+ assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
55
+ f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
56
+ assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
57
+ f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
58
+ assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
59
+ f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
60
+
61
+ assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
62
+ f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
63
+ assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
64
+ f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
65
+ assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
66
+ f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'
67
+
68
+ self.video_exist = torch.tensor(1, dtype=torch.bool)
69
+ self.text_exist = torch.tensor(1, dtype=torch.bool)
70
+
71
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
72
+ latents = self.mean
73
+ return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))
74
+
75
+ def get_memory_mapped_tensor(self) -> TensorDict:
76
+ td = TensorDict({
77
+ 'mean': self.mean,
78
+ 'std': self.std,
79
+ 'clip_features': self.clip_features,
80
+ 'sync_features': self.sync_features,
81
+ 'text_features': self.text_features,
82
+ })
83
+ return td
84
+
85
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
86
+ data = {
87
+ 'id': self.df_list[idx]['id'],
88
+ 'a_mean': self.mean[idx],
89
+ 'a_std': self.std[idx],
90
+ 'clip_features': self.clip_features[idx],
91
+ 'sync_features': self.sync_features[idx],
92
+ 'text_features': self.text_features[idx],
93
+ 'caption': self.df_list[idx]['label'],
94
+ 'video_exist': self.video_exist,
95
+ 'text_exist': self.text_exist,
96
+ }
97
+
98
+ return data
99
+
100
+ def __len__(self):
101
+ return len(self.ids)
mmaudio/data/extraction/__init__.py ADDED
File without changes
mmaudio/data/extraction/vgg_sound.py ADDED
@@ -0,0 +1,193 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from mmaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VGGSound(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ root: Union[str, Path],
29
+ *,
30
+ tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
31
+ sample_rate: int = 16_000,
32
+ duration_sec: float = 8.0,
33
+ audio_samples: Optional[int] = None,
34
+ normalize_audio: bool = False,
35
+ ):
36
+ self.root = Path(root)
37
+ self.normalize_audio = normalize_audio
38
+ if audio_samples is None:
39
+ self.audio_samples = int(sample_rate * duration_sec)
40
+ else:
41
+ self.audio_samples = audio_samples
42
+ effective_duration = audio_samples / sample_rate
43
+ # make sure the duration is close enough, within 15ms
44
+ assert abs(effective_duration - duration_sec) < 0.015, \
45
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
46
+
47
+ videos = sorted(os.listdir(self.root))
48
+ videos = set([Path(v).stem for v in videos]) # remove extensions
49
+ self.labels = {}
50
+ self.videos = []
51
+ missing_videos = []
52
+
53
+ # read the tsv for subset information
54
+ df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
55
+ for record in df_list:
56
+ id = record['id']
57
+ label = record['label']
58
+ if id in videos:
59
+ self.labels[id] = label
60
+ self.videos.append(id)
61
+ else:
62
+ missing_videos.append(id)
63
+
64
+ if local_rank == 0:
65
+ log.info(f'{len(videos)} videos found in {root}')
66
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
67
+ log.info(f'{len(missing_videos)} videos missing in {root}')
68
+
69
+ self.sample_rate = sample_rate
70
+ self.duration_sec = duration_sec
71
+
72
+ self.expected_audio_length = audio_samples
73
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
74
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
75
+
76
+ self.clip_transform = v2.Compose([
77
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
78
+ v2.ToImage(),
79
+ v2.ToDtype(torch.float32, scale=True),
80
+ ])
81
+
82
+ self.sync_transform = v2.Compose([
83
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
84
+ v2.CenterCrop(_SYNC_SIZE),
85
+ v2.ToImage(),
86
+ v2.ToDtype(torch.float32, scale=True),
87
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
88
+ ])
89
+
90
+ self.resampler = {}
91
+
92
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
93
+ video_id = self.videos[idx]
94
+ label = self.labels[video_id]
95
+
96
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
97
+ reader.add_basic_video_stream(
98
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
99
+ frame_rate=_CLIP_FPS,
100
+ format='rgb24',
101
+ )
102
+ reader.add_basic_video_stream(
103
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
104
+ frame_rate=_SYNC_FPS,
105
+ format='rgb24',
106
+ )
107
+ reader.add_basic_audio_stream(frames_per_chunk=2**30, )
108
+
109
+ reader.fill_buffer()
110
+ data_chunk = reader.pop_chunks()
111
+
112
+ clip_chunk = data_chunk[0]
113
+ sync_chunk = data_chunk[1]
114
+ audio_chunk = data_chunk[2]
115
+
116
+ if clip_chunk is None:
117
+ raise RuntimeError(f'CLIP video returned None {video_id}')
118
+ if clip_chunk.shape[0] < self.clip_expected_length:
119
+ raise RuntimeError(
120
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
121
+ )
122
+
123
+ if sync_chunk is None:
124
+ raise RuntimeError(f'Sync video returned None {video_id}')
125
+ if sync_chunk.shape[0] < self.sync_expected_length:
126
+ raise RuntimeError(
127
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
128
+ )
129
+
130
+ # process audio
131
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
132
+ audio_chunk = audio_chunk.transpose(0, 1)
133
+ audio_chunk = audio_chunk.mean(dim=0) # mono
134
+ if self.normalize_audio:
135
+ abs_max = audio_chunk.abs().max()
136
+ audio_chunk = audio_chunk / abs_max * 0.95
137
+ if abs_max <= 1e-6:
138
+ raise RuntimeError(f'Audio is silent {video_id}')
139
+
140
+ # resample
141
+ if sample_rate == self.sample_rate:
142
+ audio_chunk = audio_chunk
143
+ else:
144
+ if sample_rate not in self.resampler:
145
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
146
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
147
+ sample_rate,
148
+ self.sample_rate,
149
+ lowpass_filter_width=64,
150
+ rolloff=0.9475937167399596,
151
+ resampling_method='sinc_interp_kaiser',
152
+ beta=14.769656459379492,
153
+ )
154
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
155
+
156
+ if audio_chunk.shape[0] < self.expected_audio_length:
157
+ raise RuntimeError(f'Audio too short {video_id}')
158
+ audio_chunk = audio_chunk[:self.expected_audio_length]
159
+
160
+ # truncate the video
161
+ clip_chunk = clip_chunk[:self.clip_expected_length]
162
+ if clip_chunk.shape[0] != self.clip_expected_length:
163
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
164
+ f'expected {self.clip_expected_length}, '
165
+ f'got {clip_chunk.shape[0]}')
166
+ clip_chunk = self.clip_transform(clip_chunk)
167
+
168
+ sync_chunk = sync_chunk[:self.sync_expected_length]
169
+ if sync_chunk.shape[0] != self.sync_expected_length:
170
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
171
+ f'expected {self.sync_expected_length}, '
172
+ f'got {sync_chunk.shape[0]}')
173
+ sync_chunk = self.sync_transform(sync_chunk)
174
+
175
+ data = {
176
+ 'id': video_id,
177
+ 'caption': label,
178
+ 'audio': audio_chunk,
179
+ 'clip_video': clip_chunk,
180
+ 'sync_video': sync_chunk,
181
+ }
182
+
183
+ return data
184
+
185
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
186
+ try:
187
+ return self.sample(idx)
188
+ except Exception as e:
189
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
190
+ return None
191
+
192
+ def __len__(self):
193
+ return len(self.labels)
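A sketch of driving the extraction dataset; paths are placeholders and failed clips (returned as None) are dropped by the error-avoidance collate from data_setup.py (import paths assumed):

    from torch.utils.data import DataLoader

    from mmaudio.data.data_setup import error_avoidance_collate
    from mmaudio.data.extraction.vgg_sound import VGGSound

    dataset = VGGSound('data/vggsound/videos',
                       tsv_path='sets/vgg3-train.tsv',  # columns: id, label
                       sample_rate=16_000,
                       duration_sec=8.0)
    loader = DataLoader(dataset, batch_size=2, num_workers=2,
                        collate_fn=error_avoidance_collate)
    for batch in loader:
        # batch['audio']:      (B, 128000) mono waveforms at 16 kHz
        # batch['clip_video']: (B, 64, 3, 384, 384)   8 fps x 8 s
        # batch['sync_video']: (B, 200, 3, 224, 224)  25 fps x 8 s
        ...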
mmaudio/data/extraction/wav_dataset.py ADDED
@@ -0,0 +1,132 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import open_clip
7
+ import pandas as pd
8
+ import torch
9
+ import torchaudio
10
+ from torch.utils.data.dataset import Dataset
11
+
12
+ log = logging.getLogger()
13
+
14
+
15
+ class WavTextClipsDataset(Dataset):
16
+
17
+ def __init__(
18
+ self,
19
+ root: Union[str, Path],
20
+ *,
21
+ captions_tsv: Union[str, Path],
22
+ clips_tsv: Union[str, Path],
23
+ sample_rate: int,
24
+ num_samples: int,
25
+ normalize_audio: bool = False,
26
+ reject_silent: bool = False,
27
+ tokenizer_id: str = 'ViT-H-14-378-quickgelu',
28
+ ):
29
+ self.root = Path(root)
30
+ self.sample_rate = sample_rate
31
+ self.num_samples = num_samples
32
+ self.normalize_audio = normalize_audio
33
+ self.reject_silent = reject_silent
34
+ self.tokenizer = open_clip.get_tokenizer(tokenizer_id)
35
+
36
+ audios = sorted(os.listdir(self.root))
37
+ audios = set([
38
+ Path(audio).stem for audio in audios
39
+ if audio.endswith('.wav') or audio.endswith('.flac')
40
+ ])
41
+ self.captions = {}
42
+
43
+ # read the caption tsv
44
+ df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
45
+ for record in df_list:
46
+ id = record['id']
47
+ caption = record['caption']
48
+ self.captions[id] = caption
49
+
50
+ # read the clip tsv
51
+ df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
52
+ 'id': str,
53
+ 'name': str
54
+ }).to_dict('records')
55
+ self.clips = []
56
+ for record in df_list:
57
+ record['id'] = record['id']
58
+ record['name'] = record['name']
59
+ id = record['id']
60
+ name = record['name']
61
+ if name not in self.captions:
62
+ log.warning(f'Audio {name} not found in {captions_tsv}')
63
+ continue
64
+ record['caption'] = self.captions[name]
65
+ self.clips.append(record)
66
+
67
+ log.info(f'Found {len(self.clips)} audio files in {self.root}')
68
+
69
+ self.resampler = {}
70
+
71
+ def __getitem__(self, idx: int) -> torch.Tensor:
72
+ try:
73
+ clip = self.clips[idx]
74
+ audio_name = clip['name']
75
+ audio_id = clip['id']
76
+ caption = clip['caption']
77
+ start_sample = clip['start_sample']
78
+ end_sample = clip['end_sample']
79
+
80
+ audio_path = self.root / f'{audio_name}.flac'
81
+ if not audio_path.exists():
82
+ audio_path = self.root / f'{audio_name}.wav'
83
+ assert audio_path.exists()
84
+
85
+ audio_chunk, sample_rate = torchaudio.load(audio_path)
86
+ audio_chunk = audio_chunk.mean(dim=0) # mono
87
+ abs_max = audio_chunk.abs().max()
88
+ if self.normalize_audio:
89
+ audio_chunk = audio_chunk / abs_max * 0.95
90
+
91
+ if self.reject_silent and abs_max < 1e-6:
92
+ log.warning(f'Rejecting silent audio {audio_name}')
93
+ return None
94
+
95
+ audio_chunk = audio_chunk[start_sample:end_sample]
96
+
97
+ # resample
98
+ if sample_rate == self.sample_rate:
99
+ audio_chunk = audio_chunk
100
+ else:
101
+ if sample_rate not in self.resampler:
102
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
103
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
104
+ sample_rate,
105
+ self.sample_rate,
106
+ lowpass_filter_width=64,
107
+ rolloff=0.9475937167399596,
108
+ resampling_method='sinc_interp_kaiser',
109
+ beta=14.769656459379492,
110
+ )
111
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
112
+
113
+ if audio_chunk.shape[0] < self.num_samples:
114
+ raise ValueError('Audio is too short')
115
+ audio_chunk = audio_chunk[:self.num_samples]
116
+
117
+ tokens = self.tokenizer([caption])[0]
118
+
119
+ output = {
120
+ 'waveform': audio_chunk,
121
+ 'id': audio_id,
122
+ 'caption': caption,
123
+ 'tokens': tokens,
124
+ }
125
+
126
+ return output
127
+ except Exception as e:
128
+ log.error(f"Error reading {self.clips[idx]['name']}: {e}")
129
+ return None
130
+
131
+ def __len__(self):
132
+ return len(self.clips)
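A sketch with placeholder paths; the captions TSV needs id/caption columns, the clips TSV needs id/name/start_sample/end_sample columns, and each item bundles the cropped waveform with CLIP text tokens:

    from mmaudio.data.extraction.wav_dataset import WavTextClipsDataset

    dataset = WavTextClipsDataset('data/audio',                    # .wav / .flac files
                                  captions_tsv='sets/captions.tsv',
                                  clips_tsv='sets/clips.tsv',
                                  sample_rate=16_000,
                                  num_samples=16_000 * 8)          # 8-second clips
    item = dataset[0]
    print(item['waveform'].shape)  # (128000,)
    print(item['tokens'].shape)    # (77,) CLIP text tokens for the caption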
mmaudio/data/mm_dataset.py ADDED
@@ -0,0 +1,45 @@
1
+ import bisect
2
+
3
+ import torch
4
+ from torch.utils.data.dataset import Dataset
5
+
6
+
7
+ # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
8
+ class MultiModalDataset(Dataset):
9
+ datasets: list[Dataset]
10
+ cumulative_sizes: list[int]
11
+
12
+ @staticmethod
13
+ def cumsum(sequence):
14
+ r, s = [], 0
15
+ for e in sequence:
16
+ l = len(e)
17
+ r.append(l + s)
18
+ s += l
19
+ return r
20
+
21
+ def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
22
+ super().__init__()
23
+ self.video_datasets = list(video_datasets)
24
+ self.audio_datasets = list(audio_datasets)
25
+ self.datasets = self.video_datasets + self.audio_datasets
26
+
27
+ self.cumulative_sizes = self.cumsum(self.datasets)
28
+
29
+ def __len__(self):
30
+ return self.cumulative_sizes[-1]
31
+
32
+ def __getitem__(self, idx):
33
+ if idx < 0:
34
+ if -idx > len(self):
35
+ raise ValueError("absolute value of index should not exceed dataset length")
36
+ idx = len(self) + idx
37
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
38
+ if dataset_idx == 0:
39
+ sample_idx = idx
40
+ else:
41
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
42
+ return self.datasets[dataset_idx][sample_idx]
43
+
44
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
45
+ return self.video_datasets[0].compute_latent_stats()
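The concatenation indexes video datasets first and audio-only datasets after them; latent statistics are delegated to the first video dataset. A runnable sketch of the index arithmetic with stand-in datasets (real training passes ExtractedVGG and ExtractedAudio instances here; the module path is assumed):

    import torch
    from torch.utils.data import TensorDataset

    from mmaudio.data.mm_dataset import MultiModalDataset

    video = TensorDataset(torch.arange(10))  # stand-in for a video-conditioned dataset
    audio = TensorDataset(torch.arange(4))   # stand-in for an audio-text dataset

    combined = MultiModalDataset([video], [audio])
    print(len(combined))           # 14
    print(combined[9][0].item())   # 9  -> last sample of the video dataset
    print(combined[10][0].item())  # 0  -> first sample of the audio-only dataset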
mmaudio/data/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tensordict import MemoryMappedTensor
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data.dataset import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from mmaudio.utils.dist_utils import local_rank, world_size
16
+
17
+ scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
18
+ shm_path = Path('/dev/shm')
19
+
20
+ log = logging.getLogger()
21
+
22
+
23
+ def reseed(seed):
24
+ random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+
28
+ def local_scatter_torch(obj: Optional[Any]):
29
+ if world_size == 1:
30
+ # Just one worker. Do nothing.
31
+ return obj
32
+
33
+ array = [obj] * world_size
34
+ target_array = [None]
35
+ if local_rank == 0:
36
+ dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
37
+ else:
38
+ dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
39
+ return target_array[0]
40
+
41
+
42
+ class ShardDataset(Dataset):
43
+
44
+ def __init__(self, root):
45
+ self.root = root
46
+ self.shards = sorted(os.listdir(root))
47
+
48
+ def __len__(self):
49
+ return len(self.shards)
50
+
51
+ def __getitem__(self, idx):
52
+ return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
53
+
54
+
55
+ def get_tmp_dir(in_memory: bool) -> Path:
56
+ return shm_path if in_memory else scratch_path
57
+
58
+
59
+ def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
60
+ in_memory: bool) -> MemoryMappedTensor:
61
+ if local_rank == 0:
62
+ with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
63
+ log.info(f'Loading shards from {data_path} into {f.name}...')
64
+ data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
65
+ data = share_tensor_to_all(data)
66
+ torch.distributed.barrier()
67
+ f.close() # why does the context manager not close the file for me?
68
+ else:
69
+ log.info('Waiting for the data to be shared with me...')
70
+ data = share_tensor_to_all(None)
71
+ torch.distributed.barrier()
72
+
73
+ return data
74
+
75
+
76
+ def load_shards(
77
+ data_path: Union[str, Path],
78
+ ids: list[int],
79
+ *,
80
+ tmp_file_path: str,
81
+ ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
82
+
83
+ id_set = set(ids)
84
+ shards = sorted(os.listdir(data_path))
85
+ log.info(f'Found {len(shards)} shards in {data_path}.')
86
+ first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
87
+
88
+ log.info(f'Rank {local_rank} created file {tmp_file_path}')
89
+ first_item = next(iter(first_shard.values()))
90
+ log.info(f'First item shape: {first_item.shape}')
91
+ mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
92
+ dtype=torch.float32,
93
+ filename=tmp_file_path,
94
+ existsok=True)
95
+ total_count = 0
96
+ used_index = set()
97
+ id_indexing = {i: idx for idx, i in enumerate(ids)}
98
+ # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
99
+ loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
100
+ for data in tqdm(loader, desc='Loading shards'):
101
+ for i, v in data.items():
102
+ if i not in id_set:
103
+ continue
104
+
105
+ # tensor_index = ids.index(i)
106
+ tensor_index = id_indexing[i]
107
+ if tensor_index in used_index:
108
+ raise ValueError(f'Duplicate id {i} found in {data_path}.')
109
+ used_index.add(tensor_index)
110
+ mm_tensor[tensor_index] = v
111
+ total_count += 1
112
+
113
+ assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
114
+ log.info(f'Loaded {total_count} tensors from {data_path}.')
115
+
116
+ return mm_tensor
117
+
118
+
119
+ def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
120
+ """
121
+ x: the tensor to be shared; None if local_rank != 0
122
+ return: the shared tensor
123
+ """
124
+
125
+ # there is no need to share your stuff with anyone if you are alone; must be in memory
126
+ if world_size == 1:
127
+ return x
128
+
129
+ if local_rank == 0:
130
+ assert x is not None, 'x must not be None if local_rank == 0'
131
+ else:
132
+ assert x is None, 'x must be None if local_rank != 0'
133
+
134
+ if local_rank == 0:
135
+ filename = x.filename
136
+ meta_information = (filename, x.shape, x.dtype)
137
+ else:
138
+ meta_information = None
139
+
140
+ filename, data_shape, data_type = local_scatter_torch(meta_information)
141
+ if local_rank == 0:
142
+ data = x
143
+ else:
144
+ data = MemoryMappedTensor.from_filename(filename=filename,
145
+ dtype=data_type,
146
+ shape=data_shape)
147
+
148
+ return data
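load_shards and ShardDataset expect each shard to be a torch.save'd dict mapping a sample id to its feature tensor; rank 0 packs the shards into a memory-mapped file (under SLURM_SCRATCH or /dev/shm) and shares it with the other ranks. A sketch of writing a compatible shard, with illustrative names:

    import os

    import torch

    os.makedirs('shards', exist_ok=True)
    shard = {
        'sample_000': torch.randn(8, 16),  # id -> precomputed feature tensor
        'sample_001': torch.randn(8, 16),
    }
    torch.save(shard, 'shards/shard-00000.pth')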
mmaudio/eval_utils.py ADDED
@@ -0,0 +1,255 @@
1
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ from colorlog import ColoredFormatter
9
+ from PIL import Image
10
+ from torchvision.transforms import v2
11
+
12
+ from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
13
+ from mmaudio.model.flow_matching import FlowMatching
14
+ from mmaudio.model.networks import MMAudio
15
+ from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
16
+ from mmaudio.model.utils.features_utils import FeaturesUtils
17
+ from mmaudio.utils.download_utils import download_model_if_needed
18
+
19
+ log = logging.getLogger()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class ModelConfig:
24
+ model_name: str
25
+ model_path: Path
26
+ vae_path: Path
27
+ bigvgan_16k_path: Optional[Path]
28
+ mode: str
29
+ synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
30
+
31
+ @property
32
+ def seq_cfg(self) -> SequenceConfig:
33
+ if self.mode == '16k':
34
+ return CONFIG_16K
35
+ elif self.mode == '44k':
36
+ return CONFIG_44K
37
+
38
+ def download_if_needed(self):
39
+ download_model_if_needed(self.model_path)
40
+ download_model_if_needed(self.vae_path)
41
+ if self.bigvgan_16k_path is not None:
42
+ download_model_if_needed(self.bigvgan_16k_path)
43
+ download_model_if_needed(self.synchformer_ckpt)
44
+
45
+
46
+ small_16k = ModelConfig(model_name='small_16k',
47
+ model_path=Path('./weights/mmaudio_small_16k.pth'),
48
+ vae_path=Path('./ext_weights/v1-16.pth'),
49
+ bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
50
+ mode='16k')
51
+ small_44k = ModelConfig(model_name='small_44k',
52
+ model_path=Path('./weights/mmaudio_small_44k.pth'),
53
+ vae_path=Path('./ext_weights/v1-44.pth'),
54
+ bigvgan_16k_path=None,
55
+ mode='44k')
56
+ medium_44k = ModelConfig(model_name='medium_44k',
57
+ model_path=Path('./weights/mmaudio_medium_44k.pth'),
58
+ vae_path=Path('./ext_weights/v1-44.pth'),
59
+ bigvgan_16k_path=None,
60
+ mode='44k')
61
+ large_44k = ModelConfig(model_name='large_44k',
62
+ model_path=Path('./weights/mmaudio_large_44k.pth'),
63
+ vae_path=Path('./ext_weights/v1-44.pth'),
64
+ bigvgan_16k_path=None,
65
+ mode='44k')
66
+ large_44k_v2 = ModelConfig(model_name='large_44k_v2',
67
+ model_path=Path('./weights/mmaudio_large_44k_v2.pth'),
68
+ vae_path=Path('./ext_weights/v1-44.pth'),
69
+ bigvgan_16k_path=None,
70
+ mode='44k')
71
+ all_model_cfg: dict[str, ModelConfig] = {
72
+ 'small_16k': small_16k,
73
+ 'small_44k': small_44k,
74
+ 'medium_44k': medium_44k,
75
+ 'large_44k': large_44k,
76
+ 'large_44k_v2': large_44k_v2,
77
+ }
78
+
79
+
80
+ def generate(
81
+ clip_video: Optional[torch.Tensor],
82
+ sync_video: Optional[torch.Tensor],
83
+ text: Optional[list[str]],
84
+ *,
85
+ negative_text: Optional[list[str]] = None,
86
+ feature_utils: FeaturesUtils,
87
+ net: MMAudio,
88
+ fm: FlowMatching,
89
+ rng: torch.Generator,
90
+ cfg_strength: float,
91
+ clip_batch_size_multiplier: int = 40,
92
+ sync_batch_size_multiplier: int = 40,
93
+ image_input: bool = False,
94
+ ) -> torch.Tensor:
95
+ device = feature_utils.device
96
+ dtype = feature_utils.dtype
97
+
98
+ bs = len(text)
99
+ if clip_video is not None:
100
+ clip_video = clip_video.to(device, dtype, non_blocking=True)
101
+ clip_features = feature_utils.encode_video_with_clip(clip_video,
102
+ batch_size=bs *
103
+ clip_batch_size_multiplier)
104
+ if image_input:
105
+ clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
106
+ else:
107
+ clip_features = net.get_empty_clip_sequence(bs)
108
+
109
+ if sync_video is not None and not image_input:
110
+ sync_video = sync_video.to(device, dtype, non_blocking=True)
111
+ sync_features = feature_utils.encode_video_with_sync(sync_video,
112
+ batch_size=bs *
113
+ sync_batch_size_multiplier)
114
+ else:
115
+ sync_features = net.get_empty_sync_sequence(bs)
116
+
117
+ if text is not None:
118
+ text_features = feature_utils.encode_text(text)
119
+ else:
120
+ text_features = net.get_empty_string_sequence(bs)
121
+
122
+ if negative_text is not None:
123
+ assert len(negative_text) == bs
124
+ negative_text_features = feature_utils.encode_text(negative_text)
125
+ else:
126
+ negative_text_features = net.get_empty_string_sequence(bs)
127
+
128
+ x0 = torch.randn(bs,
129
+ net.latent_seq_len,
130
+ net.latent_dim,
131
+ device=device,
132
+ dtype=dtype,
133
+ generator=rng)
134
+ preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features, text_features)
135
+ empty_conditions = net.get_empty_conditions(
136
+ bs, negative_text_features=negative_text_features if negative_text is not None else None)
137
+
138
+ cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
139
+ cfg_strength)
140
+ x1 = fm.to_data(cfg_ode_wrapper, x0)
141
+ x1 = net.unnormalize(x1)
142
+ spec = feature_utils.decode(x1)
143
+ audio = feature_utils.vocode(spec)
144
+ return audio
145
+
146
+
147
+ LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
148
+
149
+
150
+ def setup_eval_logging(log_level: int = logging.INFO):
151
+ logging.root.setLevel(log_level)
152
+ formatter = ColoredFormatter(LOGFORMAT)
153
+ stream = logging.StreamHandler()
154
+ stream.setLevel(log_level)
155
+ stream.setFormatter(formatter)
156
+ log = logging.getLogger()
157
+ log.setLevel(log_level)
158
+ log.addHandler(stream)
159
+
160
+
161
+ _CLIP_SIZE = 384
162
+ _CLIP_FPS = 8.0
163
+
164
+ _SYNC_SIZE = 224
165
+ _SYNC_FPS = 25.0
166
+
167
+
168
+ def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
169
+
170
+ clip_transform = v2.Compose([
171
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
172
+ v2.ToImage(),
173
+ v2.ToDtype(torch.float32, scale=True),
174
+ ])
175
+
176
+ sync_transform = v2.Compose([
177
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
178
+ v2.CenterCrop(_SYNC_SIZE),
179
+ v2.ToImage(),
180
+ v2.ToDtype(torch.float32, scale=True),
181
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
182
+ ])
183
+
184
+ output_frames, all_frames, orig_fps = read_frames(video_path,
185
+ list_of_fps=[_CLIP_FPS, _SYNC_FPS],
186
+ start_sec=0,
187
+ end_sec=duration_sec,
188
+ need_all_frames=load_all_frames)
189
+
190
+ clip_chunk, sync_chunk = output_frames
191
+ clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
192
+ sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
193
+
194
+ clip_frames = clip_transform(clip_chunk)
195
+ sync_frames = sync_transform(sync_chunk)
196
+
197
+ clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
198
+ sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
199
+
200
+ if clip_length_sec < duration_sec:
201
+ log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
202
+ log.warning(f'Truncating to {clip_length_sec:.2f} sec')
203
+ duration_sec = clip_length_sec
204
+
205
+ if sync_length_sec < duration_sec:
206
+ log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
207
+ log.warning(f'Truncating to {sync_length_sec:.2f} sec')
208
+ duration_sec = sync_length_sec
209
+
210
+ clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
211
+ sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
212
+
213
+ video_info = VideoInfo(
214
+ duration_sec=duration_sec,
215
+ fps=orig_fps,
216
+ clip_frames=clip_frames,
217
+ sync_frames=sync_frames,
218
+ all_frames=all_frames if load_all_frames else None,
219
+ )
220
+ return video_info
221
+
222
+
223
+ def load_image(image_path: Path) -> VideoInfo:
224
+ clip_transform = v2.Compose([
225
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
226
+ v2.ToImage(),
227
+ v2.ToDtype(torch.float32, scale=True),
228
+ ])
229
+
230
+ sync_transform = v2.Compose([
231
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
232
+ v2.CenterCrop(_SYNC_SIZE),
233
+ v2.ToImage(),
234
+ v2.ToDtype(torch.float32, scale=True),
235
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
236
+ ])
237
+
238
+ frame = np.array(Image.open(image_path))
239
+
240
+ clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
241
+ sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
242
+
243
+ clip_frames = clip_transform(clip_chunk)
244
+ sync_frames = sync_transform(sync_chunk)
245
+
246
+ video_info = ImageInfo(
247
+ clip_frames=clip_frames,
248
+ sync_frames=sync_frames,
249
+ original_frame=frame,
250
+ )
251
+ return video_info
252
+
253
+
254
+ def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
255
+ reencode_with_audio(video_info, output_path, audio, sampling_rate)
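An end-to-end sketch of video-to-audio inference with the helpers above. The construction of the network, feature extractor, and flow-matching solver is elided (demo.py builds them from the ModelConfig); the prompt, seed, CFG strength, and output names are placeholders, and seq_cfg.sampling_rate is assumed to carry the output sampling rate.

    from pathlib import Path

    import torch

    setup_eval_logging()
    model = all_model_cfg['large_44k_v2']
    model.download_if_needed()

    # Assumed to be constructed as in demo.py:
    #   net = ...            # MMAudio network loaded from model.model_path
    #   feature_utils = ...  # FeaturesUtils wrapping VAE, vocoder, CLIP and Synchformer
    #   fm = ...             # FlowMatching ODE solver

    duration_sec = 8.0
    video_info = load_video(Path('input.mp4'), duration_sec)
    clip_frames = video_info.clip_frames.unsqueeze(0)  # add a batch dimension
    sync_frames = video_info.sync_frames.unsqueeze(0)

    rng = torch.Generator(device='cuda').manual_seed(42)  # assumes a CUDA device
    audio = generate(clip_frames, sync_frames, ['splashing water'],
                     feature_utils=feature_utils, net=net, fm=fm,
                     rng=rng, cfg_strength=4.5)
    make_video(video_info, Path('output.mp4'), audio.float().cpu()[0],
               sampling_rate=model.seq_cfg.sampling_rate)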
mmaudio/ext/__init__.py ADDED
@@ -0,0 +1 @@
1
+
mmaudio/ext/autoencoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .autoencoder import AutoEncoderModule
mmaudio/ext/autoencoder/autoencoder.py ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Literal, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from mmaudio.ext.autoencoder.vae import VAE, get_my_vae
7
+ from mmaudio.ext.bigvgan import BigVGAN
8
+ from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
9
+ from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
10
+
11
+
12
+ class AutoEncoderModule(nn.Module):
13
+
14
+ def __init__(self,
15
+ *,
16
+ vae_ckpt_path,
17
+ vocoder_ckpt_path: Optional[str] = None,
18
+ mode: Literal['16k', '44k'],
19
+ need_vae_encoder: bool = True):
20
+ super().__init__()
21
+ self.vae: VAE = get_my_vae(mode).eval()
22
+ vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
23
+ self.vae.load_state_dict(vae_state_dict)
24
+ self.vae.remove_weight_norm()
25
+
26
+ if mode == '16k':
27
+ assert vocoder_ckpt_path is not None
28
+ self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
29
+ elif mode == '44k':
30
+ self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
31
+ use_cuda_kernel=False)
32
+ self.vocoder.remove_weight_norm()
33
+ else:
34
+ raise ValueError(f'Unknown mode: {mode}')
35
+
36
+ for param in self.parameters():
37
+ param.requires_grad = False
38
+
39
+ if not need_vae_encoder:
40
+ del self.vae.encoder
41
+
42
+ @torch.inference_mode()
43
+ def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
44
+ return self.vae.encode(x)
45
+
46
+ @torch.inference_mode()
47
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
48
+ return self.vae.decode(z)
49
+
50
+ @torch.inference_mode()
51
+ def vocode(self, spec: torch.Tensor) -> torch.Tensor:
52
+ return self.vocoder(spec)
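A sketch of the mel → latent → waveform round trip in 16 kHz mode. The checkpoint paths follow the small_16k ModelConfig above; the (batch, 80, frames) mel layout matches the 80-dim VAE statistics below, the frame count is arbitrary, and the attribute used on the returned distribution is an assumption.

    import torch

    from mmaudio.ext.autoencoder import AutoEncoderModule

    ae = AutoEncoderModule(vae_ckpt_path='./ext_weights/v1-16.pth',
                           vocoder_ckpt_path='./ext_weights/best_netG.pt',
                           mode='16k')
    mel = torch.randn(1, 80, 256)  # (batch, mel bins, frames); frame count is arbitrary
    dist = ae.encode(mel)          # DiagonalGaussianDistribution over latents
    z = dist.mean                  # assumed attribute; a stochastic draw would also work
    recon = ae.decode(z)           # back to a mel spectrogram
    wav = ae.vocode(recon)         # waveform via the BigVGAN vocoder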
mmaudio/ext/autoencoder/edm2_utils.py ADDED
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+ """Improved diffusion model architecture proposed in the paper
8
+ "Analyzing and Improving the Training Dynamics of Diffusion Models"."""
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+ # Variant of constant() that inherits dtype and device from the given
15
+ # reference tensor by default.
16
+
17
+ _constant_cache = dict()
18
+
19
+
20
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
21
+ value = np.asarray(value)
22
+ if shape is not None:
23
+ shape = tuple(shape)
24
+ if dtype is None:
25
+ dtype = torch.get_default_dtype()
26
+ if device is None:
27
+ device = torch.device('cpu')
28
+ if memory_format is None:
29
+ memory_format = torch.contiguous_format
30
+
31
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
32
+ tensor = _constant_cache.get(key, None)
33
+ if tensor is None:
34
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
35
+ if shape is not None:
36
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
37
+ tensor = tensor.contiguous(memory_format=memory_format)
38
+ _constant_cache[key] = tensor
39
+ return tensor
40
+
41
+
42
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
43
+ if dtype is None:
44
+ dtype = ref.dtype
45
+ if device is None:
46
+ device = ref.device
47
+ return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
48
+
49
+
50
+ #----------------------------------------------------------------------------
51
+ # Normalize given tensor to unit magnitude with respect to the given
52
+ # dimensions. Default = all dimensions except the first.
53
+
54
+
55
+ def normalize(x, dim=None, eps=1e-4):
56
+ if dim is None:
57
+ dim = list(range(1, x.ndim))
58
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
59
+ norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
60
+ return x / norm.to(x.dtype)
61
+
62
+
63
+ class Normalize(torch.nn.Module):
64
+
65
+ def __init__(self, dim=None, eps=1e-4):
66
+ super().__init__()
67
+ self.dim = dim
68
+ self.eps = eps
69
+
70
+ def forward(self, x):
71
+ return normalize(x, dim=self.dim, eps=self.eps)
72
+
73
+
74
+ #----------------------------------------------------------------------------
75
+ # Upsample or downsample the given tensor with the given filter,
76
+ # or keep it as is.
77
+
78
+
79
+ def resample(x, f=[1, 1], mode='keep'):
80
+ if mode == 'keep':
81
+ return x
82
+ f = np.float32(f)
83
+ assert f.ndim == 1 and len(f) % 2 == 0
84
+ pad = (len(f) - 1) // 2
85
+ f = f / f.sum()
86
+ f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
87
+ f = const_like(x, f)
88
+ c = x.shape[1]
89
+ if mode == 'down':
90
+ return torch.nn.functional.conv2d(x,
91
+ f.tile([c, 1, 1, 1]),
92
+ groups=c,
93
+ stride=2,
94
+ padding=(pad, ))
95
+ assert mode == 'up'
96
+ return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
97
+ groups=c,
98
+ stride=2,
99
+ padding=(pad, ))
100
+
101
+
102
+ #----------------------------------------------------------------------------
103
+ # Magnitude-preserving SiLU (Equation 81).
104
+
105
+
106
+ def mp_silu(x):
107
+ return torch.nn.functional.silu(x) / 0.596
108
+
109
+
110
+ class MPSiLU(torch.nn.Module):
111
+
112
+ def forward(self, x):
113
+ return mp_silu(x)
114
+
115
+
116
+ #----------------------------------------------------------------------------
117
+ # Magnitude-preserving sum (Equation 88).
118
+
119
+
120
+ def mp_sum(a, b, t=0.5):
121
+ return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
122
+
123
+
124
+ #----------------------------------------------------------------------------
125
+ # Magnitude-preserving concatenation (Equation 103).
126
+
127
+
128
+ def mp_cat(a, b, dim=1, t=0.5):
129
+ Na = a.shape[dim]
130
+ Nb = b.shape[dim]
131
+ C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
132
+ wa = C / np.sqrt(Na) * (1 - t)
133
+ wb = C / np.sqrt(Nb) * t
134
+ return torch.cat([wa * a, wb * b], dim=dim)
135
+
136
+
137
+ #----------------------------------------------------------------------------
138
+ # Magnitude-preserving convolution or fully-connected layer (Equation 47)
139
+ # with force weight normalization (Equation 66).
140
+
141
+
142
+ class MPConv1D(torch.nn.Module):
143
+
144
+ def __init__(self, in_channels, out_channels, kernel_size):
145
+ super().__init__()
146
+ self.out_channels = out_channels
147
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
148
+
149
+ self.weight_norm_removed = False
150
+
151
+ def forward(self, x, gain=1):
152
+ assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
153
+
154
+ w = self.weight * gain
155
+ if w.ndim == 2:
156
+ return x @ w.t()
157
+ assert w.ndim == 3
158
+ return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
159
+
160
+ def remove_weight_norm(self):
161
+ w = self.weight.to(torch.float32)
162
+ w = normalize(w) # traditional weight normalization
163
+ w = w / np.sqrt(w[0].numel())
164
+ w = w.to(self.weight.dtype)
165
+ self.weight.data.copy_(w)
166
+
167
+ self.weight_norm_removed = True
168
+ return self
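A quick numerical check of the magnitude-preserving helpers above (purely illustrative): unit-variance inputs should keep a root-mean-square magnitude near 1 through mp_silu, mp_sum, and a forced-weight-normalized MPConv1D.

    import torch

    from mmaudio.ext.autoencoder.edm2_utils import MPConv1D, mp_silu, mp_sum

    rms = lambda t: t.pow(2).mean().sqrt().item()

    a = torch.randn(4096, 64)
    b = torch.randn(4096, 64)
    print(rms(mp_silu(a)))           # ~1.0: SiLU rescaled by 1/0.596
    print(rms(mp_sum(a, b, t=0.5)))  # ~1.0: lerp renormalized by sqrt((1-t)^2 + t^2)

    conv = MPConv1D(in_channels=64, out_channels=64, kernel_size=3).remove_weight_norm()
    x = torch.randn(2, 64, 128)      # (batch, channels, length)
    print(rms(conv(x)))              # stays close to 1 after forced weight normalization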
mmaudio/ext/autoencoder/vae.py ADDED
@@ -0,0 +1,369 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from mmaudio.ext.autoencoder.edm2_utils import MPConv1D
8
+ from mmaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
9
+ Upsample1D, nonlinearity)
10
+ from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
11
+
12
+ log = logging.getLogger()
13
+
14
+ DATA_MEAN_80D = [
15
+ -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
16
+ -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
17
+ -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
18
+ -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
19
+ -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
20
+ -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
21
+ -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
22
+ -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
23
+ ]
24
+
25
+ DATA_STD_80D = [
26
+ 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
27
+ 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
28
+ 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
29
+ 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
30
+ 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
31
+ 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
32
+ 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
33
+ ]
34
+
35
+ DATA_MEAN_128D = [
36
+ -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
37
+ -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
38
+ -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
39
+ -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
40
+ -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
41
+ -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
42
+ -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
43
+ -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
44
+ -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
45
+ -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
46
+ -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
47
+ -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
48
+ -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
49
+ ]
50
+
51
+ DATA_STD_128D = [
52
+ 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
53
+ 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
54
+ 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
55
+ 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
56
+ 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
57
+ 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
58
+ 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
59
+ 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
60
+ 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
61
+ 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
62
+ 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
63
+ ]
64
+
65
+
66
+ class VAE(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ *,
71
+ data_dim: int,
72
+ embed_dim: int,
73
+ hidden_dim: int,
74
+ ):
75
+ super().__init__()
76
+
77
+ if data_dim == 80:
78
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
79
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
80
+ elif data_dim == 128:
81
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
82
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
83
+
84
+ self.data_mean = self.data_mean.view(1, -1, 1)
85
+ self.data_std = self.data_std.view(1, -1, 1)
86
+
87
+ self.encoder = Encoder1D(
88
+ dim=hidden_dim,
89
+ ch_mult=(1, 2, 4),
90
+ num_res_blocks=2,
91
+ attn_layers=[3],
92
+ down_layers=[0],
93
+ in_dim=data_dim,
94
+ embed_dim=embed_dim,
95
+ )
96
+ self.decoder = Decoder1D(
97
+ dim=hidden_dim,
98
+ ch_mult=(1, 2, 4),
99
+ num_res_blocks=2,
100
+ attn_layers=[3],
101
+ down_layers=[0],
102
+ in_dim=data_dim,
103
+ out_dim=data_dim,
104
+ embed_dim=embed_dim,
105
+ )
106
+
107
+ self.embed_dim = embed_dim
108
+ # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
109
+ # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
110
+
111
+ self.initialize_weights()
112
+
113
+ def initialize_weights(self):
114
+ pass
115
+
116
+ def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
117
+ if normalize:
118
+ x = self.normalize(x)
119
+ moments = self.encoder(x)
120
+ posterior = DiagonalGaussianDistribution(moments)
121
+ return posterior
122
+
123
+ def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
124
+ dec = self.decoder(z)
125
+ if unnormalize:
126
+ dec = self.unnormalize(dec)
127
+ return dec
128
+
129
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
130
+ return (x - self.data_mean) / self.data_std
131
+
132
+ def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
133
+ return x * self.data_std + self.data_mean
134
+
135
+ def forward(
136
+ self,
137
+ x: torch.Tensor,
138
+ sample_posterior: bool = True,
139
+ rng: Optional[torch.Generator] = None,
140
+ normalize: bool = True,
141
+ unnormalize: bool = True,
142
+ ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
143
+
144
+ posterior = self.encode(x, normalize=normalize)
145
+ if sample_posterior:
146
+ z = posterior.sample(rng)
147
+ else:
148
+ z = posterior.mode()
149
+ dec = self.decode(z, unnormalize=unnormalize)
150
+ return dec, posterior
151
+
152
+ def load_weights(self, src_dict) -> None:
153
+ self.load_state_dict(src_dict, strict=True)
154
+
155
+ @property
156
+ def device(self) -> torch.device:
157
+ return next(self.parameters()).device
158
+
159
+ def get_last_layer(self):
160
+ return self.decoder.conv_out.weight
161
+
162
+ def remove_weight_norm(self):
163
+ for name, m in self.named_modules():
164
+ if isinstance(m, MPConv1D):
165
+ m.remove_weight_norm()
166
+ log.debug(f"Removed weight norm from {name}")
167
+ return self
168
+
169
+
170
+ class Encoder1D(nn.Module):
171
+
172
+ def __init__(self,
173
+ *,
174
+ dim: int,
175
+ ch_mult: tuple[int] = (1, 2, 4, 8),
176
+ num_res_blocks: int,
177
+ attn_layers: list[int] = [],
178
+ down_layers: list[int] = [],
179
+ resamp_with_conv: bool = True,
180
+ in_dim: int,
181
+ embed_dim: int,
182
+ double_z: bool = True,
183
+ kernel_size: int = 3,
184
+ clip_act: float = 256.0):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.num_layers = len(ch_mult)
188
+ self.num_res_blocks = num_res_blocks
189
+ self.in_channels = in_dim
190
+ self.clip_act = clip_act
191
+ self.down_layers = down_layers
192
+ self.attn_layers = attn_layers
193
+ self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
194
+
195
+ in_ch_mult = (1, ) + tuple(ch_mult)
196
+ self.in_ch_mult = in_ch_mult
197
+ # downsampling
198
+ self.down = nn.ModuleList()
199
+ for i_level in range(self.num_layers):
200
+ block = nn.ModuleList()
201
+ attn = nn.ModuleList()
202
+ block_in = dim * in_ch_mult[i_level]
203
+ block_out = dim * ch_mult[i_level]
204
+ for i_block in range(self.num_res_blocks):
205
+ block.append(
206
+ ResnetBlock1D(in_dim=block_in,
207
+ out_dim=block_out,
208
+ kernel_size=kernel_size,
209
+ use_norm=True))
210
+ block_in = block_out
211
+ if i_level in attn_layers:
212
+ attn.append(AttnBlock1D(block_in))
213
+ down = nn.Module()
214
+ down.block = block
215
+ down.attn = attn
216
+ if i_level in down_layers:
217
+ down.downsample = Downsample1D(block_in, resamp_with_conv)
218
+ self.down.append(down)
219
+
220
+ # middle
221
+ self.mid = nn.Module()
222
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
223
+ out_dim=block_in,
224
+ kernel_size=kernel_size,
225
+ use_norm=True)
226
+ self.mid.attn_1 = AttnBlock1D(block_in)
227
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
228
+ out_dim=block_in,
229
+ kernel_size=kernel_size,
230
+ use_norm=True)
231
+
232
+ # end
233
+ self.conv_out = MPConv1D(block_in,
234
+ 2 * embed_dim if double_z else embed_dim,
235
+ kernel_size=kernel_size)
236
+
237
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
238
+
239
+ def forward(self, x):
240
+
241
+ # downsampling
242
+ hs = [self.conv_in(x)]
243
+ for i_level in range(self.num_layers):
244
+ for i_block in range(self.num_res_blocks):
245
+ h = self.down[i_level].block[i_block](hs[-1])
246
+ if len(self.down[i_level].attn) > 0:
247
+ h = self.down[i_level].attn[i_block](h)
248
+ h = h.clamp(-self.clip_act, self.clip_act)
249
+ hs.append(h)
250
+ if i_level in self.down_layers:
251
+ hs.append(self.down[i_level].downsample(hs[-1]))
252
+
253
+ # middle
254
+ h = hs[-1]
255
+ h = self.mid.block_1(h)
256
+ h = self.mid.attn_1(h)
257
+ h = self.mid.block_2(h)
258
+ h = h.clamp(-self.clip_act, self.clip_act)
259
+
260
+ # end
261
+ h = nonlinearity(h)
262
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
263
+ return h
264
+
265
+
266
+ class Decoder1D(nn.Module):
267
+
268
+ def __init__(self,
269
+ *,
270
+ dim: int,
271
+ out_dim: int,
272
+ ch_mult: tuple[int] = (1, 2, 4, 8),
273
+ num_res_blocks: int,
274
+ attn_layers: list[int] = [],
275
+ down_layers: list[int] = [],
276
+ kernel_size: int = 3,
277
+ resamp_with_conv: bool = True,
278
+ in_dim: int,
279
+ embed_dim: int,
280
+ clip_act: float = 256.0):
281
+ super().__init__()
282
+ self.ch = dim
283
+ self.num_layers = len(ch_mult)
284
+ self.num_res_blocks = num_res_blocks
285
+ self.in_channels = in_dim
286
+ self.clip_act = clip_act
287
+ self.down_layers = [i + 1 for i in down_layers]  # each encoder downsample shifts the decoder level index up by one
288
+
289
+ # compute in_ch_mult, block_in and curr_res at lowest res
290
+ block_in = dim * ch_mult[self.num_layers - 1]
291
+
292
+ # z to block_in
293
+ self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
294
+
295
+ # middle
296
+ self.mid = nn.Module()
297
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
298
+ self.mid.attn_1 = AttnBlock1D(block_in)
299
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
300
+
301
+ # upsampling
302
+ self.up = nn.ModuleList()
303
+ for i_level in reversed(range(self.num_layers)):
304
+ block = nn.ModuleList()
305
+ attn = nn.ModuleList()
306
+ block_out = dim * ch_mult[i_level]
307
+ for i_block in range(self.num_res_blocks + 1):
308
+ block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
309
+ block_in = block_out
310
+ if i_level in attn_layers:
311
+ attn.append(AttnBlock1D(block_in))
312
+ up = nn.Module()
313
+ up.block = block
314
+ up.attn = attn
315
+ if i_level in self.down_layers:
316
+ up.upsample = Upsample1D(block_in, resamp_with_conv)
317
+ self.up.insert(0, up) # prepend to get consistent order
318
+
319
+ # end
320
+ self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
321
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
322
+
323
+ def forward(self, z):
324
+ # z to block_in
325
+ h = self.conv_in(z)
326
+
327
+ # middle
328
+ h = self.mid.block_1(h)
329
+ h = self.mid.attn_1(h)
330
+ h = self.mid.block_2(h)
331
+ h = h.clamp(-self.clip_act, self.clip_act)
332
+
333
+ # upsampling
334
+ for i_level in reversed(range(self.num_layers)):
335
+ for i_block in range(self.num_res_blocks + 1):
336
+ h = self.up[i_level].block[i_block](h)
337
+ if len(self.up[i_level].attn) > 0:
338
+ h = self.up[i_level].attn[i_block](h)
339
+ h = h.clamp(-self.clip_act, self.clip_act)
340
+ if i_level in self.down_layers:
341
+ h = self.up[i_level].upsample(h)
342
+
343
+ h = nonlinearity(h)
344
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
345
+ return h
346
+
347
+
348
+ def VAE_16k(**kwargs) -> VAE:
349
+ return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
350
+
351
+
352
+ def VAE_44k(**kwargs) -> VAE:
353
+ return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
354
+
355
+
356
+ def get_my_vae(name: str, **kwargs) -> VAE:
357
+ if name == '16k':
358
+ return VAE_16k(**kwargs)
359
+ if name == '44k':
360
+ return VAE_44k(**kwargs)
361
+ raise ValueError(f'Unknown model: {name}')
362
+
363
+
364
+ if __name__ == '__main__':
365
+ network = get_my_vae('16k')  # valid names are '16k' and '44k'; there is no 'standard' variant
366
+
367
+ # print the number of parameters in terms of millions
368
+ num_params = sum(p.numel() for p in network.parameters()) / 1e6
369
+ print(f'Number of parameters: {num_params:.2f}M')
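
As a sanity check of the interfaces above, a minimal end-to-end sketch for the 16 kHz variant (the batch size, the 1024-frame input length, and the call order are illustrative assumptions). With down_layers=[0] the encoder halves the time axis once and conv_out emits 2 * embed_dim channels, so a (1, 80, 1024) mel spectrogram maps to a (1, 20, 512) latent and decodes back to (1, 80, 1024).

    import torch

    vae = get_my_vae('16k')          # data_dim=80, embed_dim=20, hidden_dim=384
    vae.remove_weight_norm()         # required before any MPConv1D forward pass
    vae.eval()

    mel = torch.randn(1, 80, 1024)   # (batch, n_mels, time)
    with torch.no_grad():
        posterior = vae.encode(mel)  # DiagonalGaussianDistribution over the latent
        z = posterior.mode()         # (1, 20, 512); use posterior.sample() during training
        recon = vae.decode(z)        # (1, 80, 1024), unnormalized back to data scale
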