Commit 73ed896
Phil Sobrepena committed
Parent(s): ddb444b

initial commit

This view is limited to 50 files because it contains too many changes.
- .gitignore +146 -0
- Dockerfile +39 -0
- LICENSE +21 -0
- README.md +56 -0
- app.py +128 -0
- batch_eval.py +110 -0
- config/__init__.py +0 -0
- config/base_config.yaml +62 -0
- config/data/base.yaml +70 -0
- config/eval_config.yaml +17 -0
- config/eval_data/base.yaml +22 -0
- config/hydra/job_logging/custom-eval.yaml +32 -0
- config/hydra/job_logging/custom-no-rank.yaml +32 -0
- config/hydra/job_logging/custom-simplest.yaml +26 -0
- config/hydra/job_logging/custom.yaml +33 -0
- config/train_config.yaml +41 -0
- demo.py +141 -0
- docs/EVAL.md +22 -0
- docs/MODELS.md +50 -0
- docs/TRAINING.md +160 -0
- docs/images/icon.png +0 -0
- docs/index.html +149 -0
- docs/style.css +78 -0
- docs/style_videos.css +52 -0
- docs/video_gen.html +254 -0
- docs/video_main.html +98 -0
- docs/video_vgg.html +452 -0
- gitattributes +35 -0
- gradio_demo.py +343 -0
- mmaudio/__init__.py +0 -0
- mmaudio/data/__init__.py +0 -0
- mmaudio/data/av_utils.py +162 -0
- mmaudio/data/data_setup.py +174 -0
- mmaudio/data/eval/__init__.py +0 -0
- mmaudio/data/eval/audiocaps.py +39 -0
- mmaudio/data/eval/moviegen.py +131 -0
- mmaudio/data/eval/video_dataset.py +197 -0
- mmaudio/data/extracted_audio.py +88 -0
- mmaudio/data/extracted_vgg.py +101 -0
- mmaudio/data/extraction/__init__.py +0 -0
- mmaudio/data/extraction/vgg_sound.py +193 -0
- mmaudio/data/extraction/wav_dataset.py +132 -0
- mmaudio/data/mm_dataset.py +45 -0
- mmaudio/data/utils.py +148 -0
- mmaudio/eval_utils.py +255 -0
- mmaudio/ext/__init__.py +1 -0
- mmaudio/ext/autoencoder/__init__.py +1 -0
- mmaudio/ext/autoencoder/autoencoder.py +52 -0
- mmaudio/ext/autoencoder/edm2_utils.py +168 -0
- mmaudio/ext/autoencoder/vae.py +369 -0
.gitignore
ADDED
@@ -0,0 +1,146 @@
+run_*.sh
+log/
+saves
+saves/
+weights/
+weights
+output/
+output
+pretrained/
+workspace
+workspace/
+ext_weights/
+ext_weights
+.checkpoints/
+.vscode/
+training/example_output/
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
Dockerfile
ADDED
@@ -0,0 +1,39 @@
+FROM nvidia/cuda:11.8.0-runtime-ubuntu22.04
+
+WORKDIR /code
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    python3.9 \
+    python3-pip \
+    git \
+    ffmpeg \
+    libsm6 \
+    libxext6 \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install Python dependencies
+COPY requirements.txt .
+RUN pip3 install --no-cache-dir -r requirements.txt
+
+# Clone and install MMAudio
+RUN git clone https://github.com/hkchengrex/MMAudio.git && \
+    cd MMAudio && \
+    pip3 install -e .
+
+# Copy the application files
+COPY app.py .
+
+# Create output directory
+RUN mkdir -p output/gradio && chmod 777 output/gradio
+
+# Set environment variables for Hugging Face Spaces
+ENV PYTHONUNBUFFERED=1
+ENV GRADIO_SERVER_NAME=0.0.0.0
+ENV GRADIO_SERVER_PORT=7860
+
+# Expose the port
+EXPOSE 7860
+
+# Run the Gradio app
+CMD ["python3", "app.py"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Sony Research Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,56 @@
+---
+title: Sonisphere
+emoji: 🐢
+colorFrom: green
+colorTo: gray
+sdk: gradio
+sdk_version: 5.20.0
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+# Sonisphere Demo
+
+This is a Hugging Face Spaces demo for [MMAudio](https://hkchengrex.com/MMAudio/), a powerful model for generating realistic audio for videos.
+
+## 🎥 Features
+
+- Upload any video and generate matching audio
+- Control the generation with text prompts
+- Adjust generation parameters like steps and guidance strength
+- Process videos up to 30 seconds in length
+
+## 🚀 Usage
+
+1. Upload a video or use one of the example videos
+2. Enter a text prompt describing the desired audio
+3. (Optional) Add a negative prompt to specify what you don't want
+4. Adjust the generation parameters if needed
+5. Click "Submit" and wait for the generation to complete
+
+## ⚙️ Parameters
+
+- **Prompt**: Describe the audio you want to generate
+- **Negative prompt**: Specify what you don't want in the audio (default: "music")
+- **Seed**: Control randomness (-1 for a random seed)
+- **Number of steps**: More steps = better quality but slower (default: 25)
+- **Guidance Strength**: Controls how closely the generation follows the prompt (default: 4.5)
+- **Duration**: Length of the generated audio in seconds (default: 8)
+
+## 📝 Notes
+
+- Processing high-resolution videos (>384px on the shorter side) takes longer and doesn't improve results
+- The model works best with videos between 5 and 30 seconds
+- Generation time depends on video length and the number of steps
+
+## 🔗 Links
+
+- [Project Page](https://hkchengrex.com/MMAudio/)
+- [GitHub Repository](https://github.com/hkchengrex/MMAudio)
+- [Paper](https://arxiv.org/abs/2401.09774)
+
+## 📜 License
+
+This demo uses the MMAudio model, which is released under the [MIT license](https://github.com/hkchengrex/MMAudio/blob/main/LICENSE).
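Since the README lists the demo's parameters in a fixed order (video, prompt, negative prompt, seed, steps, guidance strength, duration), the Space can also be driven programmatically. A hedged sketch with `gradio_client`; the Space id and the `/predict` endpoint name are illustrative assumptions, not confirmed by this commit:

```python
from gradio_client import Client, handle_file

client = Client("your-username/sonisphere")  # hypothetical Space id
result = client.predict(
    handle_file("clip.mp4"),  # video to sonify
    "waves, seagulls",        # prompt
    "music",                  # negative prompt
    -1,                       # seed (-1: random)
    25,                       # number of steps
    4.5,                      # guidance strength
    8,                        # duration in seconds
    api_name="/predict",      # default endpoint for a gr.Interface
)
print(result)  # local path to the generated video with audio
```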
app.py
ADDED
@@ -0,0 +1,128 @@
+import gc
+import logging
+from datetime import datetime
+from fractions import Fraction
+from pathlib import Path
+
+import gradio as gr
+import torch
+import torchaudio
+
+from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
+                                load_video, make_video, setup_eval_logging)
+from mmaudio.model.flow_matching import FlowMatching
+from mmaudio.model.networks import MMAudio, get_my_mmaudio
+from mmaudio.model.sequence_config import SequenceConfig
+from mmaudio.model.utils.features_utils import FeaturesUtils
+
+# Setup logging
+setup_eval_logging()
+log = logging.getLogger()
+
+# Configure device and dtype
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+if device == 'cpu':
+    log.warning('CUDA is not available, running on CPU')
+dtype = torch.bfloat16
+
+# Configure model and paths
+model: ModelConfig = all_model_cfg['large_44k_v2']
+model.download_if_needed()
+output_dir = Path('./output/gradio')
+output_dir.mkdir(exist_ok=True, parents=True)
+
+def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
+    seq_cfg = model.seq_cfg
+
+    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
+    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
+    log.info(f'Loaded weights from {model.model_path}')
+
+    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
+                                  synchformer_ckpt=model.synchformer_ckpt,
+                                  enable_conditions=True,
+                                  mode=model.mode,
+                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                  need_vae_encoder=False)
+    feature_utils = feature_utils.to(device, dtype).eval()
+
+    return net, feature_utils, seq_cfg
+
+# Load model once at startup
+net, feature_utils, seq_cfg = get_model()
+
+@torch.inference_mode()
+def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
+                   cfg_strength: float, duration: float):
+    try:
+        rng = torch.Generator(device=device)
+        if seed >= 0:
+            rng.manual_seed(seed)
+        else:
+            rng.seed()
+        fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
+
+        video_info = load_video(video, duration)
+        clip_frames = video_info.clip_frames.unsqueeze(0)
+        sync_frames = video_info.sync_frames.unsqueeze(0)
+        duration = video_info.duration_sec
+
+        seq_cfg.duration = duration
+        net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
+
+        audios = generate(clip_frames, sync_frames, [prompt],
+                          negative_text=[negative_prompt],
+                          feature_utils=feature_utils,
+                          net=net,
+                          fm=fm,
+                          rng=rng,
+                          cfg_strength=cfg_strength)
+        audio = audios.float().cpu()[0]
+
+        current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
+        video_save_path = output_dir / f'{current_time_string}.mp4'
+        make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        return video_save_path
+    except Exception as e:
+        log.error(f"Error in video_to_audio: {str(e)}")
+        raise gr.Error(f"An error occurred: {str(e)}")
+
+# Create the Gradio interface
+demo = gr.Interface(
+    fn=video_to_audio,
+    title="MMAudio — Video-to-Audio Synthesis",
+    description="""
+    Generate realistic audio for your videos using MMAudio!
+
+    Project page: [MMAudio](https://hkchengrex.com/MMAudio/)
+    Code: [GitHub](https://github.com/hkchengrex/MMAudio)
+
+    Note: Processing high-resolution videos (>384px on shorter side) takes longer and doesn't improve results.
+    """,
+    inputs=[
+        gr.Video(label="Upload Video"),
+        gr.Text(label="Prompt", placeholder="Describe the audio you want to generate..."),
+        gr.Text(label="Negative prompt", value="music", placeholder="What you don't want in the audio..."),
+        gr.Number(label="Seed (-1: random)", value=-1, precision=0, minimum=-1),
+        gr.Number(label="Number of steps", value=25, precision=0, minimum=1),
+        gr.Slider(label="Guidance Strength", value=4.5, minimum=1, maximum=10, step=0.5),
+        gr.Slider(label="Duration (seconds)", value=8, minimum=1, maximum=30, step=1),
+    ],
+    outputs=gr.Video(label="Generated Result"),
+    examples=[
+        ["https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4",
+         "waves, seagulls", "", 0, 25, 4.5, 10],
+        ["https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4",
+         "", "music", 0, 25, 4.5, 10],
+    ],
+    cache_examples=True,
+)
+
+# Launch the app
+if __name__ == "__main__":
+    demo.launch(server_name="0.0.0.0", server_port=7860)
+
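A detail of `app.py` worth calling out is the seed handling in `video_to_audio`: a non-negative seed makes sampling reproducible, while -1 re-seeds from system entropy. The same pattern in isolation (plain PyTorch, no MMAudio dependencies):

```python
import torch

def make_rng(seed: int, device: str = 'cpu') -> torch.Generator:
    # Mirrors the app's logic: seed >= 0 gives reproducible sampling;
    # a negative seed draws fresh entropy from the OS instead.
    rng = torch.Generator(device=device)
    if seed >= 0:
        rng.manual_seed(seed)
    else:
        rng.seed()
    return rng

# Two generators built from the same seed produce identical noise.
a = torch.randn(4, generator=make_rng(42))
b = torch.randn(4, generator=make_rng(42))
assert torch.equal(a, b)
```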
batch_eval.py
ADDED
@@ -0,0 +1,110 @@
+import logging
+import os
+from pathlib import Path
+
+import hydra
+import torch
+import torch.distributed as distributed
+import torchaudio
+from hydra.core.hydra_config import HydraConfig
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from mmaudio.data.data_setup import setup_eval_dataset
+from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
+from mmaudio.model.flow_matching import FlowMatching
+from mmaudio.model.networks import MMAudio, get_my_mmaudio
+from mmaudio.model.utils.features_utils import FeaturesUtils
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+local_rank = int(os.environ['LOCAL_RANK'])
+world_size = int(os.environ['WORLD_SIZE'])
+log = logging.getLogger()
+
+
+@torch.inference_mode()
+@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
+def main(cfg: DictConfig):
+    device = 'cuda'
+    torch.cuda.set_device(local_rank)
+
+    if cfg.model not in all_model_cfg:
+        raise ValueError(f'Unknown model variant: {cfg.model}')
+    model: ModelConfig = all_model_cfg[cfg.model]
+    model.download_if_needed()
+    seq_cfg = model.seq_cfg
+
+    run_dir = Path(HydraConfig.get().run.dir)
+    if cfg.output_name is None:
+        output_dir = run_dir / cfg.dataset
+    else:
+        output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # load a pretrained model
+    seq_cfg.duration = cfg.duration_s
+    net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
+    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
+    log.info(f'Loaded weights from {model.model_path}')
+    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
+    log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
+    log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
+    log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')
+
+    # misc setup
+    rng = torch.Generator(device=device)
+    rng.manual_seed(cfg.seed)
+    fm = FlowMatching(cfg.sampling.min_sigma,
+                      inference_mode=cfg.sampling.method,
+                      num_steps=cfg.sampling.num_steps)
+
+    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
+                                  synchformer_ckpt=model.synchformer_ckpt,
+                                  enable_conditions=True,
+                                  mode=model.mode,
+                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                  need_vae_encoder=False)
+    feature_utils = feature_utils.to(device).eval()
+
+    if cfg.compile:
+        net.preprocess_conditions = torch.compile(net.preprocess_conditions)
+        net.predict_flow = torch.compile(net.predict_flow)
+        feature_utils.compile()
+
+    dataset, loader = setup_eval_dataset(cfg.dataset, cfg)
+
+    with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
+        for batch in tqdm(loader):
+            audios = generate(batch.get('clip_video', None),
+                              batch.get('sync_video', None),
+                              batch.get('caption', None),
+                              feature_utils=feature_utils,
+                              net=net,
+                              fm=fm,
+                              rng=rng,
+                              cfg_strength=cfg.cfg_strength,
+                              clip_batch_size_multiplier=64,
+                              sync_batch_size_multiplier=64)
+            audios = audios.float().cpu()
+            names = batch['name']
+            for audio, name in zip(audios, names):
+                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)
+
+
+def distributed_setup():
+    distributed.init_process_group(backend="nccl")
+    local_rank = distributed.get_rank()
+    world_size = distributed.get_world_size()
+    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
+    return local_rank, world_size
+
+
+if __name__ == '__main__':
+    distributed_setup()
+
+    main()
+
+    # clean-up
+    distributed.destroy_process_group()
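Note that `batch_eval.py` reads `LOCAL_RANK` and `WORLD_SIZE` at import time, so it must be launched via `torchrun` (which exports those variables) rather than plain `python`. A small sketch of a more forgiving pattern — the single-process fallback here is an illustrative addition, not part of the committed script:

```python
import os

# torchrun exports these for every worker; under a bare `python`
# launch they are absent, so fall back to a single process.
local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))
print(f'running as rank {local_rank} of {world_size}')
```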
config/__init__.py
ADDED
File without changes
config/base_config.yaml
ADDED
@@ -0,0 +1,62 @@
+defaults:
+  - data: base
+  - eval_data: base
+  - override hydra/job_logging: custom-simplest
+  - _self_
+
+hydra:
+  run:
+    dir: ./output/${exp_id}
+  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+enable_email: False
+
+model: small_16k
+
+exp_id: default
+debug: False
+cudnn_benchmark: True
+compile: True
+amp: True
+weights: null
+checkpoint: null
+seed: 14159265
+num_workers: 10 # per-GPU
+pin_memory: False # set to True if your system can handle it, i.e., has enough memory
+
+# NOTE: This DOES NOT affect the model during inference in any way
+# they are just for the dataloader to fill in the missing data in multi-modal loading
+# to change the sequence length for the model, see networks.py
+data_dim:
+  text_seq_len: 77
+  clip_dim: 1024
+  sync_dim: 768
+  text_dim: 1024
+
+# ema configuration
+ema:
+  enable: True
+  sigma_rels: [0.05, 0.1]
+  update_every: 1
+  checkpoint_every: 5_000
+  checkpoint_folder: ${hydra:run.dir}/ema_ckpts
+  default_output_sigma: 0.05
+
+
+# sampling
+sampling:
+  mean: 0.0
+  scale: 1.0
+  min_sigma: 0.0
+  method: euler
+  num_steps: 25
+
+# classifier-free guidance
+null_condition_probability: 0.1
+cfg_strength: 4.5
+
+# checkpoint paths to external modules
+vae_16k_ckpt: ./ext_weights/v1-16.pth
+vae_44k_ckpt: ./ext_weights/v1-44.pth
+bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
+synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
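To see how this defaults list resolves (the `data`, `eval_data`, and job-logging groups merged into one tree), Hydra's compose API can be used outside a launcher. A minimal sketch, assuming it runs from the repository root where `config/` lives; composing the `hydra/job_logging` override outside `@hydra.main` may behave differently across Hydra versions:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose config/base_config.yaml the way @hydra.main would resolve it.
with initialize(version_base='1.3.2', config_path='config'):
    cfg = compose(config_name='base_config')

print(OmegaConf.to_yaml(cfg))  # the merged tree: data, eval_data, ema, sampling, ...
print(cfg.sampling.num_steps)  # 25
```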
config/data/base.yaml
ADDED
@@ -0,0 +1,70 @@
+VGGSound:
+  root: ../data/video
+  subset_name: sets/vgg3-train.tsv
+  fps: 8
+  height: 384
+  width: 384
+  sample_duration_sec: 8.0
+
+VGGSound_test:
+  root: ../data/video
+  subset_name: sets/vgg3-test.tsv
+  fps: 8
+  height: 384
+  width: 384
+  sample_duration_sec: 8.0
+
+VGGSound_val:
+  root: ../data/video
+  subset_name: sets/vgg3-val.tsv
+  fps: 8
+  height: 384
+  width: 384
+  sample_duration_sec: 8.0
+
+ExtractedVGG:
+  tsv: ../data/v1-16-memmap/vgg-train.tsv
+  memmap_dir: ../data/v1-16-memmap/vgg-train
+
+ExtractedVGG_test:
+  tag: test
+  gt_cache: ../data/eval-cache/vggsound-test
+  output_subdir: null
+  tsv: ../data/v1-16-memmap/vgg-test.tsv
+  memmap_dir: ../data/v1-16-memmap/vgg-test
+
+ExtractedVGG_val:
+  tag: val
+  gt_cache: ../data/eval-cache/vggsound-val
+  output_subdir: val
+  tsv: ../data/v1-16-memmap/vgg-val.tsv
+  memmap_dir: ../data/v1-16-memmap/vgg-val
+
+AudioCaps:
+  tsv: ../data/v1-16-memmap/audiocaps.tsv
+  memmap_dir: ../data/v1-16-memmap/audiocaps
+
+AudioSetSL:
+  tsv: ../data/v1-16-memmap/audioset_sl.tsv
+  memmap_dir: ../data/v1-16-memmap/audioset_sl
+
+BBCSound:
+  tsv: ../data/v1-16-memmap/bbcsound.tsv
+  memmap_dir: ../data/v1-16-memmap/bbcsound
+
+FreeSound:
+  tsv: ../data/v1-16-memmap/freesound.tsv
+  memmap_dir: ../data/v1-16-memmap/freesound
+
+Clotho:
+  tsv: ../data/v1-16-memmap/clotho.tsv
+  memmap_dir: ../data/v1-16-memmap/clotho
+
+Example_video:
+  tsv: ./training/example_output/memmap/vgg-example.tsv
+  memmap_dir: ./training/example_output/memmap/vgg-example
+
+Example_audio:
+  tsv: ./training/example_output/memmap/audio-example.tsv
+  memmap_dir: ./training/example_output/memmap/audio-example
+
config/eval_config.yaml
ADDED
@@ -0,0 +1,17 @@
+defaults:
+  - base_config
+  - override hydra/job_logging: custom-simplest
+  - _self_
+
+hydra:
+  run:
+    dir: ./output/${exp_id}
+  output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+exp_id: ${model}
+dataset: audiocaps
+duration_s: 8.0
+
+# for inference, this is the per-GPU batch size
+batch_size: 16
+output_name: null
config/eval_data/base.yaml
ADDED
@@ -0,0 +1,22 @@
+AudioCaps:
+  audio_path: ../data/AudioCaps-test-audioldm-ver
+  # a csv file, with a header row of 'name' and 'caption'
+  # name should match the audio file name without extension
+  # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_audioldm_data.csv
+  csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv
+
+AudioCaps_full:
+  audio_path: ../data/AudioCaps-test-full-ver
+  # a csv file, with a header row of 'name' and 'caption'
+  # name should match the audio file name without extension
+  # Can be downloaded here: https://github.com/hkchengrex/MMAudio/releases/download/v0.1/AudioCaps_full_data.csv
+  csv_path: ../data/AudioCaps-test-full-ver/data.csv
+
+MovieGen:
+  video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
+  jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata
+
+VGGSound:
+  video_path: ../data/test-videos
+  # from the officially released csv file
+  csv_path: ../data/vggsound.csv
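The comments above pin down the caption file format: a csv with `name` and `caption` columns, where `name` matches the audio file name without its extension. A quick sketch for validating such a file against an audio directory (the paths below are placeholders):

```python
import csv
from pathlib import Path

audio_dir = Path('../data/AudioCaps-test-audioldm-ver')  # placeholder path
with open(audio_dir / 'data.csv', newline='') as f:
    rows = list(csv.DictReader(f))  # header row: 'name', 'caption'

# Every 'name' should match an audio file, minus the extension.
missing = [r['name'] for r in rows
           if not any((audio_dir / f"{r['name']}{ext}").exists()
                      for ext in ('.wav', '.flac'))]
print(f'{len(rows)} captions, {len(missing)} without a matching audio file')
```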
config/hydra/job_logging/custom-eval.yaml
ADDED
@@ -0,0 +1,32 @@
+# python logging configuration for tasks
+version: 1
+formatters:
+  simple:
+    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+  colorlog:
+    '()': 'colorlog.ColoredFormatter'
+    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+    log_colors:
+      DEBUG: purple
+      INFO: green
+      WARNING: yellow
+      ERROR: red
+      CRITICAL: red
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: colorlog
+    stream: ext://sys.stdout
+  file:
+    class: logging.FileHandler
+    formatter: simple
+    # absolute file path
+    filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
+    mode: w
+root:
+  level: INFO
+  handlers: [console, file]
+
+disable_existing_loggers: false
config/hydra/job_logging/custom-no-rank.yaml
ADDED
@@ -0,0 +1,32 @@
+# python logging configuration for tasks
+version: 1
+formatters:
+  simple:
+    format: '[%(asctime)s][%(levelname)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+  colorlog:
+    '()': 'colorlog.ColoredFormatter'
+    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+    log_colors:
+      DEBUG: purple
+      INFO: green
+      WARNING: yellow
+      ERROR: red
+      CRITICAL: red
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: colorlog
+    stream: ext://sys.stdout
+  file:
+    class: logging.FileHandler
+    formatter: simple
+    # absolute file path
+    filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
+    mode: w
+root:
+  level: INFO
+  handlers: [console, file]
+
+disable_existing_loggers: false
config/hydra/job_logging/custom-simplest.yaml
ADDED
@@ -0,0 +1,26 @@
+# python logging configuration for tasks
+version: 1
+formatters:
+  simple:
+    format: '[%(asctime)s][%(levelname)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+  colorlog:
+    '()': 'colorlog.ColoredFormatter'
+    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+    log_colors:
+      DEBUG: purple
+      INFO: green
+      WARNING: yellow
+      ERROR: red
+      CRITICAL: red
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: colorlog
+    stream: ext://sys.stdout
+root:
+  level: INFO
+  handlers: [console]
+
+disable_existing_loggers: false
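Because this console-only variant contains no `${...}` interpolations (unlike the rank-aware configs above), it can also be applied outside Hydra with `logging.config.dictConfig`. A minimal sketch; it assumes `pyyaml` and `colorlog` are installed:

```python
import logging
import logging.config

import yaml  # pyyaml

with open('config/hydra/job_logging/custom-simplest.yaml') as f:
    log_cfg = yaml.safe_load(f)

logging.config.dictConfig(log_cfg)  # colorlog supplies the ColoredFormatter
logging.getLogger().info('colored console logging is live')
```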
config/hydra/job_logging/custom.yaml
ADDED
@@ -0,0 +1,33 @@
+# @package hydra.job_logging
+# python logging configuration for tasks
+version: 1
+formatters:
+  simple:
+    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+  colorlog:
+    '()': 'colorlog.ColoredFormatter'
+    format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
+    datefmt: '%Y-%m-%d %H:%M:%S'
+    log_colors:
+      DEBUG: purple
+      INFO: green
+      WARNING: yellow
+      ERROR: red
+      CRITICAL: red
+handlers:
+  console:
+    class: logging.StreamHandler
+    formatter: colorlog
+    stream: ext://sys.stdout
+  file:
+    class: logging.FileHandler
+    formatter: simple
+    # absolute file path
+    filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
+    mode: w
+root:
+  level: INFO
+  handlers: [console, file]
+
+disable_existing_loggers: false
config/train_config.yaml
ADDED
@@ -0,0 +1,41 @@
+defaults:
+  - base_config
+  - override data: base
+  - override hydra/job_logging: custom
+  - _self_
+
+hydra:
+  run:
+    dir: ./output/${exp_id}
+  output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
+
+ema:
+  start: 0
+
+mini_train: False
+example_train: False
+enable_grad_scaler: False
+vgg_oversample_rate: 5
+
+log_text_interval: 200
+log_extra_interval: 20_000
+val_interval: 5_000
+eval_interval: 20_000
+save_eval_interval: 40_000
+save_weights_interval: 10_000
+save_checkpoint_interval: 10_000
+save_copy_iterations: []
+
+batch_size: 512
+eval_batch_size: 256 # per-GPU
+
+num_iterations: 300_000
+learning_rate: 1.0e-4
+linear_warmup_steps: 1_000
+
+lr_schedule: step
+lr_schedule_steps: [240_000, 270_000]
+lr_schedule_gamma: 0.1
+
+clip_grad_norm: 1.0
+weight_decay: 1.0e-6
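The learning-rate settings above describe a linear warmup to 1.0e-4 over the first 1,000 iterations, followed by a step decay that multiplies the rate by `lr_schedule_gamma` (0.1) at iterations 240,000 and 270,000. A small sketch of that curve, written from the config values rather than the repository's actual scheduler code:

```python
def lr_at(it: int,
          base_lr: float = 1.0e-4,
          warmup: int = 1_000,
          steps: tuple[int, ...] = (240_000, 270_000),
          gamma: float = 0.1) -> float:
    # Linear warmup to base_lr, then multiply by gamma at each schedule step.
    if it < warmup:
        return base_lr * it / warmup
    return base_lr * gamma ** sum(it >= s for s in steps)

for it in (500, 1_000, 200_000, 250_000, 280_000):
    print(f'iter {it:>7d}: lr = {lr_at(it):.1e}')
# 5.0e-05, 1.0e-04, 1.0e-04, 1.0e-05, 1.0e-06
```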
demo.py
ADDED
@@ -0,0 +1,141 @@
+import logging
+from argparse import ArgumentParser
+from pathlib import Path
+
+import torch
+import torchaudio
+
+from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
+                                setup_eval_logging)
+from mmaudio.model.flow_matching import FlowMatching
+from mmaudio.model.networks import MMAudio, get_my_mmaudio
+from mmaudio.model.utils.features_utils import FeaturesUtils
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+log = logging.getLogger()
+
+
+@torch.inference_mode()
+def main():
+    setup_eval_logging()
+
+    parser = ArgumentParser()
+    parser.add_argument('--variant',
+                        type=str,
+                        default='large_44k_v2',
+                        help='small_16k, small_44k, medium_44k, large_44k, large_44k_v2')
+    parser.add_argument('--video', type=Path, help='Path to the video file')
+    parser.add_argument('--prompt', type=str, help='Input prompt', default='')
+    parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
+    parser.add_argument('--duration', type=float, default=8.0)
+    parser.add_argument('--cfg_strength', type=float, default=4.5)
+    parser.add_argument('--num_steps', type=int, default=25)
+
+    parser.add_argument('--mask_away_clip', action='store_true')
+
+    parser.add_argument('--output', type=Path, help='Output directory', default='./output')
+    parser.add_argument('--seed', type=int, help='Random seed', default=42)
+    parser.add_argument('--skip_video_composite', action='store_true')
+    parser.add_argument('--full_precision', action='store_true')
+
+    args = parser.parse_args()
+
+    if args.variant not in all_model_cfg:
+        raise ValueError(f'Unknown model variant: {args.variant}')
+    model: ModelConfig = all_model_cfg[args.variant]
+    model.download_if_needed()
+    seq_cfg = model.seq_cfg
+
+    if args.video:
+        video_path: Path = Path(args.video).expanduser()
+    else:
+        video_path = None
+    prompt: str = args.prompt
+    negative_prompt: str = args.negative_prompt
+    output_dir: Path = args.output.expanduser()
+    seed: int = args.seed
+    num_steps: int = args.num_steps
+    duration: float = args.duration
+    cfg_strength: float = args.cfg_strength
+    skip_video_composite: bool = args.skip_video_composite
+    mask_away_clip: bool = args.mask_away_clip
+
+    device = 'cpu'
+    if torch.cuda.is_available():
+        device = 'cuda'
+    elif torch.backends.mps.is_available():
+        device = 'mps'
+    else:
+        log.warning('CUDA/MPS are not available, running on CPU')
+    dtype = torch.float32 if args.full_precision else torch.bfloat16
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # load a pretrained model
+    net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
+    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
+    log.info(f'Loaded weights from {model.model_path}')
+
+    # misc setup
+    rng = torch.Generator(device=device)
+    rng.manual_seed(seed)
+    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
+
+    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
+                                  synchformer_ckpt=model.synchformer_ckpt,
+                                  enable_conditions=True,
+                                  mode=model.mode,
+                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
+                                  need_vae_encoder=False)
+    feature_utils = feature_utils.to(device, dtype).eval()
+
+    if video_path is not None:
+        log.info(f'Using video {video_path}')
+        video_info = load_video(video_path, duration)
+        clip_frames = video_info.clip_frames
+        sync_frames = video_info.sync_frames
+        duration = video_info.duration_sec
+        if mask_away_clip:
+            clip_frames = None
+        else:
+            clip_frames = clip_frames.unsqueeze(0)
+        sync_frames = sync_frames.unsqueeze(0)
+    else:
+        log.info('No video provided -- text-to-audio mode')
+        clip_frames = sync_frames = None
+
+    seq_cfg.duration = duration
+    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
+
+    log.info(f'Prompt: {prompt}')
+    log.info(f'Negative prompt: {negative_prompt}')
+
+    audios = generate(clip_frames,
+                      sync_frames, [prompt],
+                      negative_text=[negative_prompt],
+                      feature_utils=feature_utils,
+                      net=net,
+                      fm=fm,
+                      rng=rng,
+                      cfg_strength=cfg_strength)
+    audio = audios.float().cpu()[0]
+    if video_path is not None:
+        save_path = output_dir / f'{video_path.stem}.flac'
+    else:
+        safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
+        save_path = output_dir / f'{safe_filename}.flac'
+    torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
+
+    log.info(f'Audio saved to {save_path}')
+    if video_path is not None and not skip_video_composite:
+        video_save_path = output_dir / f'{video_path.stem}.mp4'
+        make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
+        log.info(f'Video saved to {video_save_path}')
+
+    log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
+
+
+if __name__ == '__main__':
+    main()
docs/EVAL.md
ADDED
@@ -0,0 +1,22 @@
+# Evaluation
+
+## Batch Evaluation
+
+To evaluate the model on a dataset, use the `batch_eval.py` script. It is significantly more efficient for large-scale evaluation than `demo.py`, as it supports batched inference, multi-GPU inference, torch compilation, and skipping video composition.
+
+An example of running this script with four GPUs is as follows:
+
+```bash
+OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=4 batch_eval.py duration_s=8 dataset=vggsound model=small_16k num_workers=8
+```
+
+You may need to update the data paths in `config/eval_data/base.yaml`.
+More configuration options can be found in `config/base_config.yaml` and `config/eval_config.yaml`.
+
+## Precomputed Results
+
+Precomputed results for VGGSound, AudioCaps, and MovieGen are available here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
+
+## Obtaining Quantitative Metrics
+
+Our evaluation code is available here: https://github.com/hkchengrex/av-benchmark
docs/MODELS.md
ADDED
@@ -0,0 +1,50 @@
+# Pretrained models
+
+The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
+The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main
+
+| Model | Download link | File size |
+| -------- | ------- | ------- |
+| Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
+| Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
+| Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
+| Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
+| Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
+| 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
+| 16kHz BigVGAN vocoder (from Make-An-Audio 2) | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
+| 44.1kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
+| Synchformer visual encoder | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |
+
+To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP; CLIP will be downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz), not to model size.
+The 44.1kHz vocoder will be downloaded automatically.
+The `_v2` model performs worse in benchmarking (e.g., in Fréchet distance) but, in my experience, generalizes better to new data.
+
+The expected directory structure (full):
+
+```bash
+MMAudio
+├── ext_weights
+│   ├── best_netG.pt
+│   ├── synchformer_state_dict.pth
+│   ├── v1-16.pth
+│   └── v1-44.pth
+├── weights
+│   ├── mmaudio_small_16k.pth
+│   ├── mmaudio_small_44k.pth
+│   ├── mmaudio_medium_44k.pth
+│   ├── mmaudio_large_44k.pth
+│   └── mmaudio_large_44k_v2.pth
+└── ...
+```
+
+The expected directory structure (minimal, for the recommended model only):
+
+```bash
+MMAudio
+├── ext_weights
+│   ├── synchformer_state_dict.pth
+│   └── v1-44.pth
+├── weights
+│   └── mmaudio_large_44k_v2.pth
+└── ...
+```
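Since the checkpoints above run to several gigabytes and MODELS.md notes that MD5 checksums ship in `mmaudio/utils/download_utils.py`, a manual download can be verified by hashing the file in a stream. A generic sketch; the expected checksum is a placeholder to be copied from `download_utils.py`:

```python
import hashlib
from pathlib import Path

def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    # Stream in 1 MiB chunks so multi-GB checkpoints never sit in memory.
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

expected = 'replace-with-checksum-from-download_utils.py'  # placeholder
actual = md5_of(Path('weights/mmaudio_large_44k_v2.pth'))
print('OK' if actual == expected else f'mismatch: {actual}')
```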
docs/TRAINING.md
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Training
|
2 |
+
|
3 |
+
## Overview
|
4 |
+
|
5 |
+
We have put a large emphasis on making training as fast as possible.
|
6 |
+
Consequently, some pre-processing steps are required.
|
7 |
+
|
8 |
+
Namely, before starting any training, we
|
9 |
+
|
10 |
+
1. Obtain training data as videos, audios, and captions.
|
11 |
+
2. Encode training audios into spectrograms and then with VAE into mean/std
|
12 |
+
3. Extract CLIP and synchronization features from videos
|
13 |
+
4. Extract CLIP features from text (captions)
|
14 |
+
5. Encode all extracted features into [MemoryMappedTensors](https://pytorch.org/tensordict/main/reference/generated/tensordict.MemoryMappedTensor.html) with [TensorDict](https://pytorch.org/tensordict/main/reference/tensordict.html)
|
15 |
+
|
16 |
+
**NOTE:** for maximum training speed (e.g., when training the base model with 2*H100s), you would need around 3~5 GB/s of random read speed. Spinning disks would not be able to catch up and most consumer-grade SSDs would struggle. In my experience, the best bet is to have a large enough system memory such that the OS can cache the data. This way, the data is read from RAM instead of disk.
|
17 |
+
|
18 |
+
The current training script does not support `_v2` training.
|
19 |
+
|
20 |
+
## Recommended Hardware Configuration
|
21 |
+
|
22 |
+
These are what I recommend for a smooth and efficient training experience. These are not minimum requirements.
|
23 |
+
|
24 |
+
- Single-node machine. We did not implement multi-node training
|
25 |
+
- GPUs: for the small model, two 80G-H100s or above; for the large model, eight 80G-H100s or above
|
26 |
+
- System memory: for 16kHz training, 600GB+; for 44kHz training, 700GB+
|
27 |
+
- Storage: >2TB of fast NVMe storage. If you have enough system memory, OS caching will help and the storage does not need to be as fast.
|
28 |
+
|
29 |
+
## Prerequisites
|
30 |
+
|
31 |
+
1. Install [av-benchmark](https://github.com/hkchengrex/av-benchmark). We use this library to automatically evaluate on the validation set during training, and on the test set after training.
|
32 |
+
2. Extract features for evaluation using [av-benchmark](https://github.com/hkchengrex/av-benchmark) for the validation and test set as a [validation cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L38) and a [test cache](https://github.com/hkchengrex/MMAudio/blob/34bf089fdd2e457cd5ef33be96c0e1c8a0412476/config/data/base.yaml#L31). You can also download the precomputed evaluation cache [here](https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results/tree/main).
|
33 |
+
|
34 |
+
3. You will need ffmpeg to extract frames from videos. Note that `torchaudio` imposes a maximum version limit (`ffmpeg<7`). You can install it as follows:
|
35 |
+
|
36 |
+
```bash
|
37 |
+
conda install -c conda-forge 'ffmpeg<7'
|
38 |
+
```
|
39 |
+
|
40 |
+
4. Download the training datasets. We used [VGGSound](https://arxiv.org/abs/2004.14368), [AudioCaps](https://audiocaps.github.io/), and [WavCaps](https://arxiv.org/abs/2303.17395). Note that the audio files in the huggingface release of WavCaps have been downsampled to 32kHz. To the best of our ability, we located the original (high-sampling rate) audio files and used them instead to prevent artifacts during 44.1kHz training. We did not use the "SoundBible" portion of WavCaps, since it is a small set with many short audio unsuitable for our training.
|
41 |
+
|
42 |
+
5. Download the corresponding VAE (`v1-16.pth` for 16kHz training, and `v1-44.pth` for 44.1kHz training), vocoder models (`best_netG.pt` for 16kHz training; the vocoder for 44.1kHz training will be downloaded automatically), the [empty string encoding](https://github.com/hkchengrex/MMAudio/releases/download/v0.1/empty_string.pth), and Synchformer weights from [MODELS.md](https://github.com/hkchengrex/MMAudio/blob/main/docs/MODELS.md) place them in `ext_weights/`.
|
43 |
+
|
44 |
+
## Preparing Audio-Video-Text Features
|
45 |
+
|
46 |
+
We have prepared some example data in `training/example_videos`.
|
47 |
+
`training/extract_video_training_latents.py` extracts audio, video, and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
|
48 |
+
|
49 |
+
To run this script, use the `torchrun` utility:
|
50 |
+
|
51 |
+
```bash
|
52 |
+
torchrun --standalone training/extract_video_training_latents.py
|
53 |
+
```
|
54 |
+
|
55 |
+
You can run this script with multiple GPUs (with `--nproc_per_node=<n>` after `--standalone` and before the script name) to speed up extraction.
|
56 |
+
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
|
57 |
+
Change the data path definitions in `data_cfg` if necessary.
|
58 |
+
|
59 |
+
Arguments:
|
60 |
+
|
61 |
+
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
|
62 |
+
- `output_dir` -- where TensorDict and the metadata file are saved.
|
63 |
+
|
64 |
+
Outputs produced in `output_dir`:
|
65 |
+
|
66 |
+
1. A directory named `vgg-{split}` (i.e., in the TensorDict format), containing
|
67 |
+
a. `mean.memmap` mean values predicted by the VAE encoder (number of videos X sequence length X channel size)
|
68 |
+
b. `std.memmap` standard deviation values predicted by the VAE encoder (number of videos X sequence length X channel size)
|
69 |
+
c. `text_features.memmap` text features extracted from CLIP (number of videos X 77 (sequence length) X 1024)
|
70 |
+
d. `clip_features.memmap` clip features extracted from CLIP (number of videos X 64 (8 fps) X 1024)
|
71 |
+
e. `sync_features.memmap` synchronization features extracted from Synchformer (number of videos X 192 (24 fps) X 768)
|
72 |
+
f. `meta.json` that contains the metadata for the above memory mappings
|
73 |
+
2. A tab-separated values file named `vgg-{split}.tsv` that contains two columns: `id` containing video file names without extension, and `label` containing corresponding text labels (i.e., captions)
|
74 |
+
|
75 |
+
## Preparing Audio-Text Features
|
76 |
+
|
77 |
+
We have prepared some example data in `training/example_audios`.
|
78 |
+
|
79 |
+
1. Run `training/partition_clips` to partition each audio file into clips (by finding start and end points; we do not save the partitioned audio onto the disk to save disk space)
|
80 |
+
2. Run `training/extract_audio_training_latents.py` to extract each clip's audio and text features and save them as a `TensorDict` with a `.tsv` file containing metadata to `output_dir`.
|
81 |
+
|
82 |
+
### Partitioning the audio files
|
83 |
+
|
84 |
+
Run
|
85 |
+
|
86 |
+
```bash
|
87 |
+
python training/partition_clips.py
|
88 |
+
```
|
89 |
+
|
90 |
+
Arguments:
|
91 |
+
|
92 |
+
- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`)
|
93 |
+
- `output_dir` -- path to the output `.csv` file
|
94 |
+
- `start` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the beginning of the chunk to be processed
|
95 |
+
- `end` -- optional; useful when you need to run multiple processes to speed up processing -- this defines the end of the chunk to be processed
|
96 |
+
|
97 |
+
### Extracting audio and text features
|
98 |
+
|
99 |
+
Run
|
100 |
+
|
101 |
+
```bash
|
102 |
+
torchrun --standalone training/extract_audio_training_latents.py
|
103 |
+
```
|
104 |
+
|
105 |
+
You can run this with multiple GPUs (with `--nproc_per_node=<n>`) to speed up extraction.
|
106 |
+
Modify the definitions near the top of the script to switch between 16kHz/44.1kHz extraction.
|
107 |
+
|
108 |
+
Arguments:
|
109 |
+
|
110 |
+
- `data_dir` -- path to a directory containing the audio files (`.flac` or `.wav`), same as the previous step
|
111 |
+
- `captions_tsv` -- path to the captions file, a tab-separated values (tsv) file at least with columns `id` and `caption`
|
112 |
+
- `clips_tsv` -- path to the clips file, generated in the last step
|
113 |
+
- `latent_dir` -- where intermediate latent outputs are saved. It is safe to delete this directory afterwards.
|
114 |
+
- `output_dir` -- where TensorDict and the metadata file are saved.
|
115 |
+
|
116 |
+
**Reference tsv files (with overlaps removed as mentioned in the paper) can be found [here](https://github.com/hkchengrex/MMAudio/releases/tag/v0.1).**
|
117 |
+
Note that these reference tsv files are the **outputs** of `extract_audio_training_latents.py`, which means the `id` column might contain duplicate entries (one per clip). You can still use it as the `captions_tsv` input though -- the script will handle duplicates gracefully.
|
118 |
+
Among these reference tsv files, `audioset_sl.tsv`, `bbcsound.tsv`, and `freesound.tsv` are subsets that are parts of WavCaps. These subsets might be smaller than the original datasets.
|
119 |
+
The Clotho data contains both the development set and the validation set.
|
120 |
+
|
121 |
+
Outputs produced in `output_dir`:
|
122 |
+
|
123 |
+
1. A directory named `{basename(output_dir)}` (i.e., in the TensorDict format), containing
|
124 |
+
a. `mean.memmap` mean values predicted by the VAE encoder (number of audios X sequence length X channel size)
|
125 |
+
b. `std.memmap` standard deviation values predicted by the VAE encoder (number of audios X sequence length X channel size)
|
126 |
+
c. `text_features.memmap` text features extracted from CLIP (number of audios X 77 (sequence length) X 1024)
|
127 |
+
f. `meta.json` that contains the metadata for the above memory mappings
|
128 |
+
2. A tab-separated values file named `{basename(output_dir)}.tsv` that contains two columns: `id` containing audio file names without extension, and `label` containing corresponding text labels (i.e., captions)
|
129 |
+
|
130 |
+
## Training on Extracted Features
|
131 |
+
|
132 |
+
We use Distributed Data Parallel (DDP) for training.
|
133 |
+
First, specify the data path in `config/data/base.yaml`. If you used the default parameters in the scripts above to extract features for the example data, the `Example_video` and `Example_audio` items should already be correct.
|
134 |
+
|
135 |
+
To run training on the example data, use the following command:
|
136 |
+
|
137 |
+

```bash
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=1 train.py exp_id=debug compile=False debug=True example_train=True batch_size=1
```

This will not train a useful model, but it will check if everything is set up correctly.

For full training on the base model with two GPUs, use the following command:

```bash
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_1 model=small_16k
```

Any outputs from training will be stored in `output/<exp_id>`.

More configuration options can be found in `config/base_config.yaml` and `config/train_config.yaml`.
For the medium and large models, set `vgg_oversample_rate` to `3` to reduce overfitting.
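
Because these options are Hydra-style overrides, they can be set directly on the command line; for example (the experiment id and model name here are illustrative):

```bash
OMP_NUM_THREADS=4 torchrun --standalone --nproc_per_node=2 train.py exp_id=exp_2 model=medium_44k vgg_oversample_rate=3
```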

## Checkpoints

Model checkpoints, including optimizer states and the latest EMA weights, are available here: https://huggingface.co/hkchengrex/MMAudio

---

Godspeed!

docs/images/icon.png
ADDED
(binary file)
docs/index.html
ADDED
@@ -0,0 +1,149 @@
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
<script>
    window.dataLayer = window.dataLayer || [];
    function gtag(){dataLayer.push(arguments);}
    gtag('js', new Date());
    gtag('config', 'G-0JKBJ3WRJZ');
</script>

<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Source+Sans+3&display=swap" rel="stylesheet">
<meta charset="UTF-8">
<title>MMAudio</title>

<link rel="icon" type="image/png" href="images/icon.png">

<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- CSS only -->
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
    integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>

<link rel="stylesheet" href="style.css">
</head>
<body>
<br><br><br><br>
<div class="container">
    <div class="row text-center" style="font-size:38px">
        <div class="col strong">
            Taming Multimodal Joint Training for High-Quality <br>Video-to-Audio Synthesis
        </div>
    </div>

    <br>
    <div class="row text-center" style="font-size:28px">
        <div class="col">
            CVPR 2025
        </div>
    </div>
    <br>

    <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
        <div class="col-sm-auto px-lg-2">
            <a href="https://hkchengrex.github.io/">Ho Kei Cheng<sup>1</sup></a>
        </div>
        <div class="col-sm-auto px-lg-2">
            <nobr><a href="https://scholar.google.co.jp/citations?user=RRIO1CcAAAAJ">Masato Ishii<sup>2</sup></a></nobr>
        </div>
        <div class="col-sm-auto px-lg-2">
            <nobr><a href="https://scholar.google.com/citations?user=sXAjHFIAAAAJ">Akio Hayakawa<sup>2</sup></a></nobr>
        </div>
        <div class="col-sm-auto px-lg-2">
            <nobr><a href="https://scholar.google.com/citations?user=XCRO260AAAAJ">Takashi Shibuya<sup>2</sup></a></nobr>
        </div>
        <div class="col-sm-auto px-lg-2">
            <nobr><a href="https://www.alexander-schwing.de/">Alexander Schwing<sup>1</sup></a></nobr>
        </div>
        <div class="col-sm-auto px-lg-2">
            <nobr><a href="https://www.yukimitsufuji.com/">Yuki Mitsufuji<sup>2,3</sup></a></nobr>
        </div>
    </div>

    <div class="h-100 row text-center heavy justify-content-md-center" style="font-size:22px;">
        <div class="col-sm-auto px-lg-2">
            <sup>1</sup>University of Illinois Urbana-Champaign
        </div>
        <div class="col-sm-auto px-lg-2">
            <sup>2</sup>Sony AI
        </div>
        <div class="col-sm-auto px-lg-2">
            <sup>3</sup>Sony Group Corporation
        </div>
    </div>

    <br>

    <br>

    <div class="h-100 row text-center justify-content-md-center" style="font-size:20px;">
        <div class="col-sm-2">
            <a href="https://arxiv.org/abs/2412.15322">[Paper]</a>
        </div>
        <div class="col-sm-2">
            <a href="https://github.com/hkchengrex/MMAudio">[Code]</a>
        </div>
        <div class="col-sm-3">
            <a href="https://huggingface.co/spaces/hkchengrex/MMAudio">[Huggingface Demo]</a>
        </div>
        <div class="col-sm-2">
            <a href="https://colab.research.google.com/drive/1TAaXCY2-kPk4xE4PwKB3EqFbSnkUuzZ8?usp=sharing">[Colab Demo]</a>
        </div>
        <div class="col-sm-3">
            <a href="https://replicate.com/zsxkib/mmaudio">[Replicate Demo]</a>
        </div>
    </div>

    <br>

    <hr>

    <div class="row" style="font-size:32px">
        <div class="col strong">
            TL;DR
        </div>
    </div>
    <br>
    <div class="row">
        <div class="col">
            <p class="light" style="text-align: left;">
                MMAudio generates synchronized audio given video and/or text inputs.
            </p>
        </div>
    </div>

    <br>
    <hr>
    <br>

    <div class="row" style="font-size:32px">
        <div class="col strong">
            Demo
        </div>
    </div>
    <br>
    <div class="row" style="font-size:48px">
        <div class="col strong text-center">
            <a href="video_main.html" style="text-decoration: underline;">&lt;More results&gt;</a>
        </div>
    </div>
    <br>
    <div class="video-container" style="text-align: center;">
        <iframe src="https://youtube.com/embed/YElewUT2M4M"></iframe>
    </div>

    <br>

    <br><br>
    <br><br>

</div>

</body>
</html>

docs/style.css
ADDED
@@ -0,0 +1,78 @@
body {
    font-family: 'Source Sans 3', sans-serif;
    font-size: 18px;
    margin-left: auto;
    margin-right: auto;
    font-weight: 400;
    height: 100%;
    max-width: 1000px;
}

table {
    width: 100%;
    border-collapse: collapse;
}
th, td {
    border: 1px solid #ddd;
    padding: 8px;
    text-align: center;
}
th {
    background-color: #f2f2f2;
}
video {
    width: 100%;
    height: auto;
}
p {
    font-size: 28px;
}
h2 {
    font-size: 36px;
}

.strong {
    font-weight: 700;
}

.light {
    font-weight: 100;
}

.heavy {
    font-weight: 900;
}

.column {
    float: left;
}

a:link,
a:visited {
    color: #05538f;
    text-decoration: none;
}

a:hover {
    color: #63cbdd;
}

hr {
    border: 0;
    height: 1px;
    background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
}

.video-container {
    position: relative;
    padding-bottom: 56.25%; /* 16:9 */
    height: 0;
}

.video-container iframe {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
}

docs/style_videos.css
ADDED
@@ -0,0 +1,52 @@
body {
    font-family: 'Source Sans 3', sans-serif;
    font-size: 1.5vh;
    font-weight: 400;
}

table {
    width: 100%;
    border-collapse: collapse;
}
th, td {
    border: 1px solid #ddd;
    padding: 8px;
    text-align: center;
}
th {
    background-color: #f2f2f2;
}
video {
    width: 100%;
    height: auto;
}
p {
    font-size: 1.5vh;
    font-weight: bold;
}
h2 {
    font-size: 2vh;
    font-weight: bold;
}

.video-container {
    position: relative;
    padding-bottom: 56.25%; /* 16:9 */
    height: 0;
}

.video-container iframe {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
}

.video-header {
    background-color: #f2f2f2;
    text-align: center;
    font-size: 1.5vh;
    font-weight: bold;
    padding: 8px;
}

docs/video_gen.html
ADDED
@@ -0,0 +1,254 @@
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
<script>
    window.dataLayer = window.dataLayer || [];
    function gtag(){dataLayer.push(arguments);}
    gtag('js', new Date());
    gtag('config', 'G-0JKBJ3WRJZ');
</script>

<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
<meta charset="UTF-8">
<title>MMAudio</title>

<link rel="icon" type="image/png" href="images/icon.png">

<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- CSS only -->
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
    integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>

<link rel="stylesheet" href="style_videos.css">
</head>
<body>

<div id="moviegen_all">
<h2 id="moviegen" style="text-align: center;">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</h2>
<p id="moviegen1" style="overflow: hidden;">
    Example 1: Ice cracking with sharp snapping sound, and metal tool scraping against the ice surface.
    <span style="float: right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Movie Gen Audio</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/d7Lb0ihtGcE"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/F4JoJ2r2m8U"></iframe>
        </div>
    </div>
</div>
<br>

<!-- <p id="moviegen2">Example 2: Rhythmic splashing and lapping of water. <span style="float:right;"><a href="#index">Back to index</a></span> </p>

<table>
    <thead>
        <tr>
            <th>Movie Gen Audio</th>
            <th>Ours</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td width="50%">
                <div class="video-container">
                    <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
                </div>
            </td>
            <td width="50%">
                <div class="video-container">
                    <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
                </div>
            </td>
        </tr>
    </tbody>
</table> -->

<p id="moviegen2" style="overflow: hidden;">
    Example 2: Rhythmic splashing and lapping of water.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Movie Gen Audio</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/5gQNPK99CIk"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/AbwnTzG-BpA"></iframe>
        </div>
    </div>
</div>
<br>

<p id="moviegen3" style="overflow: hidden;">
    Example 3: Shovel scrapes against dry earth.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Movie Gen Audio</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/PUKGyEve7XQ"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/CNn7i8VNkdc"></iframe>
        </div>
    </div>
</div>
<br>

<p id="moviegen4" style="overflow: hidden;">
    (Failure case) Example 4: Creamy sound of mashed potatoes being scooped.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Movie Gen Audio</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/PJv1zxR9JjQ"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/c3-LJ1lNsPQ"></iframe>
        </div>
    </div>
</div>
<br>

</div>

<div id="hunyuan_sora_all">

<h2 id="hunyuan" style="text-align: center;">Results on Videos Generated by Hunyuan</h2>
<p style="overflow: hidden;">
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Typing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/8ln_9hhH_nk"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Water is rushing down a stream and pouring</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/5df1FZFQj30"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Waves on beach</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/7wQ9D5WgpFc"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Water droplet</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/q7M2nsalGjM"></iframe>
        </div>
    </div>
</div>
<br>

<h2 id="sora" style="text-align: center;">Results on Videos Generated by Sora</h2>
<p style="overflow: hidden;">
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Ships riding waves</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/JbgQzHHytk8"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Train (no text prompt given)</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/xOW7zrjpWC8"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Seashore (no text prompt given)</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/fIuw5Y8ZZ9E"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Surfing (failure: unprompted music)</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/UcSTk-v0M_s"></iframe>
        </div>
    </div>
</div>
<br>

</div>

<div id="mochi_ltx_all">
<h2 id="mochi" style="text-align: center;">Results on Videos Generated by Mochi 1</h2>
<p style="overflow: hidden;">
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Magical fire and lightning (no text prompt given)</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/tTlRZaSMNwY"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Storm (no text prompt given)</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/4hrZTMJUy3w"></iframe>
        </div>
    </div>
</div>
<br>

<h2 id="ltx" style="text-align: center;">Results on Videos Generated by LTX-Video</h2>
<p style="overflow: hidden;">
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>
<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Firewood burning and cracking</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/P7_DDpgev0g"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Waterfall, water splashing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/4MvjceYnIO0"></iframe>
        </div>
    </div>
</div>
<br>

</div>

</body>
</html>

docs/video_main.html
ADDED
@@ -0,0 +1,98 @@
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
<script>
    window.dataLayer = window.dataLayer || [];
    function gtag(){dataLayer.push(arguments);}
    gtag('js', new Date());
    gtag('config', 'G-0JKBJ3WRJZ');
</script>

<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
<meta charset="UTF-8">
<title>MMAudio</title>

<link rel="icon" type="image/png" href="images/icon.png">

<meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, user-scalable=no">
<!-- CSS only -->
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
    integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.7.1/jquery.min.js"></script>

<link rel="stylesheet" href="style_videos.css">

<script type="text/javascript">
    $(document).ready(function(){
        $("#content").load("video_gen.html #moviegen_all");
        $("#load_moveigen").click(function(){
            $("#content").load("video_gen.html #moviegen_all");
        });
        $("#load_hunyuan_sora").click(function(){
            $("#content").load("video_gen.html #hunyuan_sora_all");
        });
        $("#load_mochi_ltx").click(function(){
            $("#content").load("video_gen.html #mochi_ltx_all");
        });
        $("#load_vgg1").click(function(){
            $("#content").load("video_vgg.html #vgg1");
        });
        $("#load_vgg2").click(function(){
            $("#content").load("video_vgg.html #vgg2");
        });
        $("#load_vgg3").click(function(){
            $("#content").load("video_vgg.html #vgg3");
        });
        $("#load_vgg4").click(function(){
            $("#content").load("video_vgg.html #vgg4");
        });
        $("#load_vgg5").click(function(){
            $("#content").load("video_vgg.html #vgg5");
        });
        $("#load_vgg6").click(function(){
            $("#content").load("video_vgg.html #vgg6");
        });
        $("#load_vgg_extra").click(function(){
            $("#content").load("video_vgg.html #vgg_extra");
        });
    });
</script>
</head>
<body>
<h1 id="index" style="text-align: center;">Index</h1>
<p><b>(Click on the links to load the corresponding videos)</b> <span style="float:right;"><a href="index.html">Back to project page</a></span></p>

<ol>
    <li>
        <a href="#" id="load_moveigen">Comparisons with Movie Gen Audio on Videos Generated by MovieGen</a>
    </li>
    <li>
        <a href="#" id="load_hunyuan_sora">Results on Videos Generated by Hunyuan and Sora</a>
    </li>
    <li>
        <a href="#" id="load_mochi_ltx">Results on Videos Generated by Mochi 1 and LTX-Video</a>
    </li>
    <li>
        On VGGSound
        <ol>
            <li><a id='load_vgg1' href="#">Example 1: Wolf howling</a></li>
            <li><a id='load_vgg2' href="#">Example 2: Striking a golf ball</a></li>
            <li><a id='load_vgg3' href="#">Example 3: Hitting a drum</a></li>
            <li><a id='load_vgg4' href="#">Example 4: Dog barking</a></li>
            <li><a id='load_vgg5' href="#">Example 5: Playing a string instrument</a></li>
            <li><a id='load_vgg6' href="#">Example 6: A group of people playing tambourines</a></li>
            <li><a id='load_vgg_extra' href="#">Extra results &amp; failure cases</a></li>
        </ol>
    </li>
</ol>

<div id="content" class="container-fluid">
</div>
<br>
<br>

</body>
</html>

docs/video_vgg.html
ADDED
@@ -0,0 +1,452 @@
<!DOCTYPE html>
<html lang="en">
<head>
<!-- Google tag (gtag.js) -->
<script async src="https://www.googletagmanager.com/gtag/js?id=G-0JKBJ3WRJZ"></script>
<script>
    window.dataLayer = window.dataLayer || [];
    function gtag(){dataLayer.push(arguments);}
    gtag('js', new Date());
    gtag('config', 'G-0JKBJ3WRJZ');
</script>

<link href='https://fonts.googleapis.com/css?family=Source+Sans+Pro' rel='stylesheet' type='text/css'>
<meta charset="UTF-8">
<title>MMAudio</title>

<meta name="viewport" content="width=device-width, initial-scale=1">
<!-- CSS only -->
<link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet"
    integrity="sha384-+0n0xVW2eSR5OomGNYDnhzAbDsOXxcvSN1TPprVMTNDbiYZCxYbOOl7+AMvyTG2x" crossorigin="anonymous">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>

<link rel="stylesheet" href="style_videos.css">
</head>
<body>

<div id="vgg1">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    Example 1: Wolf howling.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Ground-truth</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/9J_V74gqMUA"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/P6O8IpjErPc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V2A-Mapper</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/w-5eyqepvTk"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">FoleyCrafter</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/VOLfoZlRkzo"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Frieren</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/49owKyA5Pa8"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">VATT</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/QVtrFgbeGDM"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V-AURA</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/8r0uEfSNjvI"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Seeing and Hearing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/bn-sLg2qulk"></iframe>
        </div>
    </div>
</div>
</div>

<div id="vgg2">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    Example 2: Striking a golf ball.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Ground-truth</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/1hwSu42kkho"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/kZibDoDCNxI"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V2A-Mapper</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/jgKfLBLhh7Y"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">FoleyCrafter</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/Lfsx8mOPcJo"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Frieren</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/tz-LpbB0MBc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">VATT</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/RTDUHMi08n4"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V-AURA</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/N-3TDOsPnZQ"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Seeing and Hearing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/QnsHnLn4gB0"></iframe>
        </div>
    </div>
</div>
</div>

<div id="vgg3">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    Example 3: Hitting a drum.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Ground-truth</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/0oeIwq77w0Q"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/-UtPV9ohuIM"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V2A-Mapper</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">FoleyCrafter</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/kkCsXPOlBvY"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Frieren</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/MbNKsVsuvig"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">VATT</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/2yYviBjrpBw"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V-AURA</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/9yivkgN-zwc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Seeing and Hearing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/6dnyQt4Fuhs"></iframe>
        </div>
    </div>
</div>
</div>

<div id="vgg4">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    Example 4: Dog barking.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Ground-truth</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/ckaqvTyMYAw"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/_aRndFZzZ-I"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V2A-Mapper</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/mNCISP3LBl0"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">FoleyCrafter</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/phZBQ3L7foE"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Frieren</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/Sb5Mg1-ORao"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">VATT</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/eHmAGOmtDDg"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V-AURA</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/NEGa3krBrm0"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Seeing and Hearing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/aO0EAXlwE7A"></iframe>
        </div>
    </div>
</div>
</div>

<div id="vgg5">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    Example 5: Playing a string instrument.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Ground-truth</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/KP1QhWauIOc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/ovaJhWSquYE"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V2A-Mapper</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/N723FS9lcy8"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">FoleyCrafter</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/t0N4ZAAXo58"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Frieren</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/8YSRs03QNNA"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">VATT</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/vOpMz55J1kY"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V-AURA</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/9JHC75vr9h0"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Seeing and Hearing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/9w0JckNzXmY"></iframe>
        </div>
    </div>
</div>
</div>

<div id="vgg6">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    Example 6: A group of people playing tambourines.
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Ground-truth</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/mx6JLxzUkRc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Ours</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/oLirHhP9Su8"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V2A-Mapper</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/HkLkHMqptv0"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">FoleyCrafter</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/rpHiiODjmNU"></iframe>
        </div>
    </div>
</div>
<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Frieren</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/1mVD3fJ0LpM"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">VATT</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/yjVFnJiEJlw"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">V-AURA</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/neVeMSWtRkU"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Seeing and Hearing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/EUE7YwyVWz8"></iframe>
        </div>
    </div>
</div>
</div>

<div id="vgg_extra">
<h2 style="text-align: center;">Comparisons with state-of-the-art methods in VGGSound</h2>
<p style="overflow: hidden;">
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-3">
        <div class="video-header">Moving train</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/Ta6H45rBzJc"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Water splashing</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/hl6AtgHXpb4"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Skateboarding</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/n4sCNi_9buI"></iframe>
        </div>
    </div>
    <div class="col-sm-3">
        <div class="video-header">Synchronized clapping</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/oxexfpLn7FE"></iframe>
        </div>
    </div>
</div>

<br><br>

<div id="extra-failure">
<h2 style="text-align: center;">Failure cases</h2>
<p style="overflow: hidden;">
    <span style="float:right;"><a href="#index">Back to index</a></span>
</p>

<div class="row g-1">
    <div class="col-sm-6">
        <div class="video-header">Human speech</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/nx0CyrDu70Y"></iframe>
        </div>
    </div>
    <div class="col-sm-6">
        <div class="video-header">Unfamiliar vision input</div>
        <div class="video-container">
            <iframe src="https://youtube.com/embed/hfnAqmK3X7w"></iframe>
        </div>
    </div>
</div>
</div>
</div>

</body>
</html>

gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text

gradio_demo.py
ADDED
@@ -0,0 +1,343 @@
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
from argparse import ArgumentParser
|
4 |
+
from datetime import datetime
|
5 |
+
from fractions import Fraction
|
6 |
+
from pathlib import Path
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import torch
|
10 |
+
import torchaudio
|
11 |
+
|
12 |
+
from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
|
13 |
+
load_video, make_video, setup_eval_logging)
|
14 |
+
from mmaudio.model.flow_matching import FlowMatching
|
15 |
+
from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
16 |
+
from mmaudio.model.sequence_config import SequenceConfig
|
17 |
+
from mmaudio.model.utils.features_utils import FeaturesUtils
|
18 |
+
|
19 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
20 |
+
torch.backends.cudnn.allow_tf32 = True
|
21 |
+
|
22 |
+
log = logging.getLogger()
|
23 |
+
|
24 |
+
device = 'cpu'
|
25 |
+
if torch.cuda.is_available():
|
26 |
+
device = 'cuda'
|
27 |
+
elif torch.backends.mps.is_available():
|
28 |
+
device = 'mps'
|
29 |
+
else:
|
30 |
+
log.warning('CUDA/MPS are not available, running on CPU')
|
31 |
+
dtype = torch.bfloat16
|
32 |
+
|
33 |
+
model: ModelConfig = all_model_cfg['large_44k_v2']
|
34 |
+
model.download_if_needed()
|
35 |
+
output_dir = Path('./output/gradio')
|
36 |
+
|
37 |
+
setup_eval_logging()
|
38 |
+
|
39 |
+
|
40 |
+
def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
|
41 |
+
seq_cfg = model.seq_cfg
|
42 |
+
|
43 |
+
net: MMAudio = get_my_mmaudio(model.model_name).to(device, dtype).eval()
|
44 |
+
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
|
45 |
+
log.info(f'Loaded weights from {model.model_path}')
|
46 |
+
|
47 |
+
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
|
48 |
+
synchformer_ckpt=model.synchformer_ckpt,
|
49 |
+
enable_conditions=True,
|
50 |
+
mode=model.mode,
|
51 |
+
bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
|
52 |
+
need_vae_encoder=False)
|
53 |
+
feature_utils = feature_utils.to(device, dtype).eval()
|
54 |
+
|
55 |
+
return net, feature_utils, seq_cfg
|
56 |
+
|
57 |
+
|
58 |
+
net, feature_utils, seq_cfg = get_model()
|
59 |
+
|
60 |
+
|
61 |
+
@torch.inference_mode()
|
62 |
+
def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
63 |
+
cfg_strength: float, duration: float):
|
64 |
+
|
65 |
+
rng = torch.Generator(device=device)
|
66 |
+
if seed >= 0:
|
67 |
+
rng.manual_seed(seed)
|
68 |
+
else:
|
69 |
+
rng.seed()
|
70 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
71 |
+
|
72 |
+
video_info = load_video(video, duration)
|
73 |
+
clip_frames = video_info.clip_frames
|
74 |
+
sync_frames = video_info.sync_frames
|
75 |
+
duration = video_info.duration_sec
|
76 |
+
clip_frames = clip_frames.unsqueeze(0)
|
77 |
+
sync_frames = sync_frames.unsqueeze(0)
|
78 |
+
seq_cfg.duration = duration
|
79 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
80 |
+
|
81 |
+
audios = generate(clip_frames,
|
82 |
+
sync_frames, [prompt],
|
83 |
+
negative_text=[negative_prompt],
|
84 |
+
feature_utils=feature_utils,
|
85 |
+
net=net,
|
86 |
+
fm=fm,
|
87 |
+
rng=rng,
|
88 |
+
cfg_strength=cfg_strength)
|
89 |
+
audio = audios.float().cpu()[0]
|
90 |
+
|
91 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
92 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
93 |
+
video_save_path = output_dir / f'{current_time_string}.mp4'
|
94 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
95 |
+
gc.collect()
|
96 |
+
return video_save_path
|
97 |
+
|
98 |
+
|
99 |
+
@torch.inference_mode()
|
100 |
+
def image_to_audio(image: gr.Image, prompt: str, negative_prompt: str, seed: int, num_steps: int,
|
101 |
+
cfg_strength: float, duration: float):
|
102 |
+
|
103 |
+
rng = torch.Generator(device=device)
|
104 |
+
if seed >= 0:
|
105 |
+
rng.manual_seed(seed)
|
106 |
+
else:
|
107 |
+
rng.seed()
|
108 |
+
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
|
109 |
+
|
110 |
+
image_info = load_image(image)
|
111 |
+
clip_frames = image_info.clip_frames
|
112 |
+
sync_frames = image_info.sync_frames
|
113 |
+
clip_frames = clip_frames.unsqueeze(0)
|
114 |
+
sync_frames = sync_frames.unsqueeze(0)
|
115 |
+
seq_cfg.duration = duration
|
116 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
117 |
+
|
118 |
+
audios = generate(clip_frames,
|
119 |
+
sync_frames, [prompt],
|
120 |
+
negative_text=[negative_prompt],
|
121 |
+
feature_utils=feature_utils,
|
122 |
+
net=net,
|
123 |
+
fm=fm,
|
124 |
+
rng=rng,
|
125 |
+
cfg_strength=cfg_strength,
|
126 |
+
image_input=True)
|
127 |
+
audio = audios.float().cpu()[0]
|
128 |
+
|
129 |
+
current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
|
130 |
+
output_dir.mkdir(exist_ok=True, parents=True)
|
131 |
+
video_save_path = output_dir / f'{current_time_string}.mp4'
|
132 |
+
video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
|
133 |
+
make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
134 |
+
gc.collect()
|
135 |
+
return video_save_path
|
136 |
+
|
137 |
+
|
138 |
+
@torch.inference_mode()
|
139 |
+
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int, cfg_strength: float,
|
140 |
+
duration: float):
|
141 |
+
|
142 |
+
rng = torch.Generator(device=device)
|
143 |
+
if seed >= 0:
|
144 |
+
rng.manual_seed(seed)
|
145 |
+
else:
|
146 |
+
rng.seed()
|
147 |
+
    fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

    clip_frames = sync_frames = None
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    current_time_string = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_dir.mkdir(exist_ok=True, parents=True)
    audio_save_path = output_dir / f'{current_time_string}.flac'
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    gc.collect()
    return audio_save_path


video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    description="""
    Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
    Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>

    NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side).
    Doing so does not improve results.
    """,
    inputs=[
        gr.Video(),
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt', value='music'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='playable_video',
    cache_examples=False,
    title='MMAudio — Video-to-Audio Synthesis',
    examples=[
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_beach.mp4',
         'waves, seagulls', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_serpent.mp4',
         '', 'music', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_seahorse.mp4',
         'bubbles', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_india.mp4',
         'Indian holy music', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_galloping.mp4',
         'galloping', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_kraken.mp4',
         'waves, storm', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/mochi_storm.mp4',
         'storm', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_spring.mp4',
         '', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_typing.mp4',
         'typing', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/hunyuan_wake_up.mp4',
         '', '', 0, 25, 4.5, 10],
        ['https://huggingface.co/hkchengrex/MMAudio/resolve/main/examples/sora_nyc.mp4',
         '', '', 0, 25, 4.5, 10],
    ])

text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    description="""
    Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
    Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>
    """,
    inputs=[
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='audio',
    cache_examples=False,
    title='MMAudio — Text-to-Audio Synthesis',
)

image_to_audio_tab = gr.Interface(
    fn=image_to_audio,
    description="""
    Project page: <a href="https://hkchengrex.com/MMAudio/">https://hkchengrex.com/MMAudio/</a><br>
    Code: <a href="https://github.com/hkchengrex/MMAudio">https://github.com/hkchengrex/MMAudio</a><br>

    NOTE: It takes longer to process high-resolution images (>384 px on the shorter side).
    Doing so does not improve results.
    """,
    inputs=[
        gr.Image(type='filepath'),
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='playable_video',
    cache_examples=False,
    title='MMAudio — Image-to-Audio Synthesis (experimental)',
)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--port', type=int, default=7860)
    args = parser.parse_args()

    gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
                       ['Video-to-Audio', 'Text-to-Audio',
                        'Image-to-Audio (experimental)']).launch(server_port=args.port,
                                                                 allowed_paths=[output_dir])
mmaudio/__init__.py
ADDED
File without changes

mmaudio/data/__init__.py
ADDED
File without changes
mmaudio/data/av_utils.py
ADDED
@@ -0,0 +1,162 @@

from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import Optional

import av
import numpy as np
import torch
from av import AudioFrame


@dataclass
class VideoInfo:
    duration_sec: float
    fps: Fraction
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    all_frames: Optional[list[np.ndarray]]

    @property
    def height(self):
        return self.all_frames[0].shape[0]

    @property
    def width(self):
        return self.all_frames[0].shape[1]

    @classmethod
    def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
                        fps: Fraction) -> 'VideoInfo':
        num_frames = int(duration_sec * fps)
        all_frames = [image_info.original_frame] * num_frames
        return cls(duration_sec=duration_sec,
                   fps=fps,
                   clip_frames=image_info.clip_frames,
                   sync_frames=image_info.sync_frames,
                   all_frames=all_frames)


@dataclass
class ImageInfo:
    clip_frames: torch.Tensor
    sync_frames: torch.Tensor
    original_frame: Optional[np.ndarray]

    @property
    def height(self):
        return self.original_frame.shape[0]

    @property
    def width(self):
        return self.original_frame.shape[1]


def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
                need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
    output_frames = [[] for _ in list_of_fps]
    next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
    time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
    all_frames = []

    with av.open(video_path) as container:
        stream = container.streams.video[0]
        fps = stream.guessed_rate
        stream.thread_type = 'AUTO'
        for packet in container.demux(stream):
            for frame in packet.decode():
                frame_time = frame.time
                if frame_time < start_sec:
                    continue
                if frame_time > end_sec:
                    break

                frame_np = None
                if need_all_frames:
                    frame_np = frame.to_ndarray(format='rgb24')
                    all_frames.append(frame_np)

                for i, _ in enumerate(list_of_fps):
                    this_time = frame_time
                    while this_time >= next_frame_time_for_each_fps[i]:
                        if frame_np is None:
                            frame_np = frame.to_ndarray(format='rgb24')

                        output_frames[i].append(frame_np)
                        next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]

    output_frames = [np.stack(frames) for frames in output_frames]
    return output_frames, all_frames, fps


def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
                        sampling_rate: int):
    container = av.open(output_path, 'w')
    output_video_stream = container.add_stream('h264', video_info.fps)
    output_video_stream.codec_context.bit_rate = 10 * 1e6  # 10 Mbps
    output_video_stream.width = video_info.width
    output_video_stream.height = video_info.height
    output_video_stream.pix_fmt = 'yuv420p'

    output_audio_stream = container.add_stream('aac', sampling_rate)

    # encode video
    for image in video_info.all_frames:
        image = av.VideoFrame.from_ndarray(image)
        packet = output_video_stream.encode(image)
        container.mux(packet)

    for packet in output_video_stream.encode():
        container.mux(packet)

    # convert float tensor audio to numpy array
    audio_np = audio.numpy().astype(np.float32)
    audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
    audio_frame.sample_rate = sampling_rate

    for packet in output_audio_stream.encode(audio_frame):
        container.mux(packet)

    for packet in output_audio_stream.encode():
        container.mux(packet)

    container.close()


def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
    """
    NOTE: I don't think we can get the exact video duration right without re-encoding
    so we are not using this but keeping it here for reference
    """
    video = av.open(video_path)
    output = av.open(output_path, 'w')
    input_video_stream = video.streams.video[0]
    output_video_stream = output.add_stream(template=input_video_stream)
    output_audio_stream = output.add_stream('aac', sampling_rate)

    duration_sec = audio.shape[-1] / sampling_rate

    for packet in video.demux(input_video_stream):
        # We need to skip the "flushing" packets that `demux` generates.
        if packet.dts is None:
            continue
        # We need to assign the packet to the new stream.
        packet.stream = output_video_stream
        output.mux(packet)

    # convert float tensor audio to numpy array
    audio_np = audio.numpy().astype(np.float32)
    audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
    audio_frame.sample_rate = sampling_rate

    for packet in output_audio_stream.encode(audio_frame):
        output.mux(packet)

    for packet in output_audio_stream.encode():
        output.mux(packet)

    video.close()
    output.close()
mmaudio/data/data_setup.py
ADDED
@@ -0,0 +1,174 @@

import logging
import random

import numpy as np
import torch
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler

from mmaudio.data.eval.audiocaps import AudioCapsData
from mmaudio.data.eval.video_dataset import MovieGen, VGGSound
from mmaudio.data.extracted_audio import ExtractedAudio
from mmaudio.data.extracted_vgg import ExtractedVGG
from mmaudio.data.mm_dataset import MultiModalDataset
from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()


# Re-seed randomness every time we start a worker
def worker_init_fn(worker_id: int):
    worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')


def load_vgg_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    dataset = ExtractedVGG(tsv_path=data_cfg.tsv,
                           data_dim=cfg.data_dim,
                           premade_mmap_dir=data_cfg.memmap_dir)

    return dataset


def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
    dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
                             data_dim=cfg.data_dim,
                             premade_mmap_dir=data_cfg.memmap_dir)

    return dataset


def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
    # the three modes are mutually exclusive; chaining them with elif ensures a
    # mini_train run is not silently overwritten by the full-data branch below
    if cfg.mini_train:
        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)
        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
        dataset = MultiModalDataset([vgg], [audiocaps])
    elif cfg.example_train:
        video = load_vgg_data(cfg, cfg.data.Example_video)
        audio = load_audio_data(cfg, cfg.data.Example_audio)
        dataset = MultiModalDataset([video], [audio])
    else:
        # load the largest one first
        freesound = load_audio_data(cfg, cfg.data.FreeSound)
        vgg = load_vgg_data(cfg, cfg.data.ExtractedVGG)
        audiocaps = load_audio_data(cfg, cfg.data.AudioCaps)
        audioset_sl = load_audio_data(cfg, cfg.data.AudioSetSL)
        bbcsound = load_audio_data(cfg, cfg.data.BBCSound)
        clotho = load_audio_data(cfg, cfg.data.Clotho)
        dataset = MultiModalDataset([vgg] * cfg.vgg_oversample_rate,
                                    [audiocaps, audioset_sl, bbcsound, freesound, clotho])

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory
    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=True,
                                       drop_last=True,
                                       pin_memory=pin_memory)

    return dataset, sampler, loader


def setup_test_datasets(cfg):
    dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_test)

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory
    sampler, loader = construct_loader(dataset,
                                       batch_size,
                                       num_workers,
                                       shuffle=False,
                                       drop_last=False,
                                       pin_memory=pin_memory)

    return dataset, sampler, loader


def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
    if cfg.example_train:
        dataset = load_vgg_data(cfg, cfg.data.Example_video)
    else:
        dataset = load_vgg_data(cfg, cfg.data.ExtractedVGG_val)

    val_batch_size = cfg.batch_size
    val_eval_batch_size = cfg.eval_batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory
    _, val_loader = construct_loader(dataset,
                                     val_batch_size,
                                     num_workers,
                                     shuffle=False,
                                     drop_last=False,
                                     pin_memory=pin_memory)
    _, eval_loader = construct_loader(dataset,
                                      val_eval_batch_size,
                                      num_workers,
                                      shuffle=False,
                                      drop_last=False,
                                      pin_memory=pin_memory)

    return dataset, val_loader, eval_loader


def setup_eval_dataset(dataset_name: str, cfg: DictConfig) -> tuple[Dataset, DataLoader]:
    if dataset_name.startswith('audiocaps_full'):
        dataset = AudioCapsData(cfg.eval_data.AudioCaps_full.audio_path,
                                cfg.eval_data.AudioCaps_full.csv_path)
    elif dataset_name.startswith('audiocaps'):
        dataset = AudioCapsData(cfg.eval_data.AudioCaps.audio_path,
                                cfg.eval_data.AudioCaps.csv_path)
    elif dataset_name.startswith('moviegen'):
        dataset = MovieGen(cfg.eval_data.MovieGen.video_path,
                           cfg.eval_data.MovieGen.jsonl_path,
                           duration_sec=cfg.duration_s)
    elif dataset_name.startswith('vggsound'):
        dataset = VGGSound(cfg.eval_data.VGGSound.video_path,
                           cfg.eval_data.VGGSound.csv_path,
                           duration_sec=cfg.duration_s)
    else:
        raise ValueError(f'Invalid dataset name: {dataset_name}')

    batch_size = cfg.batch_size
    num_workers = cfg.num_workers
    pin_memory = cfg.pin_memory
    _, loader = construct_loader(dataset,
                                 batch_size,
                                 num_workers,
                                 shuffle=False,
                                 drop_last=False,
                                 pin_memory=pin_memory,
                                 error_avoidance=True)
    return dataset, loader


def error_avoidance_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return default_collate(batch)


def construct_loader(dataset: Dataset,
                     batch_size: int,
                     num_workers: int,
                     *,
                     shuffle: bool = True,
                     drop_last: bool = True,
                     pin_memory: bool = False,
                     error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
    train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
    train_loader = DataLoader(dataset,
                              batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers,
                              worker_init_fn=worker_init_fn,
                              drop_last=drop_last,
                              persistent_workers=num_workers > 0,
                              pin_memory=pin_memory,
                              collate_fn=error_avoidance_collate if error_avoidance else None)
    return train_sampler, train_loader
mmaudio/data/eval/__init__.py
ADDED
File without changes
mmaudio/data/eval/audiocaps.py
ADDED
@@ -0,0 +1,39 @@

import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class AudioCapsData(Dataset):

    def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
        df = pd.read_csv(csv_path).to_dict(orient='records')

        audio_files = sorted(os.listdir(audio_path))
        audio_files = set(
            [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])

        self.data = []
        for row in df:
            self.data.append({
                'name': row['name'],
                'caption': row['caption'],
            })

        self.audio_path = Path(audio_path)
        self.csv_path = Path(csv_path)

        log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')

    def __getitem__(self, idx: int) -> torch.Tensor:
        return self.data[idx]

    def __len__(self):
        return len(self.data)
mmaudio/data/eval/moviegen.py
ADDED
@@ -0,0 +1,131 @@

import json
import logging
import os
from pathlib import Path
from typing import Union

import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder

from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()

_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


class MovieGenData(Dataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        sync_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
        read_clip: bool = True,
    ):
        self.video_root = Path(video_root)
        self.sync_root = Path(sync_root)
        self.jsonl_root = Path(jsonl_root)
        self.read_clip = read_clip

        videos = sorted(os.listdir(self.video_root))
        videos = [v[:-4] for v in videos]  # remove extensions
        self.captions = {}

        for v in videos:
            with open(self.jsonl_root / (v + '.jsonl')) as f:
                data = json.load(f)
                self.captions[v] = data['audio_prompt']

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')

        self.duration_sec = duration_sec

        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_augment = v2.Compose([
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        self.sync_augment = v2.Compose([
            v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.videos = videos

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        video_id = self.videos[idx]
        caption = self.captions[video_id]

        reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]
        if clip_chunk is None:
            raise RuntimeError(f'CLIP video returned None {video_id}')
        if clip_chunk.shape[0] < self.clip_expected_length:
            raise RuntimeError(f'CLIP video too short {video_id}')

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        if sync_chunk.shape[0] < self.sync_expected_length:
            raise RuntimeError(f'Sync video too short {video_id}')

        # truncate the video
        clip_chunk = clip_chunk[:self.clip_expected_length]
        if clip_chunk.shape[0] != self.clip_expected_length:
            raise RuntimeError(f'CLIP video wrong length {video_id}, '
                               f'expected {self.clip_expected_length}, '
                               f'got {clip_chunk.shape[0]}')
        clip_chunk = self.clip_augment(clip_chunk)

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            raise RuntimeError(f'Sync video wrong length {video_id}, '
                               f'expected {self.sync_expected_length}, '
                               f'got {sync_chunk.shape[0]}')
        sync_chunk = self.sync_augment(sync_chunk)

        data = {
            'name': video_id,
            'caption': caption,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        return self.sample(idx)

    def __len__(self):
        return len(self.captions)
mmaudio/data/eval/video_dataset.py
ADDED
@@ -0,0 +1,197 @@

import json
import logging
import os
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder

from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()

_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


class VideoDataset(Dataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        *,
        duration_sec: float = 8.0,
    ):
        self.video_root = Path(video_root)

        self.duration_sec = duration_sec

        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_transform = v2.Compose([
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        # to be implemented by subclasses
        self.captions = {}
        self.videos = sorted(list(self.captions.keys()))

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        video_id = self.videos[idx]
        caption = self.captions[video_id]

        reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]
        if clip_chunk is None:
            raise RuntimeError(f'CLIP video returned None {video_id}')
        if clip_chunk.shape[0] < self.clip_expected_length:
            raise RuntimeError(
                f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
            )

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        if sync_chunk.shape[0] < self.sync_expected_length:
            raise RuntimeError(
                f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
            )

        # truncate the video
        clip_chunk = clip_chunk[:self.clip_expected_length]
        if clip_chunk.shape[0] != self.clip_expected_length:
            raise RuntimeError(f'CLIP video wrong length {video_id}, '
                               f'expected {self.clip_expected_length}, '
                               f'got {clip_chunk.shape[0]}')
        clip_chunk = self.clip_transform(clip_chunk)

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            raise RuntimeError(f'Sync video wrong length {video_id}, '
                               f'expected {self.sync_expected_length}, '
                               f'got {sync_chunk.shape[0]}')
        sync_chunk = self.sync_transform(sync_chunk)

        data = {
            'name': video_id,
            'caption': caption,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        return len(self.captions)


class VGGSound(VideoDataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        csv_path: Union[str, Path],
        *,
        duration_sec: float = 8.0,
    ):
        super().__init__(video_root, duration_sec=duration_sec)
        self.video_root = Path(video_root)
        self.csv_path = Path(csv_path)

        videos = sorted(os.listdir(self.video_root))
        self.captions = {}

        df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
                                                       'split']).to_dict(orient='records')

        videos_no_found = []
        for row in df:
            if row['split'] == 'test':
                start_sec = int(row['sec'])
                video_id = str(row['id'])
                # this is how our videos are named
                video_name = f'{video_id}_{start_sec:06d}'
                if video_name + '.mp4' not in videos:
                    videos_no_found.append(video_name)
                    continue

                self.captions[video_name] = row['caption']

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')
            log.info(f'{len(self.captions)} usable videos found')
            if videos_no_found:
                log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
                log.info(
                    'A small amount is expected, as not all videos are still available on YouTube')

        self.videos = sorted(list(self.captions.keys()))


class MovieGen(VideoDataset):

    def __init__(
        self,
        video_root: Union[str, Path],
        jsonl_root: Union[str, Path],
        *,
        duration_sec: float = 10.0,
    ):
        super().__init__(video_root, duration_sec=duration_sec)
        self.video_root = Path(video_root)
        self.jsonl_root = Path(jsonl_root)

        videos = sorted(os.listdir(self.video_root))
        videos = [v[:-4] for v in videos]  # remove extensions
        self.captions = {}

        for v in videos:
            with open(self.jsonl_root / (v + '.jsonl')) as f:
                data = json.load(f)
                self.captions[v] = data['audio_prompt']

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {video_root}')

        self.videos = videos
mmaudio/data/extracted_audio.py
ADDED
@@ -0,0 +1,88 @@

import logging
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from tensordict import TensorDict
from torch.utils.data.dataset import Dataset

from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()


class ExtractedAudio(Dataset):

    def __init__(
        self,
        tsv_path: Union[str, Path],
        *,
        premade_mmap_dir: Union[str, Path],
        data_dim: dict[str, int],
    ):
        super().__init__()

        self.data_dim = data_dim
        self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
        self.ids = [str(d['id']) for d in self.df_list]

        log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
        # load precomputed memory mapped tensors
        premade_mmap_dir = Path(premade_mmap_dir)
        td = TensorDict.load_memmap(premade_mmap_dir)
        log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
        self.mean = td['mean']
        self.std = td['std']
        self.text_features = td['text_features']

        log.info(f'Loaded {len(self)} samples from {premade_mmap_dir}.')
        log.info(f'Loaded mean: {self.mean.shape}.')
        log.info(f'Loaded std: {self.std.shape}.')
        log.info(f'Loaded text features: {self.text_features.shape}.')

        assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
            f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
        assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
            f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'

        assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
            f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'
        assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
            f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'

        self.fake_clip_features = torch.zeros(self.data_dim['clip_seq_len'],
                                              self.data_dim['clip_dim'])
        self.fake_sync_features = torch.zeros(self.data_dim['sync_seq_len'],
                                              self.data_dim['sync_dim'])
        self.video_exist = torch.tensor(0, dtype=torch.bool)
        self.text_exist = torch.tensor(1, dtype=torch.bool)

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        latents = self.mean
        return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))

    def get_memory_mapped_tensor(self) -> TensorDict:
        td = TensorDict({
            'mean': self.mean,
            'std': self.std,
            'text_features': self.text_features,
        })
        return td

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        data = {
            'id': str(self.df_list[idx]['id']),
            'a_mean': self.mean[idx],
            'a_std': self.std[idx],
            'clip_features': self.fake_clip_features,
            'sync_features': self.fake_sync_features,
            'text_features': self.text_features[idx],
            'caption': self.df_list[idx]['caption'],
            'video_exist': self.video_exist,
            'text_exist': self.text_exist,
        }
        return data

    def __len__(self):
        return len(self.ids)
mmaudio/data/extracted_vgg.py
ADDED
@@ -0,0 +1,101 @@

import logging
from pathlib import Path
from typing import Union

import pandas as pd
import torch
from tensordict import TensorDict
from torch.utils.data.dataset import Dataset

from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()


class ExtractedVGG(Dataset):

    def __init__(
        self,
        tsv_path: Union[str, Path],
        *,
        premade_mmap_dir: Union[str, Path],
        data_dim: dict[str, int],
    ):
        super().__init__()

        self.data_dim = data_dim
        self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records')
        self.ids = [d['id'] for d in self.df_list]

        log.info(f'Loading precomputed mmap from {premade_mmap_dir}')
        # load precomputed memory mapped tensors
        premade_mmap_dir = Path(premade_mmap_dir)
        td = TensorDict.load_memmap(premade_mmap_dir)
        log.info(f'Loaded precomputed mmap from {premade_mmap_dir}')
        self.mean = td['mean']
        self.std = td['std']
        self.clip_features = td['clip_features']
        self.sync_features = td['sync_features']
        self.text_features = td['text_features']

        if local_rank == 0:
            log.info(f'Loaded {len(self)} samples.')
            log.info(f'Loaded mean: {self.mean.shape}.')
            log.info(f'Loaded std: {self.std.shape}.')
            log.info(f'Loaded clip_features: {self.clip_features.shape}.')
            log.info(f'Loaded sync_features: {self.sync_features.shape}.')
            log.info(f'Loaded text_features: {self.text_features.shape}.')

        assert self.mean.shape[1] == self.data_dim['latent_seq_len'], \
            f'{self.mean.shape[1]} != {self.data_dim["latent_seq_len"]}'
        assert self.std.shape[1] == self.data_dim['latent_seq_len'], \
            f'{self.std.shape[1]} != {self.data_dim["latent_seq_len"]}'

        assert self.clip_features.shape[1] == self.data_dim['clip_seq_len'], \
            f'{self.clip_features.shape[1]} != {self.data_dim["clip_seq_len"]}'
        assert self.sync_features.shape[1] == self.data_dim['sync_seq_len'], \
            f'{self.sync_features.shape[1]} != {self.data_dim["sync_seq_len"]}'
        assert self.text_features.shape[1] == self.data_dim['text_seq_len'], \
            f'{self.text_features.shape[1]} != {self.data_dim["text_seq_len"]}'

        assert self.clip_features.shape[-1] == self.data_dim['clip_dim'], \
            f'{self.clip_features.shape[-1]} != {self.data_dim["clip_dim"]}'
        assert self.sync_features.shape[-1] == self.data_dim['sync_dim'], \
            f'{self.sync_features.shape[-1]} != {self.data_dim["sync_dim"]}'
        assert self.text_features.shape[-1] == self.data_dim['text_dim'], \
            f'{self.text_features.shape[-1]} != {self.data_dim["text_dim"]}'

        self.video_exist = torch.tensor(1, dtype=torch.bool)
        self.text_exist = torch.tensor(1, dtype=torch.bool)

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        latents = self.mean
        return latents.mean(dim=(0, 1)), latents.std(dim=(0, 1))

    def get_memory_mapped_tensor(self) -> TensorDict:
        td = TensorDict({
            'mean': self.mean,
            'std': self.std,
            'clip_features': self.clip_features,
            'sync_features': self.sync_features,
            'text_features': self.text_features,
        })
        return td

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        data = {
            'id': self.df_list[idx]['id'],
            'a_mean': self.mean[idx],
            'a_std': self.std[idx],
            'clip_features': self.clip_features[idx],
            'sync_features': self.sync_features[idx],
            'text_features': self.text_features[idx],
            'caption': self.df_list[idx]['label'],
            'video_exist': self.video_exist,
            'text_exist': self.text_exist,
        }

        return data

    def __len__(self):
        return len(self.ids)
mmaudio/data/extraction/__init__.py
ADDED
File without changes
mmaudio/data/extraction/vgg_sound.py
ADDED
@@ -0,0 +1,193 @@

import logging
import os
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder

from mmaudio.utils.dist_utils import local_rank

log = logging.getLogger()

_CLIP_SIZE = 384
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


class VGGSound(Dataset):

    def __init__(
        self,
        root: Union[str, Path],
        *,
        tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
        sample_rate: int = 16_000,
        duration_sec: float = 8.0,
        audio_samples: Optional[int] = None,
        normalize_audio: bool = False,
    ):
        self.root = Path(root)
        self.normalize_audio = normalize_audio
        if audio_samples is None:
            self.audio_samples = int(sample_rate * duration_sec)
        else:
            self.audio_samples = audio_samples
            effective_duration = audio_samples / sample_rate
            # make sure the duration is close enough, within 15ms
            assert abs(effective_duration - duration_sec) < 0.015, \
                f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'

        videos = sorted(os.listdir(self.root))
        videos = set([Path(v).stem for v in videos])  # remove extensions
        self.labels = {}
        self.videos = []
        missing_videos = []

        # read the tsv for subset information
        df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            id = record['id']
            label = record['label']
            if id in videos:
                self.labels[id] = label
                self.videos.append(id)
            else:
                missing_videos.append(id)

        if local_rank == 0:
            log.info(f'{len(videos)} videos found in {root}')
            log.info(f'{len(self.videos)} videos found in {tsv_path}')
            log.info(f'{len(missing_videos)} videos missing in {root}')

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        # use the resolved sample count so the default (audio_samples=None) path works too
        self.expected_audio_length = self.audio_samples
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_transform = v2.Compose([
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])

        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.resampler = {}

    def sample(self, idx: int) -> dict[str, torch.Tensor]:
        video_id = self.videos[idx]
        label = self.labels[video_id]

        reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )
        reader.add_basic_audio_stream(frames_per_chunk=2**30)

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]
        audio_chunk = data_chunk[2]

        if clip_chunk is None:
            raise RuntimeError(f'CLIP video returned None {video_id}')
        if clip_chunk.shape[0] < self.clip_expected_length:
            raise RuntimeError(
                f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
            )

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')
        if sync_chunk.shape[0] < self.sync_expected_length:
            raise RuntimeError(
                f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
            )

        # process audio
        sample_rate = int(reader.get_out_stream_info(2).sample_rate)
        audio_chunk = audio_chunk.transpose(0, 1)
        audio_chunk = audio_chunk.mean(dim=0)  # mono
        if self.normalize_audio:
            abs_max = audio_chunk.abs().max()
            # check for silence before dividing by the (near-zero) maximum
            if abs_max <= 1e-6:
                raise RuntimeError(f'Audio is silent {video_id}')
            audio_chunk = audio_chunk / abs_max * 0.95

        # resample
        if sample_rate != self.sample_rate:
            if sample_rate not in self.resampler:
                # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                self.resampler[sample_rate] = torchaudio.transforms.Resample(
                    sample_rate,
                    self.sample_rate,
                    lowpass_filter_width=64,
                    rolloff=0.9475937167399596,
                    resampling_method='sinc_interp_kaiser',
                    beta=14.769656459379492,
                )
            audio_chunk = self.resampler[sample_rate](audio_chunk)

        if audio_chunk.shape[0] < self.expected_audio_length:
            raise RuntimeError(f'Audio too short {video_id}')
        audio_chunk = audio_chunk[:self.expected_audio_length]

        # truncate the video
        clip_chunk = clip_chunk[:self.clip_expected_length]
        if clip_chunk.shape[0] != self.clip_expected_length:
            raise RuntimeError(f'CLIP video wrong length {video_id}, '
                               f'expected {self.clip_expected_length}, '
                               f'got {clip_chunk.shape[0]}')
        clip_chunk = self.clip_transform(clip_chunk)

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            raise RuntimeError(f'Sync video wrong length {video_id}, '
                               f'expected {self.sync_expected_length}, '
                               f'got {sync_chunk.shape[0]}')
        sync_chunk = self.sync_transform(sync_chunk)

        data = {
            'id': video_id,
            'caption': label,
            'audio': audio_chunk,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        try:
            return self.sample(idx)
        except Exception as e:
            log.error(f'Error loading video {self.videos[idx]}: {e}')
            return None

    def __len__(self):
        return len(self.labels)
mmaudio/data/extraction/wav_dataset.py
ADDED
@@ -0,0 +1,132 @@

import logging
import os
from pathlib import Path
from typing import Union

import open_clip
import pandas as pd
import torch
import torchaudio
from torch.utils.data.dataset import Dataset

log = logging.getLogger()


class WavTextClipsDataset(Dataset):

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        audios = sorted(os.listdir(self.root))
        audios = set([
            Path(audio).stem for audio in audios
            if audio.endswith('.wav') or audio.endswith('.flac')
        ])
        self.captions = {}

        # read the caption tsv
        df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
        for record in df_list:
            id = record['id']
            caption = record['caption']
            self.captions[id] = caption

        # read the clip tsv
        df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in df_list:
            id = record['id']
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            record['caption'] = self.captions[name]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        self.resampler = {}

    def __getitem__(self, idx: int) -> torch.Tensor:
        try:
            clip = self.clips[idx]
            audio_name = clip['name']
            audio_id = clip['id']
            caption = clip['caption']
            start_sample = clip['start_sample']
            end_sample = clip['end_sample']

            audio_path = self.root / f'{audio_name}.flac'
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                assert audio_path.exists()

            audio_chunk, sample_rate = torchaudio.load(audio_path)
            audio_chunk = audio_chunk.mean(dim=0)  # mono
            abs_max = audio_chunk.abs().max()
            if self.normalize_audio:
                audio_chunk = audio_chunk / abs_max * 0.95

            if self.reject_silent and abs_max < 1e-6:
                log.warning('Rejecting silent audio')
                return None

            audio_chunk = audio_chunk[start_sample:end_sample]

            # resample
            if sample_rate != self.sample_rate:
                if sample_rate not in self.resampler:
                    # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
                    self.resampler[sample_rate] = torchaudio.transforms.Resample(
                        sample_rate,
                        self.sample_rate,
                        lowpass_filter_width=64,
                        rolloff=0.9475937167399596,
                        resampling_method='sinc_interp_kaiser',
                        beta=14.769656459379492,
                    )
                audio_chunk = self.resampler[sample_rate](audio_chunk)

            if audio_chunk.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            audio_chunk = audio_chunk[:self.num_samples]

            tokens = self.tokenizer([caption])[0]

            output = {
                'waveform': audio_chunk,
                'id': audio_id,
                'caption': caption,
                'tokens': tokens,
            }

            return output
        except Exception as e:
            # reference the clip record, which is always bound, rather than a
            # local path variable that may not exist yet when the error occurs
            log.error(f'Error reading {self.clips[idx]["name"]}: {e}')
            return None

    def __len__(self):
        return len(self.clips)
mmaudio/data/mm_dataset.py
ADDED
@@ -0,0 +1,45 @@

import bisect

import torch
from torch.utils.data.dataset import Dataset


# modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
class MultiModalDataset(Dataset):
    datasets: list[Dataset]
    cumulative_sizes: list[int]

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
        super().__init__()
        self.video_datasets = list(video_datasets)
        self.audio_datasets = list(audio_datasets)
        self.datasets = self.video_datasets + self.audio_datasets

        self.cumulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx = len(self) + idx
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
        return self.video_datasets[0].compute_latent_stats()
mmaudio/data/utils.py
ADDED
@@ -0,0 +1,148 @@
+import logging
+import os
+import random
+import tempfile
+from pathlib import Path
+from typing import Any, Optional, Union
+
+import torch
+import torch.distributed as dist
+from tensordict import MemoryMappedTensor
+from torch.utils.data import DataLoader
+from torch.utils.data.dataset import Dataset
+from tqdm import tqdm
+
+from mmaudio.utils.dist_utils import local_rank, world_size
+
+scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
+shm_path = Path('/dev/shm')
+
+log = logging.getLogger()
+
+
+def reseed(seed):
+    random.seed(seed)
+    torch.manual_seed(seed)
+
+
+def local_scatter_torch(obj: Optional[Any]):
+    if world_size == 1:
+        # Just one worker. Do nothing.
+        return obj
+
+    array = [obj] * world_size
+    target_array = [None]
+    if local_rank == 0:
+        dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
+    else:
+        dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
+    return target_array[0]
+
+
+class ShardDataset(Dataset):
+
+    def __init__(self, root):
+        self.root = root
+        self.shards = sorted(os.listdir(root))
+
+    def __len__(self):
+        return len(self.shards)
+
+    def __getitem__(self, idx):
+        return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
+
+
+def get_tmp_dir(in_memory: bool) -> Path:
+    return shm_path if in_memory else scratch_path
+
+
+def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
+                          in_memory: bool) -> MemoryMappedTensor:
+    if local_rank == 0:
+        with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
+            log.info(f'Loading shards from {data_path} into {f.name}...')
+            data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
+            data = share_tensor_to_all(data)
+            torch.distributed.barrier()
+            f.close()  # why does the context manager not close the file for me?
+    else:
+        log.info('Waiting for the data to be shared with me...')
+        data = share_tensor_to_all(None)
+        torch.distributed.barrier()
+
+    return data
+
+
+def load_shards(
+    data_path: Union[str, Path],
+    ids: list[int],
+    *,
+    tmp_file_path: str,
+) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
+
+    id_set = set(ids)
+    shards = sorted(os.listdir(data_path))
+    log.info(f'Found {len(shards)} shards in {data_path}.')
+    first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
+
+    log.info(f'Rank {local_rank} created file {tmp_file_path}')
+    first_item = next(iter(first_shard.values()))
+    log.info(f'First item shape: {first_item.shape}')
+    mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
+                                         dtype=torch.float32,
+                                         filename=tmp_file_path,
+                                         existsok=True)
+    total_count = 0
+    used_index = set()
+    id_indexing = {i: idx for idx, i in enumerate(ids)}
+    # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
+    loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
+    for data in tqdm(loader, desc='Loading shards'):
+        for i, v in data.items():
+            if i not in id_set:
+                continue
+
+            # tensor_index = ids.index(i)
+            tensor_index = id_indexing[i]
+            if tensor_index in used_index:
+                raise ValueError(f'Duplicate id {i} found in {data_path}.')
+            used_index.add(tensor_index)
+            mm_tensor[tensor_index] = v
+            total_count += 1
+
+    assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
+    log.info(f'Loaded {total_count} tensors from {data_path}.')
+
+    return mm_tensor
+
+
+def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
+    """
+    x: the tensor to be shared; None if local_rank != 0
+    return: the shared tensor
+    """
+
+    # there is no need to share your stuff with anyone if you are alone; must be in memory
+    if world_size == 1:
+        return x
+
+    if local_rank == 0:
+        assert x is not None, 'x must not be None if local_rank == 0'
+    else:
+        assert x is None, 'x must be None if local_rank != 0'
+
+    if local_rank == 0:
+        filename = x.filename
+        meta_information = (filename, x.shape, x.dtype)
+    else:
+        meta_information = None
+
+    filename, data_shape, data_type = local_scatter_torch(meta_information)
+    if local_rank == 0:
+        data = x
+    else:
+        data = MemoryMappedTensor.from_filename(filename=filename,
+                                                dtype=data_type,
+                                                shape=data_shape)
+
+    return data
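Note: `load_shards` expects each shard file to be a dict mapping a sample id to a fixed-shape feature tensor, and packs the requested ids into one memory-mapped tensor. A single-process sketch with hypothetical paths and ids:

import os
import torch

os.makedirs('./toy_shards', exist_ok=True)
# two toy shards, ids 0..2, each feature tensor of shape (4, 8)
torch.save({0: torch.zeros(4, 8), 1: torch.ones(4, 8)}, './toy_shards/shard-00000.pth')
torch.save({2: torch.full((4, 8), 2.0)}, './toy_shards/shard-00001.pth')

mm = load_shards('./toy_shards', ids=[0, 1, 2], tmp_file_path='./toy_shards.bin')
print(mm.shape)  # torch.Size([3, 4, 8]), backed by the memory-mapped file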
mmaudio/eval_utils.py
ADDED
@@ -0,0 +1,255 @@
+import dataclasses
+import logging
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from colorlog import ColoredFormatter
+from PIL import Image
+from torchvision.transforms import v2
+
+from mmaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
+from mmaudio.model.flow_matching import FlowMatching
+from mmaudio.model.networks import MMAudio
+from mmaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
+from mmaudio.model.utils.features_utils import FeaturesUtils
+from mmaudio.utils.download_utils import download_model_if_needed
+
+log = logging.getLogger()
+
+
+@dataclasses.dataclass
+class ModelConfig:
+    model_name: str
+    model_path: Path
+    vae_path: Path
+    bigvgan_16k_path: Optional[Path]
+    mode: str
+    synchformer_ckpt: Path = Path('./ext_weights/synchformer_state_dict.pth')
+
+    @property
+    def seq_cfg(self) -> SequenceConfig:
+        if self.mode == '16k':
+            return CONFIG_16K
+        elif self.mode == '44k':
+            return CONFIG_44K
+
+    def download_if_needed(self):
+        download_model_if_needed(self.model_path)
+        download_model_if_needed(self.vae_path)
+        if self.bigvgan_16k_path is not None:
+            download_model_if_needed(self.bigvgan_16k_path)
+        download_model_if_needed(self.synchformer_ckpt)
+
+
+small_16k = ModelConfig(model_name='small_16k',
+                        model_path=Path('./weights/mmaudio_small_16k.pth'),
+                        vae_path=Path('./ext_weights/v1-16.pth'),
+                        bigvgan_16k_path=Path('./ext_weights/best_netG.pt'),
+                        mode='16k')
+small_44k = ModelConfig(model_name='small_44k',
+                        model_path=Path('./weights/mmaudio_small_44k.pth'),
+                        vae_path=Path('./ext_weights/v1-44.pth'),
+                        bigvgan_16k_path=None,
+                        mode='44k')
+medium_44k = ModelConfig(model_name='medium_44k',
+                         model_path=Path('./weights/mmaudio_medium_44k.pth'),
+                         vae_path=Path('./ext_weights/v1-44.pth'),
+                         bigvgan_16k_path=None,
+                         mode='44k')
+large_44k = ModelConfig(model_name='large_44k',
+                        model_path=Path('./weights/mmaudio_large_44k.pth'),
+                        vae_path=Path('./ext_weights/v1-44.pth'),
+                        bigvgan_16k_path=None,
+                        mode='44k')
+large_44k_v2 = ModelConfig(model_name='large_44k_v2',
+                           model_path=Path('./weights/mmaudio_large_44k_v2.pth'),
+                           vae_path=Path('./ext_weights/v1-44.pth'),
+                           bigvgan_16k_path=None,
+                           mode='44k')
+all_model_cfg: dict[str, ModelConfig] = {
+    'small_16k': small_16k,
+    'small_44k': small_44k,
+    'medium_44k': medium_44k,
+    'large_44k': large_44k,
+    'large_44k_v2': large_44k_v2,
+}
+
+
+def generate(
+    clip_video: Optional[torch.Tensor],
+    sync_video: Optional[torch.Tensor],
+    text: Optional[list[str]],
+    *,
+    negative_text: Optional[list[str]] = None,
+    feature_utils: FeaturesUtils,
+    net: MMAudio,
+    fm: FlowMatching,
+    rng: torch.Generator,
+    cfg_strength: float,
+    clip_batch_size_multiplier: int = 40,
+    sync_batch_size_multiplier: int = 40,
+    image_input: bool = False,
+) -> torch.Tensor:
+    device = feature_utils.device
+    dtype = feature_utils.dtype
+
+    bs = len(text)
+    if clip_video is not None:
+        clip_video = clip_video.to(device, dtype, non_blocking=True)
+        clip_features = feature_utils.encode_video_with_clip(clip_video,
+                                                             batch_size=bs *
+                                                             clip_batch_size_multiplier)
+        if image_input:
+            clip_features = clip_features.expand(-1, net.clip_seq_len, -1)
+    else:
+        clip_features = net.get_empty_clip_sequence(bs)
+
+    if sync_video is not None and not image_input:
+        sync_video = sync_video.to(device, dtype, non_blocking=True)
+        sync_features = feature_utils.encode_video_with_sync(sync_video,
+                                                             batch_size=bs *
+                                                             sync_batch_size_multiplier)
+    else:
+        sync_features = net.get_empty_sync_sequence(bs)
+
+    if text is not None:
+        text_features = feature_utils.encode_text(text)
+    else:
+        text_features = net.get_empty_string_sequence(bs)
+
+    if negative_text is not None:
+        assert len(negative_text) == bs
+        negative_text_features = feature_utils.encode_text(negative_text)
+    else:
+        negative_text_features = net.get_empty_string_sequence(bs)
+
+    x0 = torch.randn(bs,
+                     net.latent_seq_len,
+                     net.latent_dim,
+                     device=device,
+                     dtype=dtype,
+                     generator=rng)
+    preprocessed_conditions = net.preprocess_conditions(clip_features, sync_features, text_features)
+    empty_conditions = net.get_empty_conditions(
+        bs, negative_text_features=negative_text_features if negative_text is not None else None)
+
+    cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
+                                                   cfg_strength)
+    x1 = fm.to_data(cfg_ode_wrapper, x0)
+    x1 = net.unnormalize(x1)
+    spec = feature_utils.decode(x1)
+    audio = feature_utils.vocode(spec)
+    return audio
+
+
+LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
+
+
+def setup_eval_logging(log_level: int = logging.INFO):
+    logging.root.setLevel(log_level)
+    formatter = ColoredFormatter(LOGFORMAT)
+    stream = logging.StreamHandler()
+    stream.setLevel(log_level)
+    stream.setFormatter(formatter)
+    log = logging.getLogger()
+    log.setLevel(log_level)
+    log.addHandler(stream)
+
+
+_CLIP_SIZE = 384
+_CLIP_FPS = 8.0
+
+_SYNC_SIZE = 224
+_SYNC_FPS = 25.0
+
+
+def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
+
+    clip_transform = v2.Compose([
+        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+    ])
+
+    sync_transform = v2.Compose([
+        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
+        v2.CenterCrop(_SYNC_SIZE),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    ])
+
+    output_frames, all_frames, orig_fps = read_frames(video_path,
+                                                      list_of_fps=[_CLIP_FPS, _SYNC_FPS],
+                                                      start_sec=0,
+                                                      end_sec=duration_sec,
+                                                      need_all_frames=load_all_frames)
+
+    clip_chunk, sync_chunk = output_frames
+    clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
+    sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
+
+    clip_frames = clip_transform(clip_chunk)
+    sync_frames = sync_transform(sync_chunk)
+
+    clip_length_sec = clip_frames.shape[0] / _CLIP_FPS
+    sync_length_sec = sync_frames.shape[0] / _SYNC_FPS
+
+    if clip_length_sec < duration_sec:
+        log.warning(f'Clip video is too short: {clip_length_sec:.2f} < {duration_sec:.2f}')
+        log.warning(f'Truncating to {clip_length_sec:.2f} sec')
+        duration_sec = clip_length_sec
+
+    if sync_length_sec < duration_sec:
+        log.warning(f'Sync video is too short: {sync_length_sec:.2f} < {duration_sec:.2f}')
+        log.warning(f'Truncating to {sync_length_sec:.2f} sec')
+        duration_sec = sync_length_sec
+
+    clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
+    sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
+
+    video_info = VideoInfo(
+        duration_sec=duration_sec,
+        fps=orig_fps,
+        clip_frames=clip_frames,
+        sync_frames=sync_frames,
+        all_frames=all_frames if load_all_frames else None,
+    )
+    return video_info
+
+
+def load_image(image_path: Path) -> VideoInfo:
+    clip_transform = v2.Compose([
+        v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+    ])
+
+    sync_transform = v2.Compose([
+        v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
+        v2.CenterCrop(_SYNC_SIZE),
+        v2.ToImage(),
+        v2.ToDtype(torch.float32, scale=True),
+        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+    ])
+
+    frame = np.array(Image.open(image_path))
+
+    clip_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
+    sync_chunk = torch.from_numpy(frame).unsqueeze(0).permute(0, 3, 1, 2)
+
+    clip_frames = clip_transform(clip_chunk)
+    sync_frames = sync_transform(sync_chunk)
+
+    video_info = ImageInfo(
+        clip_frames=clip_frames,
+        sync_frames=sync_frames,
+        original_frame=frame,
+    )
+    return video_info
+
+
+def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
+    reencode_with_audio(video_info, output_path, audio, sampling_rate)
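Note: a sketch of how these helpers chain together for one video, assuming `net` (an `MMAudio`), `feature_utils` (a `FeaturesUtils`), `device`, and `model = all_model_cfg['large_44k_v2']` are already set up as in the repo's demo.py, and that `FlowMatching` is constructed the way demo.py constructs it; the input path and text prompt are hypothetical:

import torch
from pathlib import Path

from mmaudio.model.flow_matching import FlowMatching

setup_eval_logging()
video_info = load_video(Path('./input.mp4'), duration_sec=8.0)

# demo.py also syncs seq_cfg.duration and the network's sequence lengths
# to the requested duration before calling generate().
fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=25)
rng = torch.Generator(device=device).manual_seed(42)

audios = generate(video_info.clip_frames.unsqueeze(0),
                  video_info.sync_frames.unsqueeze(0),
                  ['splashing water'],
                  feature_utils=feature_utils,
                  net=net,
                  fm=fm,
                  rng=rng,
                  cfg_strength=4.5)
make_video(video_info, Path('./output.mp4'), audios.float().cpu()[0],
           sampling_rate=model.seq_cfg.sampling_rate)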
mmaudio/ext/__init__.py
ADDED
@@ -0,0 +1 @@
+
mmaudio/ext/autoencoder/__init__.py
ADDED
@@ -0,0 +1 @@
+from .autoencoder import AutoEncoderModule
mmaudio/ext/autoencoder/autoencoder.py
ADDED
@@ -0,0 +1,52 @@
+from typing import Literal, Optional
+
+import torch
+import torch.nn as nn
+
+from mmaudio.ext.autoencoder.vae import VAE, get_my_vae
+from mmaudio.ext.bigvgan import BigVGAN
+from mmaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
+from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
+
+
+class AutoEncoderModule(nn.Module):
+
+    def __init__(self,
+                 *,
+                 vae_ckpt_path,
+                 vocoder_ckpt_path: Optional[str] = None,
+                 mode: Literal['16k', '44k'],
+                 need_vae_encoder: bool = True):
+        super().__init__()
+        self.vae: VAE = get_my_vae(mode).eval()
+        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
+        self.vae.load_state_dict(vae_state_dict)
+        self.vae.remove_weight_norm()
+
+        if mode == '16k':
+            assert vocoder_ckpt_path is not None
+            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
+        elif mode == '44k':
+            self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
+                                                     use_cuda_kernel=False)
+            self.vocoder.remove_weight_norm()
+        else:
+            raise ValueError(f'Unknown mode: {mode}')
+
+        for param in self.parameters():
+            param.requires_grad = False
+
+        if not need_vae_encoder:
+            del self.vae.encoder
+
+    @torch.inference_mode()
+    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
+        return self.vae.encode(x)
+
+    @torch.inference_mode()
+    def decode(self, z: torch.Tensor) -> torch.Tensor:
+        return self.vae.decode(z)
+
+    @torch.inference_mode()
+    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
+        return self.vocoder(spec)
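Note: a round-trip sketch of the interface `AutoEncoderModule` exposes, using the 16k checkpoint paths referenced in eval_utils.py; the mel shape is an assumption for illustration:

import torch

ae = AutoEncoderModule(vae_ckpt_path='./ext_weights/v1-16.pth',
                       vocoder_ckpt_path='./ext_weights/best_netG.pt',
                       mode='16k')
mel = torch.randn(1, 80, 512)  # (batch, n_mels, time); shape assumed for illustration
dist = ae.encode(mel)          # DiagonalGaussianDistribution over the latent
z = dist.mode()                # or dist.sample(); temporally downsampled latent
recon = ae.decode(z)           # back to a mel spectrogram
wav = ae.vocode(recon)         # waveform via the BigVGAN vocoder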
mmaudio/ext/autoencoder/edm2_utils.py
ADDED
@@ -0,0 +1,168 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# This work is licensed under a Creative Commons
+# Attribution-NonCommercial-ShareAlike 4.0 International License.
+# You should have received a copy of the license along with this
+# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
+"""Improved diffusion model architecture proposed in the paper
+"Analyzing and Improving the Training Dynamics of Diffusion Models"."""
+
+import numpy as np
+import torch
+
+#----------------------------------------------------------------------------
+# Variant of constant() that inherits dtype and device from the given
+# reference tensor by default.
+
+_constant_cache = dict()
+
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+    value = np.asarray(value)
+    if shape is not None:
+        shape = tuple(shape)
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if device is None:
+        device = torch.device('cpu')
+    if memory_format is None:
+        memory_format = torch.contiguous_format
+
+    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+    tensor = _constant_cache.get(key, None)
+    if tensor is None:
+        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+        if shape is not None:
+            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+        tensor = tensor.contiguous(memory_format=memory_format)
+        _constant_cache[key] = tensor
+    return tensor
+
+
+def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
+    if dtype is None:
+        dtype = ref.dtype
+    if device is None:
+        device = ref.device
+    return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
+
+
+#----------------------------------------------------------------------------
+# Normalize given tensor to unit magnitude with respect to the given
+# dimensions. Default = all dimensions except the first.
+
+
+def normalize(x, dim=None, eps=1e-4):
+    if dim is None:
+        dim = list(range(1, x.ndim))
+    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
+    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
+    return x / norm.to(x.dtype)
+
+
+class Normalize(torch.nn.Module):
+
+    def __init__(self, dim=None, eps=1e-4):
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+
+    def forward(self, x):
+        return normalize(x, dim=self.dim, eps=self.eps)
+
+
+#----------------------------------------------------------------------------
+# Upsample or downsample the given tensor with the given filter,
+# or keep it as is.
+
+
+def resample(x, f=[1, 1], mode='keep'):
+    if mode == 'keep':
+        return x
+    f = np.float32(f)
+    assert f.ndim == 1 and len(f) % 2 == 0
+    pad = (len(f) - 1) // 2
+    f = f / f.sum()
+    f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
+    f = const_like(x, f)
+    c = x.shape[1]
+    if mode == 'down':
+        return torch.nn.functional.conv2d(x,
+                                          f.tile([c, 1, 1, 1]),
+                                          groups=c,
+                                          stride=2,
+                                          padding=(pad, ))
+    assert mode == 'up'
+    return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
+                                                groups=c,
+                                                stride=2,
+                                                padding=(pad, ))
+
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving SiLU (Equation 81).
+
+
+def mp_silu(x):
+    return torch.nn.functional.silu(x) / 0.596
+
+
+class MPSiLU(torch.nn.Module):
+
+    def forward(self, x):
+        return mp_silu(x)
+
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving sum (Equation 88).
+
+
+def mp_sum(a, b, t=0.5):
+    return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
+
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving concatenation (Equation 103).
+
+
+def mp_cat(a, b, dim=1, t=0.5):
+    Na = a.shape[dim]
+    Nb = b.shape[dim]
+    C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
+    wa = C / np.sqrt(Na) * (1 - t)
+    wb = C / np.sqrt(Nb) * t
+    return torch.cat([wa * a, wb * b], dim=dim)
+
+
+#----------------------------------------------------------------------------
+# Magnitude-preserving convolution or fully-connected layer (Equation 47)
+# with force weight normalization (Equation 66).
+
+
+class MPConv1D(torch.nn.Module):
+
+    def __init__(self, in_channels, out_channels, kernel_size):
+        super().__init__()
+        self.out_channels = out_channels
+        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
+
+        self.weight_norm_removed = False
+
+    def forward(self, x, gain=1):
+        assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
+
+        w = self.weight * gain
+        if w.ndim == 2:
+            return x @ w.t()
+        assert w.ndim == 3
+        return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
+
+    def remove_weight_norm(self):
+        w = self.weight.to(torch.float32)
+        w = normalize(w)  # traditional weight normalization
+        w = w / np.sqrt(w[0].numel())
+        w = w.to(self.weight.dtype)
+        self.weight.data.copy_(w)
+
+        self.weight_norm_removed = True
+        return self
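Note: the `mp_*` operations above are scaled so that unit-variance inputs stay (approximately) unit-variance; a quick numerical check:

import torch

a, b = torch.randn(8, 4096), torch.randn(8, 4096)
print(mp_sum(a, b).std())         # ~1.0, versus (a + b).std() ~ sqrt(2)
print(mp_cat(a, b, dim=1).std())  # ~1.0 after the per-branch reweighting
print(mp_silu(torch.randn(100000)).std())  # ~1.0; plain silu gives ~0.596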
mmaudio/ext/autoencoder/vae.py
ADDED
@@ -0,0 +1,369 @@
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from mmaudio.ext.autoencoder.edm2_utils import MPConv1D
+from mmaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
+                                                 Upsample1D, nonlinearity)
+from mmaudio.model.utils.distributions import DiagonalGaussianDistribution
+
+log = logging.getLogger()
+
+DATA_MEAN_80D = [
+    -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
+    -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
+    -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
+    -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
+    -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
+    -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
+    -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
+    -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
+]
+
+DATA_STD_80D = [
+    1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
+    0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
+    0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
+    0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
+    0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
+    0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
+    1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
+]
+
+DATA_MEAN_128D = [
+    -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
+    -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
+    -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
+    -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
+    -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
+    -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
+    -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
+    -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
+    -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
+    -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
+    -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
+    -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
+    -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
+]
+
+DATA_STD_128D = [
+    2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
+    2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
+    2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
+    2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
+    2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
+    2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
+    2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
+    2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
+    2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
+    2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
+    2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
+]
+
+
+class VAE(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        data_dim: int,
+        embed_dim: int,
+        hidden_dim: int,
+    ):
+        super().__init__()
+
+        if data_dim == 80:
+            self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
+            self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
+        elif data_dim == 128:
+            self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
+            self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
+
+        self.data_mean = self.data_mean.view(1, -1, 1)
+        self.data_std = self.data_std.view(1, -1, 1)
+
+        self.encoder = Encoder1D(
+            dim=hidden_dim,
+            ch_mult=(1, 2, 4),
+            num_res_blocks=2,
+            attn_layers=[3],
+            down_layers=[0],
+            in_dim=data_dim,
+            embed_dim=embed_dim,
+        )
+        self.decoder = Decoder1D(
+            dim=hidden_dim,
+            ch_mult=(1, 2, 4),
+            num_res_blocks=2,
+            attn_layers=[3],
+            down_layers=[0],
+            in_dim=data_dim,
+            out_dim=data_dim,
+            embed_dim=embed_dim,
+        )
+
+        self.embed_dim = embed_dim
+        # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
+        # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
+
+        self.initialize_weights()
+
+    def initialize_weights(self):
+        pass
+
+    def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
+        if normalize:
+            x = self.normalize(x)
+        moments = self.encoder(x)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
+        dec = self.decoder(z)
+        if unnormalize:
+            dec = self.unnormalize(dec)
+        return dec
+
+    def normalize(self, x: torch.Tensor) -> torch.Tensor:
+        return (x - self.data_mean) / self.data_std
+
+    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
+        return x * self.data_std + self.data_mean
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        sample_posterior: bool = True,
+        rng: Optional[torch.Generator] = None,
+        normalize: bool = True,
+        unnormalize: bool = True,
+    ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
+
+        posterior = self.encode(x, normalize=normalize)
+        if sample_posterior:
+            z = posterior.sample(rng)
+        else:
+            z = posterior.mode()
+        dec = self.decode(z, unnormalize=unnormalize)
+        return dec, posterior
+
+    def load_weights(self, src_dict) -> None:
+        self.load_state_dict(src_dict, strict=True)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    def get_last_layer(self):
+        return self.decoder.conv_out.weight
+
+    def remove_weight_norm(self):
+        for name, m in self.named_modules():
+            if isinstance(m, MPConv1D):
+                m.remove_weight_norm()
+                log.debug(f"Removed weight norm from {name}")
+        return self
+
+
+class Encoder1D(nn.Module):
+
+    def __init__(self,
+                 *,
+                 dim: int,
+                 ch_mult: tuple[int] = (1, 2, 4, 8),
+                 num_res_blocks: int,
+                 attn_layers: list[int] = [],
+                 down_layers: list[int] = [],
+                 resamp_with_conv: bool = True,
+                 in_dim: int,
+                 embed_dim: int,
+                 double_z: bool = True,
+                 kernel_size: int = 3,
+                 clip_act: float = 256.0):
+        super().__init__()
+        self.dim = dim
+        self.num_layers = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.in_channels = in_dim
+        self.clip_act = clip_act
+        self.down_layers = down_layers
+        self.attn_layers = attn_layers
+        self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
+
+        in_ch_mult = (1, ) + tuple(ch_mult)
+        self.in_ch_mult = in_ch_mult
+        # downsampling
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_layers):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = dim * in_ch_mult[i_level]
+            block_out = dim * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock1D(in_dim=block_in,
+                                  out_dim=block_out,
+                                  kernel_size=kernel_size,
+                                  use_norm=True))
+                block_in = block_out
+                if i_level in attn_layers:
+                    attn.append(AttnBlock1D(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level in down_layers:
+                down.downsample = Downsample1D(block_in, resamp_with_conv)
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
+                                         out_dim=block_in,
+                                         kernel_size=kernel_size,
+                                         use_norm=True)
+        self.mid.attn_1 = AttnBlock1D(block_in)
+        self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
+                                         out_dim=block_in,
+                                         kernel_size=kernel_size,
+                                         use_norm=True)
+
+        # end
+        self.conv_out = MPConv1D(block_in,
+                                 2 * embed_dim if double_z else embed_dim,
+                                 kernel_size=kernel_size)
+
+        self.learnable_gain = nn.Parameter(torch.zeros([]))
+
+    def forward(self, x):
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_layers):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1])
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                h = h.clamp(-self.clip_act, self.clip_act)
+                hs.append(h)
+            if i_level in self.down_layers:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+        h = h.clamp(-self.clip_act, self.clip_act)
+
+        # end
+        h = nonlinearity(h)
+        h = self.conv_out(h, gain=(self.learnable_gain + 1))
+        return h
+
+
+class Decoder1D(nn.Module):
+
+    def __init__(self,
+                 *,
+                 dim: int,
+                 out_dim: int,
+                 ch_mult: tuple[int] = (1, 2, 4, 8),
+                 num_res_blocks: int,
+                 attn_layers: list[int] = [],
+                 down_layers: list[int] = [],
+                 kernel_size: int = 3,
+                 resamp_with_conv: bool = True,
+                 in_dim: int,
+                 embed_dim: int,
+                 clip_act: float = 256.0):
+        super().__init__()
+        self.ch = dim
+        self.num_layers = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.in_channels = in_dim
+        self.clip_act = clip_act
+        self.down_layers = [i + 1 for i in down_layers]  # each down layer adds one upsample
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        block_in = dim * ch_mult[self.num_layers - 1]
+
+        # z to block_in
+        self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
+        self.mid.attn_1 = AttnBlock1D(block_in)
+        self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_layers)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = dim * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
+                block_in = block_out
+                if i_level in attn_layers:
+                    attn.append(AttnBlock1D(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level in self.down_layers:
+                up.upsample = Upsample1D(block_in, resamp_with_conv)
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
+        self.learnable_gain = nn.Parameter(torch.zeros([]))
+
+    def forward(self, z):
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h)
+        h = h.clamp(-self.clip_act, self.clip_act)
+
+        # upsampling
+        for i_level in reversed(range(self.num_layers)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+                h = h.clamp(-self.clip_act, self.clip_act)
+            if i_level in self.down_layers:
+                h = self.up[i_level].upsample(h)
+
+        h = nonlinearity(h)
+        h = self.conv_out(h, gain=(self.learnable_gain + 1))
+        return h
+
+
+def VAE_16k(**kwargs) -> VAE:
+    return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
+
+
+def VAE_44k(**kwargs) -> VAE:
+    return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
+
+
+def get_my_vae(name: str, **kwargs) -> VAE:
+    if name == '16k':
+        return VAE_16k(**kwargs)
+    if name == '44k':
+        return VAE_44k(**kwargs)
+    raise ValueError(f'Unknown model: {name}')
+
+
+if __name__ == '__main__':
+    network = get_my_vae('16k')  # only '16k' and '44k' exist; the original 'standard' would raise
+
+    # print the number of parameters in terms of millions
+    num_params = sum(p.numel() for p in network.parameters()) / 1e6
+    print(f'Number of parameters: {num_params:.2f}M')
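Note: a shape sanity check for the 16 kHz VAE (random weights, so only the shapes are meaningful): the single down layer halves the time axis, and the encoder emits 2 * embed_dim channels that `DiagonalGaussianDistribution` splits into mean and log-variance:

import torch

vae = get_my_vae('16k').eval()
vae.remove_weight_norm()  # MPConv1D asserts this was called before inference

mel = torch.randn(1, 80, 512)  # (batch, n_mels, time)
with torch.no_grad():
    recon, posterior = vae(mel)
print(posterior.mode().shape)  # expected: torch.Size([1, 20, 256])
print(recon.shape)             # expected: torch.Size([1, 80, 512])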