wzy013 commited on
Commit
f424901
·
1 Parent(s): 0939b31

Add HunyuanVideo-Foley source code and dependencies

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +43 -0
  2. README.md +88 -5
  3. app.py +405 -0
  4. configs/hunyuanvideo-foley-xxl.yaml +49 -0
  5. hunyuanvideo_foley/__init__.py +0 -0
  6. hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc +0 -0
  7. hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc +0 -0
  8. hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc +0 -0
  9. hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc +0 -0
  10. hunyuanvideo_foley/constants.py +57 -0
  11. hunyuanvideo_foley/models/__init__.py +0 -0
  12. hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc +0 -0
  13. hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc +0 -0
  14. hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc +0 -0
  15. hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc +0 -0
  16. hunyuanvideo_foley/models/dac_vae/__init__.py +16 -0
  17. hunyuanvideo_foley/models/dac_vae/__main__.py +36 -0
  18. hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc +0 -0
  19. hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc +0 -0
  20. hunyuanvideo_foley/models/dac_vae/model/__init__.py +4 -0
  21. hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc +0 -0
  22. hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc +0 -0
  23. hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc +0 -0
  24. hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc +0 -0
  25. hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc +0 -0
  26. hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc +0 -0
  27. hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc +0 -0
  28. hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc +0 -0
  29. hunyuanvideo_foley/models/dac_vae/model/base.py +301 -0
  30. hunyuanvideo_foley/models/dac_vae/model/dac.py +410 -0
  31. hunyuanvideo_foley/models/dac_vae/model/discriminator.py +228 -0
  32. hunyuanvideo_foley/models/dac_vae/nn/__init__.py +3 -0
  33. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-312.pyc +0 -0
  34. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc +0 -0
  35. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-312.pyc +0 -0
  36. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc +0 -0
  37. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-312.pyc +0 -0
  38. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc +0 -0
  39. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-312.pyc +0 -0
  40. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc +0 -0
  41. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-312.pyc +0 -0
  42. hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc +0 -0
  43. hunyuanvideo_foley/models/dac_vae/nn/layers.py +33 -0
  44. hunyuanvideo_foley/models/dac_vae/nn/loss.py +368 -0
  45. hunyuanvideo_foley/models/dac_vae/nn/quantize.py +262 -0
  46. hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py +91 -0
  47. hunyuanvideo_foley/models/dac_vae/utils/__init__.py +121 -0
  48. hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-312.pyc +0 -0
  49. hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc +0 -0
  50. hunyuanvideo_foley/models/dac_vae/utils/decode.py +95 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
@@ -33,3 +34,45 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Standard LFS file types for Hugging Face
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
 
34
  *.zip filter=lfs diff=lfs merge=lfs -text
35
  *.zst filter=lfs diff=lfs merge=lfs -text
36
  *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ # Media files
38
+ *.png filter=lfs diff=lfs merge=lfs -text
39
+ *.jpg filter=lfs diff=lfs merge=lfs -text
40
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
41
+ *.gif filter=lfs diff=lfs merge=lfs -text
42
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ *.avi filter=lfs diff=lfs merge=lfs -text
44
+ *.mov filter=lfs diff=lfs merge=lfs -text
45
+ *.mkv filter=lfs diff=lfs merge=lfs -text
46
+ *.webm filter=lfs diff=lfs merge=lfs -text
47
+ *.wav filter=lfs diff=lfs merge=lfs -text
48
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
49
+ *.flac filter=lfs diff=lfs merge=lfs -text
50
+ # Project specific files
51
+ examples/7_result.mp4 filter=lfs diff=lfs merge=lfs -text
52
+ examples/8_video.mp4 filter=lfs diff=lfs merge=lfs -text
53
+ examples/1_result.mp4 filter=lfs diff=lfs merge=lfs -text
54
+ examples/2_video.mp4 filter=lfs diff=lfs merge=lfs -text
55
+ examples/6_result.mp4 filter=lfs diff=lfs merge=lfs -text
56
+ examples/7_video.mp4 filter=lfs diff=lfs merge=lfs -text
57
+ examples/8_result.mp4 filter=lfs diff=lfs merge=lfs -text
58
+ examples/3_video.mp4 filter=lfs diff=lfs merge=lfs -text
59
+ examples/5_video.mp4 filter=lfs diff=lfs merge=lfs -text
60
+ examples/4_result.mp4 filter=lfs diff=lfs merge=lfs -text
61
+ examples/5_result.mp4 filter=lfs diff=lfs merge=lfs -text
62
+ examples/1_video.mp4 filter=lfs diff=lfs merge=lfs -text
63
+ examples/3_result.mp4 filter=lfs diff=lfs merge=lfs -text
64
+ examples/6_video.mp4 filter=lfs diff=lfs merge=lfs -text
65
+ examples/2_result.mp4 filter=lfs diff=lfs merge=lfs -text
66
+ examples/4_video.mp4 filter=lfs diff=lfs merge=lfs -text
67
+ assets/MovieGenAudioBenchSfx/video_with_audio/0.mp4 filter=lfs diff=lfs merge=lfs -text
68
+ assets/MovieGenAudioBenchSfx/video_with_audio/4.mp4 filter=lfs diff=lfs merge=lfs -text
69
+ assets/MovieGenAudioBenchSfx/video_with_audio/6.mp4 filter=lfs diff=lfs merge=lfs -text
70
+ assets/MovieGenAudioBenchSfx/video_with_audio/7.mp4 filter=lfs diff=lfs merge=lfs -text
71
+ assets/MovieGenAudioBenchSfx/video_with_audio/1.mp4 filter=lfs diff=lfs merge=lfs -text
72
+ assets/MovieGenAudioBenchSfx/video_with_audio/2.mp4 filter=lfs diff=lfs merge=lfs -text
73
+ assets/MovieGenAudioBenchSfx/video_with_audio/3.mp4 filter=lfs diff=lfs merge=lfs -text
74
+ assets/MovieGenAudioBenchSfx/video_with_audio/5.mp4 filter=lfs diff=lfs merge=lfs -text
75
+ assets/MovieGenAudioBenchSfx/video_with_audio/8.mp4 filter=lfs diff=lfs merge=lfs -text
76
+ assets/MovieGenAudioBenchSfx/video_with_audio/9.mp4 filter=lfs diff=lfs merge=lfs -text
77
+ assets/data_pipeline.png filter=lfs diff=lfs merge=lfs -text
78
+ assets/model_arch.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,96 @@
1
  ---
2
- title: Hunyuanvideo Foley
3
- emoji: 😻
4
  colorFrom: blue
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: HunyuanVideo-Foley
3
+ emoji: 🎵
4
  colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.44.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ short_description: Generate realistic audio from video and text descriptions
12
  ---
13
 
14
+ # HunyuanVideo-Foley
15
+
16
+ <div align="center">
17
+ <h2>🎵 Text-Video-to-Audio Synthesis</h2>
18
+ <p><strong>Generate realistic audio from video and text descriptions using AI</strong></p>
19
+ </div>
20
+
21
+ ## About
22
+
23
+ HunyuanVideo-Foley is a multimodal diffusion model that generates high-quality audio effects (Foley audio) synchronized with video content. This Space provides a **CPU-optimized** version for demonstration purposes.
24
+
25
+ ### ⚠️ CPU Performance Notice
26
+
27
+ This Space runs on **free CPU** which means:
28
+ - **Slower inference** (3-5 minutes per generation)
29
+ - **Limited concurrent users**
30
+ - **Reduced sample counts** (max 3 samples)
31
+
32
+ For **faster performance**, consider:
33
+ - Using the original repository with GPU
34
+ - Running locally with CUDA support
35
+ - Upgrading to a GPU Space (if available)
36
+
37
+ ## Features
38
+
39
+ - 🎬 **Video-to-Audio**: Generate audio effects from video content
40
+ - 📝 **Text Guidance**: Control generation with text descriptions
41
+ - 🎯 **Multiple Samples**: Generate up to 3 variations
42
+ - 🔧 **Adjustable Settings**: Control CFG scale and inference steps
43
+ - 📱 **User-Friendly**: Simple drag-and-drop interface
44
+
45
+ ## How to Use
46
+
47
+ 1. **Upload Video**: Drag and drop your video file (MP4, AVI, MOV)
48
+ 2. **Add Description** (Optional): Describe the audio you want to generate
49
+ 3. **Adjust Settings**: Modify CFG scale and inference steps if needed
50
+ 4. **Generate**: Click "Generate Audio" and wait (3-5 minutes on CPU)
51
+ 5. **Download**: Save your generated audio/video combinations
52
+
53
+ ## Tips for Best Results
54
+
55
+ - 📏 **Video Length**: Keep videos under 30 seconds for faster processing
56
+ - 🎯 **Text Prompts**: Use simple, clear descriptions
57
+ - ⚡ **Settings**: Lower values process faster on CPU
58
+ - 🔄 **Multiple Attempts**: Try different settings if not satisfied
59
+
60
+ ## Technical Details
61
+
62
+ - **Model**: HunyuanVideo-Foley-XXL
63
+ - **Architecture**: Multimodal diffusion transformer
64
+ - **Audio Quality**: 48kHz professional-grade output
65
+ - **Deployment**: CPU-optimized for Hugging Face Spaces
66
+
67
+ ## Original Project
68
+
69
+ This is a **CPU deployment** of the original HunyuanVideo-Foley project:
70
+
71
+ - 📄 **Paper**: [HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment](https://arxiv.org/abs/2508.16930)
72
+ - 💻 **GitHub**: [Tencent-Hunyuan/HunyuanVideo-Foley](https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley)
73
+ - 🤗 **Models**: [tencent/HunyuanVideo-Foley](https://huggingface.co/tencent/HunyuanVideo-Foley)
74
+
75
+ ## Citation
76
+
77
+ ```bibtex
78
+ @misc{shan2025hunyuanvideofoleymultimodaldiffusionrepresentation,
79
+ title={HunyuanVideo-Foley: Multimodal Diffusion with Representation Alignment for High-Fidelity Foley Audio Generation},
80
+ author={Sizhe Shan and Qiulin Li and Yutao Cui and Miles Yang and Yuehai Wang and Qun Yang and Jin Zhou and Zhao Zhong},
81
+ year={2025},
82
+ eprint={2508.16930},
83
+ archivePrefix={arXiv},
84
+ primaryClass={eess.AS}
85
+ }
86
+ ```
87
+
88
+ ## License
89
+
90
+ This project is licensed under the Apache 2.0 License.
91
+
92
+ ---
93
+
94
+ <div align="center">
95
+ <p><em>🚀 Powered by Tencent Hunyuan | Optimized for CPU deployment</em></p>
96
+ </div>
app.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import torch
5
+ import torchaudio
6
+ from loguru import logger
7
+ from typing import Optional, Tuple
8
+ import random
9
+ import numpy as np
10
+
11
+ # Force CPU usage for Hugging Face Spaces
12
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
13
+
14
+ from hunyuanvideo_foley.utils.model_utils import load_model
15
+ from hunyuanvideo_foley.utils.feature_utils import feature_process
16
+ from hunyuanvideo_foley.utils.model_utils import denoise_process
17
+ from hunyuanvideo_foley.utils.media_utils import merge_audio_video
18
+
19
+ # Global variables for model storage
20
+ model_dict = None
21
+ cfg = None
22
+ device = None
23
+
24
+ # Model path for Hugging Face Spaces - try to download automatically
25
+ MODEL_PATH = os.environ.get("HIFI_FOLEY_MODEL_PATH", "./pretrained_models/")
26
+ CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
27
+
28
+ def setup_device(force_cpu: bool = True) -> torch.device:
29
+ """Setup computing device - force CPU for Hugging Face Spaces"""
30
+ if force_cpu:
31
+ device = torch.device("cpu")
32
+ logger.info("Using CPU device (forced for Hugging Face Spaces)")
33
+ else:
34
+ if torch.cuda.is_available():
35
+ device = torch.device("cuda:0")
36
+ logger.info("Using CUDA device")
37
+ elif torch.backends.mps.is_available():
38
+ device = torch.device("mps")
39
+ logger.info("Using MPS device")
40
+ else:
41
+ device = torch.device("cpu")
42
+ logger.info("Using CPU device")
43
+
44
+ return device
45
+
46
+ def download_models():
47
+ """Download models from Hugging Face if not present"""
48
+ try:
49
+ from huggingface_hub import snapshot_download
50
+ logger.info("Downloading models from Hugging Face...")
51
+
52
+ # Download the model files
53
+ snapshot_download(
54
+ repo_id="tencent/HunyuanVideo-Foley",
55
+ local_dir="./pretrained_models",
56
+ local_dir_use_symlinks=False
57
+ )
58
+
59
+ logger.info("Model download completed!")
60
+ return True
61
+ except Exception as e:
62
+ logger.error(f"Failed to download models: {str(e)}")
63
+ return False
64
+
65
+ def auto_load_models() -> str:
66
+ """Automatically load preset models"""
67
+ global model_dict, cfg, device
68
+
69
+ try:
70
+ # First try to download models if they don't exist
71
+ if not os.path.exists(MODEL_PATH) or not os.listdir(MODEL_PATH):
72
+ logger.info("Models not found locally, attempting to download...")
73
+ if not download_models():
74
+ return "❌ Failed to download models from Hugging Face"
75
+
76
+ if not os.path.exists(CONFIG_PATH):
77
+ return f"❌ Config file not found: {CONFIG_PATH}"
78
+
79
+ # Force CPU usage for Hugging Face Spaces
80
+ device = setup_device(force_cpu=True)
81
+
82
+ # Load model with CPU optimization
83
+ logger.info("Loading model on CPU...")
84
+ logger.info(f"Model path: {MODEL_PATH}")
85
+ logger.info(f"Config path: {CONFIG_PATH}")
86
+
87
+ # Set torch to use fewer threads for CPU inference
88
+ torch.set_num_threads(2)
89
+
90
+ model_dict, cfg = load_model(MODEL_PATH, CONFIG_PATH, device)
91
+
92
+ logger.info("✅ Model loaded successfully on CPU!")
93
+ return "✅ Model loaded successfully on CPU!"
94
+
95
+ except Exception as e:
96
+ logger.error(f"Model loading failed: {str(e)}")
97
+ return f"❌ Model loading failed: {str(e)}"
98
+
99
+ def infer_single_video(
100
+ video_file,
101
+ text_prompt: str,
102
+ guidance_scale: float = 2.0, # Lower for CPU
103
+ num_inference_steps: int = 20, # Reduced for CPU
104
+ sample_nums: int = 1
105
+ ) -> Tuple[list, str]:
106
+ """Single video inference optimized for CPU"""
107
+ global model_dict, cfg, device
108
+
109
+ if model_dict is None or cfg is None:
110
+ return [], "❌ Please load the model first!"
111
+
112
+ if video_file is None:
113
+ return [], "❌ Please upload a video file!"
114
+
115
+ # Allow empty text prompt
116
+ if text_prompt is None:
117
+ text_prompt = ""
118
+ text_prompt = text_prompt.strip()
119
+
120
+ try:
121
+ logger.info(f"Processing video: {video_file}")
122
+ logger.info(f"Text prompt: {text_prompt}")
123
+ logger.info("Running inference on CPU (this may take a while)...")
124
+
125
+ # Feature processing
126
+ visual_feats, text_feats, audio_len_in_s = feature_process(
127
+ video_file,
128
+ text_prompt,
129
+ model_dict,
130
+ cfg
131
+ )
132
+
133
+ # Denoising process with CPU-optimized settings
134
+ logger.info(f"Generating {sample_nums} audio sample(s) on CPU...")
135
+ audio, sample_rate = denoise_process(
136
+ visual_feats,
137
+ text_feats,
138
+ audio_len_in_s,
139
+ model_dict,
140
+ cfg,
141
+ guidance_scale=guidance_scale,
142
+ num_inference_steps=num_inference_steps,
143
+ batch_size=sample_nums
144
+ )
145
+
146
+ # Create temporary files to save results
147
+ temp_dir = tempfile.mkdtemp()
148
+ video_outputs = []
149
+
150
+ # Process each generated audio sample
151
+ for i in range(sample_nums):
152
+ # Save audio file
153
+ audio_output = os.path.join(temp_dir, f"generated_audio_{i+1}.wav")
154
+ torchaudio.save(audio_output, audio[i], sample_rate)
155
+
156
+ # Merge video and audio
157
+ video_output = os.path.join(temp_dir, f"video_with_audio_{i+1}.mp4")
158
+ merge_audio_video(audio_output, video_file, video_output)
159
+ video_outputs.append(video_output)
160
+
161
+ logger.info(f"Inference completed! Generated {sample_nums} samples.")
162
+ return video_outputs, f"✅ Generated {sample_nums} audio sample(s) successfully on CPU!"
163
+
164
+ except Exception as e:
165
+ logger.error(f"Inference failed: {str(e)}")
166
+ return [], f"❌ Inference failed: {str(e)}"
167
+
168
+ def update_video_outputs(video_list, status_msg):
169
+ """Update video outputs based on the number of generated samples"""
170
+ # Initialize all outputs as None
171
+ outputs = [None] * 3 # Reduced to 3 for CPU
172
+
173
+ # Set values based on generated videos
174
+ for i, video_path in enumerate(video_list[:3]): # Max 3 samples for CPU
175
+ outputs[i] = video_path
176
+
177
+ # Return all outputs plus status message
178
+ return tuple(outputs + [status_msg])
179
+
180
+ def create_gradio_interface():
181
+ """Create Gradio interface optimized for CPU deployment"""
182
+
183
+ # Custom CSS with Hugging Face Spaces styling
184
+ css = """
185
+ .gradio-container {
186
+ font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
187
+ background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
188
+ min-height: 100vh;
189
+ }
190
+
191
+ .main-header {
192
+ text-align: center;
193
+ padding: 2rem 0;
194
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
195
+ border-radius: 20px;
196
+ margin-bottom: 2rem;
197
+ box-shadow: 0 8px 32px rgba(0,0,0,0.15);
198
+ }
199
+
200
+ .main-header h1 {
201
+ color: white;
202
+ font-size: 3rem;
203
+ font-weight: 700;
204
+ margin-bottom: 0.5rem;
205
+ text-shadow: 0 2px 10px rgba(0,0,0,0.3);
206
+ }
207
+
208
+ .main-header p {
209
+ color: rgba(255, 255, 255, 0.95);
210
+ font-size: 1.2rem;
211
+ font-weight: 300;
212
+ }
213
+
214
+ .cpu-notice {
215
+ background: #fff3cd;
216
+ border: 1px solid #ffeaa7;
217
+ border-radius: 10px;
218
+ padding: 1rem;
219
+ margin: 1rem 0;
220
+ color: #856404;
221
+ }
222
+ """
223
+
224
+ with gr.Blocks(css=css, title="HunyuanVideo-Foley (CPU)") as app:
225
+
226
+ # Main header
227
+ with gr.Column(elem_classes=["main-header"]):
228
+ gr.HTML("""
229
+ <h1>🎵 HunyuanVideo-Foley</h1>
230
+ <p>Text-Video-to-Audio Synthesis (CPU Version)</p>
231
+ """)
232
+
233
+ # CPU Notice
234
+ gr.HTML("""
235
+ <div class="cpu-notice">
236
+ <strong>⚠️ CPU Deployment Notice:</strong> This Space runs on CPU which means inference will be slower than GPU version.
237
+ Each generation may take 3-5 minutes. For faster inference, consider running locally with GPU.
238
+ </div>
239
+ """)
240
+
241
+ # Usage Guide
242
+ gr.Markdown("""
243
+ ### 📋 Quick Start Guide
244
+ **1.** Upload your video file **2.** Add optional text description **3.** Click Generate Audio (be patient!)
245
+
246
+ 💡 **Tips for CPU usage:**
247
+ - Use shorter videos (< 30 seconds recommended)
248
+ - Simple text prompts work better
249
+ - Expect longer processing times
250
+ """)
251
+
252
+ # Main interface
253
+ with gr.Row():
254
+ # Input section
255
+ with gr.Column(scale=1):
256
+ gr.Markdown("### 📹 Video Input")
257
+
258
+ video_input = gr.Video(
259
+ label="Upload Video",
260
+ info="Supported formats: MP4, AVI, MOV, etc. Shorter videos recommended for CPU.",
261
+ height=300
262
+ )
263
+
264
+ text_input = gr.Textbox(
265
+ label="🎯 Audio Description (English)",
266
+ placeholder="A person walks on frozen ice",
267
+ lines=3,
268
+ info="Describe the audio you want to generate (optional)"
269
+ )
270
+
271
+ with gr.Row():
272
+ guidance_scale = gr.Slider(
273
+ minimum=1.0,
274
+ maximum=5.0,
275
+ value=2.0,
276
+ step=0.1,
277
+ label="🎚️ CFG Scale (lower for CPU)",
278
+ )
279
+
280
+ inference_steps = gr.Slider(
281
+ minimum=10,
282
+ maximum=50,
283
+ value=20,
284
+ step=5,
285
+ label="⚡ Steps (reduced for CPU)",
286
+ )
287
+
288
+ sample_nums = gr.Slider(
289
+ minimum=1,
290
+ maximum=3,
291
+ value=1,
292
+ step=1,
293
+ label="🎲 Sample Nums (max 3 for CPU)",
294
+ )
295
+
296
+ generate_btn = gr.Button(
297
+ "🎵 Generate Audio (CPU)",
298
+ variant="primary"
299
+ )
300
+
301
+ # Results section
302
+ with gr.Column(scale=1):
303
+ gr.Markdown("### 🎥 Generated Results")
304
+
305
+ # Reduced number of outputs for CPU
306
+ video_output_1 = gr.Video(
307
+ label="Sample 1",
308
+ height=250,
309
+ visible=True
310
+ )
311
+
312
+ with gr.Row():
313
+ video_output_2 = gr.Video(
314
+ label="Sample 2",
315
+ height=200,
316
+ visible=False
317
+ )
318
+ video_output_3 = gr.Video(
319
+ label="Sample 3",
320
+ height=200,
321
+ visible=False
322
+ )
323
+
324
+ result_text = gr.Textbox(
325
+ label="Status",
326
+ interactive=False,
327
+ lines=3
328
+ )
329
+
330
+ # Event handlers
331
+ def process_inference(video_file, text_prompt, guidance_scale, inference_steps, sample_nums):
332
+ # Generate videos
333
+ video_list, status_msg = infer_single_video(
334
+ video_file, text_prompt, guidance_scale, inference_steps, int(sample_nums)
335
+ )
336
+ # Update outputs with proper visibility
337
+ return update_video_outputs(video_list, status_msg)
338
+
339
+ # Add dynamic visibility control
340
+ def update_visibility(sample_nums):
341
+ sample_nums = int(sample_nums)
342
+ return [
343
+ gr.update(visible=True), # Sample 1 always visible
344
+ gr.update(visible=sample_nums >= 2), # Sample 2
345
+ gr.update(visible=sample_nums >= 3), # Sample 3
346
+ ]
347
+
348
+ # Update visibility when sample_nums changes
349
+ sample_nums.change(
350
+ fn=update_visibility,
351
+ inputs=[sample_nums],
352
+ outputs=[video_output_1, video_output_2, video_output_3]
353
+ )
354
+
355
+ generate_btn.click(
356
+ fn=process_inference,
357
+ inputs=[video_input, text_input, guidance_scale, inference_steps, sample_nums],
358
+ outputs=[
359
+ video_output_1,
360
+ video_output_2,
361
+ video_output_3,
362
+ result_text
363
+ ]
364
+ )
365
+
366
+ # Footer
367
+ gr.HTML("""
368
+ <div style="text-align: center; padding: 2rem; color: #666;">
369
+ <p>🚀 Powered by HunyuanVideo-Foley | Running on CPU for Hugging Face Spaces</p>
370
+ <p>For faster inference, visit the <a href="https://github.com/Tencent-Hunyuan/HunyuanVideo-Foley" target="_blank">original repository</a></p>
371
+ </div>
372
+ """)
373
+
374
+ return app
375
+
376
+ def set_manual_seed(global_seed):
377
+ random.seed(global_seed)
378
+ np.random.seed(global_seed)
379
+ torch.manual_seed(global_seed)
380
+
381
+ if __name__ == "__main__":
382
+ set_manual_seed(1)
383
+ # Setup logging
384
+ logger.remove()
385
+ logger.add(lambda msg: print(msg, end=''), level="INFO")
386
+
387
+ # Auto-load model
388
+ logger.info("Starting CPU application and loading model...")
389
+ model_load_result = auto_load_models()
390
+ logger.info(model_load_result)
391
+
392
+ # Create and launch Gradio app
393
+ app = create_gradio_interface()
394
+
395
+ # Log completion status
396
+ if "successfully" in model_load_result:
397
+ logger.info("Application ready, model loaded on CPU")
398
+
399
+ app.launch(
400
+ server_name="0.0.0.0",
401
+ server_port=7860, # Standard port for Hugging Face Spaces
402
+ share=False,
403
+ debug=False,
404
+ show_error=True
405
+ )
configs/hunyuanvideo-foley-xxl.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_config:
2
+ model_name: HunyuanVideo-Foley-XXL
3
+ model_type: 1d
4
+ model_precision: bf16
5
+ model_kwargs:
6
+ depth_triple_blocks: 18
7
+ depth_single_blocks: 36
8
+ hidden_size: 1536
9
+ num_heads: 12
10
+ mlp_ratio: 4
11
+ mlp_act_type: "gelu_tanh"
12
+ qkv_bias: True
13
+ qk_norm: True
14
+ qk_norm_type: "rms"
15
+ attn_mode: "torch"
16
+ embedder_type: "default"
17
+ interleaved_audio_visual_rope: True
18
+ enable_learnable_empty_visual_feat: True
19
+ sync_modulation: False
20
+ add_sync_feat_to_audio: True
21
+ cross_attention: True
22
+ use_attention_mask: False
23
+ condition_projection: "linear"
24
+ sync_feat_dim: 768 # syncformer 768 dim
25
+ condition_dim: 768 # clap 768 text condition dim (clip-text)
26
+ clip_dim: 768 # siglip2 visual dim
27
+ audio_vae_latent_dim: 128
28
+ audio_frame_rate: 50
29
+ patch_size: 1
30
+ rope_dim_list: null
31
+ rope_theta: 10000
32
+ text_length: 77
33
+ clip_length: 64
34
+ sync_length: 192
35
+ use_mmaudio_singleblock: True
36
+ depth_triple_ssl_encoder: null
37
+ depth_single_ssl_encoder: 8
38
+ use_repa_with_audiossl: True
39
+
40
+ diffusion_config:
41
+ denoise_type: "flow"
42
+ flow_path_type: "linear"
43
+ flow_predict_type: "velocity"
44
+ flow_reverse: True
45
+ flow_solver: "euler"
46
+ sample_flow_shift: 1.0
47
+ sample_use_flux_shift: False
48
+ flux_base_shift: 0.5
49
+ flux_max_shift: 1.15
hunyuanvideo_foley/__init__.py ADDED
File without changes
hunyuanvideo_foley/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (149 Bytes). View file
 
hunyuanvideo_foley/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (149 Bytes). View file
 
hunyuanvideo_foley/__pycache__/constants.cpython-312.pyc ADDED
Binary file (1.93 kB). View file
 
hunyuanvideo_foley/__pycache__/constants.cpython-313.pyc ADDED
Binary file (1.93 kB). View file
 
hunyuanvideo_foley/constants.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Constants used throughout the HunyuanVideo-Foley project."""
2
+
3
+ from typing import Dict, List
4
+
5
+ # Model configuration
6
+ DEFAULT_AUDIO_SAMPLE_RATE = 48000
7
+ DEFAULT_VIDEO_FPS = 25
8
+ DEFAULT_AUDIO_CHANNELS = 2
9
+
10
+ # Video processing
11
+ MAX_VIDEO_DURATION_SECONDS = 15.0
12
+ MIN_VIDEO_DURATION_SECONDS = 1.0
13
+
14
+ # Audio processing
15
+ AUDIO_VAE_LATENT_DIM = 128
16
+ AUDIO_FRAME_RATE = 75 # frames per second in latent space
17
+
18
+ # Visual features
19
+ FPS_VISUAL: Dict[str, int] = {
20
+ "siglip2": 8,
21
+ "synchformer": 25
22
+ }
23
+
24
+ # Model paths (can be overridden by environment variables)
25
+ DEFAULT_MODEL_PATH = "./pretrained_models/"
26
+ DEFAULT_CONFIG_PATH = "configs/hunyuanvideo-foley-xxl.yaml"
27
+
28
+ # Inference parameters
29
+ DEFAULT_GUIDANCE_SCALE = 4.5
30
+ DEFAULT_NUM_INFERENCE_STEPS = 50
31
+ MIN_GUIDANCE_SCALE = 1.0
32
+ MAX_GUIDANCE_SCALE = 10.0
33
+ MIN_INFERENCE_STEPS = 10
34
+ MAX_INFERENCE_STEPS = 100
35
+
36
+ # Text processing
37
+ MAX_TEXT_LENGTH = 100
38
+ DEFAULT_NEGATIVE_PROMPT = "noisy, harsh"
39
+
40
+ # File extensions
41
+ SUPPORTED_VIDEO_EXTENSIONS: List[str] = [".mp4", ".avi", ".mov", ".mkv", ".webm"]
42
+ SUPPORTED_AUDIO_EXTENSIONS: List[str] = [".wav", ".mp3", ".flac", ".aac"]
43
+
44
+ # Quality settings
45
+ AUDIO_QUALITY_SETTINGS: Dict[str, List[str]] = {
46
+ "high": ["-b:a", "192k"],
47
+ "medium": ["-b:a", "128k"],
48
+ "low": ["-b:a", "96k"]
49
+ }
50
+
51
+ # Error messages
52
+ ERROR_MESSAGES: Dict[str, str] = {
53
+ "model_not_loaded": "Model is not loaded. Please load the model first.",
54
+ "invalid_video_format": "Unsupported video format. Supported formats: {formats}",
55
+ "video_too_long": f"Video duration exceeds maximum of {MAX_VIDEO_DURATION_SECONDS} seconds",
56
+ "ffmpeg_not_found": "ffmpeg not found. Please install ffmpeg: https://ffmpeg.org/download.html"
57
+ }
hunyuanvideo_foley/models/__init__.py ADDED
File without changes
hunyuanvideo_foley/models/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (156 Bytes). View file
 
hunyuanvideo_foley/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (156 Bytes). View file
 
hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-312.pyc ADDED
Binary file (43.6 kB). View file
 
hunyuanvideo_foley/models/__pycache__/hifi_foley.cpython-313.pyc ADDED
Binary file (43.5 kB). View file
 
hunyuanvideo_foley/models/dac_vae/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "1.0.0"
2
+
3
+ # preserved here for legacy reasons
4
+ __model_version__ = "latest"
5
+
6
+ import audiotools
7
+
8
+ audiotools.ml.BaseModel.INTERN += ["dac.**"]
9
+ audiotools.ml.BaseModel.EXTERN += ["einops"]
10
+
11
+
12
+ from . import nn
13
+ from . import model
14
+ from . import utils
15
+ from .model import DAC
16
+ from .model import DACFile
hunyuanvideo_foley/models/dac_vae/__main__.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import argbind
4
+
5
+ from .utils import download
6
+ from .utils.decode import decode
7
+ from .utils.encode import encode
8
+
9
+ STAGES = ["encode", "decode", "download"]
10
+
11
+
12
+ def run(stage: str):
13
+ """Run stages.
14
+
15
+ Parameters
16
+ ----------
17
+ stage : str
18
+ Stage to run
19
+ """
20
+ if stage not in STAGES:
21
+ raise ValueError(f"Unknown command: {stage}. Allowed commands are {STAGES}")
22
+ stage_fn = globals()[stage]
23
+
24
+ if stage == "download":
25
+ stage_fn()
26
+ return
27
+
28
+ stage_fn()
29
+
30
+
31
+ if __name__ == "__main__":
32
+ group = sys.argv.pop(1)
33
+ args = argbind.parse_args(group=group)
34
+
35
+ with argbind.scope(args):
36
+ run(group)
hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (681 Bytes). View file
 
hunyuanvideo_foley/models/dac_vae/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (681 Bytes). View file
 
hunyuanvideo_foley/models/dac_vae/model/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .base import CodecMixin
2
+ from .base import DACFile
3
+ from .dac import DAC
4
+ from .discriminator import Discriminator
hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (327 Bytes). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (327 Bytes). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-312.pyc ADDED
Binary file (13 kB). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/base.cpython-313.pyc ADDED
Binary file (13.1 kB). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-312.pyc ADDED
Binary file (17.3 kB). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/dac.cpython-313.pyc ADDED
Binary file (16.9 kB). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-312.pyc ADDED
Binary file (11.5 kB). View file
 
hunyuanvideo_foley/models/dac_vae/model/__pycache__/discriminator.cpython-313.pyc ADDED
Binary file (11.4 kB). View file
 
hunyuanvideo_foley/models/dac_vae/model/base.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from dataclasses import dataclass
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ import tqdm
9
+ from audiotools import AudioSignal
10
+ from torch import nn
11
+
12
+ SUPPORTED_VERSIONS = ["1.0.0"]
13
+
14
+
15
+ @dataclass
16
+ class DACFile:
17
+ codes: torch.Tensor
18
+
19
+ # Metadata
20
+ chunk_length: int
21
+ original_length: int
22
+ input_db: float
23
+ channels: int
24
+ sample_rate: int
25
+ padding: bool
26
+ dac_version: str
27
+
28
+ def save(self, path):
29
+ artifacts = {
30
+ "codes": self.codes.numpy().astype(np.uint16),
31
+ "metadata": {
32
+ "input_db": self.input_db.numpy().astype(np.float32),
33
+ "original_length": self.original_length,
34
+ "sample_rate": self.sample_rate,
35
+ "chunk_length": self.chunk_length,
36
+ "channels": self.channels,
37
+ "padding": self.padding,
38
+ "dac_version": SUPPORTED_VERSIONS[-1],
39
+ },
40
+ }
41
+ path = Path(path).with_suffix(".dac")
42
+ with open(path, "wb") as f:
43
+ np.save(f, artifacts)
44
+ return path
45
+
46
+ @classmethod
47
+ def load(cls, path):
48
+ artifacts = np.load(path, allow_pickle=True)[()]
49
+ codes = torch.from_numpy(artifacts["codes"].astype(int))
50
+ if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
51
+ raise RuntimeError(
52
+ f"Given file {path} can't be loaded with this version of descript-audio-codec."
53
+ )
54
+ return cls(codes=codes, **artifacts["metadata"])
55
+
56
+
57
+ class CodecMixin:
58
+ @property
59
+ def padding(self):
60
+ if not hasattr(self, "_padding"):
61
+ self._padding = True
62
+ return self._padding
63
+
64
+ @padding.setter
65
+ def padding(self, value):
66
+ assert isinstance(value, bool)
67
+
68
+ layers = [
69
+ l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))
70
+ ]
71
+
72
+ for layer in layers:
73
+ if value:
74
+ if hasattr(layer, "original_padding"):
75
+ layer.padding = layer.original_padding
76
+ else:
77
+ layer.original_padding = layer.padding
78
+ layer.padding = tuple(0 for _ in range(len(layer.padding)))
79
+
80
+ self._padding = value
81
+
82
+ def get_delay(self):
83
+ # Any number works here, delay is invariant to input length
84
+ l_out = self.get_output_length(0)
85
+ L = l_out
86
+
87
+ layers = []
88
+ for layer in self.modules():
89
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
90
+ layers.append(layer)
91
+
92
+ for layer in reversed(layers):
93
+ d = layer.dilation[0]
94
+ k = layer.kernel_size[0]
95
+ s = layer.stride[0]
96
+
97
+ if isinstance(layer, nn.ConvTranspose1d):
98
+ L = ((L - d * (k - 1) - 1) / s) + 1
99
+ elif isinstance(layer, nn.Conv1d):
100
+ L = (L - 1) * s + d * (k - 1) + 1
101
+
102
+ L = math.ceil(L)
103
+
104
+ l_in = L
105
+
106
+ return (l_in - l_out) // 2
107
+
108
+ def get_output_length(self, input_length):
109
+ L = input_length
110
+ # Calculate output length
111
+ for layer in self.modules():
112
+ if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
113
+ d = layer.dilation[0]
114
+ k = layer.kernel_size[0]
115
+ s = layer.stride[0]
116
+
117
+ if isinstance(layer, nn.Conv1d):
118
+ L = ((L - d * (k - 1) - 1) / s) + 1
119
+ elif isinstance(layer, nn.ConvTranspose1d):
120
+ L = (L - 1) * s + d * (k - 1) + 1
121
+
122
+ L = math.floor(L)
123
+ return L
124
+
125
+ @torch.no_grad()
126
+ def compress(
127
+ self,
128
+ audio_path_or_signal: Union[str, Path, AudioSignal],
129
+ win_duration: float = 1.0,
130
+ verbose: bool = False,
131
+ normalize_db: float = -16,
132
+ n_quantizers: int = None,
133
+ ) -> DACFile:
134
+ """Processes an audio signal from a file or AudioSignal object into
135
+ discrete codes. This function processes the signal in short windows,
136
+ using constant GPU memory.
137
+
138
+ Parameters
139
+ ----------
140
+ audio_path_or_signal : Union[str, Path, AudioSignal]
141
+ audio signal to reconstruct
142
+ win_duration : float, optional
143
+ window duration in seconds, by default 5.0
144
+ verbose : bool, optional
145
+ by default False
146
+ normalize_db : float, optional
147
+ normalize db, by default -16
148
+
149
+ Returns
150
+ -------
151
+ DACFile
152
+ Object containing compressed codes and metadata
153
+ required for decompression
154
+ """
155
+ audio_signal = audio_path_or_signal
156
+ if isinstance(audio_signal, (str, Path)):
157
+ audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
158
+
159
+ self.eval()
160
+ original_padding = self.padding
161
+ original_device = audio_signal.device
162
+
163
+ audio_signal = audio_signal.clone()
164
+ audio_signal = audio_signal.to_mono()
165
+ original_sr = audio_signal.sample_rate
166
+
167
+ resample_fn = audio_signal.resample
168
+ loudness_fn = audio_signal.loudness
169
+
170
+ # If audio is > 10 minutes long, use the ffmpeg versions
171
+ if audio_signal.signal_duration >= 10 * 60 * 60:
172
+ resample_fn = audio_signal.ffmpeg_resample
173
+ loudness_fn = audio_signal.ffmpeg_loudness
174
+
175
+ original_length = audio_signal.signal_length
176
+ resample_fn(self.sample_rate)
177
+ input_db = loudness_fn()
178
+
179
+ if normalize_db is not None:
180
+ audio_signal.normalize(normalize_db)
181
+ audio_signal.ensure_max_of_audio()
182
+
183
+ nb, nac, nt = audio_signal.audio_data.shape
184
+ audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
185
+ win_duration = (
186
+ audio_signal.signal_duration if win_duration is None else win_duration
187
+ )
188
+
189
+ if audio_signal.signal_duration <= win_duration:
190
+ # Unchunked compression (used if signal length < win duration)
191
+ self.padding = True
192
+ n_samples = nt
193
+ hop = nt
194
+ else:
195
+ # Chunked inference
196
+ self.padding = False
197
+ # Zero-pad signal on either side by the delay
198
+ audio_signal.zero_pad(self.delay, self.delay)
199
+ n_samples = int(win_duration * self.sample_rate)
200
+ # Round n_samples to nearest hop length multiple
201
+ n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
202
+ hop = self.get_output_length(n_samples)
203
+
204
+ codes = []
205
+ range_fn = range if not verbose else tqdm.trange
206
+
207
+ for i in range_fn(0, nt, hop):
208
+ x = audio_signal[..., i : i + n_samples]
209
+ x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
210
+
211
+ audio_data = x.audio_data.to(self.device)
212
+ audio_data = self.preprocess(audio_data, self.sample_rate)
213
+ _, c, _, _, _ = self.encode(audio_data, n_quantizers)
214
+ codes.append(c.to(original_device))
215
+ chunk_length = c.shape[-1]
216
+
217
+ codes = torch.cat(codes, dim=-1)
218
+
219
+ dac_file = DACFile(
220
+ codes=codes,
221
+ chunk_length=chunk_length,
222
+ original_length=original_length,
223
+ input_db=input_db,
224
+ channels=nac,
225
+ sample_rate=original_sr,
226
+ padding=self.padding,
227
+ dac_version=SUPPORTED_VERSIONS[-1],
228
+ )
229
+
230
+ if n_quantizers is not None:
231
+ codes = codes[:, :n_quantizers, :]
232
+
233
+ self.padding = original_padding
234
+ return dac_file
235
+
236
+ @torch.no_grad()
237
+ def decompress(
238
+ self,
239
+ obj: Union[str, Path, DACFile],
240
+ verbose: bool = False,
241
+ ) -> AudioSignal:
242
+ """Reconstruct audio from a given .dac file
243
+
244
+ Parameters
245
+ ----------
246
+ obj : Union[str, Path, DACFile]
247
+ .dac file location or corresponding DACFile object.
248
+ verbose : bool, optional
249
+ Prints progress if True, by default False
250
+
251
+ Returns
252
+ -------
253
+ AudioSignal
254
+ Object with the reconstructed audio
255
+ """
256
+ self.eval()
257
+ if isinstance(obj, (str, Path)):
258
+ obj = DACFile.load(obj)
259
+
260
+ original_padding = self.padding
261
+ self.padding = obj.padding
262
+
263
+ range_fn = range if not verbose else tqdm.trange
264
+ codes = obj.codes
265
+ original_device = codes.device
266
+ chunk_length = obj.chunk_length
267
+ recons = []
268
+
269
+ for i in range_fn(0, codes.shape[-1], chunk_length):
270
+ c = codes[..., i : i + chunk_length].to(self.device)
271
+ z = self.quantizer.from_codes(c)[0]
272
+ r = self.decode(z)
273
+ recons.append(r.to(original_device))
274
+
275
+ recons = torch.cat(recons, dim=-1)
276
+ recons = AudioSignal(recons, self.sample_rate)
277
+
278
+ resample_fn = recons.resample
279
+ loudness_fn = recons.loudness
280
+
281
+ # If audio is > 10 minutes long, use the ffmpeg versions
282
+ if recons.signal_duration >= 10 * 60 * 60:
283
+ resample_fn = recons.ffmpeg_resample
284
+ loudness_fn = recons.ffmpeg_loudness
285
+
286
+ if obj.input_db is not None:
287
+ recons.normalize(obj.input_db)
288
+
289
+ resample_fn(obj.sample_rate)
290
+
291
+ if obj.original_length is not None:
292
+ recons = recons[..., : obj.original_length]
293
+ loudness_fn()
294
+ recons.audio_data = recons.audio_data.reshape(
295
+ -1, obj.channels, obj.original_length
296
+ )
297
+ else:
298
+ loudness_fn()
299
+
300
+ self.padding = original_padding
301
+ return recons
hunyuanvideo_foley/models/dac_vae/model/dac.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import List
3
+ from typing import Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from audiotools.ml import BaseModel
9
+ from torch import nn
10
+
11
+ from .base import CodecMixin
12
+ from ..nn.layers import Snake1d
13
+ from ..nn.layers import WNConv1d
14
+ from ..nn.layers import WNConvTranspose1d
15
+ from ..nn.quantize import ResidualVectorQuantize
16
+ from ..nn.vae_utils import DiagonalGaussianDistribution
17
+
18
+
19
+ def init_weights(m):
20
+ if isinstance(m, nn.Conv1d):
21
+ nn.init.trunc_normal_(m.weight, std=0.02)
22
+ nn.init.constant_(m.bias, 0)
23
+
24
+
25
+ class ResidualUnit(nn.Module):
26
+ def __init__(self, dim: int = 16, dilation: int = 1):
27
+ super().__init__()
28
+ pad = ((7 - 1) * dilation) // 2
29
+ self.block = nn.Sequential(
30
+ Snake1d(dim),
31
+ WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
32
+ Snake1d(dim),
33
+ WNConv1d(dim, dim, kernel_size=1),
34
+ )
35
+
36
+ def forward(self, x):
37
+ y = self.block(x)
38
+ pad = (x.shape[-1] - y.shape[-1]) // 2
39
+ if pad > 0:
40
+ x = x[..., pad:-pad]
41
+ return x + y
42
+
43
+
44
+ class EncoderBlock(nn.Module):
45
+ def __init__(self, dim: int = 16, stride: int = 1):
46
+ super().__init__()
47
+ self.block = nn.Sequential(
48
+ ResidualUnit(dim // 2, dilation=1),
49
+ ResidualUnit(dim // 2, dilation=3),
50
+ ResidualUnit(dim // 2, dilation=9),
51
+ Snake1d(dim // 2),
52
+ WNConv1d(
53
+ dim // 2,
54
+ dim,
55
+ kernel_size=2 * stride,
56
+ stride=stride,
57
+ padding=math.ceil(stride / 2),
58
+ ),
59
+ )
60
+
61
+ def forward(self, x):
62
+ return self.block(x)
63
+
64
+
65
+ class Encoder(nn.Module):
66
+ def __init__(
67
+ self,
68
+ d_model: int = 64,
69
+ strides: list = [2, 4, 8, 8],
70
+ d_latent: int = 64,
71
+ ):
72
+ super().__init__()
73
+ # Create first convolution
74
+ self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]
75
+
76
+ # Create EncoderBlocks that double channels as they downsample by `stride`
77
+ for stride in strides:
78
+ d_model *= 2
79
+ self.block += [EncoderBlock(d_model, stride=stride)]
80
+
81
+ # Create last convolution
82
+ self.block += [
83
+ Snake1d(d_model),
84
+ WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
85
+ ]
86
+
87
+ # Wrap black into nn.Sequential
88
+ self.block = nn.Sequential(*self.block)
89
+ self.enc_dim = d_model
90
+
91
+ def forward(self, x):
92
+ return self.block(x)
93
+
94
+
95
+ class DecoderBlock(nn.Module):
96
+ def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1):
97
+ super().__init__()
98
+ self.block = nn.Sequential(
99
+ Snake1d(input_dim),
100
+ WNConvTranspose1d(
101
+ input_dim,
102
+ output_dim,
103
+ kernel_size=2 * stride,
104
+ stride=stride,
105
+ padding=math.ceil(stride / 2),
106
+ output_padding=stride % 2,
107
+ ),
108
+ ResidualUnit(output_dim, dilation=1),
109
+ ResidualUnit(output_dim, dilation=3),
110
+ ResidualUnit(output_dim, dilation=9),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.block(x)
115
+
116
+
117
+ class Decoder(nn.Module):
118
+ def __init__(
119
+ self,
120
+ input_channel,
121
+ channels,
122
+ rates,
123
+ d_out: int = 1,
124
+ ):
125
+ super().__init__()
126
+
127
+ # Add first conv layer
128
+ layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]
129
+
130
+ # Add upsampling + MRF blocks
131
+ for i, stride in enumerate(rates):
132
+ input_dim = channels // 2**i
133
+ output_dim = channels // 2 ** (i + 1)
134
+ layers += [DecoderBlock(input_dim, output_dim, stride)]
135
+
136
+ # Add final conv layer
137
+ layers += [
138
+ Snake1d(output_dim),
139
+ WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
140
+ nn.Tanh(),
141
+ ]
142
+
143
+ self.model = nn.Sequential(*layers)
144
+
145
+ def forward(self, x):
146
+ return self.model(x)
147
+
148
+
149
+ class DAC(BaseModel, CodecMixin):
150
+ def __init__(
151
+ self,
152
+ encoder_dim: int = 64,
153
+ encoder_rates: List[int] = [2, 4, 8, 8],
154
+ latent_dim: int = None,
155
+ decoder_dim: int = 1536,
156
+ decoder_rates: List[int] = [8, 8, 4, 2],
157
+ n_codebooks: int = 9,
158
+ codebook_size: int = 1024,
159
+ codebook_dim: Union[int, list] = 8,
160
+ quantizer_dropout: bool = False,
161
+ sample_rate: int = 44100,
162
+ continuous: bool = False,
163
+ ):
164
+ super().__init__()
165
+
166
+ self.encoder_dim = encoder_dim
167
+ self.encoder_rates = encoder_rates
168
+ self.decoder_dim = decoder_dim
169
+ self.decoder_rates = decoder_rates
170
+ self.sample_rate = sample_rate
171
+ self.continuous = continuous
172
+
173
+ if latent_dim is None:
174
+ latent_dim = encoder_dim * (2 ** len(encoder_rates))
175
+
176
+ self.latent_dim = latent_dim
177
+
178
+ self.hop_length = np.prod(encoder_rates)
179
+ self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)
180
+
181
+ if not continuous:
182
+ self.n_codebooks = n_codebooks
183
+ self.codebook_size = codebook_size
184
+ self.codebook_dim = codebook_dim
185
+ self.quantizer = ResidualVectorQuantize(
186
+ input_dim=latent_dim,
187
+ n_codebooks=n_codebooks,
188
+ codebook_size=codebook_size,
189
+ codebook_dim=codebook_dim,
190
+ quantizer_dropout=quantizer_dropout,
191
+ )
192
+ else:
193
+ self.quant_conv = torch.nn.Conv1d(latent_dim, 2 * latent_dim, 1)
194
+ self.post_quant_conv = torch.nn.Conv1d(latent_dim, latent_dim, 1)
195
+
196
+ self.decoder = Decoder(
197
+ latent_dim,
198
+ decoder_dim,
199
+ decoder_rates,
200
+ )
201
+ self.sample_rate = sample_rate
202
+ self.apply(init_weights)
203
+
204
+ self.delay = self.get_delay()
205
+
206
+ @property
207
+ def dtype(self):
208
+ """Get the dtype of the model parameters."""
209
+ # Return the dtype of the first parameter found
210
+ for param in self.parameters():
211
+ return param.dtype
212
+ return torch.float32 # fallback
213
+
214
+ @property
215
+ def device(self):
216
+ """Get the device of the model parameters."""
217
+ # Return the device of the first parameter found
218
+ for param in self.parameters():
219
+ return param.device
220
+ return torch.device('cpu') # fallback
221
+
222
+ def preprocess(self, audio_data, sample_rate):
223
+ if sample_rate is None:
224
+ sample_rate = self.sample_rate
225
+ assert sample_rate == self.sample_rate
226
+
227
+ length = audio_data.shape[-1]
228
+ right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
229
+ audio_data = nn.functional.pad(audio_data, (0, right_pad))
230
+
231
+ return audio_data
232
+
233
+ def encode(
234
+ self,
235
+ audio_data: torch.Tensor,
236
+ n_quantizers: int = None,
237
+ ):
238
+ """Encode given audio data and return quantized latent codes
239
+
240
+ Parameters
241
+ ----------
242
+ audio_data : Tensor[B x 1 x T]
243
+ Audio data to encode
244
+ n_quantizers : int, optional
245
+ Number of quantizers to use, by default None
246
+ If None, all quantizers are used.
247
+
248
+ Returns
249
+ -------
250
+ dict
251
+ A dictionary with the following keys:
252
+ "z" : Tensor[B x D x T]
253
+ Quantized continuous representation of input
254
+ "codes" : Tensor[B x N x T]
255
+ Codebook indices for each codebook
256
+ (quantized discrete representation of input)
257
+ "latents" : Tensor[B x N*D x T]
258
+ Projected latents (continuous representation of input before quantization)
259
+ "vq/commitment_loss" : Tensor[1]
260
+ Commitment loss to train encoder to predict vectors closer to codebook
261
+ entries
262
+ "vq/codebook_loss" : Tensor[1]
263
+ Codebook loss to update the codebook
264
+ "length" : int
265
+ Number of samples in input audio
266
+ """
267
+ z = self.encoder(audio_data) # [B x D x T]
268
+ if not self.continuous:
269
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
270
+ else:
271
+ z = self.quant_conv(z) # [B x 2D x T]
272
+ z = DiagonalGaussianDistribution(z)
273
+ codes, latents, commitment_loss, codebook_loss = None, None, 0, 0
274
+
275
+ return z, codes, latents, commitment_loss, codebook_loss
276
+
277
+ def decode(self, z: torch.Tensor):
278
+ """Decode given latent codes and return audio data
279
+
280
+ Parameters
281
+ ----------
282
+ z : Tensor[B x D x T]
283
+ Quantized continuous representation of input
284
+ length : int, optional
285
+ Number of samples in output audio, by default None
286
+
287
+ Returns
288
+ -------
289
+ dict
290
+ A dictionary with the following keys:
291
+ "audio" : Tensor[B x 1 x length]
292
+ Decoded audio data.
293
+ """
294
+ if not self.continuous:
295
+ audio = self.decoder(z)
296
+ else:
297
+ z = self.post_quant_conv(z)
298
+ audio = self.decoder(z)
299
+
300
+ return audio
301
+
302
+ def forward(
303
+ self,
304
+ audio_data: torch.Tensor,
305
+ sample_rate: int = None,
306
+ n_quantizers: int = None,
307
+ ):
308
+ """Model forward pass
309
+
310
+ Parameters
311
+ ----------
312
+ audio_data : Tensor[B x 1 x T]
313
+ Audio data to encode
314
+ sample_rate : int, optional
315
+ Sample rate of audio data in Hz, by default None
316
+ If None, defaults to `self.sample_rate`
317
+ n_quantizers : int, optional
318
+ Number of quantizers to use, by default None.
319
+ If None, all quantizers are used.
320
+
321
+ Returns
322
+ -------
323
+ dict
324
+ A dictionary with the following keys:
325
+ "z" : Tensor[B x D x T]
326
+ Quantized continuous representation of input
327
+ "codes" : Tensor[B x N x T]
328
+ Codebook indices for each codebook
329
+ (quantized discrete representation of input)
330
+ "latents" : Tensor[B x N*D x T]
331
+ Projected latents (continuous representation of input before quantization)
332
+ "vq/commitment_loss" : Tensor[1]
333
+ Commitment loss to train encoder to predict vectors closer to codebook
334
+ entries
335
+ "vq/codebook_loss" : Tensor[1]
336
+ Codebook loss to update the codebook
337
+ "length" : int
338
+ Number of samples in input audio
339
+ "audio" : Tensor[B x 1 x length]
340
+ Decoded audio data.
341
+ """
342
+ length = audio_data.shape[-1]
343
+ audio_data = self.preprocess(audio_data, sample_rate)
344
+ if not self.continuous:
345
+ z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)
346
+
347
+ x = self.decode(z)
348
+ return {
349
+ "audio": x[..., :length],
350
+ "z": z,
351
+ "codes": codes,
352
+ "latents": latents,
353
+ "vq/commitment_loss": commitment_loss,
354
+ "vq/codebook_loss": codebook_loss,
355
+ }
356
+ else:
357
+ posterior, _, _, _, _ = self.encode(audio_data, n_quantizers)
358
+ z = posterior.sample()
359
+ x = self.decode(z)
360
+
361
+ kl_loss = posterior.kl()
362
+ kl_loss = kl_loss.mean()
363
+
364
+ return {
365
+ "audio": x[..., :length],
366
+ "z": z,
367
+ "kl_loss": kl_loss,
368
+ }
369
+
370
+
371
+ if __name__ == "__main__":
372
+ import numpy as np
373
+ from functools import partial
374
+
375
+ model = DAC().to("cpu")
376
+
377
+ for n, m in model.named_modules():
378
+ o = m.extra_repr()
379
+ p = sum([np.prod(p.size()) for p in m.parameters()])
380
+ fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
381
+ setattr(m, "extra_repr", partial(fn, o=o, p=p))
382
+ print(model)
383
+ print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
384
+
385
+ length = 88200 * 2
386
+ x = torch.randn(1, 1, length).to(model.device)
387
+ x.requires_grad_(True)
388
+ x.retain_grad()
389
+
390
+ # Make a forward pass
391
+ out = model(x)["audio"]
392
+ print("Input shape:", x.shape)
393
+ print("Output shape:", out.shape)
394
+
395
+ # Create gradient variable
396
+ grad = torch.zeros_like(out)
397
+ grad[:, :, grad.shape[-1] // 2] = 1
398
+
399
+ # Make a backward pass
400
+ out.backward(grad)
401
+
402
+ # Check non-zero values
403
+ gradmap = x.grad.squeeze(0)
404
+ gradmap = (gradmap != 0).sum(0) # sum across features
405
+ rf = (gradmap != 0).sum()
406
+
407
+ print(f"Receptive field: {rf.item()}")
408
+
409
+ x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
410
+ model.decompress(model.compress(x, verbose=True), verbose=True)
hunyuanvideo_foley/models/dac_vae/model/discriminator.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from audiotools import AudioSignal
5
+ from audiotools import ml
6
+ from audiotools import STFTParams
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def WNConv1d(*args, **kwargs):
12
+ act = kwargs.pop("act", True)
13
+ conv = weight_norm(nn.Conv1d(*args, **kwargs))
14
+ if not act:
15
+ return conv
16
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
17
+
18
+
19
+ def WNConv2d(*args, **kwargs):
20
+ act = kwargs.pop("act", True)
21
+ conv = weight_norm(nn.Conv2d(*args, **kwargs))
22
+ if not act:
23
+ return conv
24
+ return nn.Sequential(conv, nn.LeakyReLU(0.1))
25
+
26
+
27
+ class MPD(nn.Module):
28
+ def __init__(self, period):
29
+ super().__init__()
30
+ self.period = period
31
+ self.convs = nn.ModuleList(
32
+ [
33
+ WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
34
+ WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
35
+ WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)),
36
+ WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)),
37
+ WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)),
38
+ ]
39
+ )
40
+ self.conv_post = WNConv2d(
41
+ 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False
42
+ )
43
+
44
+ def pad_to_period(self, x):
45
+ t = x.shape[-1]
46
+ x = F.pad(x, (0, self.period - t % self.period), mode="reflect")
47
+ return x
48
+
49
+ def forward(self, x):
50
+ fmap = []
51
+
52
+ x = self.pad_to_period(x)
53
+ x = rearrange(x, "b c (l p) -> b c l p", p=self.period)
54
+
55
+ for layer in self.convs:
56
+ x = layer(x)
57
+ fmap.append(x)
58
+
59
+ x = self.conv_post(x)
60
+ fmap.append(x)
61
+
62
+ return fmap
63
+
64
+
65
+ class MSD(nn.Module):
66
+ def __init__(self, rate: int = 1, sample_rate: int = 44100):
67
+ super().__init__()
68
+ self.convs = nn.ModuleList(
69
+ [
70
+ WNConv1d(1, 16, 15, 1, padding=7),
71
+ WNConv1d(16, 64, 41, 4, groups=4, padding=20),
72
+ WNConv1d(64, 256, 41, 4, groups=16, padding=20),
73
+ WNConv1d(256, 1024, 41, 4, groups=64, padding=20),
74
+ WNConv1d(1024, 1024, 41, 4, groups=256, padding=20),
75
+ WNConv1d(1024, 1024, 5, 1, padding=2),
76
+ ]
77
+ )
78
+ self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False)
79
+ self.sample_rate = sample_rate
80
+ self.rate = rate
81
+
82
+ def forward(self, x):
83
+ x = AudioSignal(x, self.sample_rate)
84
+ x.resample(self.sample_rate // self.rate)
85
+ x = x.audio_data
86
+
87
+ fmap = []
88
+
89
+ for l in self.convs:
90
+ x = l(x)
91
+ fmap.append(x)
92
+ x = self.conv_post(x)
93
+ fmap.append(x)
94
+
95
+ return fmap
96
+
97
+
98
+ BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)]
99
+
100
+
101
+ class MRD(nn.Module):
102
+ def __init__(
103
+ self,
104
+ window_length: int,
105
+ hop_factor: float = 0.25,
106
+ sample_rate: int = 44100,
107
+ bands: list = BANDS,
108
+ ):
109
+ """Complex multi-band spectrogram discriminator.
110
+ Parameters
111
+ ----------
112
+ window_length : int
113
+ Window length of STFT.
114
+ hop_factor : float, optional
115
+ Hop factor of the STFT, defaults to ``0.25 * window_length``.
116
+ sample_rate : int, optional
117
+ Sampling rate of audio in Hz, by default 44100
118
+ bands : list, optional
119
+ Bands to run discriminator over.
120
+ """
121
+ super().__init__()
122
+
123
+ self.window_length = window_length
124
+ self.hop_factor = hop_factor
125
+ self.sample_rate = sample_rate
126
+ self.stft_params = STFTParams(
127
+ window_length=window_length,
128
+ hop_length=int(window_length * hop_factor),
129
+ match_stride=True,
130
+ )
131
+
132
+ n_fft = window_length // 2 + 1
133
+ bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
134
+ self.bands = bands
135
+
136
+ ch = 32
137
+ convs = lambda: nn.ModuleList(
138
+ [
139
+ WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)),
140
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
141
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
142
+ WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)),
143
+ WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)),
144
+ ]
145
+ )
146
+ self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
147
+ self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False)
148
+
149
+ def spectrogram(self, x):
150
+ x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params)
151
+ x = torch.view_as_real(x.stft())
152
+ x = rearrange(x, "b 1 f t c -> (b 1) c t f")
153
+ # Split into bands
154
+ x_bands = [x[..., b[0] : b[1]] for b in self.bands]
155
+ return x_bands
156
+
157
+ def forward(self, x):
158
+ x_bands = self.spectrogram(x)
159
+ fmap = []
160
+
161
+ x = []
162
+ for band, stack in zip(x_bands, self.band_convs):
163
+ for layer in stack:
164
+ band = layer(band)
165
+ fmap.append(band)
166
+ x.append(band)
167
+
168
+ x = torch.cat(x, dim=-1)
169
+ x = self.conv_post(x)
170
+ fmap.append(x)
171
+
172
+ return fmap
173
+
174
+
175
+ class Discriminator(ml.BaseModel):
176
+ def __init__(
177
+ self,
178
+ rates: list = [],
179
+ periods: list = [2, 3, 5, 7, 11],
180
+ fft_sizes: list = [2048, 1024, 512],
181
+ sample_rate: int = 44100,
182
+ bands: list = BANDS,
183
+ ):
184
+ """Discriminator that combines multiple discriminators.
185
+
186
+ Parameters
187
+ ----------
188
+ rates : list, optional
189
+ sampling rates (in Hz) to run MSD at, by default []
190
+ If empty, MSD is not used.
191
+ periods : list, optional
192
+ periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11]
193
+ fft_sizes : list, optional
194
+ Window sizes of the FFT to run MRD at, by default [2048, 1024, 512]
195
+ sample_rate : int, optional
196
+ Sampling rate of audio in Hz, by default 44100
197
+ bands : list, optional
198
+ Bands to run MRD at, by default `BANDS`
199
+ """
200
+ super().__init__()
201
+ discs = []
202
+ discs += [MPD(p) for p in periods]
203
+ discs += [MSD(r, sample_rate=sample_rate) for r in rates]
204
+ discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes]
205
+ self.discriminators = nn.ModuleList(discs)
206
+
207
+ def preprocess(self, y):
208
+ # Remove DC offset
209
+ y = y - y.mean(dim=-1, keepdims=True)
210
+ # Peak normalize the volume of input audio
211
+ y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
212
+ return y
213
+
214
+ def forward(self, x):
215
+ x = self.preprocess(x)
216
+ fmaps = [d(x) for d in self.discriminators]
217
+ return fmaps
218
+
219
+
220
+ if __name__ == "__main__":
221
+ disc = Discriminator()
222
+ x = torch.zeros(1, 1, 44100)
223
+ results = disc(x)
224
+ for i, result in enumerate(results):
225
+ print(f"disc{i}")
226
+ for i, r in enumerate(result):
227
+ print(r.shape, r.mean(), r.min(), r.max())
228
+ print()
hunyuanvideo_foley/models/dac_vae/nn/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import layers
2
+ from . import loss
3
+ from . import quantize
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (261 Bytes). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (261 Bytes). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-312.pyc ADDED
Binary file (2.2 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/layers.cpython-313.pyc ADDED
Binary file (2.27 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-312.pyc ADDED
Binary file (16.4 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/loss.cpython-313.pyc ADDED
Binary file (15.9 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-312.pyc ADDED
Binary file (12.4 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/quantize.cpython-313.pyc ADDED
Binary file (11.9 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-312.pyc ADDED
Binary file (5.96 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/__pycache__/vae_utils.cpython-313.pyc ADDED
Binary file (6.12 kB). View file
 
hunyuanvideo_foley/models/dac_vae/nn/layers.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+ from torch.nn.utils import weight_norm
7
+
8
+
9
+ def WNConv1d(*args, **kwargs):
10
+ return weight_norm(nn.Conv1d(*args, **kwargs))
11
+
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+
17
+ # Scripting this brings model speed up 1.4x
18
+ @torch.jit.script
19
+ def snake(x, alpha):
20
+ shape = x.shape
21
+ x = x.reshape(shape[0], shape[1], -1)
22
+ x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23
+ x = x.reshape(shape)
24
+ return x
25
+
26
+
27
+ class Snake1d(nn.Module):
28
+ def __init__(self, channels):
29
+ super().__init__()
30
+ self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31
+
32
+ def forward(self, x):
33
+ return snake(x, self.alpha)
hunyuanvideo_foley/models/dac_vae/nn/loss.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from audiotools import AudioSignal
7
+ from audiotools import STFTParams
8
+ from torch import nn
9
+
10
+
11
+ class L1Loss(nn.L1Loss):
12
+ """L1 Loss between AudioSignals. Defaults
13
+ to comparing ``audio_data``, but any
14
+ attribute of an AudioSignal can be used.
15
+
16
+ Parameters
17
+ ----------
18
+ attribute : str, optional
19
+ Attribute of signal to compare, defaults to ``audio_data``.
20
+ weight : float, optional
21
+ Weight of this loss, defaults to 1.0.
22
+
23
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
24
+ """
25
+
26
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
27
+ self.attribute = attribute
28
+ self.weight = weight
29
+ super().__init__(**kwargs)
30
+
31
+ def forward(self, x: AudioSignal, y: AudioSignal):
32
+ """
33
+ Parameters
34
+ ----------
35
+ x : AudioSignal
36
+ Estimate AudioSignal
37
+ y : AudioSignal
38
+ Reference AudioSignal
39
+
40
+ Returns
41
+ -------
42
+ torch.Tensor
43
+ L1 loss between AudioSignal attributes.
44
+ """
45
+ if isinstance(x, AudioSignal):
46
+ x = getattr(x, self.attribute)
47
+ y = getattr(y, self.attribute)
48
+ return super().forward(x, y)
49
+
50
+
51
+ class SISDRLoss(nn.Module):
52
+ """
53
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
54
+ of estimated and reference audio signals or aligned features.
55
+
56
+ Parameters
57
+ ----------
58
+ scaling : int, optional
59
+ Whether to use scale-invariant (True) or
60
+ signal-to-noise ratio (False), by default True
61
+ reduction : str, optional
62
+ How to reduce across the batch (either 'mean',
63
+ 'sum', or none).], by default ' mean'
64
+ zero_mean : int, optional
65
+ Zero mean the references and estimates before
66
+ computing the loss, by default True
67
+ clip_min : int, optional
68
+ The minimum possible loss value. Helps network
69
+ to not focus on making already good examples better, by default None
70
+ weight : float, optional
71
+ Weight of this loss, defaults to 1.0.
72
+
73
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ scaling: int = True,
79
+ reduction: str = "mean",
80
+ zero_mean: int = True,
81
+ clip_min: int = None,
82
+ weight: float = 1.0,
83
+ ):
84
+ self.scaling = scaling
85
+ self.reduction = reduction
86
+ self.zero_mean = zero_mean
87
+ self.clip_min = clip_min
88
+ self.weight = weight
89
+ super().__init__()
90
+
91
+ def forward(self, x: AudioSignal, y: AudioSignal):
92
+ eps = 1e-8
93
+ # nb, nc, nt
94
+ if isinstance(x, AudioSignal):
95
+ references = x.audio_data
96
+ estimates = y.audio_data
97
+ else:
98
+ references = x
99
+ estimates = y
100
+
101
+ nb = references.shape[0]
102
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
103
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
104
+
105
+ # samples now on axis 1
106
+ if self.zero_mean:
107
+ mean_reference = references.mean(dim=1, keepdim=True)
108
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
109
+ else:
110
+ mean_reference = 0
111
+ mean_estimate = 0
112
+
113
+ _references = references - mean_reference
114
+ _estimates = estimates - mean_estimate
115
+
116
+ references_projection = (_references**2).sum(dim=-2) + eps
117
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
118
+
119
+ scale = (
120
+ (references_on_estimates / references_projection).unsqueeze(1)
121
+ if self.scaling
122
+ else 1
123
+ )
124
+
125
+ e_true = scale * _references
126
+ e_res = _estimates - e_true
127
+
128
+ signal = (e_true**2).sum(dim=1)
129
+ noise = (e_res**2).sum(dim=1)
130
+ sdr = -10 * torch.log10(signal / noise + eps)
131
+
132
+ if self.clip_min is not None:
133
+ sdr = torch.clamp(sdr, min=self.clip_min)
134
+
135
+ if self.reduction == "mean":
136
+ sdr = sdr.mean()
137
+ elif self.reduction == "sum":
138
+ sdr = sdr.sum()
139
+ return sdr
140
+
141
+
142
+ class MultiScaleSTFTLoss(nn.Module):
143
+ """Computes the multi-scale STFT loss from [1].
144
+
145
+ Parameters
146
+ ----------
147
+ window_lengths : List[int], optional
148
+ Length of each window of each STFT, by default [2048, 512]
149
+ loss_fn : typing.Callable, optional
150
+ How to compare each loss, by default nn.L1Loss()
151
+ clamp_eps : float, optional
152
+ Clamp on the log magnitude, below, by default 1e-5
153
+ mag_weight : float, optional
154
+ Weight of raw magnitude portion of loss, by default 1.0
155
+ log_weight : float, optional
156
+ Weight of log magnitude portion of loss, by default 1.0
157
+ pow : float, optional
158
+ Power to raise magnitude to before taking log, by default 2.0
159
+ weight : float, optional
160
+ Weight of this loss, by default 1.0
161
+ match_stride : bool, optional
162
+ Whether to match the stride of convolutional layers, by default False
163
+
164
+ References
165
+ ----------
166
+
167
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
168
+ "DDSP: Differentiable Digital Signal Processing."
169
+ International Conference on Learning Representations. 2019.
170
+
171
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ window_lengths: List[int] = [2048, 512],
177
+ loss_fn: typing.Callable = nn.L1Loss(),
178
+ clamp_eps: float = 1e-5,
179
+ mag_weight: float = 1.0,
180
+ log_weight: float = 1.0,
181
+ pow: float = 2.0,
182
+ weight: float = 1.0,
183
+ match_stride: bool = False,
184
+ window_type: str = None,
185
+ ):
186
+ super().__init__()
187
+ self.stft_params = [
188
+ STFTParams(
189
+ window_length=w,
190
+ hop_length=w // 4,
191
+ match_stride=match_stride,
192
+ window_type=window_type,
193
+ )
194
+ for w in window_lengths
195
+ ]
196
+ self.loss_fn = loss_fn
197
+ self.log_weight = log_weight
198
+ self.mag_weight = mag_weight
199
+ self.clamp_eps = clamp_eps
200
+ self.weight = weight
201
+ self.pow = pow
202
+
203
+ def forward(self, x: AudioSignal, y: AudioSignal):
204
+ """Computes multi-scale STFT between an estimate and a reference
205
+ signal.
206
+
207
+ Parameters
208
+ ----------
209
+ x : AudioSignal
210
+ Estimate signal
211
+ y : AudioSignal
212
+ Reference signal
213
+
214
+ Returns
215
+ -------
216
+ torch.Tensor
217
+ Multi-scale STFT loss.
218
+ """
219
+ loss = 0.0
220
+ for s in self.stft_params:
221
+ x.stft(s.window_length, s.hop_length, s.window_type)
222
+ y.stft(s.window_length, s.hop_length, s.window_type)
223
+ loss += self.log_weight * self.loss_fn(
224
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
225
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
226
+ )
227
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
228
+ return loss
229
+
230
+
231
+ class MelSpectrogramLoss(nn.Module):
232
+ """Compute distance between mel spectrograms. Can be used
233
+ in a multi-scale way.
234
+
235
+ Parameters
236
+ ----------
237
+ n_mels : List[int]
238
+ Number of mels per STFT, by default [150, 80],
239
+ window_lengths : List[int], optional
240
+ Length of each window of each STFT, by default [2048, 512]
241
+ loss_fn : typing.Callable, optional
242
+ How to compare each loss, by default nn.L1Loss()
243
+ clamp_eps : float, optional
244
+ Clamp on the log magnitude, below, by default 1e-5
245
+ mag_weight : float, optional
246
+ Weight of raw magnitude portion of loss, by default 1.0
247
+ log_weight : float, optional
248
+ Weight of log magnitude portion of loss, by default 1.0
249
+ pow : float, optional
250
+ Power to raise magnitude to before taking log, by default 2.0
251
+ weight : float, optional
252
+ Weight of this loss, by default 1.0
253
+ match_stride : bool, optional
254
+ Whether to match the stride of convolutional layers, by default False
255
+
256
+ Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
257
+ """
258
+
259
+ def __init__(
260
+ self,
261
+ n_mels: List[int] = [150, 80],
262
+ window_lengths: List[int] = [2048, 512],
263
+ loss_fn: typing.Callable = nn.L1Loss(),
264
+ clamp_eps: float = 1e-5,
265
+ mag_weight: float = 1.0,
266
+ log_weight: float = 1.0,
267
+ pow: float = 2.0,
268
+ weight: float = 1.0,
269
+ match_stride: bool = False,
270
+ mel_fmin: List[float] = [0.0, 0.0],
271
+ mel_fmax: List[float] = [None, None],
272
+ window_type: str = None,
273
+ ):
274
+ super().__init__()
275
+ self.stft_params = [
276
+ STFTParams(
277
+ window_length=w,
278
+ hop_length=w // 4,
279
+ match_stride=match_stride,
280
+ window_type=window_type,
281
+ )
282
+ for w in window_lengths
283
+ ]
284
+ self.n_mels = n_mels
285
+ self.loss_fn = loss_fn
286
+ self.clamp_eps = clamp_eps
287
+ self.log_weight = log_weight
288
+ self.mag_weight = mag_weight
289
+ self.weight = weight
290
+ self.mel_fmin = mel_fmin
291
+ self.mel_fmax = mel_fmax
292
+ self.pow = pow
293
+
294
+ def forward(self, x: AudioSignal, y: AudioSignal):
295
+ """Computes mel loss between an estimate and a reference
296
+ signal.
297
+
298
+ Parameters
299
+ ----------
300
+ x : AudioSignal
301
+ Estimate signal
302
+ y : AudioSignal
303
+ Reference signal
304
+
305
+ Returns
306
+ -------
307
+ torch.Tensor
308
+ Mel loss.
309
+ """
310
+ loss = 0.0
311
+ for n_mels, fmin, fmax, s in zip(
312
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
313
+ ):
314
+ kwargs = {
315
+ "window_length": s.window_length,
316
+ "hop_length": s.hop_length,
317
+ "window_type": s.window_type,
318
+ }
319
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
320
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
321
+
322
+ loss += self.log_weight * self.loss_fn(
323
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
324
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
325
+ )
326
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
327
+ return loss
328
+
329
+
330
+ class GANLoss(nn.Module):
331
+ """
332
+ Computes a discriminator loss, given a discriminator on
333
+ generated waveforms/spectrograms compared to ground truth
334
+ waveforms/spectrograms. Computes the loss for both the
335
+ discriminator and the generator in separate functions.
336
+ """
337
+
338
+ def __init__(self, discriminator):
339
+ super().__init__()
340
+ self.discriminator = discriminator
341
+
342
+ def forward(self, fake, real):
343
+ d_fake = self.discriminator(fake.audio_data)
344
+ d_real = self.discriminator(real.audio_data)
345
+ return d_fake, d_real
346
+
347
+ def discriminator_loss(self, fake, real):
348
+ d_fake, d_real = self.forward(fake.clone().detach(), real)
349
+
350
+ loss_d = 0
351
+ for x_fake, x_real in zip(d_fake, d_real):
352
+ loss_d += torch.mean(x_fake[-1] ** 2)
353
+ loss_d += torch.mean((1 - x_real[-1]) ** 2)
354
+ return loss_d
355
+
356
+ def generator_loss(self, fake, real):
357
+ d_fake, d_real = self.forward(fake, real)
358
+
359
+ loss_g = 0
360
+ for x_fake in d_fake:
361
+ loss_g += torch.mean((1 - x_fake[-1]) ** 2)
362
+
363
+ loss_feature = 0
364
+
365
+ for i in range(len(d_fake)):
366
+ for j in range(len(d_fake[i]) - 1):
367
+ loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach())
368
+ return loss_g, loss_feature
hunyuanvideo_foley/models/dac_vae/nn/quantize.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ from .layers import WNConv1d
11
+
12
+
13
+ class VectorQuantize(nn.Module):
14
+ """
15
+ Implementation of VQ similar to Karpathy's repo:
16
+ https://github.com/karpathy/deep-vector-quantization
17
+ Additionally uses following tricks from Improved VQGAN
18
+ (https://arxiv.org/pdf/2110.04627.pdf):
19
+ 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
20
+ for improved codebook usage
21
+ 2. l2-normalized codes: Converts euclidean distance to cosine similarity which
22
+ improves training stability
23
+ """
24
+
25
+ def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
26
+ super().__init__()
27
+ self.codebook_size = codebook_size
28
+ self.codebook_dim = codebook_dim
29
+
30
+ self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
31
+ self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
32
+ self.codebook = nn.Embedding(codebook_size, codebook_dim)
33
+
34
+ def forward(self, z):
35
+ """Quantized the input tensor using a fixed codebook and returns
36
+ the corresponding codebook vectors
37
+
38
+ Parameters
39
+ ----------
40
+ z : Tensor[B x D x T]
41
+
42
+ Returns
43
+ -------
44
+ Tensor[B x D x T]
45
+ Quantized continuous representation of input
46
+ Tensor[1]
47
+ Commitment loss to train encoder to predict vectors closer to codebook
48
+ entries
49
+ Tensor[1]
50
+ Codebook loss to update the codebook
51
+ Tensor[B x T]
52
+ Codebook indices (quantized discrete representation of input)
53
+ Tensor[B x D x T]
54
+ Projected latents (continuous representation of input before quantization)
55
+ """
56
+
57
+ # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
58
+ z_e = self.in_proj(z) # z_e : (B x D x T)
59
+ z_q, indices = self.decode_latents(z_e)
60
+
61
+ commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
62
+ codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
63
+
64
+ z_q = (
65
+ z_e + (z_q - z_e).detach()
66
+ ) # noop in forward pass, straight-through gradient estimator in backward pass
67
+
68
+ z_q = self.out_proj(z_q)
69
+
70
+ return z_q, commitment_loss, codebook_loss, indices, z_e
71
+
72
+ def embed_code(self, embed_id):
73
+ return F.embedding(embed_id, self.codebook.weight)
74
+
75
+ def decode_code(self, embed_id):
76
+ return self.embed_code(embed_id).transpose(1, 2)
77
+
78
+ def decode_latents(self, latents):
79
+ encodings = rearrange(latents, "b d t -> (b t) d")
80
+ codebook = self.codebook.weight # codebook: (N x D)
81
+
82
+ # L2 normalize encodings and codebook (ViT-VQGAN)
83
+ encodings = F.normalize(encodings)
84
+ codebook = F.normalize(codebook)
85
+
86
+ # Compute euclidean distance with codebook
87
+ dist = (
88
+ encodings.pow(2).sum(1, keepdim=True)
89
+ - 2 * encodings @ codebook.t()
90
+ + codebook.pow(2).sum(1, keepdim=True).t()
91
+ )
92
+ indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
93
+ z_q = self.decode_code(indices)
94
+ return z_q, indices
95
+
96
+
97
+ class ResidualVectorQuantize(nn.Module):
98
+ """
99
+ Introduced in SoundStream: An end2end neural audio codec
100
+ https://arxiv.org/abs/2107.03312
101
+ """
102
+
103
+ def __init__(
104
+ self,
105
+ input_dim: int = 512,
106
+ n_codebooks: int = 9,
107
+ codebook_size: int = 1024,
108
+ codebook_dim: Union[int, list] = 8,
109
+ quantizer_dropout: float = 0.0,
110
+ ):
111
+ super().__init__()
112
+ if isinstance(codebook_dim, int):
113
+ codebook_dim = [codebook_dim for _ in range(n_codebooks)]
114
+
115
+ self.n_codebooks = n_codebooks
116
+ self.codebook_dim = codebook_dim
117
+ self.codebook_size = codebook_size
118
+
119
+ self.quantizers = nn.ModuleList(
120
+ [
121
+ VectorQuantize(input_dim, codebook_size, codebook_dim[i])
122
+ for i in range(n_codebooks)
123
+ ]
124
+ )
125
+ self.quantizer_dropout = quantizer_dropout
126
+
127
+ def forward(self, z, n_quantizers: int = None):
128
+ """Quantized the input tensor using a fixed set of `n` codebooks and returns
129
+ the corresponding codebook vectors
130
+ Parameters
131
+ ----------
132
+ z : Tensor[B x D x T]
133
+ n_quantizers : int, optional
134
+ No. of quantizers to use
135
+ (n_quantizers < self.n_codebooks ex: for quantizer dropout)
136
+ Note: if `self.quantizer_dropout` is True, this argument is ignored
137
+ when in training mode, and a random number of quantizers is used.
138
+ Returns
139
+ -------
140
+ dict
141
+ A dictionary with the following keys:
142
+
143
+ "z" : Tensor[B x D x T]
144
+ Quantized continuous representation of input
145
+ "codes" : Tensor[B x N x T]
146
+ Codebook indices for each codebook
147
+ (quantized discrete representation of input)
148
+ "latents" : Tensor[B x N*D x T]
149
+ Projected latents (continuous representation of input before quantization)
150
+ "vq/commitment_loss" : Tensor[1]
151
+ Commitment loss to train encoder to predict vectors closer to codebook
152
+ entries
153
+ "vq/codebook_loss" : Tensor[1]
154
+ Codebook loss to update the codebook
155
+ """
156
+ z_q = 0
157
+ residual = z
158
+ commitment_loss = 0
159
+ codebook_loss = 0
160
+
161
+ codebook_indices = []
162
+ latents = []
163
+
164
+ if n_quantizers is None:
165
+ n_quantizers = self.n_codebooks
166
+ if self.training:
167
+ n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
168
+ dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
169
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
170
+ n_quantizers[:n_dropout] = dropout[:n_dropout]
171
+ n_quantizers = n_quantizers.to(z.device)
172
+
173
+ for i, quantizer in enumerate(self.quantizers):
174
+ if self.training is False and i >= n_quantizers:
175
+ break
176
+
177
+ z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(
178
+ residual
179
+ )
180
+
181
+ # Create mask to apply quantizer dropout
182
+ mask = (
183
+ torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
184
+ )
185
+ z_q = z_q + z_q_i * mask[:, None, None]
186
+ residual = residual - z_q_i
187
+
188
+ # Sum losses
189
+ commitment_loss += (commitment_loss_i * mask).mean()
190
+ codebook_loss += (codebook_loss_i * mask).mean()
191
+
192
+ codebook_indices.append(indices_i)
193
+ latents.append(z_e_i)
194
+
195
+ codes = torch.stack(codebook_indices, dim=1)
196
+ latents = torch.cat(latents, dim=1)
197
+
198
+ return z_q, codes, latents, commitment_loss, codebook_loss
199
+
200
+ def from_codes(self, codes: torch.Tensor):
201
+ """Given the quantized codes, reconstruct the continuous representation
202
+ Parameters
203
+ ----------
204
+ codes : Tensor[B x N x T]
205
+ Quantized discrete representation of input
206
+ Returns
207
+ -------
208
+ Tensor[B x D x T]
209
+ Quantized continuous representation of input
210
+ """
211
+ z_q = 0.0
212
+ z_p = []
213
+ n_codebooks = codes.shape[1]
214
+ for i in range(n_codebooks):
215
+ z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
216
+ z_p.append(z_p_i)
217
+
218
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
219
+ z_q = z_q + z_q_i
220
+ return z_q, torch.cat(z_p, dim=1), codes
221
+
222
+ def from_latents(self, latents: torch.Tensor):
223
+ """Given the unquantized latents, reconstruct the
224
+ continuous representation after quantization.
225
+
226
+ Parameters
227
+ ----------
228
+ latents : Tensor[B x N x T]
229
+ Continuous representation of input after projection
230
+
231
+ Returns
232
+ -------
233
+ Tensor[B x D x T]
234
+ Quantized representation of full-projected space
235
+ Tensor[B x D x T]
236
+ Quantized representation of latent space
237
+ """
238
+ z_q = 0
239
+ z_p = []
240
+ codes = []
241
+ dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
242
+
243
+ n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[
244
+ 0
245
+ ]
246
+ for i in range(n_codebooks):
247
+ j, k = dims[i], dims[i + 1]
248
+ z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
249
+ z_p.append(z_p_i)
250
+ codes.append(codes_i)
251
+
252
+ z_q_i = self.quantizers[i].out_proj(z_p_i)
253
+ z_q = z_q + z_q_i
254
+
255
+ return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
256
+
257
+
258
+ if __name__ == "__main__":
259
+ rvq = ResidualVectorQuantize(quantizer_dropout=True)
260
+ x = torch.randn(16, 512, 80)
261
+ y = rvq(x)
262
+ print(y["latents"].shape)
hunyuanvideo_foley/models/dac_vae/nn/vae_utils.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
+ class AbstractDistribution:
6
+ def sample(self):
7
+ raise NotImplementedError()
8
+
9
+ def mode(self):
10
+ raise NotImplementedError()
11
+
12
+
13
+ class DiracDistribution(AbstractDistribution):
14
+ def __init__(self, value):
15
+ self.value = value
16
+
17
+ def sample(self):
18
+ return self.value
19
+
20
+ def mode(self):
21
+ return self.value
22
+
23
+
24
+ class DiagonalGaussianDistribution(object):
25
+ def __init__(self, parameters, deterministic=False):
26
+ self.parameters = parameters
27
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29
+ self.deterministic = deterministic
30
+ self.std = torch.exp(0.5 * self.logvar)
31
+ self.var = torch.exp(self.logvar)
32
+ if self.deterministic:
33
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34
+
35
+ def sample(self):
36
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37
+ return x
38
+
39
+ def kl(self, other=None):
40
+ if self.deterministic:
41
+ return torch.Tensor([0.0])
42
+ else:
43
+ if other is None:
44
+ return 0.5 * torch.mean(
45
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
46
+ dim=[1, 2],
47
+ )
48
+ else:
49
+ return 0.5 * torch.mean(
50
+ torch.pow(self.mean - other.mean, 2) / other.var
51
+ + self.var / other.var
52
+ - 1.0
53
+ - self.logvar
54
+ + other.logvar,
55
+ dim=[1, 2],
56
+ )
57
+
58
+ def nll(self, sample, dims=[1, 2]):
59
+ if self.deterministic:
60
+ return torch.Tensor([0.0])
61
+ logtwopi = np.log(2.0 * np.pi)
62
+ return 0.5 * torch.sum(
63
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
64
+ dim=dims,
65
+ )
66
+
67
+ def mode(self):
68
+ return self.mean
69
+
70
+
71
+ def normal_kl(mean1, logvar1, mean2, logvar2):
72
+ """
73
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
74
+ Compute the KL divergence between two gaussians.
75
+ Shapes are automatically broadcasted, so batches can be compared to
76
+ scalars, among other use cases.
77
+ """
78
+ tensor = None
79
+ for obj in (mean1, logvar1, mean2, logvar2):
80
+ if isinstance(obj, torch.Tensor):
81
+ tensor = obj
82
+ break
83
+ assert tensor is not None, "at least one argument must be a Tensor"
84
+
85
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
86
+ # Tensors, but it does not work for torch.exp().
87
+ logvar1, logvar2 = [x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) for x in (logvar1, logvar2)]
88
+
89
+ return 0.5 * (
90
+ -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
91
+ )
hunyuanvideo_foley/models/dac_vae/utils/__init__.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from audiotools import ml
5
+
6
+ from ..model import DAC
7
+ Accelerator = ml.Accelerator
8
+
9
+ __MODEL_LATEST_TAGS__ = {
10
+ ("44khz", "8kbps"): "0.0.1",
11
+ ("24khz", "8kbps"): "0.0.4",
12
+ ("16khz", "8kbps"): "0.0.5",
13
+ ("44khz", "16kbps"): "1.0.0",
14
+ }
15
+
16
+ __MODEL_URLS__ = {
17
+ (
18
+ "44khz",
19
+ "0.0.1",
20
+ "8kbps",
21
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth",
22
+ (
23
+ "24khz",
24
+ "0.0.4",
25
+ "8kbps",
26
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth",
27
+ (
28
+ "16khz",
29
+ "0.0.5",
30
+ "8kbps",
31
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth",
32
+ (
33
+ "44khz",
34
+ "1.0.0",
35
+ "16kbps",
36
+ ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth",
37
+ }
38
+
39
+
40
+ @argbind.bind(group="download", positional=True, without_prefix=True)
41
+ def download(
42
+ model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest"
43
+ ):
44
+ """
45
+ Function that downloads the weights file from URL if a local cache is not found.
46
+
47
+ Parameters
48
+ ----------
49
+ model_type : str
50
+ The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz".
51
+ model_bitrate: str
52
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
53
+ Only 44khz model supports 16kbps.
54
+ tag : str
55
+ The tag of the model to download. Defaults to "latest".
56
+
57
+ Returns
58
+ -------
59
+ Path
60
+ Directory path required to load model via audiotools.
61
+ """
62
+ model_type = model_type.lower()
63
+ tag = tag.lower()
64
+
65
+ assert model_type in [
66
+ "44khz",
67
+ "24khz",
68
+ "16khz",
69
+ ], "model_type must be one of '44khz', '24khz', or '16khz'"
70
+
71
+ assert model_bitrate in [
72
+ "8kbps",
73
+ "16kbps",
74
+ ], "model_bitrate must be one of '8kbps', or '16kbps'"
75
+
76
+ if tag == "latest":
77
+ tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)]
78
+
79
+ download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None)
80
+
81
+ if download_link is None:
82
+ raise ValueError(
83
+ f"Could not find model with tag {tag} and model type {model_type}"
84
+ )
85
+
86
+ local_path = (
87
+ Path.home()
88
+ / ".cache"
89
+ / "descript"
90
+ / "dac"
91
+ / f"weights_{model_type}_{model_bitrate}_{tag}.pth"
92
+ )
93
+ if not local_path.exists():
94
+ local_path.parent.mkdir(parents=True, exist_ok=True)
95
+
96
+ # Download the model
97
+ import requests
98
+
99
+ response = requests.get(download_link)
100
+
101
+ if response.status_code != 200:
102
+ raise ValueError(
103
+ f"Could not download model. Received response code {response.status_code}"
104
+ )
105
+ local_path.write_bytes(response.content)
106
+
107
+ return local_path
108
+
109
+
110
+ def load_model(
111
+ model_type: str = "44khz",
112
+ model_bitrate: str = "8kbps",
113
+ tag: str = "latest",
114
+ load_path: str = None,
115
+ ):
116
+ if not load_path:
117
+ load_path = download(
118
+ model_type=model_type, model_bitrate=model_bitrate, tag=tag
119
+ )
120
+ generator = DAC.load(load_path)
121
+ return generator
hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (3.77 kB). View file
 
hunyuanvideo_foley/models/dac_vae/utils/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (3.73 kB). View file
 
hunyuanvideo_foley/models/dac_vae/utils/decode.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from pathlib import Path
3
+
4
+ import argbind
5
+ import numpy as np
6
+ import torch
7
+ from audiotools import AudioSignal
8
+ from tqdm import tqdm
9
+
10
+ from ..model import DACFile
11
+ from . import load_model
12
+
13
+ warnings.filterwarnings("ignore", category=UserWarning)
14
+
15
+
16
+ @argbind.bind(group="decode", positional=True, without_prefix=True)
17
+ @torch.inference_mode()
18
+ @torch.no_grad()
19
+ def decode(
20
+ input: str,
21
+ output: str = "",
22
+ weights_path: str = "",
23
+ model_tag: str = "latest",
24
+ model_bitrate: str = "8kbps",
25
+ device: str = "cuda",
26
+ model_type: str = "44khz",
27
+ verbose: bool = False,
28
+ ):
29
+ """Decode audio from codes.
30
+
31
+ Parameters
32
+ ----------
33
+ input : str
34
+ Path to input directory or file
35
+ output : str, optional
36
+ Path to output directory, by default "".
37
+ If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`.
38
+ weights_path : str, optional
39
+ Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the
40
+ model_tag and model_type.
41
+ model_tag : str, optional
42
+ Tag of the model to use, by default "latest". Ignored if `weights_path` is specified.
43
+ model_bitrate: str
44
+ Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps".
45
+ device : str, optional
46
+ Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU.
47
+ model_type : str, optional
48
+ The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified.
49
+ """
50
+ generator = load_model(
51
+ model_type=model_type,
52
+ model_bitrate=model_bitrate,
53
+ tag=model_tag,
54
+ load_path=weights_path,
55
+ )
56
+ generator.to(device)
57
+ generator.eval()
58
+
59
+ # Find all .dac files in input directory
60
+ _input = Path(input)
61
+ input_files = list(_input.glob("**/*.dac"))
62
+
63
+ # If input is a .dac file, add it to the list
64
+ if _input.suffix == ".dac":
65
+ input_files.append(_input)
66
+
67
+ # Create output directory
68
+ output = Path(output)
69
+ output.mkdir(parents=True, exist_ok=True)
70
+
71
+ for i in tqdm(range(len(input_files)), desc=f"Decoding files"):
72
+ # Load file
73
+ artifact = DACFile.load(input_files[i])
74
+
75
+ # Reconstruct audio from codes
76
+ recons = generator.decompress(artifact, verbose=verbose)
77
+
78
+ # Compute output path
79
+ relative_path = input_files[i].relative_to(input)
80
+ output_dir = output / relative_path.parent
81
+ if not relative_path.name:
82
+ output_dir = output
83
+ relative_path = input_files[i]
84
+ output_name = relative_path.with_suffix(".wav").name
85
+ output_path = output_dir / output_name
86
+ output_path.parent.mkdir(parents=True, exist_ok=True)
87
+
88
+ # Write to file
89
+ recons.write(output_path)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ args = argbind.parse_args()
94
+ with argbind.scope(args):
95
+ decode()