Spaces:
Running
on
A100
Running
on
A100
Commit
·
01c0e76
0
Parent(s):
Initial commit with LFS-tracked binary files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .claude/settings.local.json +19 -0
- .gitattributes +26 -0
- .gitignore +1 -0
- CLAUDE.md +137 -0
- LICENSE +77 -0
- Notice.txt +100 -0
- README.md +296 -0
- app.py +360 -0
- asset/method.png +3 -0
- asset/teaser.png +3 -0
- asset/village.png +3 -0
- docs_for_ai_coding_bots/.DS_Store +0 -0
- docs_for_ai_coding_bots/huggingface_hub/Downloading-model-from-hub.md +174 -0
- docs_for_ai_coding_bots/huggingface_hub/Using-the-cache-in-hf-hub-library.md +531 -0
- hymm_sp/__init__.py +0 -0
- hymm_sp/config.py +160 -0
- hymm_sp/constants.py +58 -0
- hymm_sp/data_kits/data_tools.py +115 -0
- hymm_sp/data_kits/video_dataset.py +259 -0
- hymm_sp/diffusion/__init__.py +30 -0
- hymm_sp/diffusion/pipelines/__init__.py +5 -0
- hymm_sp/diffusion/pipelines/pipeline_hunyuan_video_game.py +1152 -0
- hymm_sp/diffusion/schedulers/__init__.py +2 -0
- hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py +240 -0
- hymm_sp/helpers.py +194 -0
- hymm_sp/inference.py +201 -0
- hymm_sp/modules/__init__.py +38 -0
- hymm_sp/modules/activation_layers.py +23 -0
- hymm_sp/modules/attn_layers.py +437 -0
- hymm_sp/modules/cameranet.py +248 -0
- hymm_sp/modules/embed_layers.py +146 -0
- hymm_sp/modules/fp8_optimization.py +246 -0
- hymm_sp/modules/mlp_layers.py +97 -0
- hymm_sp/modules/models.py +697 -0
- hymm_sp/modules/modulate_layers.py +76 -0
- hymm_sp/modules/norm_layers.py +77 -0
- hymm_sp/modules/parallel_states.py +381 -0
- hymm_sp/modules/posemb_layers.py +112 -0
- hymm_sp/modules/token_refiner.py +265 -0
- hymm_sp/sample_batch.py +298 -0
- hymm_sp/sample_inference.py +716 -0
- hymm_sp/text_encoder/__init__.py +310 -0
- hymm_sp/vae/__init__.py +79 -0
- hymm_sp/vae/autoencoder_kl_causal_3d.py +781 -0
- hymm_sp/vae/unet_causal_3d_blocks.py +900 -0
- hymm_sp/vae/vae.py +433 -0
- requirements.txt +60 -0
- scripts/run_sample_batch_4090.sh +35 -0
- scripts/run_sample_batch_distill.sh +24 -0
- scripts/run_sample_batch_sp.sh +24 -0
.claude/settings.local.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"permissions": {
|
3 |
+
"allow": [
|
4 |
+
"Bash(git remote set-url:*)",
|
5 |
+
"Bash(git lfs track:*)",
|
6 |
+
"Bash(git add:*)",
|
7 |
+
"Bash(git commit:*)",
|
8 |
+
"Bash(git push:*)",
|
9 |
+
"Bash(git rm:*)",
|
10 |
+
"Bash(git lfs:*)",
|
11 |
+
"Bash(git gc:*)",
|
12 |
+
"Bash(GIT_TRACE=1 git push origin main -f)",
|
13 |
+
"Bash(git rev-list:*)",
|
14 |
+
"Bash(git checkout:*)"
|
15 |
+
],
|
16 |
+
"deny": [],
|
17 |
+
"ask": []
|
18 |
+
}
|
19 |
+
}
|
.gitattributes
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.avi filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.gif filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.hdf5 filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.mov filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.bmp filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tar.gz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.webm filter=lfs diff=lfs merge=lfs -text
|
25 |
+
asset/teaser.png filter=lfs diff=lfs merge=lfs -text
|
26 |
+
asset/*.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.DS_Store
|
CLAUDE.md
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CLAUDE.md
|
2 |
+
|
3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
4 |
+
|
5 |
+
## Project Overview
|
6 |
+
|
7 |
+
Hunyuan-GameCraft is a high-dynamic interactive game video generation system that creates gameplay videos with controllable camera movements and actions. The system uses diffusion models and action-controlled generation to synthesize realistic game footage from reference images and keyboard/mouse input controls.
|
8 |
+
|
9 |
+
## Key Commands
|
10 |
+
|
11 |
+
### Installation
|
12 |
+
```bash
|
13 |
+
# Create and activate conda environment
|
14 |
+
conda create -n HYGameCraft python==3.10
|
15 |
+
conda activate HYGameCraft
|
16 |
+
|
17 |
+
# Install PyTorch and dependencies
|
18 |
+
conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
|
19 |
+
|
20 |
+
# Install requirements
|
21 |
+
python -m pip install -r requirements.txt
|
22 |
+
|
23 |
+
# Install flash attention (optional, for acceleration)
|
24 |
+
python -m pip install ninja
|
25 |
+
python -m pip install git+https://github.com/Dao-AILab/[email protected]
|
26 |
+
```
|
27 |
+
|
28 |
+
### Download Models
|
29 |
+
```bash
|
30 |
+
cd weights
|
31 |
+
huggingface-cli download tencent/Hunyuan-GameCraft-1.0 --local-dir ./
|
32 |
+
```
|
33 |
+
|
34 |
+
### Run Inference
|
35 |
+
|
36 |
+
**Multi-GPU (8 GPUs) - Standard Model:**
|
37 |
+
```bash
|
38 |
+
torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
|
39 |
+
--image-path "asset/village.png" \
|
40 |
+
--prompt "YOUR_PROMPT" \
|
41 |
+
--ckpt weights/gamecraft_models/mp_rank_00_model_states.pt \
|
42 |
+
--video-size 704 1216 \
|
43 |
+
--cfg-scale 2.0 \
|
44 |
+
--image-start \
|
45 |
+
--action-list w s d a \
|
46 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
47 |
+
--seed 250160 \
|
48 |
+
--infer-steps 50 \
|
49 |
+
--save-path './results/'
|
50 |
+
```
|
51 |
+
|
52 |
+
**Single GPU with Low VRAM (24GB minimum):**
|
53 |
+
```bash
|
54 |
+
export DISABLE_SP=1
|
55 |
+
export CPU_OFFLOAD=1
|
56 |
+
torchrun --nnodes=1 --nproc_per_node=1 --master_port 29605 hymm_sp/sample_batch.py \
|
57 |
+
--ckpt weights/gamecraft_models/mp_rank_00_model_states.pt \
|
58 |
+
--cpu-offload \
|
59 |
+
--use-fp8 \
|
60 |
+
[other parameters...]
|
61 |
+
```
|
62 |
+
|
63 |
+
**Distilled Model (faster, 8 inference steps):**
|
64 |
+
```bash
|
65 |
+
torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
|
66 |
+
--ckpt weights/gamecraft_models/mp_rank_00_model_states_distill.pt \
|
67 |
+
--cfg-scale 1.0 \
|
68 |
+
--infer-steps 8 \
|
69 |
+
--use-fp8 \
|
70 |
+
[other parameters...]
|
71 |
+
```
|
72 |
+
|
73 |
+
## Architecture Overview
|
74 |
+
|
75 |
+
### Core Components
|
76 |
+
|
77 |
+
1. **Main Entry Points**
|
78 |
+
- `hymm_sp/sample_batch.py`: Main script for batch video generation with distributed processing
|
79 |
+
- `hymm_sp/sample_inference.py`: Core inference logic and model sampling
|
80 |
+
- `hymm_sp/config.py`: Configuration parsing and argument handling
|
81 |
+
|
82 |
+
2. **Model Architecture (`hymm_sp/modules/`)**
|
83 |
+
- `models.py`: Core diffusion model implementation
|
84 |
+
- `cameranet.py`: Camera control and action encoding for game interactions
|
85 |
+
- `token_refiner.py`: Text token refinement for prompt conditioning
|
86 |
+
- `parallel_states.py`: Distributed training/inference state management
|
87 |
+
- `fp8_optimization.py`: FP8 quantization for memory/speed optimization
|
88 |
+
|
89 |
+
3. **VAE Module (`hymm_sp/vae/`)**
|
90 |
+
- `autoencoder_kl_causal_3d.py`: 3D causal VAE for video encoding/decoding
|
91 |
+
- Handles latent space conversion for video frames
|
92 |
+
|
93 |
+
4. **Diffusion Pipeline (`hymm_sp/diffusion/`)**
|
94 |
+
- `pipeline_hunyuan_video_game.py`: Custom pipeline for game video generation
|
95 |
+
- `scheduling_flow_match_discrete.py`: Flow matching scheduler for denoising
|
96 |
+
|
97 |
+
5. **Data Processing (`hymm_sp/data_kits/`)**
|
98 |
+
- `video_dataset.py`: Dataset handling for video inputs
|
99 |
+
- `data_tools.py`: Video saving and processing utilities
|
100 |
+
|
101 |
+
### Key Features
|
102 |
+
|
103 |
+
- **Action Control**: Maps keyboard inputs (w/a/s/d) to continuous camera space for smooth transitions
|
104 |
+
- **Hybrid History Conditioning**: Extends video sequences autoregressively while preserving scene context
|
105 |
+
- **Model Distillation**: Accelerated inference model (8 steps vs 50 steps)
|
106 |
+
- **Memory Optimization**: FP8 quantization, CPU offloading, and SageAttention support
|
107 |
+
- **Distributed Processing**: Multi-GPU support with sequence parallelism
|
108 |
+
|
109 |
+
### Important Parameters
|
110 |
+
|
111 |
+
- `--action-list`: Sequence of keyboard actions (w/a/s/d)
|
112 |
+
- `--action-speed-list`: Movement speed for each action (0.0-3.0)
|
113 |
+
- `--video-size`: Output resolution (height width)
|
114 |
+
- `--cfg-scale`: Classifier-free guidance scale (1.0 for distilled, 2.0 for standard)
|
115 |
+
- `--infer-steps`: Denoising steps (8 for distilled, 50 for standard)
|
116 |
+
- `--use-fp8`: Enable FP8 optimization for memory reduction
|
117 |
+
- `--cpu-offload`: Offload model to CPU for low VRAM scenarios
|
118 |
+
|
119 |
+
### Model Weights Structure
|
120 |
+
```
|
121 |
+
weights/
|
122 |
+
├── gamecraft_models/
|
123 |
+
│ ├── mp_rank_00_model_states.pt # Standard model
|
124 |
+
│ └── mp_rank_00_model_states_distill.pt # Distilled model
|
125 |
+
└── stdmodels/
|
126 |
+
├── vae_3d/ # 3D VAE model
|
127 |
+
├── llava-llama-3-8b-v1_1-transformers/ # Text encoder
|
128 |
+
└── openai_clip-vit-large-patch14/ # CLIP encoder
|
129 |
+
```
|
130 |
+
|
131 |
+
## Development Notes
|
132 |
+
|
133 |
+
- Environment variable `MODEL_BASE` should point to `weights/stdmodels`
|
134 |
+
- Use `export DISABLE_SP=1` and `export CPU_OFFLOAD=1` for single GPU inference
|
135 |
+
- Minimum GPU memory: 24GB (very slow), Recommended: 80GB per GPU
|
136 |
+
- Action length determines video duration (1 action = 33 frames at 25 FPS)
|
137 |
+
- SageAttention can be installed for additional acceleration
|
LICENSE
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
2 |
+
Tencent Hunyuan-GameCraft Release Date: August 14, 2025
|
3 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
4 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
5 |
+
1. DEFINITIONS.
|
6 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
7 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
8 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
9 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
10 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
11 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
12 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
13 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
14 |
+
i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
|
15 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan-GameCraft released at [https://github.com/Tencent-Hunyuan/Hunyuan-GameCraft-1.0].
|
16 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
17 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
18 |
+
m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
19 |
+
n. “including” shall mean including but not limited to.
|
20 |
+
2. GRANT OF RIGHTS.
|
21 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
22 |
+
3. DISTRIBUTION.
|
23 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
24 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
25 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
26 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
27 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
28 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
29 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
30 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
31 |
+
5. RULES OF USE.
|
32 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
33 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
34 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
35 |
+
6. INTELLECTUAL PROPERTY.
|
36 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
37 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
38 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
39 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
40 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
41 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
42 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
43 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
44 |
+
8. SURVIVAL AND TERMINATION.
|
45 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
46 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
47 |
+
9. GOVERNING LAW AND JURISDICTION.
|
48 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
49 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
50 |
+
|
51 |
+
EXHIBIT A
|
52 |
+
ACCEPTABLE USE POLICY
|
53 |
+
|
54 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
55 |
+
Last modified: November 5, 2024
|
56 |
+
|
57 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
58 |
+
1. Outside the Territory;
|
59 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
60 |
+
3. To harm Yourself or others;
|
61 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
62 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
63 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
64 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
65 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
66 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
67 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
68 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
69 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
70 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
71 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
72 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
73 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
74 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
75 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
76 |
+
19. For military purposes;
|
77 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
Notice.txt
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Usage and Legal Notices:
|
2 |
+
|
3 |
+
Tencent is pleased to support the open source community by making Tencent Hunyuan-GameCraft available.
|
4 |
+
|
5 |
+
Copyright (C) 2025 Tencent. All rights reserved. The below softwares in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
|
6 |
+
Tencent Hunyuan-GameCraft is licensed under Tencent Hunyuan Community License Agreement, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent Hunyuan-GameCraft does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
7 |
+
|
8 |
+
|
9 |
+
Other dependencies and licenses:
|
10 |
+
|
11 |
+
|
12 |
+
Open Source Software Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT and Other Licenses of the Third-Party Components therein:
|
13 |
+
The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
14 |
+
--------------------------------------------------------------------
|
15 |
+
1. HunyuanVideo
|
16 |
+
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
17 |
+
|
18 |
+
|
19 |
+
Terms of the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT:
|
20 |
+
--------------------------------------------------------------------
|
21 |
+
TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
|
22 |
+
Tencent HunyuanVideo Release Date: December 3, 2024
|
23 |
+
THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
|
24 |
+
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
25 |
+
1. DEFINITIONS.
|
26 |
+
a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
27 |
+
b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
28 |
+
c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
29 |
+
d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
30 |
+
e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
31 |
+
f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
32 |
+
g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
33 |
+
h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
34 |
+
i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
|
35 |
+
j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo released at [https://github.com/Tencent/HunyuanVideo].
|
36 |
+
k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
37 |
+
l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
|
38 |
+
m. “Third Party�� or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
39 |
+
n. “including” shall mean including but not limited to.
|
40 |
+
2. GRANT OF RIGHTS.
|
41 |
+
We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
42 |
+
3. DISTRIBUTION.
|
43 |
+
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
|
44 |
+
a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
45 |
+
b. You must cause any modified files to carry prominent notices stating that You changed the files;
|
46 |
+
c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
47 |
+
d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
48 |
+
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
49 |
+
4. ADDITIONAL COMMERCIAL TERMS.
|
50 |
+
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
51 |
+
5. RULES OF USE.
|
52 |
+
a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
53 |
+
b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
|
54 |
+
c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
|
55 |
+
6. INTELLECTUAL PROPERTY.
|
56 |
+
a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
57 |
+
b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
58 |
+
c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
59 |
+
d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
60 |
+
7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
61 |
+
a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
62 |
+
b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
63 |
+
c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
64 |
+
8. SURVIVAL AND TERMINATION.
|
65 |
+
a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
66 |
+
b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
67 |
+
9. GOVERNING LAW AND JURISDICTION.
|
68 |
+
a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
69 |
+
b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
70 |
+
|
71 |
+
EXHIBIT A
|
72 |
+
ACCEPTABLE USE POLICY
|
73 |
+
|
74 |
+
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
75 |
+
Last modified: November 5, 2024
|
76 |
+
|
77 |
+
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
78 |
+
1. Outside the Territory;
|
79 |
+
2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
80 |
+
3. To harm Yourself or others;
|
81 |
+
4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
82 |
+
5. To override or circumvent the safety guardrails and safeguards We have put in place;
|
83 |
+
6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
84 |
+
7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
85 |
+
8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
86 |
+
9. To intentionally defame, disparage or otherwise harass others;
|
87 |
+
10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
88 |
+
11. To generate or disseminate personal identifiable information with the purpose of harming others;
|
89 |
+
12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
90 |
+
13. To impersonate another individual without consent, authorization, or legal right;
|
91 |
+
14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
92 |
+
15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
93 |
+
16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
94 |
+
17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
95 |
+
18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
96 |
+
19. For military purposes;
|
97 |
+
20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
98 |
+
|
99 |
+
For the license of other third party components, please refer to the following URL:
|
100 |
+
https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/ff2dd59277b3177785d8279d4170968afa3b1d55/Notice
|
README.md
ADDED
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Hunyuan-GameCraft
|
3 |
+
emoji: 🎮
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: purple
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.42.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
license: mit
|
11 |
+
short_description: Interactive Game Video Generation
|
12 |
+
---
|
13 |
+
|
14 |
+
<!-- ## **Hunyuan-GameCraft** -->
|
15 |
+
|
16 |
+
<!-- <p align="center">
|
17 |
+
<img src="assets/material/logo.png" height=100>
|
18 |
+
</p> -->
|
19 |
+
|
20 |
+
# **Hunyuan-GameCraft** 🎮
|
21 |
+
|
22 |
+
<div align="center">
|
23 |
+
<a href="https://github.com/Tencent-Hunyuan/Hunyuan-GameCraft-1.0"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue"></a>  
|
24 |
+
<a href="https://hunyuan-gamecraft.github.io/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Web&color=green"></a>  
|
25 |
+
<a href="https://arxiv.org/abs/2506.17201"><img src="https://img.shields.io/badge/ArXiv-2506.17201-red"></a>  
|
26 |
+
<a href="https://huggingface.co/tencent/Hunyuan-GameCraft-1.0"><img src="https://img.shields.io/static/v1?label=Huggingface&message=Hunyuan-GameCraft-1.0&color=yellow"></a>  
|
27 |
+
</div>
|
28 |
+
|
29 |
+

|
30 |
+
|
31 |
+
> [**Hunyuan-GameCraft: High-dynamic Interactive Game Video Generation with Hybrid History Condition**](https://arxiv.org/abs/2506.17201) <be>
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
## 🔥🔥🔥 News!!
|
36 |
+
* Aug 14, 2025: 👋 We release the inference code and model weights of Hunyuan-GameCraft. [Download](weights/README.md).
|
37 |
+
|
38 |
+
|
39 |
+
## 📑 Open-source Plan
|
40 |
+
|
41 |
+
- Hunyuan-GameCraft
|
42 |
+
- [x] Inference
|
43 |
+
- [x] Checkpoints
|
44 |
+
- [ ] Gradio & Huggingface Demo
|
45 |
+
|
46 |
+
## Contents
|
47 |
+
- [**Hunyuan-GameCraft** 🌅](#Hunyuan-GameCraft-)
|
48 |
+
- [🔥🔥🔥 News!!](#-news)
|
49 |
+
- [📑 Open-source Plan](#-open-source-plan)
|
50 |
+
- [Contents](#contents)
|
51 |
+
- [**Abstract**](#abstract)
|
52 |
+
- [**Overall Architecture**](#-overall-architecture)
|
53 |
+
- [📜 Requirements](#-requirements)
|
54 |
+
- [🛠️ Dependencies and Installation](#️-dependencies-and-installation)
|
55 |
+
- [Installation Guide for Linux](#installation-guide-for-linux)
|
56 |
+
- [🧱 Download Pretrained Models](#-download-pretrained-models)
|
57 |
+
- [🚀 Parallel Inference on Multiple GPUs](#-parallel-inference-on-multiple-gpus)
|
58 |
+
- [🔑 Single-gpu Inference](#-single-gpu-inference)
|
59 |
+
- [Run with very low VRAM](#run-with-very-low-vram)
|
60 |
+
- [Run a Gradio Server](#run-a-gradio-server)
|
61 |
+
- [🔗 BibTeX](#-bibtex)
|
62 |
+
- [Acknowledgements](#acknowledgements)
|
63 |
+
---
|
64 |
+
|
65 |
+
## **Abstract**
|
66 |
+
|
67 |
+
Recent advances in diffusion-based and controllable video generation have enabled high-quality and temporally coherent video synthesis, laying the groundwork for immersive interactive gaming experiences. However, current methods face limitations in **dynamics**, **physically realistic**, **long-term consistency**, and **efficiency**, which limit the ability to create various gameplay videos. To address these gaps, we introduce Hunyuan-GameCraft, a novel framework for high-dynamic interactive video generation in game environments. To achieve fine-grained action control, we unify standard keyboard and mouse inputs into a **shared camera representation space**, facilitating smooth interpolation between various camera and movement operations. Then we propose a **hybrid history-conditioned training strategy** that extends video sequences autoregressively while preserving game scene information. Additionally, to enhance inference efficiency and playability, we achieve **model distillation** to reduce computational overhead while maintaining consistency across long temporal sequences, making it suitable for real-time deployment in complex interactive environments. The model is trained on a large-scale dataset comprising over one million gameplay recordings across over 100 AAA games, ensuring broad coverage and diversity, then fine-tuned on a carefully annotated synthetic dataset to enhance precision and control. The curated game scene data significantly improves the visual fidelity, realism and action controllability. Extensive experiments demonstrate that Hunyuan-GameCraft significantly outperforms existing models, advancing the realism and playability of interactive game video generation.
|
68 |
+
|
69 |
+
## **Overall Architecture**
|
70 |
+
|
71 |
+

|
72 |
+
|
73 |
+
Given a reference image and the corresponding prompt, the keyboard or mouse signal, we transform these options to the continuous camera space. Then we design a light-weight action encoder to encode the input camera trajectory. The action and image features are added after patchify. For long video extension, we design a variable mask indicator, where 1 and 0 indicate history frames and predicted frames, respectively.
|
74 |
+
|
75 |
+
|
76 |
+
## 📜 Requirements
|
77 |
+
|
78 |
+
* An NVIDIA GPU with CUDA support is required.
|
79 |
+
* The model is tested on a machine with 8*H20/H800GPUs.
|
80 |
+
* **Minimum**: The minimum GPU memory required is 24GB but very slow.
|
81 |
+
* **Recommended**: We recommend using a GPU with 80GB of memory for better generation quality.
|
82 |
+
* Tested operating system: Linux
|
83 |
+
|
84 |
+
|
85 |
+
## 🛠️ Dependencies and Installation
|
86 |
+
|
87 |
+
Begin by cloning the repository:
|
88 |
+
```shell
|
89 |
+
git clone https://github.com/Tencent-Hunyuan/Hunyuan-GameCraft-1.0.git
|
90 |
+
cd Hunyuan-GameCraft-1.0
|
91 |
+
```
|
92 |
+
|
93 |
+
### Installation Guide for Linux
|
94 |
+
|
95 |
+
We recommend CUDA versions 12.4 for the manual installation.
|
96 |
+
|
97 |
+
Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
|
98 |
+
|
99 |
+
```shell
|
100 |
+
# 1. Create conda environment
|
101 |
+
conda create -n HYGameCraft python==3.10
|
102 |
+
|
103 |
+
# 2. Activate the environment
|
104 |
+
conda activate HYGameCraft
|
105 |
+
|
106 |
+
# 3. Install PyTorch and other dependencies using conda
|
107 |
+
conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
|
108 |
+
|
109 |
+
# 4. Install pip dependencies
|
110 |
+
python -m pip install -r requirements.txt
|
111 |
+
# 5. Install flash attention v2 for acceleration (requires CUDA 11.8 or above)
|
112 |
+
python -m pip install ninja
|
113 |
+
python -m pip install git+https://github.com/Dao-AILab/[email protected]
|
114 |
+
```
|
115 |
+
|
116 |
+
Additionally, you can also use HunyuanVideo Docker image. Use the following command to pull and run the docker image.
|
117 |
+
|
118 |
+
```shell
|
119 |
+
# For CUDA 12.4 (updated to avoid float point exception)
|
120 |
+
docker pull hunyuanvideo/hunyuanvideo:cuda_12
|
121 |
+
docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
|
122 |
+
pip install diffusers==0.34.0 transformers==4.54.1
|
123 |
+
|
124 |
+
```
|
125 |
+
|
126 |
+
|
127 |
+
## 🧱 Download Pretrained Models
|
128 |
+
|
129 |
+
The details of download pretrained models are shown [here](weights/README.md).
|
130 |
+
|
131 |
+
## 🚀 Parallel Inference on Multiple GPUs
|
132 |
+
|
133 |
+
For example, to generate a video using 8 GPUs, you can use the following command, where `--action-list w s d a` simulate keyboard manipulation signals to help you generate a video of the corresponding content. `--action-speed-list 0.2 0.2 0.2 0.2` represents the displacement distance and can be replaced with any value between 0 and 3.
|
134 |
+
|
135 |
+
You can try any combination and any length of the action list (one action per 33 frames, 25FPS) to generate a long video, and make sure the length of `--action-speed-list` must be the same as `--action-list`. It should be noticed that the inference time is linearly related to the action length:
|
136 |
+
|
137 |
+
```bash
|
138 |
+
#!/bin/bash
|
139 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
140 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
141 |
+
export MODEL_BASE="weights/stdmodels"
|
142 |
+
checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states.pt"
|
143 |
+
|
144 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
145 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
146 |
+
|
147 |
+
torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
|
148 |
+
--image-path "asset/village.png" \
|
149 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
150 |
+
--add-pos-prompt "Realistic, High-quality." \
|
151 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
152 |
+
--ckpt ${checkpoint_path} \
|
153 |
+
--video-size 704 1216 \
|
154 |
+
--cfg-scale 2.0 \
|
155 |
+
--image-start \
|
156 |
+
--action-list w s d a \
|
157 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
158 |
+
--seed 250160 \
|
159 |
+
--infer-steps 50 \
|
160 |
+
--flow-shift-eval-video 5.0 \
|
161 |
+
--save-path './results/'
|
162 |
+
|
163 |
+
```
|
164 |
+
|
165 |
+
|
166 |
+
Additionally, we support FP8 optimization and [SageAttn](https://github.com/thu-ml/SageAttention). To enable FP8, simply add the `--use-fp8` to your command.
|
167 |
+
And install SageAttention with:
|
168 |
+
```bash
|
169 |
+
git clone https://github.com/thu-ml/SageAttention.git
|
170 |
+
cd SageAttention
|
171 |
+
python setup.py install # or pip install -e .
|
172 |
+
```
|
173 |
+
|
174 |
+
We also provide an accelerated model, you can use the following command:
|
175 |
+
```bash
|
176 |
+
#!/bin/bash
|
177 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
178 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
179 |
+
export MODEL_BASE="weights/stdmodels"
|
180 |
+
checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
|
181 |
+
|
182 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
183 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
184 |
+
|
185 |
+
torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
|
186 |
+
--image-path "asset/village.png" \
|
187 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
188 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
189 |
+
--ckpt ${checkpoint_path} \
|
190 |
+
--video-size 704 1216 \
|
191 |
+
--cfg-scale 1.0 \
|
192 |
+
--image-start \
|
193 |
+
--action-list w s d a \
|
194 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
195 |
+
--seed 250160 \
|
196 |
+
--infer-steps 8 \
|
197 |
+
--use-fp8 \
|
198 |
+
--flow-shift-eval-video 5.0 \
|
199 |
+
--save-path './results_distill/'
|
200 |
+
```
|
201 |
+
|
202 |
+
|
203 |
+
## 🔑 Single-gpu with Low-VRAM Inference
|
204 |
+
|
205 |
+
For example, to generate a video with 1 GPU with Low-VRAM (minimum GPU memory required is 24GB for 704px1216p but very slow), you can use the following command:
|
206 |
+
|
207 |
+
```bash
|
208 |
+
#!/bin/bash
|
209 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
210 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
211 |
+
export MODEL_BASE="weights/stdmodels"
|
212 |
+
checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states.pt"
|
213 |
+
|
214 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
215 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
216 |
+
|
217 |
+
# disable sp and cpu offload
|
218 |
+
export DISABLE_SP=1
|
219 |
+
export CPU_OFFLOAD=1
|
220 |
+
|
221 |
+
torchrun --nnodes=1 --nproc_per_node=1 --master_port 29605 hymm_sp/sample_batch.py \
|
222 |
+
--image-path "asset/village.png" \
|
223 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
224 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
225 |
+
--ckpt ${checkpoint_path} \
|
226 |
+
--video-size 704 1216 \
|
227 |
+
--cfg-scale 2.0 \
|
228 |
+
--image-start \
|
229 |
+
--action-list w a d s \
|
230 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
231 |
+
--seed 250160 \
|
232 |
+
--sample-n-frames 33 \
|
233 |
+
--infer-steps 50 \
|
234 |
+
--flow-shift-eval-video 5.0 \
|
235 |
+
--cpu-offload \
|
236 |
+
--use-fp8 \
|
237 |
+
--save-path './results_poor/'
|
238 |
+
|
239 |
+
```
|
240 |
+
|
241 |
+
As for using the accelerated model, you can use the following command:
|
242 |
+
|
243 |
+
```bash
|
244 |
+
#!/bin/bash
|
245 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
246 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
247 |
+
export MODEL_BASE="weights/stdmodels"
|
248 |
+
checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
|
249 |
+
|
250 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
251 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
252 |
+
|
253 |
+
# disable sp and cpu offload
|
254 |
+
export DISABLE_SP=1
|
255 |
+
export CPU_OFFLOAD=1
|
256 |
+
|
257 |
+
torchrun --nnodes=1 --nproc_per_node=1 --master_port 29605 hymm_sp/sample_batch.py \
|
258 |
+
--image-path "asset/village.png" \
|
259 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
260 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
261 |
+
--ckpt ${checkpoint_path} \
|
262 |
+
--video-size 704 1216 \
|
263 |
+
--cfg-scale 1.0 \
|
264 |
+
--image-start \
|
265 |
+
--action-list w a d s \
|
266 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
267 |
+
--seed 250160 \
|
268 |
+
--sample-n-frames 33 \
|
269 |
+
--infer-steps 8 \
|
270 |
+
--flow-shift-eval-video 5.0 \
|
271 |
+
--cpu-offload \
|
272 |
+
--use-fp8 \
|
273 |
+
--save-path './results_distill_poor/'
|
274 |
+
```
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
## 🔗 BibTeX
|
279 |
+
|
280 |
+
If you find [Hunyuan-GameCraft](https://arxiv.org/abs/2506.17201) useful for your research and applications, please cite using this BibTeX:
|
281 |
+
|
282 |
+
```BibTeX
|
283 |
+
@misc{li2025hunyuangamecrafthighdynamicinteractivegame,
|
284 |
+
title={Hunyuan-GameCraft: High-dynamic Interactive Game Video Generation with Hybrid History Condition},
|
285 |
+
author={Jiaqi Li and Junshu Tang and Zhiyong Xu and Longhuang Wu and Yuan Zhou and Shuai Shao and Tianbao Yu and Zhiguo Cao and Qinglin Lu},
|
286 |
+
year={2025},
|
287 |
+
eprint={2506.17201},
|
288 |
+
archivePrefix={arXiv},
|
289 |
+
primaryClass={cs.CV},
|
290 |
+
url={https://arxiv.org/abs/2506.17201},
|
291 |
+
}
|
292 |
+
```
|
293 |
+
|
294 |
+
## Acknowledgements
|
295 |
+
|
296 |
+
We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-Avatar](https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar),[SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories, for their open research and exploration.
|
app.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import gradio as gr
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
from pathlib import Path
|
7 |
+
from PIL import Image
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from loguru import logger
|
10 |
+
from huggingface_hub import hf_hub_download
|
11 |
+
import tempfile
|
12 |
+
|
13 |
+
from hymm_sp.sample_inference import HunyuanVideoSampler
|
14 |
+
from hymm_sp.data_kits.data_tools import save_videos_grid
|
15 |
+
from hymm_sp.config import parse_args
|
16 |
+
import argparse
|
17 |
+
|
18 |
+
os.environ["MODEL_BASE"] = "weights/stdmodels"
|
19 |
+
os.environ["DISABLE_SP"] = "1"
|
20 |
+
os.environ["CPU_OFFLOAD"] = "1"
|
21 |
+
|
22 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
23 |
+
|
24 |
+
class CropResize:
|
25 |
+
def __init__(self, size=(704, 1216)):
|
26 |
+
self.target_h, self.target_w = size
|
27 |
+
|
28 |
+
def __call__(self, img):
|
29 |
+
w, h = img.size
|
30 |
+
scale = max(
|
31 |
+
self.target_w / w,
|
32 |
+
self.target_h / h
|
33 |
+
)
|
34 |
+
new_size = (int(h * scale), int(w * scale))
|
35 |
+
resize_transform = transforms.Resize(
|
36 |
+
new_size,
|
37 |
+
interpolation=transforms.InterpolationMode.BILINEAR
|
38 |
+
)
|
39 |
+
resized_img = resize_transform(img)
|
40 |
+
crop_transform = transforms.CenterCrop((self.target_h, self.target_w))
|
41 |
+
return crop_transform(resized_img)
|
42 |
+
|
43 |
+
def create_args():
|
44 |
+
args = argparse.Namespace()
|
45 |
+
args.ckpt = "weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
|
46 |
+
args.video_size = [704, 1216]
|
47 |
+
args.cfg_scale = 1.0
|
48 |
+
args.image_start = True
|
49 |
+
args.seed = None
|
50 |
+
args.infer_steps = 8
|
51 |
+
args.use_fp8 = True
|
52 |
+
args.flow_shift_eval_video = 5.0
|
53 |
+
args.sample_n_frames = 33
|
54 |
+
args.num_images = 1
|
55 |
+
args.use_linear_quadratic_schedule = False
|
56 |
+
args.linear_schedule_end = 0.25
|
57 |
+
args.use_deepcache = False
|
58 |
+
args.cpu_offload = True
|
59 |
+
args.use_sage = True
|
60 |
+
args.save_path = './results/'
|
61 |
+
args.save_path_suffix = ''
|
62 |
+
args.add_pos_prompt = "Realistic, High-quality."
|
63 |
+
args.add_neg_prompt = "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border."
|
64 |
+
return args
|
65 |
+
|
66 |
+
logger.info("Initializing Hunyuan-GameCraft model...")
|
67 |
+
|
68 |
+
if not os.path.exists("weights/gamecraft_models/mp_rank_00_model_states_distill.pt"):
|
69 |
+
logger.info("Downloading model weights from Hugging Face...")
|
70 |
+
os.makedirs("weights/gamecraft_models", exist_ok=True)
|
71 |
+
hf_hub_download(
|
72 |
+
repo_id="tencent/Hunyuan-GameCraft-1.0",
|
73 |
+
filename="gamecraft_models/mp_rank_00_model_states_distill.pt",
|
74 |
+
local_dir="weights/",
|
75 |
+
local_dir_use_symlinks=False
|
76 |
+
)
|
77 |
+
|
78 |
+
args = create_args()
|
79 |
+
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(
|
80 |
+
args.ckpt,
|
81 |
+
args=args,
|
82 |
+
device=torch.device("cpu")
|
83 |
+
)
|
84 |
+
args = hunyuan_video_sampler.args
|
85 |
+
|
86 |
+
if args.cpu_offload:
|
87 |
+
from diffusers.hooks import apply_group_offloading
|
88 |
+
onload_device = torch.device("cuda")
|
89 |
+
apply_group_offloading(
|
90 |
+
hunyuan_video_sampler.pipeline.transformer,
|
91 |
+
onload_device=onload_device,
|
92 |
+
offload_type="block_level",
|
93 |
+
num_blocks_per_group=1
|
94 |
+
)
|
95 |
+
logger.info("Enabled CPU offloading for transformer blocks")
|
96 |
+
|
97 |
+
logger.info("Model loaded successfully!")
|
98 |
+
|
99 |
+
def generate_video(
|
100 |
+
input_image,
|
101 |
+
prompt,
|
102 |
+
action_sequence,
|
103 |
+
action_speeds,
|
104 |
+
negative_prompt,
|
105 |
+
seed,
|
106 |
+
cfg_scale,
|
107 |
+
num_inference_steps,
|
108 |
+
progress=gr.Progress(track_tqdm=True)
|
109 |
+
):
|
110 |
+
try:
|
111 |
+
progress(0, desc="Initializing...")
|
112 |
+
|
113 |
+
if input_image is None:
|
114 |
+
return None, "Please upload an image first!"
|
115 |
+
|
116 |
+
action_list = action_sequence.lower().replace(" ", "").split(",") if action_sequence else ["w"]
|
117 |
+
speed_list = [float(s.strip()) for s in action_speeds.split(",")] if action_speeds else [0.2]
|
118 |
+
|
119 |
+
if len(speed_list) != len(action_list):
|
120 |
+
if len(speed_list) == 1:
|
121 |
+
speed_list = speed_list * len(action_list)
|
122 |
+
else:
|
123 |
+
return None, f"Error: Number of speeds ({len(speed_list)}) must match number of actions ({len(action_list)})"
|
124 |
+
|
125 |
+
for action in action_list:
|
126 |
+
if action not in ['w', 'a', 's', 'd']:
|
127 |
+
return None, f"Error: Invalid action '{action}'. Use only w, a, s, d"
|
128 |
+
|
129 |
+
for speed in speed_list:
|
130 |
+
if not 0.0 <= speed <= 3.0:
|
131 |
+
return None, f"Error: Speed {speed} out of range. Use values between 0.0 and 3.0"
|
132 |
+
|
133 |
+
progress(0.1, desc="Processing image...")
|
134 |
+
|
135 |
+
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
|
136 |
+
input_image.save(tmp_file.name)
|
137 |
+
image_path = tmp_file.name
|
138 |
+
|
139 |
+
closest_size = (704, 1216)
|
140 |
+
ref_image_transform = transforms.Compose([
|
141 |
+
CropResize(closest_size),
|
142 |
+
transforms.CenterCrop(closest_size),
|
143 |
+
transforms.ToTensor(),
|
144 |
+
transforms.Normalize([0.5], [0.5])
|
145 |
+
])
|
146 |
+
|
147 |
+
raw_ref_image = Image.open(image_path).convert('RGB')
|
148 |
+
ref_image_pixel_values = ref_image_transform(raw_ref_image)
|
149 |
+
ref_image_pixel_values = ref_image_pixel_values.unsqueeze(0).unsqueeze(2).to(device)
|
150 |
+
|
151 |
+
progress(0.2, desc="Encoding image...")
|
152 |
+
|
153 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
|
154 |
+
if args.cpu_offload:
|
155 |
+
hunyuan_video_sampler.vae.quant_conv.to('cuda')
|
156 |
+
hunyuan_video_sampler.vae.encoder.to('cuda')
|
157 |
+
|
158 |
+
hunyuan_video_sampler.pipeline.vae.enable_tiling()
|
159 |
+
|
160 |
+
raw_last_latents = hunyuan_video_sampler.vae.encode(
|
161 |
+
ref_image_pixel_values
|
162 |
+
).latent_dist.sample().to(dtype=torch.float16)
|
163 |
+
raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
|
164 |
+
raw_ref_latents = raw_last_latents.clone()
|
165 |
+
|
166 |
+
hunyuan_video_sampler.pipeline.vae.disable_tiling()
|
167 |
+
if args.cpu_offload:
|
168 |
+
hunyuan_video_sampler.vae.quant_conv.to('cpu')
|
169 |
+
hunyuan_video_sampler.vae.encoder.to('cpu')
|
170 |
+
|
171 |
+
ref_images = [raw_ref_image]
|
172 |
+
last_latents = raw_last_latents
|
173 |
+
ref_latents = raw_ref_latents
|
174 |
+
|
175 |
+
progress(0.3, desc="Starting video generation...")
|
176 |
+
|
177 |
+
if seed is None or seed == -1:
|
178 |
+
seed = random.randint(0, 1_000_000)
|
179 |
+
|
180 |
+
all_samples = []
|
181 |
+
|
182 |
+
for idx, (action_id, action_speed) in enumerate(zip(action_list, speed_list)):
|
183 |
+
is_image = (idx == 0)
|
184 |
+
|
185 |
+
progress(0.3 + (0.6 * idx / len(action_list)),
|
186 |
+
desc=f"Generating segment {idx+1}/{len(action_list)} (action: {action_id})")
|
187 |
+
|
188 |
+
outputs = hunyuan_video_sampler.predict(
|
189 |
+
prompt=prompt,
|
190 |
+
action_id=action_id,
|
191 |
+
action_speed=action_speed,
|
192 |
+
is_image=is_image,
|
193 |
+
size=(704, 1216),
|
194 |
+
seed=seed,
|
195 |
+
last_latents=last_latents,
|
196 |
+
ref_latents=ref_latents,
|
197 |
+
video_length=args.sample_n_frames,
|
198 |
+
guidance_scale=cfg_scale,
|
199 |
+
num_images_per_prompt=1,
|
200 |
+
negative_prompt=negative_prompt,
|
201 |
+
infer_steps=num_inference_steps,
|
202 |
+
flow_shift=args.flow_shift_eval_video,
|
203 |
+
use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
|
204 |
+
linear_schedule_end=args.linear_schedule_end,
|
205 |
+
use_deepcache=args.use_deepcache,
|
206 |
+
cpu_offload=args.cpu_offload,
|
207 |
+
ref_images=ref_images,
|
208 |
+
output_dir=None,
|
209 |
+
return_latents=True,
|
210 |
+
use_sage=args.use_sage,
|
211 |
+
)
|
212 |
+
|
213 |
+
ref_latents = outputs["ref_latents"]
|
214 |
+
last_latents = outputs["last_latents"]
|
215 |
+
|
216 |
+
sub_samples = outputs['samples'][0]
|
217 |
+
all_samples.append(sub_samples)
|
218 |
+
|
219 |
+
progress(0.9, desc="Finalizing video...")
|
220 |
+
|
221 |
+
if len(all_samples) > 0:
|
222 |
+
out_cat = torch.cat(all_samples, dim=2)
|
223 |
+
else:
|
224 |
+
out_cat = all_samples[0]
|
225 |
+
|
226 |
+
with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
|
227 |
+
output_path = tmp_video.name
|
228 |
+
|
229 |
+
save_videos_grid(out_cat, output_path, n_rows=1, fps=25)
|
230 |
+
|
231 |
+
if os.path.exists(image_path):
|
232 |
+
os.remove(image_path)
|
233 |
+
|
234 |
+
progress(1.0, desc="Complete!")
|
235 |
+
return output_path, "Video generated successfully!"
|
236 |
+
|
237 |
+
except Exception as e:
|
238 |
+
logger.error(f"Error generating video: {e}")
|
239 |
+
return None, f"Error: {str(e)}"
|
240 |
+
|
241 |
+
with gr.Blocks(title="Hunyuan-GameCraft") as demo:
|
242 |
+
gr.Markdown("""
|
243 |
+
# 🎮 Hunyuan-GameCraft Video Generation
|
244 |
+
|
245 |
+
Generate interactive game-style videos from a single image using keyboard actions (W/A/S/D).
|
246 |
+
Using the **distilled model** for faster generation (8 inference steps).
|
247 |
+
""")
|
248 |
+
|
249 |
+
with gr.Row():
|
250 |
+
with gr.Column(scale=1):
|
251 |
+
input_image = gr.Image(
|
252 |
+
label="Input Image",
|
253 |
+
type="pil",
|
254 |
+
height=400
|
255 |
+
)
|
256 |
+
|
257 |
+
prompt = gr.Textbox(
|
258 |
+
label="Prompt",
|
259 |
+
placeholder="Describe the scene...",
|
260 |
+
value="A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
|
261 |
+
lines=3
|
262 |
+
)
|
263 |
+
|
264 |
+
with gr.Accordion("Action Controls", open=True):
|
265 |
+
action_sequence = gr.Textbox(
|
266 |
+
label="Action Sequence (comma-separated)",
|
267 |
+
placeholder="w, a, s, d",
|
268 |
+
value="w, s, d, a",
|
269 |
+
info="Use w (forward), a (left), s (backward), d (right)"
|
270 |
+
)
|
271 |
+
|
272 |
+
action_speeds = gr.Textbox(
|
273 |
+
label="Action Speeds (comma-separated)",
|
274 |
+
placeholder="0.2, 0.2, 0.2, 0.2",
|
275 |
+
value="0.2, 0.2, 0.2, 0.2",
|
276 |
+
info="Speed for each action (0.0 to 3.0). Single value applies to all."
|
277 |
+
)
|
278 |
+
|
279 |
+
with gr.Accordion("Advanced Settings", open=False):
|
280 |
+
negative_prompt = gr.Textbox(
|
281 |
+
label="Negative Prompt",
|
282 |
+
value="overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border.",
|
283 |
+
lines=2
|
284 |
+
)
|
285 |
+
|
286 |
+
seed = gr.Number(
|
287 |
+
label="Seed",
|
288 |
+
value=-1,
|
289 |
+
precision=0,
|
290 |
+
info="Set to -1 for random seed"
|
291 |
+
)
|
292 |
+
|
293 |
+
cfg_scale = gr.Slider(
|
294 |
+
label="CFG Scale",
|
295 |
+
minimum=0.5,
|
296 |
+
maximum=3.0,
|
297 |
+
value=1.0,
|
298 |
+
step=0.1,
|
299 |
+
info="Classifier-free guidance scale (1.0 for distilled model)"
|
300 |
+
)
|
301 |
+
|
302 |
+
num_inference_steps = gr.Slider(
|
303 |
+
label="Inference Steps",
|
304 |
+
minimum=4,
|
305 |
+
maximum=20,
|
306 |
+
value=8,
|
307 |
+
step=1,
|
308 |
+
info="Number of denoising steps (8 for distilled model)"
|
309 |
+
)
|
310 |
+
|
311 |
+
generate_btn = gr.Button("Generate Video", variant="primary")
|
312 |
+
|
313 |
+
with gr.Column(scale=1):
|
314 |
+
output_video = gr.Video(
|
315 |
+
label="Generated Video",
|
316 |
+
height=400
|
317 |
+
)
|
318 |
+
status_text = gr.Textbox(
|
319 |
+
label="Status",
|
320 |
+
interactive=False
|
321 |
+
)
|
322 |
+
|
323 |
+
gr.Markdown("""
|
324 |
+
### Tips:
|
325 |
+
- Each action generates 33 frames (1.3 seconds at 25 FPS)
|
326 |
+
- The distilled model is optimized for speed with 8 inference steps
|
327 |
+
- Use FP8 optimization for better memory efficiency
|
328 |
+
- Minimum GPU memory: 24GB VRAM
|
329 |
+
""")
|
330 |
+
|
331 |
+
generate_btn.click(
|
332 |
+
fn=generate_video,
|
333 |
+
inputs=[
|
334 |
+
input_image,
|
335 |
+
prompt,
|
336 |
+
action_sequence,
|
337 |
+
action_speeds,
|
338 |
+
negative_prompt,
|
339 |
+
seed,
|
340 |
+
cfg_scale,
|
341 |
+
num_inference_steps
|
342 |
+
],
|
343 |
+
outputs=[output_video, status_text]
|
344 |
+
)
|
345 |
+
|
346 |
+
gr.Examples(
|
347 |
+
examples=[
|
348 |
+
[
|
349 |
+
"asset/village.png",
|
350 |
+
"A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
|
351 |
+
"w, a, d, s",
|
352 |
+
"0.2, 0.2, 0.2, 0.2"
|
353 |
+
]
|
354 |
+
],
|
355 |
+
inputs=[input_image, prompt, action_sequence, action_speeds],
|
356 |
+
label="Example"
|
357 |
+
)
|
358 |
+
|
359 |
+
if __name__ == "__main__":
|
360 |
+
demo.launch(share=True)
|
asset/method.png
ADDED
![]() |
Git LFS Details
|
asset/teaser.png
ADDED
![]() |
Git LFS Details
|
asset/village.png
ADDED
![]() |
Git LFS Details
|
docs_for_ai_coding_bots/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
docs_for_ai_coding_bots/huggingface_hub/Downloading-model-from-hub.md
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[](#download-files-from-the-hub)Download files from the Hub
|
2 |
+
===========================================================
|
3 |
+
|
4 |
+
The `huggingface_hub` library provides functions to download files from the repositories stored on the Hub. You can use these functions independently or integrate them into your own library, making it more convenient for your users to interact with the Hub. This guide will show you how to:
|
5 |
+
|
6 |
+
* Download and cache a single file.
|
7 |
+
* Download and cache an entire repository.
|
8 |
+
* Download files to a local folder.
|
9 |
+
|
10 |
+
[](#download-a-single-file)Download a single file
|
11 |
+
-------------------------------------------------
|
12 |
+
|
13 |
+
The [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) function is the main function for downloading files from the Hub. It downloads the remote file, caches it on disk (in a version-aware way), and returns its local file path.
|
14 |
+
|
15 |
+
The returned filepath is a pointer to the HF local cache. Therefore, it is important to not modify the file to avoid having a corrupted cache. If you are interested in getting to know more about how files are cached, please refer to our [caching guide](./manage-cache).
|
16 |
+
|
17 |
+
### [](#from-latest-version)From latest version
|
18 |
+
|
19 |
+
Select the file to download using the `repo_id`, `repo_type` and `filename` parameters. By default, the file will be considered as being part of a `model` repo.
|
20 |
+
|
21 |
+
Copied
|
22 |
+
|
23 |
+
\>>> from huggingface\_hub import hf\_hub\_download
|
24 |
+
\>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json")
|
25 |
+
'/root/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade/config.json'
|
26 |
+
|
27 |
+
\# Download from a dataset
|
28 |
+
\>>> hf\_hub\_download(repo\_id="google/fleurs", filename="fleurs.py", repo\_type="dataset")
|
29 |
+
'/root/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34/fleurs.py'
|
30 |
+
|
31 |
+
### [](#from-specific-version)From specific version
|
32 |
+
|
33 |
+
By default, the latest version from the `main` branch is downloaded. However, in some cases you want to download a file at a particular version (e.g. from a specific branch, a PR, a tag or a commit hash). To do so, use the `revision` parameter:
|
34 |
+
|
35 |
+
Copied
|
36 |
+
|
37 |
+
\# Download from the \`v1.0\` tag
|
38 |
+
\>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0")
|
39 |
+
|
40 |
+
\# Download from the \`test-branch\` branch
|
41 |
+
\>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="test-branch")
|
42 |
+
|
43 |
+
\# Download from Pull Request #3
|
44 |
+
\>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="refs/pr/3")
|
45 |
+
|
46 |
+
\# Download from a specific commit hash
|
47 |
+
\>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="877b84a8f93f2d619faa2a6e514a32beef88ab0a")
|
48 |
+
|
49 |
+
**Note:** When using the commit hash, it must be the full-length hash instead of a 7-character commit hash.
|
50 |
+
|
51 |
+
### [](#construct-a-download-url)Construct a download URL
|
52 |
+
|
53 |
+
In case you want to construct the URL used to download a file from a repo, you can use [hf\_hub\_url()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_url) which returns a URL. Note that it is used internally by [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download).
|
54 |
+
|
55 |
+
[](#download-an-entire-repository)Download an entire repository
|
56 |
+
---------------------------------------------------------------
|
57 |
+
|
58 |
+
[snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) downloads an entire repository at a given revision. It uses internally [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) which means all downloaded files are also cached on your local disk. Downloads are made concurrently to speed-up the process.
|
59 |
+
|
60 |
+
To download a whole repository, just pass the `repo_id` and `repo_type`:
|
61 |
+
|
62 |
+
Copied
|
63 |
+
|
64 |
+
\>>> from huggingface\_hub import snapshot\_download
|
65 |
+
\>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp")
|
66 |
+
'/home/lysandre/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade'
|
67 |
+
|
68 |
+
\# Or from a dataset
|
69 |
+
\>>> snapshot\_download(repo\_id="google/fleurs", repo\_type="dataset")
|
70 |
+
'/home/lysandre/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34'
|
71 |
+
|
72 |
+
[snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) downloads the latest revision by default. If you want a specific repository revision, use the `revision` parameter:
|
73 |
+
|
74 |
+
Copied
|
75 |
+
|
76 |
+
\>>> from huggingface\_hub import snapshot\_download
|
77 |
+
\>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp", revision="refs/pr/1")
|
78 |
+
|
79 |
+
### [](#filter-files-to-download)Filter files to download
|
80 |
+
|
81 |
+
[snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) provides an easy way to download a repository. However, you don’t always want to download the entire content of a repository. For example, you might want to prevent downloading all `.bin` files if you know you’ll only use the `.safetensors` weights. You can do that using `allow_patterns` and `ignore_patterns` parameters.
|
82 |
+
|
83 |
+
These parameters accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). The pattern matching is based on [`fnmatch`](https://docs.python.org/3/library/fnmatch.html).
|
84 |
+
|
85 |
+
For example, you can use `allow_patterns` to only download JSON configuration files:
|
86 |
+
|
87 |
+
Copied
|
88 |
+
|
89 |
+
\>>> from huggingface\_hub import snapshot\_download
|
90 |
+
\>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp", allow\_patterns="\*.json")
|
91 |
+
|
92 |
+
On the other hand, `ignore_patterns` can exclude certain files from being downloaded. The following example ignores the `.msgpack` and `.h5` file extensions:
|
93 |
+
|
94 |
+
Copied
|
95 |
+
|
96 |
+
\>>> from huggingface\_hub import snapshot\_download
|
97 |
+
\>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp", ignore\_patterns=\["\*.msgpack", "\*.h5"\])
|
98 |
+
|
99 |
+
Finally, you can combine both to precisely filter your download. Here is an example to download all json and markdown files except `vocab.json`.
|
100 |
+
|
101 |
+
Copied
|
102 |
+
|
103 |
+
\>>> from huggingface\_hub import snapshot\_download
|
104 |
+
\>>> snapshot\_download(repo\_id="gpt2", allow\_patterns=\["\*.md", "\*.json"\], ignore\_patterns="vocab.json")
|
105 |
+
|
106 |
+
[](#download-files-to-a-local-folder)Download file(s) to a local folder
|
107 |
+
-----------------------------------------------------------------------
|
108 |
+
|
109 |
+
By default, we recommend using the [cache system](./manage-cache) to download files from the Hub. You can specify a custom cache location using the `cache_dir` parameter in [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) and [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download), or by setting the [`HF_HOME`](../package_reference/environment_variables#hf_home) environment variable.
|
110 |
+
|
111 |
+
However, if you need to download files to a specific folder, you can pass a `local_dir` parameter to the download function. This is useful to get a workflow closer to what the `git` command offers. The downloaded files will maintain their original file structure within the specified folder. For example, if `filename="data/train.csv"` and `local_dir="path/to/folder"`, the resulting filepath will be `"path/to/folder/data/train.csv"`.
|
112 |
+
|
113 |
+
A `.cache/huggingface/` folder is created at the root of your local directory containing metadata about the downloaded files. This prevents re-downloading files if they’re already up-to-date. If the metadata has changed, then the new file version is downloaded. This makes the `local_dir` optimized for pulling only the latest changes.
|
114 |
+
|
115 |
+
After completing the download, you can safely remove the `.cache/huggingface/` folder if you no longer need it. However, be aware that re-running your script without this folder may result in longer recovery times, as metadata will be lost. Rest assured that your local data will remain intact and unaffected.
|
116 |
+
|
117 |
+
Don’t worry about the `.cache/huggingface/` folder when committing changes to the Hub! This folder is automatically ignored by both `git` and [upload\_folder()](/docs/huggingface_hub/v0.32.2/en/package_reference/hf_api#huggingface_hub.HfApi.upload_folder).
|
118 |
+
|
119 |
+
[](#download-from-the-cli)Download from the CLI
|
120 |
+
-----------------------------------------------
|
121 |
+
|
122 |
+
You can use the `huggingface-cli download` command from the terminal to directly download files from the Hub. Internally, it uses the same [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) and [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) helpers described above and prints the returned path to the terminal.
|
123 |
+
|
124 |
+
Copied
|
125 |
+
|
126 |
+
\>>> huggingface-cli download gpt2 config.json
|
127 |
+
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
|
128 |
+
|
129 |
+
You can download multiple files at once which displays a progress bar and returns the snapshot path in which the files are located:
|
130 |
+
|
131 |
+
Copied
|
132 |
+
|
133 |
+
\>>> huggingface-cli download gpt2 config.json model.safetensors
|
134 |
+
Fetching 2 files: 100%|████████████████████████████████████████████| 2/2 \[00:00<00:00, 23831.27it/s\]
|
135 |
+
/home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10
|
136 |
+
|
137 |
+
For more details about the CLI download command, please refer to the [CLI guide](./cli#huggingface-cli-download).
|
138 |
+
|
139 |
+
[](#faster-downloads)Faster downloads
|
140 |
+
-------------------------------------
|
141 |
+
|
142 |
+
There are two options to speed up downloads. Both involve installing a Python package written in Rust.
|
143 |
+
|
144 |
+
* `hf_xet` is newer and uses the Xet storage backend for upload/download. It is available in production, but is in the process of being rolled out to all users, so join the [waitlist](https://huggingface.co/join/xet) to get onboarded soon!
|
145 |
+
* `hf_transfer` is a power-tool to download and upload to our LFS storage backend (note: this is less future-proof than Xet). It is thoroughly tested and has been in production for a long time, but it has some limitations.
|
146 |
+
|
147 |
+
### [](#hfxet)hf\_xet
|
148 |
+
|
149 |
+
Take advantage of faster downloads through `hf_xet`, the Python binding to the [`xet-core`](https://github.com/huggingface/xet-core) library that enables chunk-based deduplication for faster downloads and uploads. `hf_xet` integrates seamlessly with `huggingface_hub`, but uses the Rust `xet-core` library and Xet storage instead of LFS.
|
150 |
+
|
151 |
+
`hf_xet` uses the Xet storage system, which breaks files down into immutable chunks, storing collections of these chunks (called blocks or xorbs) remotely and retrieving them to reassemble the file when requested. When downloading, after confirming the user is authorized to access the files, `hf_xet` will query the Xet content-addressable service (CAS) with the LFS SHA256 hash for this file to receive the reconstruction metadata (ranges within xorbs) to assemble these files, along with presigned URLs to download the xorbs directly. Then `hf_xet` will efficiently download the xorb ranges necessary and will write out the files on disk. `hf_xet` uses a local disk cache to only download chunks once, learn more in the [Chunk-based caching(Xet)](./manage-cache#chunk-based-caching-xet) section.
|
152 |
+
|
153 |
+
To enable it, specify the `hf_xet` package when installing `huggingface_hub`:
|
154 |
+
|
155 |
+
Copied
|
156 |
+
|
157 |
+
pip install -U "huggingface\_hub\[hf\_xet\]"
|
158 |
+
|
159 |
+
Note: `hf_xet` will only be utilized when the files being downloaded are being stored with Xet Storage.
|
160 |
+
|
161 |
+
All other `huggingface_hub` APIs will continue to work without any modification. To learn more about the benefits of Xet storage and `hf_xet`, refer to this [section](https://huggingface.co/docs/hub/storage-backends).
|
162 |
+
|
163 |
+
### [](#hftransfer)hf\_transfer
|
164 |
+
|
165 |
+
If you are running on a machine with high bandwidth, you can increase your download speed with [`hf_transfer`](https://github.com/huggingface/hf_transfer), a Rust-based library developed to speed up file transfers with the Hub. To enable it:
|
166 |
+
|
167 |
+
1. Specify the `hf_transfer` extra when installing `huggingface_hub` (e.g. `pip install huggingface_hub[hf_transfer]`).
|
168 |
+
2. Set `HF_HUB_ENABLE_HF_TRANSFER=1` as an environment variable.
|
169 |
+
|
170 |
+
`hf_transfer` is a power user tool! It is tested and production-ready, but it lacks user-friendly features like advanced error handling or proxies. For more details, please take a look at this [section](https://huggingface.co/docs/huggingface_hub/hf_transfer).
|
171 |
+
|
172 |
+
[< \> Update on GitHub](https://github.com/huggingface/huggingface_hub/blob/main/docs/source/en/guides/download.md)
|
173 |
+
|
174 |
+
Command Line Interface (CLI)
|
docs_for_ai_coding_bots/huggingface_hub/Using-the-cache-in-hf-hub-library.md
ADDED
@@ -0,0 +1,531 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[](#understand-caching)Understand caching
|
2 |
+
=========================================
|
3 |
+
|
4 |
+
`huggingface_hub` utilizes the local disk as two caches, which avoid re-downloading items again. The first cache is a file-based cache, which caches individual files downloaded from the Hub and ensures that the same file is not downloaded again when a repo gets updated. The second cache is a chunk cache, where each chunk represents a byte range from a file and ensures that chunks that are shared across files are only downloaded once.
|
5 |
+
|
6 |
+
[](#file-based-caching)File-based caching
|
7 |
+
-----------------------------------------
|
8 |
+
|
9 |
+
The Hugging Face Hub cache-system is designed to be the central cache shared across libraries that depend on the Hub. It has been updated in v0.8.0 to prevent re-downloading same files between revisions.
|
10 |
+
|
11 |
+
The caching system is designed as follows:
|
12 |
+
|
13 |
+
Copied
|
14 |
+
|
15 |
+
<CACHE\_DIR\>
|
16 |
+
├─ <MODELS\>
|
17 |
+
├─ <DATASETS\>
|
18 |
+
├─ <SPACES\>
|
19 |
+
|
20 |
+
The default `<CACHE_DIR>` is `~/.cache/huggingface/hub`. However, it is customizable with the `cache_dir` argument on all methods, or by specifying either `HF_HOME` or `HF_HUB_CACHE` environment variable.
|
21 |
+
|
22 |
+
Models, datasets and spaces share a common root. Each of these repositories contains the repository type, the namespace (organization or username) if it exists and the repository name:
|
23 |
+
|
24 |
+
Copied
|
25 |
+
|
26 |
+
<CACHE\_DIR\>
|
27 |
+
├─ models\--julien\-c\--EsperBERTo\-small
|
28 |
+
├─ models\--lysandrejik\--arxiv\-nlp
|
29 |
+
├─ models\--bert\-base\-cased
|
30 |
+
├─ datasets\--glue
|
31 |
+
├─ datasets\--huggingface\--DataMeasurementsFiles
|
32 |
+
├─ spaces\--dalle\-mini\--dalle\-mini
|
33 |
+
|
34 |
+
It is within these folders that all files will now be downloaded from the Hub. Caching ensures that a file isn’t downloaded twice if it already exists and wasn’t updated; but if it was updated, and you’re asking for the latest file, then it will download the latest file (while keeping the previous file intact in case you need it again).
|
35 |
+
|
36 |
+
In order to achieve this, all folders contain the same skeleton:
|
37 |
+
|
38 |
+
Copied
|
39 |
+
|
40 |
+
<CACHE\_DIR>
|
41 |
+
├─ datasets\--glue
|
42 |
+
│ ├─ refs
|
43 |
+
│ ├─ blobs
|
44 |
+
│ ├─ snapshots
|
45 |
+
...
|
46 |
+
|
47 |
+
Each folder is designed to contain the following:
|
48 |
+
|
49 |
+
### [](#refs)Refs
|
50 |
+
|
51 |
+
The `refs` folder contains files which indicates the latest revision of the given reference. For example, if we have previously fetched a file from the `main` branch of a repository, the `refs` folder will contain a file named `main`, which will itself contain the commit identifier of the current head.
|
52 |
+
|
53 |
+
If the latest commit of `main` has `aaaaaa` as identifier, then it will contain `aaaaaa`.
|
54 |
+
|
55 |
+
If that same branch gets updated with a new commit, that has `bbbbbb` as an identifier, then re-downloading a file from that reference will update the `refs/main` file to contain `bbbbbb`.
|
56 |
+
|
57 |
+
### [](#blobs)Blobs
|
58 |
+
|
59 |
+
The `blobs` folder contains the actual files that we have downloaded. The name of each file is their hash.
|
60 |
+
|
61 |
+
### [](#snapshots)Snapshots
|
62 |
+
|
63 |
+
The `snapshots` folder contains symlinks to the blobs mentioned above. It is itself made up of several folders: one per known revision!
|
64 |
+
|
65 |
+
In the explanation above, we had initially fetched a file from the `aaaaaa` revision, before fetching a file from the `bbbbbb` revision. In this situation, we would now have two folders in the `snapshots` folder: `aaaaaa` and `bbbbbb`.
|
66 |
+
|
67 |
+
In each of these folders, live symlinks that have the names of the files that we have downloaded. For example, if we had downloaded the `README.md` file at revision `aaaaaa`, we would have the following path:
|
68 |
+
|
69 |
+
Copied
|
70 |
+
|
71 |
+
<CACHE\_DIR>/<REPO\_NAME>/snapshots/aaaaaa/README.md
|
72 |
+
|
73 |
+
That `README.md` file is actually a symlink linking to the blob that has the hash of the file.
|
74 |
+
|
75 |
+
By creating the skeleton this way we open the mechanism to file sharing: if the same file was fetched in revision `bbbbbb`, it would have the same hash and the file would not need to be re-downloaded.
|
76 |
+
|
77 |
+
### [](#noexist-advanced).no\_exist (advanced)
|
78 |
+
|
79 |
+
In addition to the `blobs`, `refs` and `snapshots` folders, you might also find a `.no_exist` folder in your cache. This folder keeps track of files that you’ve tried to download once but don’t exist on the Hub. Its structure is the same as the `snapshots` folder with 1 subfolder per known revision:
|
80 |
+
|
81 |
+
Copied
|
82 |
+
|
83 |
+
<CACHE\_DIR>/<REPO\_NAME>/.no\_exist/aaaaaa/config\_that\_does\_not\_exist.json
|
84 |
+
|
85 |
+
Unlike the `snapshots` folder, files are simple empty files (no symlinks). In this example, the file `"config_that_does_not_exist.json"` does not exist on the Hub for the revision `"aaaaaa"`. As it only stores empty files, this folder is neglectable in term of disk usage.
|
86 |
+
|
87 |
+
So now you might wonder, why is this information even relevant? In some cases, a framework tries to load optional files for a model. Saving the non-existence of optional files makes it faster to load a model as it saves 1 HTTP call per possible optional file. This is for example the case in `transformers` where each tokenizer can support additional files. The first time you load the tokenizer on your machine, it will cache which optional files exist (and which doesn’t) to make the loading time faster for the next initializations.
|
88 |
+
|
89 |
+
To test if a file is cached locally (without making any HTTP request), you can use the [try\_to\_load\_from\_cache()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.try_to_load_from_cache) helper. It will either return the filepath (if exists and cached), the object `_CACHED_NO_EXIST` (if non-existence is cached) or `None` (if we don’t know).
|
90 |
+
|
91 |
+
Copied
|
92 |
+
|
93 |
+
from huggingface\_hub import try\_to\_load\_from\_cache, \_CACHED\_NO\_EXIST
|
94 |
+
|
95 |
+
filepath = try\_to\_load\_from\_cache()
|
96 |
+
if isinstance(filepath, str):
|
97 |
+
\# file exists and is cached
|
98 |
+
...
|
99 |
+
elif filepath is \_CACHED\_NO\_EXIST:
|
100 |
+
\# non-existence of file is cached
|
101 |
+
...
|
102 |
+
else:
|
103 |
+
\# file is not cached
|
104 |
+
...
|
105 |
+
|
106 |
+
### [](#in-practice)In practice
|
107 |
+
|
108 |
+
In practice, your cache should look like the following tree:
|
109 |
+
|
110 |
+
Copied
|
111 |
+
|
112 |
+
\[ 96\] .
|
113 |
+
└── \[ 160\] models--julien-c--EsperBERTo-small
|
114 |
+
├── \[ 160\] blobs
|
115 |
+
│ ├── \[321M\] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
|
116 |
+
│ ├── \[ 398\] 7cb18dc9bafbfcf74629a4b760af1b160957a83e
|
117 |
+
│ └── \[1.4K\] d7edf6bd2a681fb0175f7735299831ee1b22b812
|
118 |
+
├── \[ 96\] refs
|
119 |
+
│ └── \[ 40\] main
|
120 |
+
└── \[ 128\] snapshots
|
121 |
+
├── \[ 128\] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
|
122 |
+
│ ├── \[ 52\] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
|
123 |
+
│ └── \[ 76\] pytorch\_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
|
124 |
+
└── \[ 128\] bbc77c8132af1cc5cf678da3f1ddf2de43606d48
|
125 |
+
├── \[ 52\] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
|
126 |
+
└── \[ 76\] pytorch\_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
|
127 |
+
|
128 |
+
### [](#limitations)Limitations
|
129 |
+
|
130 |
+
In order to have an efficient cache-system, `huggingface-hub` uses symlinks. However, symlinks are not supported on all machines. This is a known limitation especially on Windows. When this is the case, `huggingface_hub` do not use the `blobs/` directory but directly stores the files in the `snapshots/` directory instead. This workaround allows users to download and cache files from the Hub exactly the same way. Tools to inspect and delete the cache (see below) are also supported. However, the cache-system is less efficient as a single file might be downloaded several times if multiple revisions of the same repo is downloaded.
|
131 |
+
|
132 |
+
If you want to benefit from the symlink-based cache-system on a Windows machine, you either need to [activate Developer Mode](https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development) or to run Python as an administrator.
|
133 |
+
|
134 |
+
When symlinks are not supported, a warning message is displayed to the user to alert them they are using a degraded version of the cache-system. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable to true.
|
135 |
+
|
136 |
+
[](#chunk-based-caching-xet)Chunk-based caching (Xet)
|
137 |
+
-----------------------------------------------------
|
138 |
+
|
139 |
+
To provide more efficient file transfers, `hf_xet` adds a `xet` directory to the existing `huggingface_hub` cache, creating additional caching layer to enable chunk-based deduplication. This cache holds chunks, which are immutable byte ranges from files (up to 64KB) that are created using content-defined chunking. For more information on the Xet Storage system, see this [section](https://huggingface.co/docs/hub/storage-backends).
|
140 |
+
|
141 |
+
The `xet` directory, located at `~/.cache/huggingface/xet` by default, contains two caches, utilized for uploads and downloads with the following structure
|
142 |
+
|
143 |
+
Copied
|
144 |
+
|
145 |
+
<CACHE\_DIR>
|
146 |
+
├─ chunk\_cache
|
147 |
+
├─ shard\_cache
|
148 |
+
|
149 |
+
The `xet` cache, like the rest of `hf_xet` is fully integrated with `huggingface_hub`. If you use the existing APIs for interacting with cached assets, there is no need to update your workflow. The `xet` cache is built as an optimization layer on top of the existing `hf_xet` chunk-based deduplication and `huggingface_hub` cache system.
|
150 |
+
|
151 |
+
The `chunk-cache` directory contains cached data chunks that are used to speed up downloads while the `shard-cache` directory contains cached shards that are utilized on the upload path.
|
152 |
+
|
153 |
+
### [](#chunkcache)chunk\_cache
|
154 |
+
|
155 |
+
This cache is used on the download path. The cache directory structure is based on a base-64 encoded hash from the content-addressed store (CAS) that backs each Xet-enabled repository. A CAS hash serves as the key to lookup the offsets of where the data is stored.
|
156 |
+
|
157 |
+
At the topmost level, the first two letters of the base 64 encoded CAS hash are used to create a subdirectory in the `chunk_cache` (keys that share these first two letters are grouped here). The inner levels are comprised of subdirectories with the full key as the directory name. At the base are the cache items which are ranges of blocks that contain the cached chunks.
|
158 |
+
|
159 |
+
Copied
|
160 |
+
|
161 |
+
<CACHE\_DIR>
|
162 |
+
├─ xet
|
163 |
+
│ ├─ chunk\_cache
|
164 |
+
│ │ ├─ A1
|
165 |
+
│ │ │ ├─ A1GerURLUcISVivdseeoY1PnYifYkOaCCJ7V5Q9fjgxkZWZhdWx0
|
166 |
+
│ │ │ │ ├─ AAAAAAEAAAA5DQAAAAAAAIhRLjDI3SS5jYs4ysNKZiJy9XFI8CN7Ww0UyEA9KPD9
|
167 |
+
│ │ │ │ ├─ AQAAAAIAAABzngAAAAAAAPNqPjd5Zby5aBvabF7Z1itCx0ryMwoCnuQcDwq79jlB
|
168 |
+
|
169 |
+
When requesting a file, the first thing `hf_xet` does is communicate with Xet storage’s content addressed store (CAS) for reconstruction information. The reconstruction information contains information about the CAS keys required to download the file in its entirety.
|
170 |
+
|
171 |
+
Before executing the requests for the CAS keys, the `chunk_cache` is consulted. If a key in the cache matches a CAS key, then there is no reason to issue a request for that content. `hf_xet` uses the chunks stored in the directory instead.
|
172 |
+
|
173 |
+
As the `chunk_cache` is purely an optimization, not a guarantee, `hf_xet` utilizes a computationally efficient eviction policy. When the `chunk_cache` is full (see `Limits and Limitations` below), `hf_xet` implements a random eviction policy when selecting an eviction candidate. This significantly reduces the overhead of managing a robust caching system (e.g., LRU) while still providing most of the benefits of caching chunks.
|
174 |
+
|
175 |
+
### [](#shardcache)shard\_cache
|
176 |
+
|
177 |
+
This cache is used when uploading content to the Hub. The directory is flat, comprising only of shard files, each using an ID for the shard name.
|
178 |
+
|
179 |
+
Copied
|
180 |
+
|
181 |
+
<CACHE\_DIR>
|
182 |
+
├─ xet
|
183 |
+
│ ├─ shard\_cache
|
184 |
+
│ │ ├─ 1fe4ffd5cf0c3375f1ef9aec5016cf773ccc5ca294293d3f92d92771dacfc15d.mdb
|
185 |
+
│ │ ├─ 906ee184dc1cd0615164a89ed64e8147b3fdccd1163d80d794c66814b3b09992.mdb
|
186 |
+
│ │ ├─ ceeeb7ea4cf6c0a8d395a2cf9c08871211fbbd17b9b5dc1005811845307e6b8f.mdb
|
187 |
+
│ │ ├─ e8535155b1b11ebd894c908e91a1e14e3461dddd1392695ddc90ae54a548d8b2.mdb
|
188 |
+
|
189 |
+
The `shard_cache` contains shards that are:
|
190 |
+
|
191 |
+
* Locally generated and successfully uploaded to the CAS
|
192 |
+
* Downloaded from CAS as part of the global deduplication algorithm
|
193 |
+
|
194 |
+
Shards provide a mapping between files and chunks. During uploads, each file is chunked and the hash of the chunk is saved. Every shard in the cache is then consulted. If a shard contains a chunk hash that is present in the local file being uploaded, then that chunk can be discarded as it is already stored in CAS.
|
195 |
+
|
196 |
+
All shards have an expiration date of 3-4 weeks from when they are downloaded. Shards that are expired are not loaded during upload and are deleted one week after expiration.
|
197 |
+
|
198 |
+
### [](#limits-and-limitations)Limits and Limitations
|
199 |
+
|
200 |
+
The `chunk_cache` is limited to 10GB in size while the `shard_cache` is technically without limits (in practice, the size and use of shards are such that limiting the cache is unnecessary).
|
201 |
+
|
202 |
+
By design, both caches are without high-level APIs. These caches are used primarily to facilitate the reconstruction (download) or upload of a file. To interact with the assets themselves, it’s recommended that you use the [`huggingface_hub` cache system APIs](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
|
203 |
+
|
204 |
+
If you need to reclaim the space utilized by either cache or need to debug any potential cache-related issues, simply remove the `xet` cache entirely by running `rm -rf ~/<cache_dir>/xet` where `<cache_dir>` is the location of your Hugging Face cache, typically `~/.cache/huggingface`
|
205 |
+
|
206 |
+
Example full `xet`cache directory tree:
|
207 |
+
|
208 |
+
Copied
|
209 |
+
|
210 |
+
<CACHE\_DIR>
|
211 |
+
├─ xet
|
212 |
+
│ ├─ chunk\_cache
|
213 |
+
│ │ ├─ L1
|
214 |
+
│ │ │ ├─ L1GerURLUcISVivdseeoY1PnYifYkOaCCJ7V5Q9fjgxkZWZhdWx0
|
215 |
+
│ │ │ │ ├─ AAAAAAEAAAA5DQAAAAAAAIhRLjDI3SS5jYs4ysNKZiJy9XFI8CN7Ww0UyEA9KPD9
|
216 |
+
│ │ │ │ ├─ AQAAAAIAAABzngAAAAAAAPNqPjd5Zby5aBvabF7Z1itCx0ryMwoCnuQcDwq79jlB
|
217 |
+
│ ├─ shard\_cache
|
218 |
+
│ │ ├─ 1fe4ffd5cf0c3375f1ef9aec5016cf773ccc5ca294293d3f92d92771dacfc15d.mdb
|
219 |
+
│ │ ├─ 906ee184dc1cd0615164a89ed64e8147b3fdccd1163d80d794c66814b3b09992.mdb
|
220 |
+
│ │ ├─ ceeeb7ea4cf6c0a8d395a2cf9c08871211fbbd17b9b5dc1005811845307e6b8f.mdb
|
221 |
+
│ │ ├─ e8535155b1b11ebd894c908e91a1e14e3461dddd1392695ddc90ae54a548d8b2.mdb
|
222 |
+
|
223 |
+
To learn more about Xet Storage, see this [section](https://huggingface.co/docs/hub/storage-backends).
|
224 |
+
|
225 |
+
[](#caching-assets)Caching assets
|
226 |
+
---------------------------------
|
227 |
+
|
228 |
+
In addition to caching files from the Hub, downstream libraries often requires to cache other files related to HF but not handled directly by `huggingface_hub` (example: file downloaded from GitHub, preprocessed data, logs,…). In order to cache those files, called `assets`, one can use [cached\_assets\_path()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.cached_assets_path). This small helper generates paths in the HF cache in a unified way based on the name of the library requesting it and optionally on a namespace and a subfolder name. The goal is to let every downstream libraries manage its assets its own way (e.g. no rule on the structure) as long as it stays in the right assets folder. Those libraries can then leverage tools from `huggingface_hub` to manage the cache, in particular scanning and deleting parts of the assets from a CLI command.
|
229 |
+
|
230 |
+
Copied
|
231 |
+
|
232 |
+
from huggingface\_hub import cached\_assets\_path
|
233 |
+
|
234 |
+
assets\_path = cached\_assets\_path(library\_name="datasets", namespace="SQuAD", subfolder="download")
|
235 |
+
something\_path = assets\_path / "something.json" \# Do anything you like in your assets folder !
|
236 |
+
|
237 |
+
[cached\_assets\_path()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.cached_assets_path) is the recommended way to store assets but is not mandatory. If your library already uses its own cache, feel free to use it!
|
238 |
+
|
239 |
+
### [](#assets-in-practice)Assets in practice
|
240 |
+
|
241 |
+
In practice, your assets cache should look like the following tree:
|
242 |
+
|
243 |
+
Copied
|
244 |
+
|
245 |
+
assets/
|
246 |
+
└── datasets/
|
247 |
+
│ ├── SQuAD/
|
248 |
+
│ │ ├── downloaded/
|
249 |
+
│ │ ├── extracted/
|
250 |
+
│ │ └── processed/
|
251 |
+
│ ├── Helsinki-NLP--tatoeba\_mt/
|
252 |
+
│ ├── downloaded/
|
253 |
+
│ ├── extracted/
|
254 |
+
│ └── processed/
|
255 |
+
└── transformers/
|
256 |
+
├── default/
|
257 |
+
│ ├── something/
|
258 |
+
├── bert-base-cased/
|
259 |
+
│ ├── default/
|
260 |
+
│ └── training/
|
261 |
+
hub/
|
262 |
+
└── models--julien-c--EsperBERTo-small/
|
263 |
+
├── blobs/
|
264 |
+
│ ├── (...)
|
265 |
+
│ ├── (...)
|
266 |
+
├── refs/
|
267 |
+
│ └── (...)
|
268 |
+
└── \[ 128\] snapshots/
|
269 |
+
├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/
|
270 |
+
│ ├── (...)
|
271 |
+
└── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/
|
272 |
+
└── (...)
|
273 |
+
|
274 |
+
[](#manage-your-file-based-cache)Manage your file-based cache
|
275 |
+
-------------------------------------------------------------
|
276 |
+
|
277 |
+
### [](#scan-your-cache)Scan your cache
|
278 |
+
|
279 |
+
At the moment, cached files are never deleted from your local directory: when you download a new revision of a branch, previous files are kept in case you need them again. Therefore it can be useful to scan your cache directory in order to know which repos and revisions are taking the most disk space. `huggingface_hub` provides an helper to do so that can be used via `huggingface-cli` or in a python script.
|
280 |
+
|
281 |
+
**Scan cache from the terminal**
|
282 |
+
|
283 |
+
The easiest way to scan your HF cache-system is to use the `scan-cache` command from `huggingface-cli` tool. This command scans the cache and prints a report with information like repo id, repo type, disk usage, refs and full local path.
|
284 |
+
|
285 |
+
The snippet below shows a scan report in a folder in which 4 models and 2 datasets are cached.
|
286 |
+
|
287 |
+
Copied
|
288 |
+
|
289 |
+
➜ huggingface-cli scan-cache
|
290 |
+
REPO ID REPO TYPE SIZE ON DISK NB FILES LAST\_ACCESSED LAST\_MODIFIED REFS LOCAL PATH
|
291 |
+
--------------------------- --------- ------------ -------- ------------- ------------- ------------------- -------------------------------------------------------------------------
|
292 |
+
glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue
|
293 |
+
google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs
|
294 |
+
Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner
|
295 |
+
bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased
|
296 |
+
t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base
|
297 |
+
t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small
|
298 |
+
|
299 |
+
Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
|
300 |
+
Got 1 warning(s) while scanning. Use -vvv to print details.
|
301 |
+
|
302 |
+
To get a more detailed report, use the `--verbose` option. For each repo, you get a list of all revisions that have been downloaded. As explained above, the files that don’t change between 2 revisions are shared thanks to the symlinks. This means that the size of the repo on disk is expected to be less than the sum of the size of each of its revisions. For example, here `bert-base-cased` has 2 revisions of 1.4G and 1.5G but the total disk usage is only 1.9G.
|
303 |
+
|
304 |
+
Copied
|
305 |
+
|
306 |
+
➜ huggingface-cli scan-cache -v
|
307 |
+
REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST\_MODIFIED REFS LOCAL PATH
|
308 |
+
--------------------------- --------- ---------------------------------------- ------------ -------- ------------- ----------- ----------------------------------------------------------------------------------------------------------------------------
|
309 |
+
glue dataset 9338f7b671827df886678df2bdd7cc7b4f36dffd 97.7K 14 4 days ago main, 2.4.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/9338f7b671827df886678df2bdd7cc7b4f36dffd
|
310 |
+
glue dataset f021ae41c879fcabcf823648ec685e3fead91fe7 97.8K 14 1 week ago 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/f021ae41c879fcabcf823648ec685e3fead91fe7
|
311 |
+
google/fleurs dataset 129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 25.4K 3 2 weeks ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8
|
312 |
+
google/fleurs dataset 24f85a01eb955224ca3946e70050869c56446805 64.9M 4 1 week ago main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/24f85a01eb955224ca3946e70050869c56446805
|
313 |
+
Jean-Baptiste/camembert-ner model dbec8489a1c44ecad9da8a9185115bccabd799fe 441.0M 7 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner/snapshots/dbec8489a1c44ecad9da8a9185115bccabd799fe
|
314 |
+
bert-base-cased model 378aa1bda6387fd00e824948ebe3488630ad8565 1.5G 9 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/378aa1bda6387fd00e824948ebe3488630ad8565
|
315 |
+
bert-base-cased model a8d257ba9925ef39f3036bfc338acf5283c512d9 1.4G 9 3 days ago main /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/a8d257ba9925ef39f3036bfc338acf5283c512d9
|
316 |
+
t5-base model 23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 10.1K 3 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-base/snapshots/23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9
|
317 |
+
t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a
|
318 |
+
t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617
|
319 |
+
t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5
|
320 |
+
|
321 |
+
Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
|
322 |
+
Got 1 warning(s) while scanning. Use -vvv to print details.
|
323 |
+
|
324 |
+
**Grep example**
|
325 |
+
|
326 |
+
Since the output is in tabular format, you can combine it with any `grep`\-like tools to filter the entries. Here is an example to filter only revisions from the “t5-small” model on a Unix-based machine.
|
327 |
+
|
328 |
+
Copied
|
329 |
+
|
330 |
+
➜ eval "huggingface-cli scan-cache -v" | grep "t5-small"
|
331 |
+
t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a
|
332 |
+
t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617
|
333 |
+
t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5
|
334 |
+
|
335 |
+
**Scan cache from Python**
|
336 |
+
|
337 |
+
For a more advanced usage, use [scan\_cache\_dir()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.scan_cache_dir) which is the python utility called by the CLI tool.
|
338 |
+
|
339 |
+
You can use it to get a detailed report structured around 4 dataclasses:
|
340 |
+
|
341 |
+
* [HFCacheInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo): complete report returned by [scan\_cache\_dir()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.scan_cache_dir)
|
342 |
+
* [CachedRepoInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedRepoInfo): information about a cached repo
|
343 |
+
* [CachedRevisionInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedRevisionInfo): information about a cached revision (e.g. “snapshot”) inside a repo
|
344 |
+
* [CachedFileInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedFileInfo): information about a cached file in a snapshot
|
345 |
+
|
346 |
+
Here is a simple usage example. See reference for details.
|
347 |
+
|
348 |
+
Copied
|
349 |
+
|
350 |
+
\>>> from huggingface\_hub import scan\_cache\_dir
|
351 |
+
|
352 |
+
\>>> hf\_cache\_info = scan\_cache\_dir()
|
353 |
+
HFCacheInfo(
|
354 |
+
size\_on\_disk=3398085269,
|
355 |
+
repos=frozenset({
|
356 |
+
CachedRepoInfo(
|
357 |
+
repo\_id='t5-small',
|
358 |
+
repo\_type='model',
|
359 |
+
repo\_path=PosixPath(...),
|
360 |
+
size\_on\_disk=970726914,
|
361 |
+
nb\_files=11,
|
362 |
+
last\_accessed=1662971707.3567169,
|
363 |
+
last\_modified=1662971107.3567169,
|
364 |
+
revisions=frozenset({
|
365 |
+
CachedRevisionInfo(
|
366 |
+
commit\_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5',
|
367 |
+
size\_on\_disk=970726339,
|
368 |
+
snapshot\_path=PosixPath(...),
|
369 |
+
\# No \`last\_accessed\` as blobs are shared among revisions
|
370 |
+
last\_modified=1662971107.3567169,
|
371 |
+
files=frozenset({
|
372 |
+
CachedFileInfo(
|
373 |
+
file\_name='config.json',
|
374 |
+
size\_on\_disk=1197
|
375 |
+
file\_path=PosixPath(...),
|
376 |
+
blob\_path=PosixPath(...),
|
377 |
+
blob\_last\_accessed=1662971707.3567169,
|
378 |
+
blob\_last\_modified=1662971107.3567169,
|
379 |
+
),
|
380 |
+
CachedFileInfo(...),
|
381 |
+
...
|
382 |
+
}),
|
383 |
+
),
|
384 |
+
CachedRevisionInfo(...),
|
385 |
+
...
|
386 |
+
}),
|
387 |
+
),
|
388 |
+
CachedRepoInfo(...),
|
389 |
+
...
|
390 |
+
}),
|
391 |
+
warnings=\[
|
392 |
+
CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."),
|
393 |
+
CorruptedCacheException(...),
|
394 |
+
...
|
395 |
+
\],
|
396 |
+
)
|
397 |
+
|
398 |
+
### [](#clean-your-cache)Clean your cache
|
399 |
+
|
400 |
+
Scanning your cache is interesting but what you really want to do next is usually to delete some portions to free up some space on your drive. This is possible using the `delete-cache` CLI command. One can also programmatically use the [delete\_revisions()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo.delete_revisions) helper from [HFCacheInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo) object returned when scanning the cache.
|
401 |
+
|
402 |
+
**Delete strategy**
|
403 |
+
|
404 |
+
To delete some cache, you need to pass a list of revisions to delete. The tool will define a strategy to free up the space based on this list. It returns a [DeleteCacheStrategy](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.DeleteCacheStrategy) object that describes which files and folders will be deleted. The [DeleteCacheStrategy](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.DeleteCacheStrategy) allows give you how much space is expected to be freed. Once you agree with the deletion, you must execute it to make the deletion effective. In order to avoid discrepancies, you cannot edit a strategy object manually.
|
405 |
+
|
406 |
+
The strategy to delete revisions is the following:
|
407 |
+
|
408 |
+
* the `snapshot` folder containing the revision symlinks is deleted.
|
409 |
+
* blobs files that are targeted only by revisions to be deleted are deleted as well.
|
410 |
+
* if a revision is linked to 1 or more `refs`, references are deleted.
|
411 |
+
* if all revisions from a repo are deleted, the entire cached repository is deleted.
|
412 |
+
|
413 |
+
Revision hashes are unique across all repositories. This means you don’t need to provide any `repo_id` or `repo_type` when removing revisions.
|
414 |
+
|
415 |
+
If a revision is not found in the cache, it will be silently ignored. Besides, if a file or folder cannot be found while trying to delete it, a warning will be logged but no error is thrown. The deletion continues for other paths contained in the [DeleteCacheStrategy](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.DeleteCacheStrategy) object.
|
416 |
+
|
417 |
+
**Clean cache from the terminal**
|
418 |
+
|
419 |
+
The easiest way to delete some revisions from your HF cache-system is to use the `delete-cache` command from `huggingface-cli` tool. The command has two modes. By default, a TUI (Terminal User Interface) is displayed to the user to select which revisions to delete. This TUI is currently in beta as it has not been tested on all platforms. If the TUI doesn’t work on your machine, you can disable it using the `--disable-tui` flag.
|
420 |
+
|
421 |
+
**Using the TUI**
|
422 |
+
|
423 |
+
This is the default mode. To use it, you first need to install extra dependencies by running the following command:
|
424 |
+
|
425 |
+
Copied
|
426 |
+
|
427 |
+
pip install huggingface\_hub\["cli"\]
|
428 |
+
|
429 |
+
Then run the command:
|
430 |
+
|
431 |
+
Copied
|
432 |
+
|
433 |
+
huggingface-cli delete\-cache
|
434 |
+
|
435 |
+
You should now see a list of revisions that you can select/deselect:
|
436 |
+
|
437 |
+

|
438 |
+
|
439 |
+
Instructions:
|
440 |
+
|
441 |
+
* Press keyboard arrow keys `<up>` and `<down>` to move the cursor.
|
442 |
+
* Press `<space>` to toggle (select/unselect) an item.
|
443 |
+
* When a revision is selected, the first line is updated to show you how much space will be freed.
|
444 |
+
* Press `<enter>` to confirm your selection.
|
445 |
+
* If you want to cancel the operation and quit, you can select the first item (“None of the following”). If this item is selected, the delete process will be cancelled, no matter what other items are selected. Otherwise you can also press `<ctrl+c>` to quit the TUI.
|
446 |
+
|
447 |
+
Once you’ve selected the revisions you want to delete and pressed `<enter>`, a last confirmation message will be prompted. Press `<enter>` again and the deletion will be effective. If you want to cancel, enter `n`.
|
448 |
+
|
449 |
+
Copied
|
450 |
+
|
451 |
+
✗ huggingface-cli delete-cache --dir ~/.cache/huggingface/hub
|
452 |
+
? Select revisions to delete: 2 revision(s) selected.
|
453 |
+
? 2 revisions selected counting for 3.1G. Confirm deletion ? Yes
|
454 |
+
Start deletion.
|
455 |
+
Done. Deleted 1 repo(s) and 0 revision(s) for a total of 3.1G.
|
456 |
+
|
457 |
+
**Without TUI**
|
458 |
+
|
459 |
+
As mentioned above, the TUI mode is currently in beta and is optional. It may be the case that it doesn’t work on your machine or that you don’t find it convenient.
|
460 |
+
|
461 |
+
Another approach is to use the `--disable-tui` flag. The process is very similar as you will be asked to manually review the list of revisions to delete. However, this manual step will not take place in the terminal directly but in a temporary file generated on the fly and that you can manually edit.
|
462 |
+
|
463 |
+
This file has all the instructions you need in the header. Open it in your favorite text editor. To select/deselect a revision, simply comment/uncomment it with a `#`. Once the manual review is done and the file is edited, you can save it. Go back to your terminal and press `<enter>`. By default it will compute how much space would be freed with the updated list of revisions. You can continue to edit the file or confirm with `"y"`.
|
464 |
+
|
465 |
+
Copied
|
466 |
+
|
467 |
+
huggingface-cli delete-cache --disable-tui
|
468 |
+
|
469 |
+
Example of command file:
|
470 |
+
|
471 |
+
Copied
|
472 |
+
|
473 |
+
\# INSTRUCTIONS
|
474 |
+
# ------------
|
475 |
+
# This is a temporary file created by running \`huggingface-cli delete-cache\` with the
|
476 |
+
# \`--disable-tui\` option. It contains a set of revisions that can be deleted from your
|
477 |
+
# local cache directory.
|
478 |
+
#
|
479 |
+
# Please manually review the revisions you want to delete:
|
480 |
+
# - Revision hashes can be commented out with '#'.
|
481 |
+
# - Only non-commented revisions in this file will be deleted.
|
482 |
+
# - Revision hashes that are removed from this file are ignored as well.
|
483 |
+
# - If \`CANCEL\_DELETION\` line is uncommented, the all cache deletion is cancelled and
|
484 |
+
# no changes will be applied.
|
485 |
+
#
|
486 |
+
# Once you've manually reviewed this file, please confirm deletion in the terminal. This
|
487 |
+
# file will be automatically removed once done.
|
488 |
+
# ------------
|
489 |
+
|
490 |
+
# KILL SWITCH
|
491 |
+
# ------------
|
492 |
+
# Un-comment following line to completely cancel the deletion process
|
493 |
+
# CANCEL\_DELETION
|
494 |
+
# ------------
|
495 |
+
|
496 |
+
# REVISIONS
|
497 |
+
# ------------
|
498 |
+
# Dataset chrisjay/crowd-speech-africa (761.7M, used 5 days ago)
|
499 |
+
ebedcd8c55c90d39fd27126d29d8484566cd27ca # Refs: main # modified 5 days ago
|
500 |
+
|
501 |
+
# Dataset oscar (3.3M, used 4 days ago)
|
502 |
+
# 916f956518279c5e60c63902ebdf3ddf9fa9d629 # Refs: main # modified 4 days ago
|
503 |
+
|
504 |
+
# Dataset wikiann (804.1K, used 2 weeks ago)
|
505 |
+
89d089624b6323d69dcd9e5eb2def0551887a73a # Refs: main # modified 2 weeks ago
|
506 |
+
|
507 |
+
# Dataset z-uo/male-LJSpeech-italian (5.5G, used 5 days ago)
|
508 |
+
# 9cfa5647b32c0a30d0adfca06bf198d82192a0d1 # Refs: main # modified 5 days ago
|
509 |
+
|
510 |
+
**Clean cache from Python**
|
511 |
+
|
512 |
+
For more flexibility, you can also use the [delete\_revisions()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo.delete_revisions) method programmatically. Here is a simple example. See reference for details.
|
513 |
+
|
514 |
+
Copied
|
515 |
+
|
516 |
+
\>>> from huggingface\_hub import scan\_cache\_dir
|
517 |
+
|
518 |
+
\>>> delete\_strategy = scan\_cache\_dir().delete\_revisions(
|
519 |
+
... "81fd1d6e7847c99f5862c9fb81387956d99ec7aa"
|
520 |
+
... "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
|
521 |
+
... "6c0e6080953db56375760c0471a8c5f2929baf11",
|
522 |
+
... )
|
523 |
+
\>>> print("Will free " + delete\_strategy.expected\_freed\_size\_str)
|
524 |
+
Will free 8.6G
|
525 |
+
|
526 |
+
\>>> delete\_strategy.execute()
|
527 |
+
Cache deletion done. Saved 8.6G.
|
528 |
+
|
529 |
+
[< \> Update on GitHub](https://github.com/huggingface/huggingface_hub/blob/main/docs/source/en/guides/manage-cache.md)
|
530 |
+
|
531 |
+
Create and manage a repository
|
hymm_sp/__init__.py
ADDED
File without changes
|
hymm_sp/config.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from hymm_sp.constants import *
|
3 |
+
import re
|
4 |
+
import collections.abc
|
5 |
+
|
6 |
+
def as_tuple(x):
|
7 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
8 |
+
return tuple(x)
|
9 |
+
if x is None or isinstance(x, (int, float, str)):
|
10 |
+
return (x,)
|
11 |
+
else:
|
12 |
+
raise ValueError(f"Unknown type {type(x)}")
|
13 |
+
|
14 |
+
def parse_args(namespace=None):
|
15 |
+
parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script")
|
16 |
+
parser = add_extra_args(parser)
|
17 |
+
args = parser.parse_args(namespace=namespace)
|
18 |
+
args = sanity_check_args(args)
|
19 |
+
return args
|
20 |
+
|
21 |
+
def add_extra_args(parser: argparse.ArgumentParser):
|
22 |
+
parser = add_network_args(parser)
|
23 |
+
parser = add_extra_models_args(parser)
|
24 |
+
parser = add_denoise_schedule_args(parser)
|
25 |
+
parser = add_evaluation_args(parser)
|
26 |
+
parser = add_test_args(parser)
|
27 |
+
return parser
|
28 |
+
|
29 |
+
def add_test_args(parser: argparse.ArgumentParser):
|
30 |
+
group = parser.add_argument_group(title="Test")
|
31 |
+
|
32 |
+
group.add_argument("--image-start", action="store_true", help="Use one image from video for training")
|
33 |
+
group.add_argument("--use-csv-pose", action="store_true", help="Use one image from video for training")
|
34 |
+
group.add_argument("--add-button", action="store_true", help="Use one image from video for training")
|
35 |
+
group.add_argument("--action-list", type=str, nargs='+', default=None, help="CSV file for evaluation.")
|
36 |
+
group.add_argument("--action-speed-list", type=float, nargs='+', default=None, help="CSV file for evaluation.")
|
37 |
+
group.add_argument("--pose", type=str, default=None, help="CSV file for evaluation.")
|
38 |
+
|
39 |
+
return parser
|
40 |
+
|
41 |
+
|
42 |
+
def add_network_args(parser: argparse.ArgumentParser):
|
43 |
+
group = parser.add_argument_group(title="Network")
|
44 |
+
group.add_argument("--model", type=str, default="HYVideo-T/2",
|
45 |
+
help="Model architecture to use. It it also used to determine the experiment directory.")
|
46 |
+
group.add_argument("--latent-channels", type=str, default=None,
|
47 |
+
help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
|
48 |
+
"it still needs to match the latent channels of the VAE model.")
|
49 |
+
group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
|
50 |
+
return parser
|
51 |
+
|
52 |
+
def add_extra_models_args(parser: argparse.ArgumentParser):
|
53 |
+
group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)")
|
54 |
+
|
55 |
+
# VAE
|
56 |
+
group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.")
|
57 |
+
group.add_argument("--vae-precision", type=str, default="fp16",
|
58 |
+
help="Precision mode for the VAE model.")
|
59 |
+
group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.")
|
60 |
+
group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH),
|
61 |
+
help="Name of the text encoder model.")
|
62 |
+
group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS,
|
63 |
+
help="Precision mode for the text encoder model.")
|
64 |
+
group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.")
|
65 |
+
group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
|
66 |
+
group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH),
|
67 |
+
help="Name of the tokenizer model.")
|
68 |
+
group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"],
|
69 |
+
help="Inference mode for the text encoder model. It should match the text encoder type. T5 and "
|
70 |
+
"CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.")
|
71 |
+
|
72 |
+
group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE,
|
73 |
+
help="Video prompt template for the decoder-only text encoder model.")
|
74 |
+
group.add_argument("--hidden-state-skip-layer", type=int, default=2,
|
75 |
+
help="Skip layer for hidden states.")
|
76 |
+
group.add_argument("--apply-final-norm", action="store_true",
|
77 |
+
help="Apply final normalization to the used text encoder hidden states.")
|
78 |
+
|
79 |
+
# - CLIP
|
80 |
+
group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH),
|
81 |
+
help="Name of the second text encoder model.")
|
82 |
+
group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS,
|
83 |
+
help="Precision mode for the second text encoder model.")
|
84 |
+
group.add_argument("--text-states-dim-2", type=int, default=768,
|
85 |
+
help="Dimension of the second text encoder hidden states.")
|
86 |
+
group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH),
|
87 |
+
help="Name of the second tokenizer model.")
|
88 |
+
group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.")
|
89 |
+
group.set_defaults(use_attention_mask=True)
|
90 |
+
group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION,
|
91 |
+
help="A projection layer for bridging the text encoder hidden states and the diffusion model "
|
92 |
+
"conditions.")
|
93 |
+
return parser
|
94 |
+
|
95 |
+
|
96 |
+
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
|
97 |
+
group = parser.add_argument_group(title="Denoise schedule")
|
98 |
+
group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.")
|
99 |
+
group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.")
|
100 |
+
group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
|
101 |
+
group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching."
|
102 |
+
"Follow MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)")
|
103 |
+
group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.")
|
104 |
+
return parser
|
105 |
+
|
106 |
+
def add_evaluation_args(parser: argparse.ArgumentParser):
|
107 |
+
group = parser.add_argument_group(title="Validation Loss Evaluation")
|
108 |
+
parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS,
|
109 |
+
help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.")
|
110 |
+
parser.add_argument("--reproduce", action="store_true",
|
111 |
+
help="Enable reproducibility by setting random seeds and deterministic algorithms.")
|
112 |
+
parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.")
|
113 |
+
parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"],
|
114 |
+
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.")
|
115 |
+
parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.")
|
116 |
+
group.add_argument( "--use-fp8", action="store_true", help="Enable use fp8 for inference acceleration.")
|
117 |
+
group.add_argument("--video-size", type=int, nargs='+', default=512,
|
118 |
+
help="Video size for training. If a single value is provided, it will be used for both width "
|
119 |
+
"and height. If two values are provided, they will be used for width and height "
|
120 |
+
"respectively.")
|
121 |
+
group.add_argument("--sample-n-frames", type=int, default=33,
|
122 |
+
help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1")
|
123 |
+
group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.")
|
124 |
+
group.add_argument("--val-disable-autocast", action="store_true",
|
125 |
+
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.")
|
126 |
+
group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.")
|
127 |
+
group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.")
|
128 |
+
group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.")
|
129 |
+
group.add_argument("--prompt", type=str, default='', help="Main prompt")
|
130 |
+
group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.")
|
131 |
+
group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.")
|
132 |
+
group.add_argument("--add-pos-prompt", type=str, default='', help="Addition prompt for sampling during evaluation.")
|
133 |
+
group.add_argument("--add-neg-prompt", type=str, default='', help="Addition negative prompt for sampling during evaluation.")
|
134 |
+
group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.")
|
135 |
+
group.add_argument("--image-path", type=str, default="", help="")
|
136 |
+
group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.")
|
137 |
+
group.add_argument("--input", type=str, default=None, help="test data.")
|
138 |
+
group.add_argument("--item-name", type=str, default=None, help="")
|
139 |
+
group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.")
|
140 |
+
group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.")
|
141 |
+
group.add_argument("--use-deepcache", type=int, default=1)
|
142 |
+
group.add_argument("--use-sage", action="store_true", help="Use sage attention for speed up.")
|
143 |
+
|
144 |
+
return parser
|
145 |
+
|
146 |
+
def sanity_check_args(args):
|
147 |
+
# VAE channels
|
148 |
+
vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
|
149 |
+
if not re.match(vae_pattern, args.vae):
|
150 |
+
raise ValueError(
|
151 |
+
f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
|
152 |
+
)
|
153 |
+
vae_channels = int(args.vae.split("-")[1][:-1])
|
154 |
+
if args.latent_channels is None:
|
155 |
+
args.latent_channels = vae_channels
|
156 |
+
if vae_channels != args.latent_channels:
|
157 |
+
raise ValueError(
|
158 |
+
f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
|
159 |
+
)
|
160 |
+
return args
|
hymm_sp/constants.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
"PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE",
|
6 |
+
"PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH",
|
7 |
+
"TEXT_PROJECTION",
|
8 |
+
]
|
9 |
+
|
10 |
+
# =================== Constant Values =====================
|
11 |
+
|
12 |
+
PRECISION_TO_TYPE = {
|
13 |
+
'fp32': torch.float32,
|
14 |
+
'fp16': torch.float16,
|
15 |
+
'bf16': torch.bfloat16,
|
16 |
+
}
|
17 |
+
|
18 |
+
PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
19 |
+
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
20 |
+
"1. The main content and theme of the video."
|
21 |
+
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
22 |
+
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
23 |
+
"4. background environment, light, style and atmosphere."
|
24 |
+
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
25 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
26 |
+
)
|
27 |
+
|
28 |
+
PROMPT_TEMPLATE = {
|
29 |
+
"li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95},
|
30 |
+
}
|
31 |
+
|
32 |
+
# ======================= Model ======================
|
33 |
+
PRECISIONS = {"fp32", "fp16", "bf16"}
|
34 |
+
|
35 |
+
# =================== Model Path =====================
|
36 |
+
MODEL_BASE = os.getenv("MODEL_BASE")
|
37 |
+
|
38 |
+
# 3D VAE
|
39 |
+
VAE_PATH = {
|
40 |
+
"884-16c-hy0801": f"{MODEL_BASE}/vae_3d/hyvae",
|
41 |
+
}
|
42 |
+
|
43 |
+
# Text Encoder
|
44 |
+
TEXT_ENCODER_PATH = {
|
45 |
+
"clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14",
|
46 |
+
"llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1-transformers",
|
47 |
+
}
|
48 |
+
|
49 |
+
# Tokenizer
|
50 |
+
TOKENIZER_PATH = {
|
51 |
+
"clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14",
|
52 |
+
"llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1-transformers",
|
53 |
+
}
|
54 |
+
|
55 |
+
TEXT_PROJECTION = {
|
56 |
+
"linear", # Default, an nn.Linear() layer
|
57 |
+
"single_refiner", # Single TokenRefiner. Refer to LI-DiT
|
58 |
+
}
|
hymm_sp/data_kits/data_tools.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import imageio
|
6 |
+
import torchvision
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
|
10 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
|
11 |
+
"""
|
12 |
+
Saves a batch of videos as a grid animation in GIF or video format.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
videos (torch.Tensor): Input video tensor with shape (batch, channels, time, height, width)
|
16 |
+
path (str): Output file path (e.g., "output/videos.gif")
|
17 |
+
rescale (bool): If True, rescales video values from [-1, 1] to [0, 1]
|
18 |
+
n_rows (int): Number of rows in the grid layout
|
19 |
+
fps (int): Frames per second for the output animation
|
20 |
+
quality (int): Quality parameter for imageio (1-10, higher = better quality)
|
21 |
+
|
22 |
+
Process:
|
23 |
+
1. Rearranges tensor dimensions to (time, batch, channels, height, width)
|
24 |
+
2. For each frame in time:
|
25 |
+
a. Creates a grid of videos using torchvision.utils.make_grid
|
26 |
+
b. Adjusts dimensions to (height, width, channels)
|
27 |
+
c. Rescales values if needed
|
28 |
+
d. Converts to 8-bit uint8 format (0-255)
|
29 |
+
3. Saves frames as an animated GIF/video using imageio
|
30 |
+
"""
|
31 |
+
# Rearrange dimensions to (time, batch, channels, height, width) for frame-wise processing
|
32 |
+
videos = rearrange(videos, "b c t h w -> t b c h w")
|
33 |
+
outputs = [] # Stores processed frames for animation
|
34 |
+
|
35 |
+
for frame in videos:
|
36 |
+
# Create a grid of videos with n_rows rows
|
37 |
+
grid = torchvision.utils.make_grid(frame, nrow=n_rows)
|
38 |
+
|
39 |
+
# Convert from (channels, height, width) to (height, width, channels)
|
40 |
+
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
41 |
+
|
42 |
+
# Rescale from [-1, 1] to [0, 1] if needed (common in GAN outputs)
|
43 |
+
if rescale:
|
44 |
+
grid = (grid + 1.0) / 2.0
|
45 |
+
|
46 |
+
# Clamp values to valid range [0, 1] and convert to 8-bit uint8 (0-255)
|
47 |
+
grid = torch.clamp(grid, 0, 1)
|
48 |
+
grid_np = (grid * 255).numpy().astype(np.uint8)
|
49 |
+
|
50 |
+
outputs.append(grid_np)
|
51 |
+
|
52 |
+
# Create output directory if it doesn't exist
|
53 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
54 |
+
|
55 |
+
# Save frames as an animated GIF/video
|
56 |
+
imageio.mimsave(path, outputs, fps=fps, quality=quality)
|
57 |
+
|
58 |
+
|
59 |
+
def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
|
60 |
+
"""
|
61 |
+
Resizes and pads an image to fit a target size while preserving aspect ratio.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
crop_img (np.ndarray): Input image (shape: [height, width, channels])
|
65 |
+
size (tuple): Target size in (width, height) format
|
66 |
+
color (tuple): RGB color for padding (default: white)
|
67 |
+
resize_ratio (float): Scaling factor for resizing before padding (0-1)
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
np.ndarray: Padded image with shape (target_height, target_width, channels)
|
71 |
+
|
72 |
+
Process:
|
73 |
+
1. Calculates scaling factors to fit image within target size
|
74 |
+
2. Resizes image while preserving aspect ratio
|
75 |
+
3. Adds padding to reach exact target size, centering the resized image
|
76 |
+
"""
|
77 |
+
# Get input image dimensions
|
78 |
+
crop_h, crop_w = crop_img.shape[:2]
|
79 |
+
target_w, target_h = size # Target dimensions (width, height)
|
80 |
+
|
81 |
+
# Calculate scaling factors to fit image within target size
|
82 |
+
scale_h = target_h / crop_h # Scale needed to fit height
|
83 |
+
scale_w = target_w / crop_w # Scale needed to fit width
|
84 |
+
|
85 |
+
# Choose the smaller scale to avoid exceeding target dimensions
|
86 |
+
if scale_w > scale_h:
|
87 |
+
# Height is the limiting factor: resize based on height
|
88 |
+
resize_h = int(target_h * resize_ratio)
|
89 |
+
resize_w = int(crop_w / crop_h * resize_h) # Preserve aspect ratio
|
90 |
+
else:
|
91 |
+
# Width is the limiting factor: resize based on width
|
92 |
+
resize_w = int(target_w * resize_ratio)
|
93 |
+
resize_h = int(crop_h / crop_w * resize_w) # Preserve aspect ratio
|
94 |
+
|
95 |
+
# Resize the image using OpenCV
|
96 |
+
resized_img = cv2.resize(crop_img, (resize_w, resize_h))
|
97 |
+
|
98 |
+
# Calculate padding needed to reach target size (centered)
|
99 |
+
pad_left = (target_w - resize_w) // 2
|
100 |
+
pad_top = (target_h - resize_h) // 2
|
101 |
+
pad_right = target_w - resize_w - pad_left # Ensure total width matches target
|
102 |
+
pad_bottom = target_h - resize_h - pad_top # Ensure total height matches target
|
103 |
+
|
104 |
+
# Add padding with the specified color
|
105 |
+
padded_img = cv2.copyMakeBorder(
|
106 |
+
resized_img,
|
107 |
+
top=pad_top,
|
108 |
+
bottom=pad_bottom,
|
109 |
+
left=pad_left,
|
110 |
+
right=pad_right,
|
111 |
+
borderType=cv2.BORDER_CONSTANT,
|
112 |
+
value=color
|
113 |
+
)
|
114 |
+
|
115 |
+
return padded_img
|
hymm_sp/data_kits/video_dataset.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
from PIL import Image
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
import csv
|
11 |
+
|
12 |
+
|
13 |
+
def fix_nulls(s):
|
14 |
+
"""
|
15 |
+
Helper generator to remove null characters from input lines.
|
16 |
+
Prevents parsing errors caused by invalid null bytes in CSV/JSON files.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
s: Input iterable containing strings with potential null characters
|
20 |
+
|
21 |
+
Yields:
|
22 |
+
Strings with null characters replaced by spaces
|
23 |
+
"""
|
24 |
+
for line in s:
|
25 |
+
yield line.replace('\0', ' ')
|
26 |
+
|
27 |
+
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
|
28 |
+
"""
|
29 |
+
Find the closest aspect ratio from predefined buckets
|
30 |
+
|
31 |
+
Args:
|
32 |
+
height: Image height
|
33 |
+
width: Image width
|
34 |
+
ratios: List of predefined aspect ratios to match against
|
35 |
+
buckets: List of size tuples corresponding to ratios
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
Tuple containing:
|
39 |
+
- Closest matching size bucket
|
40 |
+
- Closest ratio value
|
41 |
+
"""
|
42 |
+
aspect_ratio = float(height) / float(width)
|
43 |
+
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
|
44 |
+
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
|
45 |
+
return buckets[closest_ratio_id], float(closest_ratio)
|
46 |
+
|
47 |
+
|
48 |
+
def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
|
49 |
+
"""
|
50 |
+
Generate valid crop sizes that maintain compatible dimensions with model patches
|
51 |
+
|
52 |
+
Args:
|
53 |
+
base_size: Base dimension for calculating patch count
|
54 |
+
patch_size: Size of model's input patches
|
55 |
+
max_ratio: Maximum allowed aspect ratio (height/width)
|
56 |
+
|
57 |
+
Returns:
|
58 |
+
List of (width, height) tuples representing valid crop sizes
|
59 |
+
"""
|
60 |
+
# Calculate total number of patches from base size
|
61 |
+
num_patches = round((base_size / patch_size) ** 2)
|
62 |
+
assert max_ratio >= 1.0, "Maximum ratio must be at least 1.0"
|
63 |
+
|
64 |
+
crop_size_list = []
|
65 |
+
wp, hp = num_patches, 1 # Initialize with maximum width patches
|
66 |
+
|
67 |
+
# Generate valid patch combinations
|
68 |
+
while wp > 0:
|
69 |
+
# Only add sizes that maintain acceptable aspect ratio
|
70 |
+
if max(wp, hp) / min(wp, hp) <= max_ratio:
|
71 |
+
crop_size_list.append((wp * patch_size, hp * patch_size))
|
72 |
+
|
73 |
+
# Move to next valid patch configuration
|
74 |
+
if (hp + 1) * wp <= num_patches:
|
75 |
+
hp += 1
|
76 |
+
else:
|
77 |
+
wp -= 1
|
78 |
+
return crop_size_list
|
79 |
+
|
80 |
+
|
81 |
+
class VideoCSVDataset(Dataset):
|
82 |
+
"""
|
83 |
+
Dataset class for loading video generation data from CSV files
|
84 |
+
|
85 |
+
Handles:
|
86 |
+
- CSV parsing with null character handling
|
87 |
+
- Loading prompt and metadata
|
88 |
+
- Supporting multiple task types (image-to-video, etc.)
|
89 |
+
"""
|
90 |
+
def __init__(self, csv_path, col_name='prompt', task_type=''):
|
91 |
+
"""
|
92 |
+
Args:
|
93 |
+
csv_path: Path to CSV file containing dataset metadata
|
94 |
+
col_name: Column name containing generation prompts
|
95 |
+
task_type: Type of task (e.g., "i2v" for image-to-video)
|
96 |
+
"""
|
97 |
+
# Read CSV with null character handling
|
98 |
+
with open(csv_path, 'r', newline="\n", encoding='utf-8-sig') as csvfile:
|
99 |
+
self.dataset = list(csv.DictReader(fix_nulls(csvfile), delimiter=';'))
|
100 |
+
|
101 |
+
self.col_name = col_name
|
102 |
+
self.task_type = task_type
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
"""Return total number of samples in dataset"""
|
106 |
+
return len(self.dataset)
|
107 |
+
|
108 |
+
def __getitem__(self, idx):
|
109 |
+
"""
|
110 |
+
Get dataset item by index
|
111 |
+
|
112 |
+
Args:
|
113 |
+
idx: Index of sample to retrieve
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
Dictionary containing:
|
117 |
+
- Prompt and metadata
|
118 |
+
- Paths to auxiliary files (npy, video, poses)
|
119 |
+
- Index for tracking outputs
|
120 |
+
"""
|
121 |
+
example = {}
|
122 |
+
example["prompt"] = self.dataset[idx][self.col_name]
|
123 |
+
example['seed'] = int(self.dataset[idx]['seed'])
|
124 |
+
example['index'] = self.dataset[idx]['index']
|
125 |
+
|
126 |
+
# Add optional auxiliary paths if present in CSV
|
127 |
+
if "npy_path" in self.dataset[idx]:
|
128 |
+
example['npy_path'] = self.dataset[idx]['npy_path']
|
129 |
+
if "video_path" in self.dataset[idx]:
|
130 |
+
example['video_path'] = self.dataset[idx]['video_path']
|
131 |
+
if "monst3r_poses" in self.dataset[idx]:
|
132 |
+
example['monst3r_poses'] = self.dataset[idx]['monst3r_poses']
|
133 |
+
|
134 |
+
# Add image reference path for image-to-video tasks
|
135 |
+
if self.task_type == "i2v":
|
136 |
+
example['ref_image'] = self.dataset[idx]['ref_image_path']
|
137 |
+
|
138 |
+
return example
|
139 |
+
|
140 |
+
|
141 |
+
class JsonDataset(object):
|
142 |
+
"""
|
143 |
+
Dataset class for loading data from JSON files and image sequences
|
144 |
+
|
145 |
+
Handles:
|
146 |
+
- Reading image data from multiple formats
|
147 |
+
- Preprocessing for model compatibility
|
148 |
+
- Generating conditional and unconditional inputs
|
149 |
+
"""
|
150 |
+
def __init__(self, args):
|
151 |
+
"""
|
152 |
+
Args:
|
153 |
+
args: Command-line arguments containing configuration
|
154 |
+
"""
|
155 |
+
self.args = args
|
156 |
+
self.data_list = args.input
|
157 |
+
self.pad_color = (255, 255, 255) # White padding
|
158 |
+
self.llava_size = (336, 336) # Standard size for LLaVA model
|
159 |
+
self.ref_size = (args.video_size[1], args.video_size[0]) # Reference output size
|
160 |
+
|
161 |
+
# Get list of data paths from input list or single file
|
162 |
+
if self.data_list.endswith('.list'):
|
163 |
+
self.data_paths = [line.strip() for line in open(self.data_list, 'r')] if self.data_list else []
|
164 |
+
else:
|
165 |
+
self.data_paths = [self.data_list]
|
166 |
+
|
167 |
+
# Transformation pipeline for LLaVA model input
|
168 |
+
self.llava_transform = transforms.Compose(
|
169 |
+
[
|
170 |
+
transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
|
171 |
+
transforms.ToTensor(),
|
172 |
+
transforms.Normalize(
|
173 |
+
(0.48145466, 0.4578275, 0.4082107),
|
174 |
+
(0.26862954, 0.26130258, 0.27577711)
|
175 |
+
),
|
176 |
+
]
|
177 |
+
)
|
178 |
+
|
179 |
+
def __len__(self):
|
180 |
+
"""Return total number of data items"""
|
181 |
+
return len(self.data_paths)
|
182 |
+
|
183 |
+
def read_image(self, image_path):
|
184 |
+
"""
|
185 |
+
Read image from path with fallback handling
|
186 |
+
|
187 |
+
Args:
|
188 |
+
image_path: Path to image file or dictionary containing path
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
Tuple of (LLaVA-formatted image, reference-sized image)
|
192 |
+
"""
|
193 |
+
# Extract path from dictionary if needed
|
194 |
+
if isinstance(image_path, dict):
|
195 |
+
image_path = image_path['seg_item_image_path']
|
196 |
+
|
197 |
+
try:
|
198 |
+
# Primary method: OpenCV for faster reading
|
199 |
+
face_image_masked = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
|
200 |
+
except:
|
201 |
+
# Fallback: PIL for special formats
|
202 |
+
face_image_masked = Image.open(image_path).convert('RGB')
|
203 |
+
|
204 |
+
# Prepare images for different processing stages
|
205 |
+
cat_face_image = pad_image(face_image_masked.copy(), self.ref_size)
|
206 |
+
llava_face_image = pad_image(face_image_masked.copy(), self.llava_size)
|
207 |
+
return llava_face_image, cat_face_image
|
208 |
+
|
209 |
+
def __getitem__(self, idx):
|
210 |
+
"""
|
211 |
+
Get preprocessed data item by index
|
212 |
+
|
213 |
+
Args:
|
214 |
+
idx: Index of item to retrieve
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
Dictionary containing:
|
218 |
+
- Preprocessed tensors for model input
|
219 |
+
- Metadata (prompt, index, paths)
|
220 |
+
"""
|
221 |
+
data_path = self.data_paths[idx]
|
222 |
+
data_name = os.path.basename(os.path.splitext(data_path)[0])
|
223 |
+
|
224 |
+
# Load data from JSON or use default parameters
|
225 |
+
if data_path.endswith('.json'):
|
226 |
+
data = json.load(open(data_path, 'r'))
|
227 |
+
llava_item_image, cat_item_image = self.read_image(data)
|
228 |
+
item_prompt = data['item_prompt']
|
229 |
+
seed = data['seed']
|
230 |
+
prompt = data['prompt']
|
231 |
+
negative_prompt = data.get('negative_prompt', '') # Default to empty string
|
232 |
+
else:
|
233 |
+
# Handle non-JSON data (direct image files)
|
234 |
+
llava_item_image, cat_item_image = self.read_image(data_path)
|
235 |
+
item_prompt = 'object'
|
236 |
+
seed = self.args.seed
|
237 |
+
prompt = self.args.pos_prompt
|
238 |
+
negative_prompt = self.args.neg_prompt
|
239 |
+
|
240 |
+
# Convert to tensors with appropriate transformations
|
241 |
+
llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
|
242 |
+
cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0 # Normalize to [0,1]
|
243 |
+
|
244 |
+
# Create unconditional input (white background)
|
245 |
+
uncond_llava_item_image = np.ones_like(llava_item_image) * 255
|
246 |
+
uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
|
247 |
+
|
248 |
+
# Assemble final batch dictionary
|
249 |
+
return {
|
250 |
+
"pixel_value_llava": llava_item_tensor,
|
251 |
+
"uncond_pixel_value_llava": uncond_llava_item_tensor,
|
252 |
+
"pixel_value_ref": cat_item_tensor,
|
253 |
+
"prompt": prompt,
|
254 |
+
"negative_prompt": negative_prompt,
|
255 |
+
"seed": seed,
|
256 |
+
"name": item_prompt,
|
257 |
+
'data_name': data_name,
|
258 |
+
'index': [idx] # Index for output tracking
|
259 |
+
}
|
hymm_sp/diffusion/__init__.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .pipelines import HunyuanVideoGamePipeline
|
2 |
+
from .schedulers import FlowMatchDiscreteScheduler
|
3 |
+
|
4 |
+
def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None,
|
5 |
+
device=None, progress_bar_config=None):
|
6 |
+
""" Load the denoising scheduler for inference. """
|
7 |
+
if scheduler is None:
|
8 |
+
scheduler = FlowMatchDiscreteScheduler(
|
9 |
+
shift=args.flow_shift_eval_video,
|
10 |
+
reverse=args.flow_reverse,
|
11 |
+
solver=args.flow_solver,
|
12 |
+
)
|
13 |
+
# Only enable progress bar for rank 0
|
14 |
+
progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0}
|
15 |
+
|
16 |
+
pipeline = HunyuanVideoGamePipeline(vae=vae,
|
17 |
+
text_encoder=text_encoder,
|
18 |
+
text_encoder_2=text_encoder_2,
|
19 |
+
transformer=model,
|
20 |
+
scheduler=scheduler,
|
21 |
+
# safety_checker=None,
|
22 |
+
# feature_extractor=None,
|
23 |
+
# requires_safety_checker=False,
|
24 |
+
progress_bar_config=progress_bar_config,
|
25 |
+
args=args,
|
26 |
+
)
|
27 |
+
if not args.cpu_offload:
|
28 |
+
pipeline = pipeline.to(device)
|
29 |
+
|
30 |
+
return pipeline
|
hymm_sp/diffusion/pipelines/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import the HunyuanVideoGamePipeline class from the current package
|
2 |
+
# This pipeline is specifically designed for handling video game content generation
|
3 |
+
# using the Hunyuan model architecture, providing specialized functionality
|
4 |
+
# for game-related video synthesis, character animation, and environment rendering.
|
5 |
+
from .pipeline_hunyuan_video_game import HunyuanVideoGamePipeline
|
hymm_sp/diffusion/pipelines/pipeline_hunyuan_video_game.py
ADDED
@@ -0,0 +1,1152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
import inspect
|
20 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from packaging import version
|
24 |
+
from diffusers.utils import BaseOutput
|
25 |
+
from dataclasses import dataclass
|
26 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
27 |
+
from diffusers.configuration_utils import FrozenDict
|
28 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
29 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
30 |
+
from diffusers.models import AutoencoderKL, ImageProjection
|
31 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
32 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
33 |
+
from diffusers.utils import (
|
34 |
+
USE_PEFT_BACKEND,
|
35 |
+
deprecate,
|
36 |
+
logging,
|
37 |
+
replace_example_docstring,
|
38 |
+
scale_lora_layers,
|
39 |
+
unscale_lora_layers,
|
40 |
+
)
|
41 |
+
from diffusers.utils.torch_utils import randn_tensor
|
42 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
43 |
+
|
44 |
+
from hymm_sp.constants import PRECISION_TO_TYPE
|
45 |
+
from hymm_sp.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
46 |
+
from hymm_sp.text_encoder import TextEncoder
|
47 |
+
from einops import rearrange
|
48 |
+
from ...modules import HYVideoDiffusionTransformer
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
51 |
+
|
52 |
+
EXAMPLE_DOC_STRING = """"""
|
53 |
+
|
54 |
+
|
55 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
56 |
+
"""
|
57 |
+
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
58 |
+
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
|
59 |
+
"""
|
60 |
+
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
61 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
62 |
+
# rescale the results from guidance (fixes overexposure)
|
63 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
64 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
65 |
+
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
66 |
+
return noise_cfg
|
67 |
+
|
68 |
+
|
69 |
+
def retrieve_timesteps(
|
70 |
+
scheduler,
|
71 |
+
num_inference_steps: Optional[int] = None,
|
72 |
+
device: Optional[Union[str, torch.device]] = None,
|
73 |
+
timesteps: Optional[List[int]] = None,
|
74 |
+
sigmas: Optional[List[float]] = None,
|
75 |
+
**kwargs,
|
76 |
+
):
|
77 |
+
"""
|
78 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
79 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
80 |
+
|
81 |
+
Args:
|
82 |
+
scheduler (`SchedulerMixin`):
|
83 |
+
The scheduler to get timesteps from.
|
84 |
+
num_inference_steps (`int`):
|
85 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
86 |
+
must be `None`.
|
87 |
+
device (`str` or `torch.device`, *optional*):
|
88 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
89 |
+
timesteps (`List[int]`, *optional*):
|
90 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
91 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
92 |
+
sigmas (`List[float]`, *optional*):
|
93 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
94 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
98 |
+
second element is the number of inference steps.
|
99 |
+
"""
|
100 |
+
if timesteps is not None and sigmas is not None:
|
101 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
102 |
+
if timesteps is not None:
|
103 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
104 |
+
if not accepts_timesteps:
|
105 |
+
raise ValueError(
|
106 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
107 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
108 |
+
)
|
109 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
110 |
+
timesteps = scheduler.timesteps
|
111 |
+
num_inference_steps = len(timesteps)
|
112 |
+
elif sigmas is not None:
|
113 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
114 |
+
if not accept_sigmas:
|
115 |
+
raise ValueError(
|
116 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
117 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
118 |
+
)
|
119 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
120 |
+
timesteps = scheduler.timesteps
|
121 |
+
num_inference_steps = len(timesteps)
|
122 |
+
else:
|
123 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
124 |
+
timesteps = scheduler.timesteps
|
125 |
+
return timesteps, num_inference_steps
|
126 |
+
|
127 |
+
@dataclass
|
128 |
+
class HunyuanVideoPipelineOutput(BaseOutput):
|
129 |
+
videos: Union[torch.Tensor, np.ndarray]
|
130 |
+
|
131 |
+
|
132 |
+
class HunyuanVideoGamePipeline(DiffusionPipeline):
|
133 |
+
r"""
|
134 |
+
Pipeline for text-to-video generation using HunyuanVideo.
|
135 |
+
|
136 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
137 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
138 |
+
|
139 |
+
Args:
|
140 |
+
vae ([`AutoencoderKL`]):
|
141 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
142 |
+
text_encoder ([`TextEncoder`]):
|
143 |
+
Frozen text-encoder.
|
144 |
+
text_encoder_2 ([`TextEncoder`]):
|
145 |
+
Frozen text-encoder_2.
|
146 |
+
transformer ([`HYVideoDiffusionTransformer`]):
|
147 |
+
A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
|
148 |
+
scheduler ([`SchedulerMixin`]):
|
149 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
150 |
+
"""
|
151 |
+
|
152 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
153 |
+
_optional_components = ["text_encoder_2"]
|
154 |
+
_exclude_from_cpu_offload = ["transformer"]
|
155 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
156 |
+
|
157 |
+
def __init__(
|
158 |
+
self,
|
159 |
+
vae: AutoencoderKL,
|
160 |
+
text_encoder: TextEncoder,
|
161 |
+
transformer: HYVideoDiffusionTransformer,
|
162 |
+
scheduler: KarrasDiffusionSchedulers,
|
163 |
+
text_encoder_2: Optional[TextEncoder] = None,
|
164 |
+
progress_bar_config: Dict[str, Any] = None,
|
165 |
+
args=None,
|
166 |
+
):
|
167 |
+
super().__init__()
|
168 |
+
|
169 |
+
# ==========================================================================================
|
170 |
+
if progress_bar_config is None:
|
171 |
+
progress_bar_config = {}
|
172 |
+
if not hasattr(self, '_progress_bar_config'):
|
173 |
+
self._progress_bar_config = {}
|
174 |
+
self._progress_bar_config.update(progress_bar_config)
|
175 |
+
|
176 |
+
self.args = args
|
177 |
+
# ==========================================================================================
|
178 |
+
|
179 |
+
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
|
180 |
+
deprecation_message = (
|
181 |
+
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
|
182 |
+
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
|
183 |
+
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
|
184 |
+
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
|
185 |
+
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
|
186 |
+
" file"
|
187 |
+
)
|
188 |
+
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
|
189 |
+
new_config = dict(scheduler.config)
|
190 |
+
new_config["steps_offset"] = 1
|
191 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
192 |
+
|
193 |
+
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
|
194 |
+
deprecation_message = (
|
195 |
+
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
|
196 |
+
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
|
197 |
+
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
|
198 |
+
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
|
199 |
+
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
|
200 |
+
)
|
201 |
+
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
|
202 |
+
new_config = dict(scheduler.config)
|
203 |
+
new_config["clip_sample"] = False
|
204 |
+
scheduler._internal_dict = FrozenDict(new_config)
|
205 |
+
|
206 |
+
self.register_modules(
|
207 |
+
vae=vae,
|
208 |
+
text_encoder=text_encoder,
|
209 |
+
transformer=transformer,
|
210 |
+
scheduler=scheduler,
|
211 |
+
text_encoder_2=text_encoder_2
|
212 |
+
)
|
213 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
214 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
215 |
+
|
216 |
+
def encode_prompt(
|
217 |
+
self,
|
218 |
+
prompt,
|
219 |
+
device,
|
220 |
+
num_videos_per_prompt,
|
221 |
+
do_classifier_free_guidance,
|
222 |
+
negative_prompt=None,
|
223 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
224 |
+
attention_mask: Optional[torch.Tensor] = None,
|
225 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
226 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
227 |
+
lora_scale: Optional[float] = None,
|
228 |
+
clip_skip: Optional[int] = None,
|
229 |
+
text_encoder: Optional[TextEncoder] = None,
|
230 |
+
data_type: Optional[str] = "image",
|
231 |
+
):
|
232 |
+
r"""
|
233 |
+
Encodes the prompt into text encoder hidden states.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
prompt (`str` or `List[str]`, *optional*):
|
237 |
+
prompt to be encoded
|
238 |
+
device: (`torch.device`):
|
239 |
+
torch device
|
240 |
+
num_videos_per_prompt (`int`):
|
241 |
+
number of images that should be generated per prompt
|
242 |
+
do_classifier_free_guidance (`bool`):
|
243 |
+
whether to use classifier free guidance or not
|
244 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
245 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
246 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
247 |
+
less than `1`).
|
248 |
+
pixel_value_llava (`torch.Tensor`, *optional*):
|
249 |
+
The image tensor for llava.
|
250 |
+
uncond_pixel_value_llava (`torch.Tensor`, *optional*):
|
251 |
+
The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
252 |
+
less than `1`).
|
253 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
254 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
255 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
256 |
+
attention_mask (`torch.Tensor`, *optional*):
|
257 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
258 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
259 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
260 |
+
argument.
|
261 |
+
negative_attention_mask (`torch.Tensor`, *optional*):
|
262 |
+
lora_scale (`float`, *optional*):
|
263 |
+
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
264 |
+
clip_skip (`int`, *optional*):
|
265 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
266 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
267 |
+
text_encoder (TextEncoder, *optional*):
|
268 |
+
"""
|
269 |
+
if text_encoder is None:
|
270 |
+
text_encoder = self.text_encoder
|
271 |
+
|
272 |
+
# set lora scale so that monkey patched LoRA
|
273 |
+
# function of text encoder can correctly access it
|
274 |
+
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
275 |
+
self._lora_scale = lora_scale
|
276 |
+
|
277 |
+
# dynamically adjust the LoRA scale
|
278 |
+
if not USE_PEFT_BACKEND:
|
279 |
+
adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
|
280 |
+
else:
|
281 |
+
scale_lora_layers(text_encoder.model, lora_scale)
|
282 |
+
|
283 |
+
if prompt is not None and isinstance(prompt, str):
|
284 |
+
batch_size = 1
|
285 |
+
elif prompt is not None and isinstance(prompt, list):
|
286 |
+
batch_size = len(prompt)
|
287 |
+
else:
|
288 |
+
batch_size = prompt_embeds.shape[0]
|
289 |
+
|
290 |
+
if prompt_embeds is None:
|
291 |
+
# textual inversion: process multi-vector tokens if necessary
|
292 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
293 |
+
prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
|
294 |
+
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
|
295 |
+
|
296 |
+
if clip_skip is None:
|
297 |
+
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
|
298 |
+
prompt_embeds = prompt_outputs.hidden_state
|
299 |
+
else:
|
300 |
+
prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
|
301 |
+
# Access the `hidden_states` first, that contains a tuple of
|
302 |
+
# all the hidden states from the encoder layers. Then index into
|
303 |
+
# the tuple to access the hidden states from the desired layer.
|
304 |
+
prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
|
305 |
+
# We also need to apply the final LayerNorm here to not mess with the
|
306 |
+
# representations. The `last_hidden_states` that we typically use for
|
307 |
+
# obtaining the final prompt representations passes through the LayerNorm
|
308 |
+
# layer.
|
309 |
+
prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
|
310 |
+
|
311 |
+
attention_mask = prompt_outputs.attention_mask
|
312 |
+
if attention_mask is not None:
|
313 |
+
attention_mask = attention_mask.to(device)
|
314 |
+
bs_embed, seq_len = attention_mask.shape
|
315 |
+
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
|
316 |
+
attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
|
317 |
+
|
318 |
+
|
319 |
+
if text_encoder is not None:
|
320 |
+
prompt_embeds_dtype = text_encoder.dtype
|
321 |
+
elif self.transformer is not None:
|
322 |
+
prompt_embeds_dtype = self.transformer.dtype
|
323 |
+
else:
|
324 |
+
prompt_embeds_dtype = prompt_embeds.dtype
|
325 |
+
|
326 |
+
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
327 |
+
|
328 |
+
if prompt_embeds.ndim == 2:
|
329 |
+
bs_embed, _ = prompt_embeds.shape
|
330 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
331 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
|
332 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
|
333 |
+
else:
|
334 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
335 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
336 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
337 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
|
338 |
+
|
339 |
+
# get unconditional embeddings for classifier free guidance
|
340 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
341 |
+
uncond_tokens: List[str]
|
342 |
+
if negative_prompt is None:
|
343 |
+
uncond_tokens = [""] * batch_size
|
344 |
+
elif prompt is not None and type(prompt) is not type(negative_prompt):
|
345 |
+
raise TypeError(
|
346 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
347 |
+
f" {type(prompt)}."
|
348 |
+
)
|
349 |
+
elif isinstance(negative_prompt, str):
|
350 |
+
uncond_tokens = [negative_prompt]
|
351 |
+
elif batch_size != len(negative_prompt):
|
352 |
+
raise ValueError(
|
353 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
354 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
355 |
+
" the batch size of `prompt`."
|
356 |
+
)
|
357 |
+
else:
|
358 |
+
uncond_tokens = negative_prompt
|
359 |
+
|
360 |
+
# textual inversion: process multi-vector tokens if necessary
|
361 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
362 |
+
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
|
363 |
+
uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
|
364 |
+
|
365 |
+
negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
|
366 |
+
negative_prompt_embeds = negative_prompt_outputs.hidden_state
|
367 |
+
|
368 |
+
negative_attention_mask = negative_prompt_outputs.attention_mask
|
369 |
+
if negative_attention_mask is not None:
|
370 |
+
negative_attention_mask = negative_attention_mask.to(device)
|
371 |
+
_, seq_len = negative_attention_mask.shape
|
372 |
+
negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt)
|
373 |
+
negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
|
374 |
+
|
375 |
+
if do_classifier_free_guidance:
|
376 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
377 |
+
seq_len = negative_prompt_embeds.shape[1]
|
378 |
+
|
379 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
380 |
+
|
381 |
+
if negative_prompt_embeds.ndim == 2:
|
382 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt)
|
383 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
|
384 |
+
else:
|
385 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
386 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
387 |
+
|
388 |
+
if text_encoder is not None:
|
389 |
+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
390 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
391 |
+
unscale_lora_layers(text_encoder.model, lora_scale)
|
392 |
+
|
393 |
+
return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
|
394 |
+
|
395 |
+
def decode_latents(self, latents, enable_tiling=True):
|
396 |
+
deprecation_message = \
|
397 |
+
"The decode_latents method is deprecated and will be removed \
|
398 |
+
in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
399 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
400 |
+
|
401 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
402 |
+
if enable_tiling:
|
403 |
+
self.vae.enable_tiling()
|
404 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
405 |
+
self.vae.disable_tiling()
|
406 |
+
else:
|
407 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
408 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
409 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
410 |
+
if image.ndim==4: image = image.cpu().permute(0, 2, 3, 1).float()
|
411 |
+
else: image = image.cpu().float()
|
412 |
+
return image
|
413 |
+
|
414 |
+
def prepare_extra_func_kwargs(self, func, kwargs):
|
415 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
416 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
417 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
418 |
+
# and should be between [0, 1]
|
419 |
+
extra_step_kwargs = {}
|
420 |
+
|
421 |
+
for k, v in kwargs.items():
|
422 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
423 |
+
if accepts:
|
424 |
+
extra_step_kwargs[k] = v
|
425 |
+
return extra_step_kwargs
|
426 |
+
|
427 |
+
def check_inputs(
|
428 |
+
self,
|
429 |
+
prompt,
|
430 |
+
height,
|
431 |
+
width,
|
432 |
+
frame,
|
433 |
+
callback_steps,
|
434 |
+
negative_prompt=None,
|
435 |
+
prompt_embeds=None,
|
436 |
+
negative_prompt_embeds=None,
|
437 |
+
callback_on_step_end_tensor_inputs=None,
|
438 |
+
vae_ver='88-4c-sd'
|
439 |
+
):
|
440 |
+
if height % 8 != 0 or width % 8 != 0:
|
441 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
442 |
+
|
443 |
+
# if frame is not None:
|
444 |
+
# if '884' in vae_ver:
|
445 |
+
# if frame!=1 and (frame-1)%4!=0:
|
446 |
+
# raise ValueError(f'`frame` has to be 1 or a multiple of 4 but is {frame}.')
|
447 |
+
# elif '888' in vae_ver:
|
448 |
+
# if frame!=1 and (frame-1)%8!=0:
|
449 |
+
# raise ValueError(f'`frame` has to be 1 or a multiple of 8 but is {frame}.')
|
450 |
+
|
451 |
+
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
452 |
+
raise ValueError(
|
453 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
454 |
+
f" {type(callback_steps)}."
|
455 |
+
)
|
456 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
457 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
458 |
+
):
|
459 |
+
raise ValueError(
|
460 |
+
f"`callback_on_step_end_tensor_inputs` has to be in \
|
461 |
+
{self._callback_tensor_inputs}, but found \
|
462 |
+
{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
463 |
+
)
|
464 |
+
|
465 |
+
if prompt is not None and prompt_embeds is not None:
|
466 |
+
raise ValueError(
|
467 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
468 |
+
" only forward one of the two."
|
469 |
+
)
|
470 |
+
elif prompt is None and prompt_embeds is None:
|
471 |
+
raise ValueError(
|
472 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
473 |
+
)
|
474 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
475 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
476 |
+
|
477 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
478 |
+
raise ValueError(
|
479 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
480 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
481 |
+
)
|
482 |
+
|
483 |
+
|
484 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
485 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
486 |
+
raise ValueError(
|
487 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
488 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
489 |
+
f" {negative_prompt_embeds.shape}."
|
490 |
+
)
|
491 |
+
|
492 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
493 |
+
# get the original timestep using init_timestep
|
494 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
495 |
+
|
496 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
497 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
498 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
499 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
500 |
+
|
501 |
+
return timesteps.to(device), num_inference_steps - t_start
|
502 |
+
|
503 |
+
def prepare_latents(self, batch_size, num_channels_latents, num_inference_steps,
|
504 |
+
height, width, frame, dtype, device, timesteps,generator,
|
505 |
+
latents=None, gt_latents=None, denoise_strength=1.0,):
|
506 |
+
shape = (
|
507 |
+
batch_size,
|
508 |
+
num_channels_latents,
|
509 |
+
frame,
|
510 |
+
int(height) // self.vae_scale_factor,
|
511 |
+
int(width) // self.vae_scale_factor,
|
512 |
+
)
|
513 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
514 |
+
raise ValueError(
|
515 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
516 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
517 |
+
)
|
518 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
519 |
+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
|
520 |
+
|
521 |
+
if gt_latents.shape[2] == 1:
|
522 |
+
gt_latents = gt_latents.repeat(1, 1, frame, 1, 1)
|
523 |
+
|
524 |
+
# TODO: correct
|
525 |
+
x0 = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
526 |
+
# print("!!!!!!!!!!!!!! RANDOM NOISE !!!!!!!!!!!!!!!!!!")
|
527 |
+
# x0 = randn_tensor(shape, device=device, dtype=dtype)
|
528 |
+
x1 = gt_latents
|
529 |
+
|
530 |
+
t = torch.tensor([0.999]).to(device=device)
|
531 |
+
latents = x0 * t + x1 * (1 - t)
|
532 |
+
latents = torch.randn_like(x1)
|
533 |
+
# print("!!!randn_like", latents.shape)
|
534 |
+
latents = latents.to(dtype=dtype)
|
535 |
+
|
536 |
+
if latents is None:
|
537 |
+
latents = noise
|
538 |
+
original_latents = None
|
539 |
+
else:
|
540 |
+
latents = latents.to(device)
|
541 |
+
|
542 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
543 |
+
latents = latents * self.scheduler.init_noise_sigma
|
544 |
+
|
545 |
+
return latents, timesteps
|
546 |
+
|
547 |
+
# Copied from diffusers.pipelines.latent_consistency_models.
|
548 |
+
# pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
549 |
+
def get_guidance_scale_embedding(
|
550 |
+
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
|
551 |
+
) -> torch.Tensor:
|
552 |
+
"""
|
553 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
554 |
+
|
555 |
+
Args:
|
556 |
+
w (`torch.Tensor`):
|
557 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
558 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
559 |
+
Dimension of the embeddings to generate.
|
560 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
561 |
+
Data type of the generated embeddings.
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
565 |
+
"""
|
566 |
+
assert len(w.shape) == 1
|
567 |
+
w = w * 1000.0
|
568 |
+
|
569 |
+
half_dim = embedding_dim // 2
|
570 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
571 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
572 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
573 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
574 |
+
if embedding_dim % 2 == 1: # zero pad
|
575 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
576 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
577 |
+
return emb
|
578 |
+
|
579 |
+
@property
|
580 |
+
def guidance_scale(self):
|
581 |
+
return self._guidance_scale
|
582 |
+
|
583 |
+
@property
|
584 |
+
def guidance_rescale(self):
|
585 |
+
return self._guidance_rescale
|
586 |
+
|
587 |
+
@property
|
588 |
+
def clip_skip(self):
|
589 |
+
return self._clip_skip
|
590 |
+
|
591 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
592 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
593 |
+
# corresponds to doing no classifier free guidance.
|
594 |
+
@property
|
595 |
+
def do_classifier_free_guidance(self):
|
596 |
+
# return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
|
597 |
+
return self._guidance_scale > 1
|
598 |
+
|
599 |
+
@property
|
600 |
+
def cross_attention_kwargs(self):
|
601 |
+
return self._cross_attention_kwargs
|
602 |
+
|
603 |
+
@property
|
604 |
+
def num_timesteps(self):
|
605 |
+
return self._num_timesteps
|
606 |
+
|
607 |
+
@property
|
608 |
+
def interrupt(self):
|
609 |
+
return self._interrupt
|
610 |
+
|
611 |
+
@torch.no_grad()
|
612 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
613 |
+
def __call__(
|
614 |
+
self,
|
615 |
+
prompt: Union[str, List[str]],
|
616 |
+
cam_latents: Union[torch.Tensor], # cam_latents
|
617 |
+
last_latents: Union[torch.Tensor],
|
618 |
+
uncond_cam_latents: Union[torch.Tensor],
|
619 |
+
gt_latents: Union[torch.Tensor],
|
620 |
+
height: int,
|
621 |
+
width: int,
|
622 |
+
video_length: int, # frame is called video_len in hunyuan_multimodal/dev_video
|
623 |
+
data_type: str='video',
|
624 |
+
num_inference_steps: int = 50,
|
625 |
+
timesteps: List[int] = None,
|
626 |
+
sigmas: List[float] = None,
|
627 |
+
guidance_scale: float = 7.5,
|
628 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
629 |
+
ref_latents: Optional[torch.Tensor] = None,
|
630 |
+
uncond_ref_latents: Optional[torch.Tensor] = None,
|
631 |
+
ip_cfg_scale: float = 0.0,
|
632 |
+
use_deepcache: int = 1,
|
633 |
+
num_videos_per_prompt: Optional[int] = 1,
|
634 |
+
eta: float = 0.0,
|
635 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
636 |
+
latents: Optional[torch.Tensor] = None,
|
637 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
638 |
+
attention_mask: Optional[torch.Tensor] = None,
|
639 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
640 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
641 |
+
output_type: Optional[str] = "pil",
|
642 |
+
return_dict: bool = True,
|
643 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
644 |
+
guidance_rescale: float = 0.0,
|
645 |
+
clip_skip: Optional[int] = None,
|
646 |
+
callback_on_step_end: Optional[
|
647 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
648 |
+
] = None,
|
649 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
650 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
651 |
+
vae_ver: str='88-4c-sd',
|
652 |
+
enable_tiling: bool=False,
|
653 |
+
n_tokens: Optional[int] = None,
|
654 |
+
video_val_flag: bool=False,
|
655 |
+
denoise_strength: float = 1.0,
|
656 |
+
mask = None,
|
657 |
+
cpu_offload: bool=False,
|
658 |
+
use_sage: bool=False,
|
659 |
+
**kwargs,
|
660 |
+
):
|
661 |
+
r"""
|
662 |
+
The call function to the pipeline for generation.
|
663 |
+
|
664 |
+
Args:
|
665 |
+
prompt (`str` or `List[str]`):
|
666 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
667 |
+
height (`int`):
|
668 |
+
The height in pixels of the generated image.
|
669 |
+
width (`int`):
|
670 |
+
The width in pixels of the generated image.
|
671 |
+
video_length (`int`):
|
672 |
+
The number of frames in the generated video.
|
673 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
674 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
675 |
+
expense of slower inference.
|
676 |
+
timesteps (`List[int]`, *optional*):
|
677 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
678 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
679 |
+
passed will be used. Must be in descending order.
|
680 |
+
sigmas (`List[float]`, *optional*):
|
681 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
682 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
683 |
+
will be used.
|
684 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
685 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
686 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
687 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
688 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
689 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
690 |
+
ref_latents (`torch.Tensor`, *optional*):
|
691 |
+
The image tensor for time-concat.
|
692 |
+
uncond_ref_latents (`torch.Tensor`, *optional*):
|
693 |
+
The image tensor for time-concat. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
694 |
+
less than `1`).
|
695 |
+
pixel_value_llava (`torch.Tensor`, *optional*):
|
696 |
+
The image tensor for llava.
|
697 |
+
uncond_pixel_value_llava (`torch.Tensor`, *optional*):
|
698 |
+
The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
699 |
+
less than `1`).
|
700 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
701 |
+
The number of images to generate per prompt.
|
702 |
+
eta (`float`, *optional*, defaults to 0.0):
|
703 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
704 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
705 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
706 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
707 |
+
generation deterministic.
|
708 |
+
latents (`torch.Tensor`, *optional*):
|
709 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
710 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
711 |
+
tensor is generated by sampling using the supplied random `generator`.
|
712 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
713 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
714 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
715 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
716 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
717 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
718 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
719 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
720 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
721 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
722 |
+
plain tuple.
|
723 |
+
cross_attention_kwargs (`dict`, *optional*):
|
724 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
725 |
+
[`self.processor`]
|
726 |
+
(https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
727 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
728 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
729 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
730 |
+
using zero terminal SNR.
|
731 |
+
clip_skip (`int`, *optional*):
|
732 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
733 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
734 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
735 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
736 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
737 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
738 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
739 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
740 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
741 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
742 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
743 |
+
|
744 |
+
Examples:
|
745 |
+
|
746 |
+
Returns:
|
747 |
+
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
748 |
+
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
|
749 |
+
otherwise a list with the generated images is returned.
|
750 |
+
"""
|
751 |
+
callback = kwargs.pop("callback", None)
|
752 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
753 |
+
if callback is not None:
|
754 |
+
deprecate(
|
755 |
+
"callback",
|
756 |
+
"1.0.0",
|
757 |
+
"Passing `callback` as an input argument to \
|
758 |
+
`__call__` is deprecated, consider using `callback_on_step_end`",
|
759 |
+
)
|
760 |
+
if callback_steps is not None:
|
761 |
+
deprecate(
|
762 |
+
"callback_steps",
|
763 |
+
"1.0.0",
|
764 |
+
"Passing `callback_steps` as an input argument to \
|
765 |
+
`__call__` is deprecated, consider using `callback_on_step_end`",
|
766 |
+
)
|
767 |
+
|
768 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
769 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
770 |
+
|
771 |
+
# 0. Default height and width to transformer
|
772 |
+
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
773 |
+
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
774 |
+
# to deal with lora scaling and other possible forward hooks
|
775 |
+
|
776 |
+
# 1. Check inputs. Raise error if not correct
|
777 |
+
self.check_inputs(
|
778 |
+
prompt,
|
779 |
+
height,
|
780 |
+
width,
|
781 |
+
video_length,
|
782 |
+
callback_steps,
|
783 |
+
negative_prompt,
|
784 |
+
prompt_embeds,
|
785 |
+
negative_prompt_embeds,
|
786 |
+
callback_on_step_end_tensor_inputs,
|
787 |
+
vae_ver=vae_ver
|
788 |
+
)
|
789 |
+
|
790 |
+
self._guidance_scale = guidance_scale
|
791 |
+
self._guidance_rescale = guidance_rescale
|
792 |
+
self._clip_skip = clip_skip
|
793 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
794 |
+
self._interrupt = False
|
795 |
+
|
796 |
+
# 2. Define call parameters
|
797 |
+
if prompt is not None and isinstance(prompt, str):
|
798 |
+
batch_size = 1
|
799 |
+
elif prompt is not None and isinstance(prompt, list):
|
800 |
+
batch_size = len(prompt)
|
801 |
+
else:
|
802 |
+
batch_size = prompt_embeds.shape[0]
|
803 |
+
|
804 |
+
# device = self._execution_device
|
805 |
+
device = torch.device("cuda")
|
806 |
+
|
807 |
+
# 3. Encode input prompt
|
808 |
+
lora_scale = (
|
809 |
+
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
810 |
+
)
|
811 |
+
|
812 |
+
prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask = \
|
813 |
+
self.encode_prompt(
|
814 |
+
prompt,
|
815 |
+
device,
|
816 |
+
num_videos_per_prompt,
|
817 |
+
self.do_classifier_free_guidance,
|
818 |
+
negative_prompt,
|
819 |
+
prompt_embeds=prompt_embeds,
|
820 |
+
attention_mask=attention_mask,
|
821 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
822 |
+
negative_attention_mask=negative_attention_mask,
|
823 |
+
lora_scale=lora_scale,
|
824 |
+
clip_skip=self.clip_skip,
|
825 |
+
data_type=data_type
|
826 |
+
)
|
827 |
+
|
828 |
+
if self.text_encoder_2 is not None:
|
829 |
+
prompt_embeds_2, negative_prompt_embeds_2, prompt_mask_2, negative_prompt_mask_2 = \
|
830 |
+
self.encode_prompt(
|
831 |
+
prompt,
|
832 |
+
device,
|
833 |
+
num_videos_per_prompt,
|
834 |
+
self.do_classifier_free_guidance,
|
835 |
+
negative_prompt,
|
836 |
+
prompt_embeds=None,
|
837 |
+
attention_mask=None,
|
838 |
+
negative_prompt_embeds=None,
|
839 |
+
negative_attention_mask=None,
|
840 |
+
lora_scale=lora_scale,
|
841 |
+
clip_skip=self.clip_skip,
|
842 |
+
text_encoder=self.text_encoder_2,
|
843 |
+
)
|
844 |
+
else:
|
845 |
+
prompt_embeds_2 = None
|
846 |
+
negative_prompt_embeds_2 = None
|
847 |
+
prompt_mask_2 = None
|
848 |
+
negative_prompt_mask_2 = None
|
849 |
+
|
850 |
+
# For classifier free guidance, we need to do two forward passes.
|
851 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
852 |
+
# to avoid doing two forward passes
|
853 |
+
if self.do_classifier_free_guidance:
|
854 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
855 |
+
if prompt_mask is not None:
|
856 |
+
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
|
857 |
+
if prompt_embeds_2 is not None:
|
858 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
859 |
+
if prompt_mask_2 is not None:
|
860 |
+
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
|
861 |
+
|
862 |
+
if self.do_classifier_free_guidance:
|
863 |
+
if ref_latents is not None:
|
864 |
+
ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
|
865 |
+
if prompt_mask[0].sum() > 575:
|
866 |
+
prompt_mask[0] = torch.cat(
|
867 |
+
[torch.ones((1, prompt_mask[0].sum() - 575)).to(prompt_mask),
|
868 |
+
torch.zeros((1, prompt_mask.shape[1] - prompt_mask[0].sum() + 575)).to(prompt_mask)], dim=1)
|
869 |
+
|
870 |
+
if ip_cfg_scale>0:
|
871 |
+
prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1:]])
|
872 |
+
prompt_embeds_2 = torch.cat([prompt_embeds_2, prompt_embeds_2[1:]])
|
873 |
+
prompt_mask = torch.cat([prompt_mask, prompt_mask[1:]], dim=0)
|
874 |
+
ref_latents = torch.cat([uncond_ref_latents, uncond_ref_latents, ref_latents[1:]], dim=0)
|
875 |
+
|
876 |
+
# 4. Prepare timesteps
|
877 |
+
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
|
878 |
+
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
|
879 |
+
)
|
880 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
881 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas, **extra_set_timesteps_kwargs,
|
882 |
+
)
|
883 |
+
|
884 |
+
|
885 |
+
if '884' in vae_ver:
|
886 |
+
frame_length = (video_length - 2) // 4 + 2
|
887 |
+
elif '888' in vae_ver:
|
888 |
+
frame_length = (video_length - 1) // 8 + 1
|
889 |
+
else:
|
890 |
+
frame_length = video_length
|
891 |
+
|
892 |
+
# 5. Prepare latent variables
|
893 |
+
num_channels_latents = self.transformer.config.in_channels
|
894 |
+
latents, timesteps = self.prepare_latents(
|
895 |
+
batch_size * num_videos_per_prompt,
|
896 |
+
num_channels_latents,
|
897 |
+
num_inference_steps,
|
898 |
+
height,
|
899 |
+
width,
|
900 |
+
frame_length,
|
901 |
+
prompt_embeds.dtype,
|
902 |
+
device,
|
903 |
+
timesteps,
|
904 |
+
generator,
|
905 |
+
latents,
|
906 |
+
gt_latents,
|
907 |
+
denoise_strength,
|
908 |
+
)
|
909 |
+
|
910 |
+
gt_latents = gt_latents.repeat(1, 1, frame_length, 1, 1)
|
911 |
+
gt_latents_concat = gt_latents.clone()
|
912 |
+
|
913 |
+
if frame_length == 10:
|
914 |
+
gt_latents_concat[:,:,1:,:,:] = 0.0
|
915 |
+
mask_concat = torch.ones(gt_latents.shape[0],
|
916 |
+
1,
|
917 |
+
gt_latents.shape[2],
|
918 |
+
gt_latents.shape[3],
|
919 |
+
gt_latents.shape[4]).to(device=gt_latents.device)
|
920 |
+
mask_concat[:, :, 1:,...] = 0.0
|
921 |
+
else:
|
922 |
+
gt_latents_concat[:,:,gt_latents_concat.shape[2]//2:,:,:] = 0.0
|
923 |
+
mask_zeros = torch.zeros(gt_latents.shape[0],
|
924 |
+
1,
|
925 |
+
gt_latents.shape[2]//2,
|
926 |
+
gt_latents.shape[3],
|
927 |
+
gt_latents.shape[4])
|
928 |
+
mask_ones = torch.ones(gt_latents.shape[0],
|
929 |
+
1,
|
930 |
+
gt_latents.shape[2]//2,
|
931 |
+
gt_latents.shape[3],
|
932 |
+
gt_latents.shape[4])
|
933 |
+
mask_concat = torch.cat([mask_ones, mask_zeros], dim=2).to(device=gt_latents.device)
|
934 |
+
|
935 |
+
# 6. Prepare extra step kwargs.
|
936 |
+
extra_step_kwargs = self.prepare_extra_func_kwargs(
|
937 |
+
self.scheduler.step, {"generator": generator, "eta": eta},
|
938 |
+
)
|
939 |
+
|
940 |
+
target_dtype = PRECISION_TO_TYPE[self.args.precision]
|
941 |
+
autocast_enabled = (target_dtype != torch.float32) and not self.args.val_disable_autocast
|
942 |
+
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
|
943 |
+
vae_autocast_enabled = (vae_dtype != torch.float32) and not self.args.val_disable_autocast
|
944 |
+
|
945 |
+
# 7. Denoising loop
|
946 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
947 |
+
self._num_timesteps = len(timesteps)
|
948 |
+
|
949 |
+
start_scale = ip_cfg_scale # 3.0
|
950 |
+
end_scale = 1.0
|
951 |
+
step_scale = (start_scale - end_scale) / (self._num_timesteps - 1 + 1e-3)
|
952 |
+
if cpu_offload: torch.cuda.empty_cache()
|
953 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
954 |
+
for i, t in enumerate(timesteps):
|
955 |
+
if self.interrupt:
|
956 |
+
continue
|
957 |
+
|
958 |
+
if last_latents.shape[2] == 1:
|
959 |
+
latents[:,:,0,:,:] = last_latents[:,:,-1,:,:]
|
960 |
+
else:
|
961 |
+
latents[:,:,:latents.shape[2]//2,:,:] = last_latents
|
962 |
+
gt_latents_concat[:,:,:latents.shape[2]//2,:,:] = last_latents
|
963 |
+
|
964 |
+
# expand the latents if we are doing classifier free guidance
|
965 |
+
latents_concat = torch.concat([latents, gt_latents_concat, mask_concat], dim=1)
|
966 |
+
latent_model_input = torch.cat([latents_concat] * 2) \
|
967 |
+
if self.do_classifier_free_guidance else latents_concat
|
968 |
+
|
969 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
970 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
971 |
+
|
972 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
973 |
+
guidance_expand = None
|
974 |
+
|
975 |
+
cam_latents_ = torch.cat([uncond_cam_latents, cam_latents], dim=0) \
|
976 |
+
if self.do_classifier_free_guidance else cam_latents
|
977 |
+
|
978 |
+
# predict the noise residual
|
979 |
+
with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
|
980 |
+
is_cache = False
|
981 |
+
if use_deepcache and num_inference_steps==50:
|
982 |
+
|
983 |
+
no_cache_steps = list(range(0, 10)) + list(range(10, 40, 2)) + list(range(40, 50))
|
984 |
+
if i in no_cache_steps:
|
985 |
+
is_cache = False
|
986 |
+
else:
|
987 |
+
is_cache = True
|
988 |
+
if latent_model_input.shape[-1]*latent_model_input.shape[-2]>64*112 and cpu_offload:
|
989 |
+
if i==0:
|
990 |
+
print(f'cpu_offload={cpu_offload} and \
|
991 |
+
{latent_model_input.shape[-2:]} is large, split infer noise-pred')
|
992 |
+
noise_pred_uncond = self.transformer(latent_model_input[:1],
|
993 |
+
t_expand[:1],
|
994 |
+
text_states=prompt_embeds[:1],
|
995 |
+
text_mask=prompt_mask[:1],
|
996 |
+
text_states_2=prompt_embeds_2[:1],
|
997 |
+
freqs_cos=freqs_cis[0],
|
998 |
+
freqs_sin=freqs_cis[1],
|
999 |
+
guidance=guidance_expand,
|
1000 |
+
return_dict=True,
|
1001 |
+
is_cache=is_cache,
|
1002 |
+
cam_latents=cam_latents_[:1])['x']
|
1003 |
+
torch.cuda.empty_cache()
|
1004 |
+
noise_pred_text = self.transformer(latent_model_input[1:],
|
1005 |
+
t_expand[1:],
|
1006 |
+
text_states=prompt_embeds[1:],
|
1007 |
+
text_mask=prompt_mask[1:],
|
1008 |
+
text_states_2=prompt_embeds_2[1:],
|
1009 |
+
freqs_cos=freqs_cis[0],
|
1010 |
+
freqs_sin=freqs_cis[1],
|
1011 |
+
guidance=guidance_expand,
|
1012 |
+
return_dict=True,
|
1013 |
+
is_cache=is_cache,
|
1014 |
+
cam_latents=cam_latents_[1:])['x']
|
1015 |
+
noise_pred = torch.cat([noise_pred_uncond, noise_pred_text], dim=0)
|
1016 |
+
torch.cuda.empty_cache()
|
1017 |
+
else:
|
1018 |
+
noise_pred = self.transformer( # For an input image (1, 256, 256)
|
1019 |
+
latent_model_input, # [2, 16, 1, 32, 32] #
|
1020 |
+
t_expand, # [2]
|
1021 |
+
text_states=prompt_embeds, # [2, 256, 4096]
|
1022 |
+
text_mask=prompt_mask, # [2, 256]
|
1023 |
+
text_states_2=prompt_embeds_2, # [2, 768]
|
1024 |
+
freqs_cos=freqs_cis[0], # [seqlen, head_dim]
|
1025 |
+
freqs_sin=freqs_cis[1], # [seqlen, head_dim]
|
1026 |
+
guidance=guidance_expand,
|
1027 |
+
return_dict=True,
|
1028 |
+
is_cache=is_cache,
|
1029 |
+
cam_latents=cam_latents_,
|
1030 |
+
use_sage=use_sage,
|
1031 |
+
)['x']
|
1032 |
+
|
1033 |
+
# perform guidance
|
1034 |
+
if self.do_classifier_free_guidance and ip_cfg_scale < 0.1:
|
1035 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
1036 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
1037 |
+
|
1038 |
+
if ip_cfg_scale > 0:
|
1039 |
+
noise_pred_uncond, noise_pred_text, noise_pred_ip = noise_pred.chunk(3)
|
1040 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * \
|
1041 |
+
(noise_pred_text - noise_pred_uncond) + start_scale * (noise_pred_ip-noise_pred_text)
|
1042 |
+
start_scale -= step_scale
|
1043 |
+
if i==0:
|
1044 |
+
print(f'i={i}, noise_pred shape={noise_pred.shape}')
|
1045 |
+
|
1046 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
1047 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
1048 |
+
noise_pred = rescale_noise_cfg(noise_pred,
|
1049 |
+
noise_pred_text,
|
1050 |
+
guidance_rescale=self.guidance_rescale)
|
1051 |
+
|
1052 |
+
# compute the previous noisy sample x_t -> x_t-1
|
1053 |
+
# latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
1054 |
+
if last_latents.shape[2] == 1:
|
1055 |
+
latents[:,:,1:,:,:] = self.scheduler.step(noise_pred[:,:,1:,:,:],
|
1056 |
+
t,
|
1057 |
+
latents[:,:,1:,:,:],
|
1058 |
+
**extra_step_kwargs,
|
1059 |
+
return_dict=False)[0]
|
1060 |
+
else:
|
1061 |
+
latents[:,:,noise_pred.shape[2]//2:,:,:] = self.scheduler.step(
|
1062 |
+
noise_pred[:,:,noise_pred.shape[2]//2:,:,:],
|
1063 |
+
t,
|
1064 |
+
latents[:,:,latents.shape[2]//2:,:,:],
|
1065 |
+
**extra_step_kwargs, return_dict=False)[0]
|
1066 |
+
|
1067 |
+
|
1068 |
+
if callback_on_step_end is not None:
|
1069 |
+
callback_kwargs = {}
|
1070 |
+
for k in callback_on_step_end_tensor_inputs:
|
1071 |
+
callback_kwargs[k] = locals()[k]
|
1072 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
1073 |
+
|
1074 |
+
latents = callback_outputs.pop("latents", latents)
|
1075 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
1076 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
1077 |
+
|
1078 |
+
# call the callback, if provided
|
1079 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
1080 |
+
if progress_bar is not None:
|
1081 |
+
progress_bar.update()
|
1082 |
+
if callback is not None and i % callback_steps == 0:
|
1083 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
1084 |
+
callback(step_idx, t, latents)
|
1085 |
+
|
1086 |
+
if cpu_offload: torch.cuda.empty_cache()
|
1087 |
+
# if mask_latents is not None:
|
1088 |
+
# latents = mask_latents * latents + (1 - mask_latents) * original_latents
|
1089 |
+
if last_latents.shape[2] == 1:
|
1090 |
+
latents = latents[:,:,1:,:,:]
|
1091 |
+
|
1092 |
+
if not output_type == "latent":
|
1093 |
+
expand_temporal_dim = False
|
1094 |
+
if len(latents.shape) == 4:
|
1095 |
+
if isinstance(self.vae, AutoencoderKLCausal3D):
|
1096 |
+
latents = latents.unsqueeze(2)
|
1097 |
+
expand_temporal_dim = True
|
1098 |
+
elif len(latents.shape) == 5:
|
1099 |
+
pass
|
1100 |
+
else:
|
1101 |
+
raise ValueError(
|
1102 |
+
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
|
1103 |
+
|
1104 |
+
if not last_latents.shape[2] == 1:
|
1105 |
+
last_latents = latents[:,:,latents.shape[2]//2:,:,:]
|
1106 |
+
else:
|
1107 |
+
last_latents = latents
|
1108 |
+
latent_decode = last_latents.clone()
|
1109 |
+
latent_decode = latent_decode / self.vae.config.scaling_factor
|
1110 |
+
|
1111 |
+
with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
|
1112 |
+
if enable_tiling:
|
1113 |
+
self.vae.enable_tiling()
|
1114 |
+
if cpu_offload:
|
1115 |
+
self.vae.post_quant_conv.to('cuda')
|
1116 |
+
self.vae.decoder.to('cuda')
|
1117 |
+
image = self.vae.decode(latent_decode, return_dict=False, generator=generator)[0]
|
1118 |
+
self.vae.disable_tiling()
|
1119 |
+
if cpu_offload:
|
1120 |
+
self.vae.post_quant_conv.to('cpu')
|
1121 |
+
self.vae.decoder.to('cpu')
|
1122 |
+
torch.cuda.empty_cache()
|
1123 |
+
else:
|
1124 |
+
image = self.vae.decode(latent_decode, return_dict=False, generator=generator)[0]
|
1125 |
+
# if image is None:
|
1126 |
+
# return (None, )
|
1127 |
+
|
1128 |
+
# if expand_temporal_dim or (not video_val_flag and image.shape[2] == 1):
|
1129 |
+
# image = image.squeeze(2)
|
1130 |
+
|
1131 |
+
if image is not None and (expand_temporal_dim or (not video_val_flag and image.shape[2] == 1)):
|
1132 |
+
image = image.squeeze(2)
|
1133 |
+
|
1134 |
+
if image is not None:
|
1135 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
1136 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
1137 |
+
image = image.cpu().float()
|
1138 |
+
|
1139 |
+
# Offload all models
|
1140 |
+
self.maybe_free_model_hooks()
|
1141 |
+
|
1142 |
+
if cpu_offload: torch.cuda.empty_cache()
|
1143 |
+
if not return_dict:
|
1144 |
+
return image
|
1145 |
+
|
1146 |
+
return_latents = kwargs.get("return_latents", False)
|
1147 |
+
|
1148 |
+
if return_latents:
|
1149 |
+
return HunyuanVideoPipelineOutput(videos=image), \
|
1150 |
+
latents, timesteps, last_latents, last_latents[:,:,-1:, ...]
|
1151 |
+
|
1152 |
+
return HunyuanVideoPipelineOutput(videos=image)
|
hymm_sp/diffusion/schedulers/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
2 |
+
# from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
#
|
16 |
+
# Modified from diffusers==0.29.2
|
17 |
+
#
|
18 |
+
# ==============================================================================
|
19 |
+
|
20 |
+
from dataclasses import dataclass
|
21 |
+
from typing import Optional, Tuple, Union
|
22 |
+
|
23 |
+
import torch
|
24 |
+
|
25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
26 |
+
from diffusers.utils import BaseOutput, logging
|
27 |
+
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
@dataclass
|
34 |
+
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
|
35 |
+
"""
|
36 |
+
Output class for the scheduler's `step` function output.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
40 |
+
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
41 |
+
denoising loop.
|
42 |
+
"""
|
43 |
+
|
44 |
+
prev_sample: torch.FloatTensor
|
45 |
+
|
46 |
+
|
47 |
+
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
48 |
+
"""
|
49 |
+
Euler scheduler.
|
50 |
+
|
51 |
+
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
52 |
+
methods the library implements for all schedulers such as loading and saving.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
num_train_timesteps (`int`, defaults to 1000):
|
56 |
+
The number of diffusion steps to train the model.
|
57 |
+
timestep_spacing (`str`, defaults to `"linspace"`):
|
58 |
+
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
59 |
+
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
60 |
+
shift (`float`, defaults to 1.0):
|
61 |
+
The shift value for the timestep schedule.
|
62 |
+
reverse (`bool`, defaults to `True`):
|
63 |
+
Whether to reverse the timestep schedule.
|
64 |
+
"""
|
65 |
+
|
66 |
+
_compatibles = []
|
67 |
+
order = 1
|
68 |
+
|
69 |
+
@register_to_config
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
num_train_timesteps: int = 1000,
|
73 |
+
shift: float = 1.0,
|
74 |
+
reverse: bool = True,
|
75 |
+
solver: str = "euler",
|
76 |
+
n_tokens: Optional[int] = None,
|
77 |
+
):
|
78 |
+
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
|
79 |
+
|
80 |
+
if not reverse:
|
81 |
+
sigmas = sigmas.flip(0)
|
82 |
+
|
83 |
+
self.sigmas = sigmas
|
84 |
+
# the value fed to model
|
85 |
+
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
|
86 |
+
|
87 |
+
self._step_index = None
|
88 |
+
self._begin_index = None
|
89 |
+
|
90 |
+
self.supported_solver = ["euler"]
|
91 |
+
if solver not in self.supported_solver:
|
92 |
+
raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
|
93 |
+
|
94 |
+
@property
|
95 |
+
def step_index(self):
|
96 |
+
"""
|
97 |
+
The index counter for current timestep. It will increase 1 after each scheduler step.
|
98 |
+
"""
|
99 |
+
return self._step_index
|
100 |
+
|
101 |
+
@property
|
102 |
+
def begin_index(self):
|
103 |
+
"""
|
104 |
+
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
105 |
+
"""
|
106 |
+
return self._begin_index
|
107 |
+
|
108 |
+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
109 |
+
def set_begin_index(self, begin_index: int = 0):
|
110 |
+
"""
|
111 |
+
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
begin_index (`int`):
|
115 |
+
The begin index for the scheduler.
|
116 |
+
"""
|
117 |
+
self._begin_index = begin_index
|
118 |
+
|
119 |
+
def _sigma_to_t(self, sigma):
|
120 |
+
return sigma * self.config.num_train_timesteps
|
121 |
+
|
122 |
+
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
|
123 |
+
n_tokens: int = None):
|
124 |
+
"""
|
125 |
+
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
126 |
+
|
127 |
+
Args:
|
128 |
+
num_inference_steps (`int`):
|
129 |
+
The number of diffusion steps used when generating samples with a pre-trained model.
|
130 |
+
device (`str` or `torch.device`, *optional*):
|
131 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
132 |
+
n_tokens (`int`, *optional*):
|
133 |
+
Number of tokens in the input sequence.
|
134 |
+
"""
|
135 |
+
self.num_inference_steps = num_inference_steps
|
136 |
+
|
137 |
+
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
|
138 |
+
sigmas = self.sd3_time_shift(sigmas)
|
139 |
+
|
140 |
+
if not self.config.reverse:
|
141 |
+
sigmas = 1 - sigmas
|
142 |
+
|
143 |
+
self.sigmas = sigmas
|
144 |
+
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
|
145 |
+
|
146 |
+
# Reset step index
|
147 |
+
self._step_index = None
|
148 |
+
|
149 |
+
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
150 |
+
if schedule_timesteps is None:
|
151 |
+
schedule_timesteps = self.timesteps
|
152 |
+
|
153 |
+
indices = (schedule_timesteps == timestep).nonzero()
|
154 |
+
|
155 |
+
# The sigma index that is taken for the **very** first `step`
|
156 |
+
# is always the second index (or the last index if there is only 1)
|
157 |
+
# This way we can ensure we don't accidentally skip a sigma in
|
158 |
+
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
159 |
+
pos = 1 if len(indices) > 1 else 0
|
160 |
+
|
161 |
+
return indices[pos].item()
|
162 |
+
|
163 |
+
def _init_step_index(self, timestep):
|
164 |
+
if self.begin_index is None:
|
165 |
+
if isinstance(timestep, torch.Tensor):
|
166 |
+
timestep = timestep.to(self.timesteps.device)
|
167 |
+
self._step_index = self.index_for_timestep(timestep)
|
168 |
+
else:
|
169 |
+
self._step_index = self._begin_index
|
170 |
+
|
171 |
+
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
172 |
+
return sample
|
173 |
+
|
174 |
+
def sd3_time_shift(self, t: torch.Tensor):
|
175 |
+
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
|
176 |
+
|
177 |
+
def step(
|
178 |
+
self,
|
179 |
+
model_output: torch.FloatTensor,
|
180 |
+
timestep: Union[float, torch.FloatTensor],
|
181 |
+
sample: torch.FloatTensor,
|
182 |
+
return_dict: bool = True,
|
183 |
+
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
|
184 |
+
"""
|
185 |
+
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
186 |
+
process from the learned model outputs (most often the predicted noise).
|
187 |
+
|
188 |
+
Args:
|
189 |
+
model_output (`torch.FloatTensor`):
|
190 |
+
The direct output from learned diffusion model.
|
191 |
+
timestep (`float`):
|
192 |
+
The current discrete timestep in the diffusion chain.
|
193 |
+
sample (`torch.FloatTensor`):
|
194 |
+
A current instance of a sample created by the diffusion process.
|
195 |
+
return_dict (`bool`):
|
196 |
+
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
|
197 |
+
tuple.
|
198 |
+
|
199 |
+
Returns:
|
200 |
+
[`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
|
201 |
+
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
|
202 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
203 |
+
"""
|
204 |
+
|
205 |
+
if (
|
206 |
+
isinstance(timestep, int)
|
207 |
+
or isinstance(timestep, torch.IntTensor)
|
208 |
+
or isinstance(timestep, torch.LongTensor)
|
209 |
+
):
|
210 |
+
raise ValueError(
|
211 |
+
(
|
212 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
213 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
214 |
+
" one of the `scheduler.timesteps` as a timestep."
|
215 |
+
),
|
216 |
+
)
|
217 |
+
|
218 |
+
if self.step_index is None:
|
219 |
+
self._init_step_index(timestep)
|
220 |
+
|
221 |
+
# Upcast to avoid precision issues when computing prev_sample
|
222 |
+
sample = sample.to(torch.float32)
|
223 |
+
|
224 |
+
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
|
225 |
+
|
226 |
+
if self.config.solver == "euler":
|
227 |
+
prev_sample = sample + model_output.float() * dt
|
228 |
+
else:
|
229 |
+
raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
|
230 |
+
|
231 |
+
# upon completion increase step index by one
|
232 |
+
self._step_index += 1
|
233 |
+
|
234 |
+
if not return_dict:
|
235 |
+
return (prev_sample,)
|
236 |
+
|
237 |
+
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
|
238 |
+
|
239 |
+
def __len__(self):
|
240 |
+
return self.config.num_train_timesteps
|
hymm_sp/helpers.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Union, List
|
3 |
+
from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd
|
4 |
+
|
5 |
+
from itertools import repeat
|
6 |
+
import collections.abc
|
7 |
+
|
8 |
+
|
9 |
+
def _ntuple(n):
|
10 |
+
"""
|
11 |
+
Creates a helper function to convert inputs to tuples of specified length.
|
12 |
+
|
13 |
+
Converts iterable inputs (excluding strings) to tuples of length n,
|
14 |
+
or repeats single values n times to form a tuple. Useful for handling
|
15 |
+
multi-dimensional parameters like sizes and strides.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
n (int): Target length of the tuple
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
function: Parser function that converts inputs to n-length tuples
|
22 |
+
"""
|
23 |
+
def parse(x):
|
24 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
25 |
+
x = tuple(x)
|
26 |
+
if len(x) == 1:
|
27 |
+
x = tuple(repeat(x[0], n))
|
28 |
+
return x
|
29 |
+
return tuple(repeat(x, n))
|
30 |
+
return parse
|
31 |
+
|
32 |
+
|
33 |
+
# Create common tuple conversion functions for 1-4 dimensions
|
34 |
+
to_1tuple = _ntuple(1)
|
35 |
+
to_2tuple = _ntuple(2)
|
36 |
+
to_3tuple = _ntuple(3)
|
37 |
+
to_4tuple = _ntuple(4)
|
38 |
+
|
39 |
+
|
40 |
+
def get_rope_freq_from_size(
|
41 |
+
latents_size,
|
42 |
+
ndim,
|
43 |
+
target_ndim,
|
44 |
+
args,
|
45 |
+
rope_theta_rescale_factor: Union[float, List[float]] = 1.0,
|
46 |
+
rope_interpolation_factor: Union[float, List[float]] = 1.0,
|
47 |
+
concat_dict={}
|
48 |
+
):
|
49 |
+
"""
|
50 |
+
Calculates RoPE (Rotary Position Embedding) frequencies based on latent dimensions.
|
51 |
+
|
52 |
+
Converts latent space dimensions to rope-compatible sizes by accounting for
|
53 |
+
patch size, then generates the appropriate frequency embeddings for each dimension.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
latents_size: Dimensions of the latent space tensor
|
57 |
+
ndim (int): Number of dimensions in the latent space
|
58 |
+
target_ndim (int): Target number of dimensions for the embeddings
|
59 |
+
args: Configuration arguments containing model parameters (patch_size, rope_theta, etc.)
|
60 |
+
rope_theta_rescale_factor: Rescaling factor(s) for theta parameter (per dimension)
|
61 |
+
rope_interpolation_factor: Interpolation factor(s) for position embeddings (per dimension)
|
62 |
+
concat_dict: Dictionary for special concatenation modes (e.g., time-based extensions)
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
tuple: Cosine and sine frequency embeddings (freqs_cos, freqs_sin)
|
66 |
+
"""
|
67 |
+
# Calculate rope sizes by dividing latent dimensions by patch size
|
68 |
+
if isinstance(args.patch_size, int):
|
69 |
+
# Validate all latent dimensions are divisible by patch size
|
70 |
+
assert all(s % args.patch_size == 0 for s in latents_size), \
|
71 |
+
f"Latent size (last {ndim} dimensions) must be divisible by patch size ({args.patch_size}), " \
|
72 |
+
f"but got {latents_size}."
|
73 |
+
rope_sizes = [s // args.patch_size for s in latents_size]
|
74 |
+
elif isinstance(args.patch_size, list):
|
75 |
+
# Validate with per-dimension patch sizes
|
76 |
+
assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
|
77 |
+
f"Latent size (last {ndim} dimensions) must be divisible by patch size ({args.patch_size}), " \
|
78 |
+
f"but got {latents_size}."
|
79 |
+
rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
|
80 |
+
|
81 |
+
# Add singleton dimensions if needed to match target_ndim (typically for time axis)
|
82 |
+
if len(rope_sizes) != target_ndim:
|
83 |
+
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes
|
84 |
+
|
85 |
+
# Calculate head dimension and validate rope dimensions
|
86 |
+
head_dim = args.hidden_size // args.num_heads
|
87 |
+
rope_dim_list = args.rope_dim_list
|
88 |
+
|
89 |
+
# Default: split head dimension equally across target dimensions
|
90 |
+
if rope_dim_list is None:
|
91 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
92 |
+
|
93 |
+
# Ensure rope dimensions sum to head dimension
|
94 |
+
assert sum(rope_dim_list) == head_dim, \
|
95 |
+
"Sum of rope_dim_list must equal attention head dimension (hidden_size // num_heads)"
|
96 |
+
|
97 |
+
# Generate rotary position embeddings
|
98 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(
|
99 |
+
rope_dim_list,
|
100 |
+
rope_sizes,
|
101 |
+
theta=args.rope_theta,
|
102 |
+
use_real=True,
|
103 |
+
theta_rescale_factor=rope_theta_rescale_factor,
|
104 |
+
interpolation_factor=rope_interpolation_factor,
|
105 |
+
concat_dict=concat_dict
|
106 |
+
)
|
107 |
+
return freqs_cos, freqs_sin
|
108 |
+
|
109 |
+
|
110 |
+
def get_nd_rotary_pos_embed_new(
|
111 |
+
rope_dim_list,
|
112 |
+
start,
|
113 |
+
*args,
|
114 |
+
theta=10000.,
|
115 |
+
use_real=False,
|
116 |
+
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
117 |
+
interpolation_factor: Union[float, List[float]] = 1.0,
|
118 |
+
concat_dict={}
|
119 |
+
):
|
120 |
+
"""
|
121 |
+
Generates multi-dimensional Rotary Position Embeddings (RoPE).
|
122 |
+
|
123 |
+
Creates position embeddings for n-dimensional spaces by generating a meshgrid
|
124 |
+
of positions and applying 1D rotary embeddings to each dimension, then combining them.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
rope_dim_list (list): List of embedding dimensions for each axis
|
128 |
+
start: Starting dimensions for generating the meshgrid
|
129 |
+
*args: Additional arguments for meshgrid generation
|
130 |
+
theta (float): Base theta parameter for RoPE frequency calculation
|
131 |
+
use_real (bool): If True, returns separate cosine and sine embeddings
|
132 |
+
theta_rescale_factor: Rescaling factor(s) for theta (per dimension)
|
133 |
+
interpolation_factor: Interpolation factor(s) for position scaling (per dimension)
|
134 |
+
concat_dict: Dictionary for special concatenation modes (e.g., time-based extensions)
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
tuple or tensor: Cosine and sine embeddings if use_real=True, combined embedding otherwise
|
138 |
+
"""
|
139 |
+
# Generate n-dimensional meshgrid of positions (shape: [dim, *sizes])
|
140 |
+
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
|
141 |
+
|
142 |
+
# Handle special concatenation modes (e.g., adding time-based bias)
|
143 |
+
if concat_dict:
|
144 |
+
if concat_dict['mode'] == 'timecat':
|
145 |
+
# Add bias as first element in first dimension
|
146 |
+
bias = grid[:, :1].clone()
|
147 |
+
bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
|
148 |
+
grid = torch.cat([bias, grid], dim=1)
|
149 |
+
elif concat_dict['mode'] == 'timecat-w':
|
150 |
+
# Add biased first element with spatial offset
|
151 |
+
bias = grid[:, :1].clone()
|
152 |
+
bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
|
153 |
+
bias[2] += start[-1] # Spatial offset reference: OminiControl implementation
|
154 |
+
grid = torch.cat([bias, grid], dim=1)
|
155 |
+
|
156 |
+
# Normalize theta rescale factors to list format (per dimension)
|
157 |
+
if isinstance(theta_rescale_factor, (int, float)):
|
158 |
+
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
159 |
+
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
160 |
+
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
161 |
+
assert len(theta_rescale_factor) == len(rope_dim_list), \
|
162 |
+
"Length of theta_rescale_factor must match number of dimensions"
|
163 |
+
|
164 |
+
# Normalize interpolation factors to list format (per dimension)
|
165 |
+
if isinstance(interpolation_factor, (int, float)):
|
166 |
+
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
167 |
+
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
168 |
+
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
169 |
+
assert len(interpolation_factor) == len(rope_dim_list), \
|
170 |
+
"Length of interpolation_factor must match number of dimensions"
|
171 |
+
|
172 |
+
# Generate 1D rotary embeddings for each dimension and combine
|
173 |
+
embs = []
|
174 |
+
for i in range(len(rope_dim_list)):
|
175 |
+
# Flatten grid dimension and generate embeddings
|
176 |
+
emb = get_1d_rotary_pos_embed(
|
177 |
+
rope_dim_list[i],
|
178 |
+
grid[i].reshape(-1), # Flatten to 1D positions
|
179 |
+
theta,
|
180 |
+
use_real=use_real,
|
181 |
+
theta_rescale_factor=theta_rescale_factor[i],
|
182 |
+
interpolation_factor=interpolation_factor[i]
|
183 |
+
)
|
184 |
+
embs.append(emb)
|
185 |
+
|
186 |
+
# Combine embeddings from all dimensions
|
187 |
+
if use_real:
|
188 |
+
# Return separate cosine and sine components
|
189 |
+
cos = torch.cat([emb[0] for emb in embs], dim=1)
|
190 |
+
sin = torch.cat([emb[1] for emb in embs], dim=1)
|
191 |
+
return cos, sin
|
192 |
+
else:
|
193 |
+
# Return combined embedding
|
194 |
+
return torch.cat(embs, dim=1)
|
hymm_sp/inference.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pathlib import Path
|
3 |
+
from loguru import logger
|
4 |
+
from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
|
5 |
+
from hymm_sp.vae import load_vae
|
6 |
+
from hymm_sp.modules import load_model
|
7 |
+
from hymm_sp.text_encoder import TextEncoder
|
8 |
+
import torch.distributed
|
9 |
+
from hymm_sp.modules.parallel_states import (
|
10 |
+
initialize_sequence_parallel_state,
|
11 |
+
get_sequence_parallel_state,
|
12 |
+
nccl_info,
|
13 |
+
)
|
14 |
+
from hymm_sp.modules.fp8_optimization import convert_fp8_linear
|
15 |
+
|
16 |
+
|
17 |
+
class Inference(object):
|
18 |
+
def __init__(self,
|
19 |
+
args,
|
20 |
+
vae,
|
21 |
+
vae_kwargs,
|
22 |
+
text_encoder,
|
23 |
+
model,
|
24 |
+
text_encoder_2=None,
|
25 |
+
pipeline=None,
|
26 |
+
cpu_offload=False,
|
27 |
+
device=None,
|
28 |
+
logger=None):
|
29 |
+
self.vae = vae
|
30 |
+
self.vae_kwargs = vae_kwargs
|
31 |
+
|
32 |
+
self.text_encoder = text_encoder
|
33 |
+
self.text_encoder_2 = text_encoder_2
|
34 |
+
|
35 |
+
self.model = model
|
36 |
+
self.pipeline = pipeline
|
37 |
+
self.cpu_offload = cpu_offload
|
38 |
+
|
39 |
+
self.args = args
|
40 |
+
self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
41 |
+
if nccl_info.sp_size > 1:
|
42 |
+
self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
43 |
+
|
44 |
+
self.logger = logger
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def from_pretrained(cls,
|
48 |
+
pretrained_model_path,
|
49 |
+
args,
|
50 |
+
device=None,
|
51 |
+
**kwargs):
|
52 |
+
"""
|
53 |
+
Initialize the Inference pipeline.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
pretrained_model_path (str or pathlib.Path): The model path,
|
57 |
+
including t2v, text encoder and vae checkpoints.
|
58 |
+
device (int): The device for inference. Default is 0.
|
59 |
+
logger (logging.Logger): The logger for the inference pipeline. Default is None.
|
60 |
+
"""
|
61 |
+
# ========================================================================
|
62 |
+
logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
|
63 |
+
|
64 |
+
# ======================== Get the args path =============================
|
65 |
+
|
66 |
+
# Set device and disable gradient
|
67 |
+
if device is None:
|
68 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
69 |
+
torch.set_grad_enabled(False)
|
70 |
+
logger.info("Building model...")
|
71 |
+
factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
|
72 |
+
in_channels = args.latent_channels
|
73 |
+
out_channels = args.latent_channels
|
74 |
+
print("="*25, f"build model", "="*25)
|
75 |
+
model = load_model(
|
76 |
+
args,
|
77 |
+
in_channels=in_channels,
|
78 |
+
out_channels=out_channels,
|
79 |
+
factor_kwargs=factor_kwargs
|
80 |
+
)
|
81 |
+
if args.cpu_offload:
|
82 |
+
print(f'='*20, f'load transformer to cpu')
|
83 |
+
model = model.to('cpu')
|
84 |
+
torch.cuda.empty_cache()
|
85 |
+
else:
|
86 |
+
model = model.to(device)
|
87 |
+
model = Inference.load_state_dict(args, model, pretrained_model_path)
|
88 |
+
model.eval()
|
89 |
+
|
90 |
+
if args.use_fp8:
|
91 |
+
convert_fp8_linear(model)
|
92 |
+
|
93 |
+
# ============================= Build extra models ========================
|
94 |
+
# VAE
|
95 |
+
print("="*25, f"load vae", "="*25)
|
96 |
+
vae, _, s_ratio, t_ratio = load_vae(args.vae,
|
97 |
+
args.vae_precision,
|
98 |
+
logger=logger,
|
99 |
+
device='cpu' if args.cpu_offload else device)
|
100 |
+
vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}
|
101 |
+
|
102 |
+
# Parallel VAE
|
103 |
+
device_vaes = []
|
104 |
+
device_vaes.append(vae)
|
105 |
+
if nccl_info.sp_size > 1 and nccl_info.rank_within_group == 0:
|
106 |
+
for i in range(1, nccl_info.sp_size):
|
107 |
+
cur_device = torch.device(f"cuda:{i}")
|
108 |
+
# print("!!!!!!!!!! Load vae for ", cur_device)
|
109 |
+
device_vae, _, _, _ = load_vae(args.vae,
|
110 |
+
args.vae_precision,
|
111 |
+
logger=logger,
|
112 |
+
device='cpu' if args.cpu_offload else cur_device)
|
113 |
+
device_vaes.append(device_vae)
|
114 |
+
vae.device_vaes = device_vaes
|
115 |
+
|
116 |
+
# Text encoder
|
117 |
+
if args.prompt_template_video is not None:
|
118 |
+
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
|
119 |
+
else:
|
120 |
+
crop_start = 0
|
121 |
+
max_length = args.text_len + crop_start
|
122 |
+
|
123 |
+
# prompt_template_video
|
124 |
+
prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] \
|
125 |
+
if args.prompt_template_video is not None else None
|
126 |
+
print("="*25, f"load llava", "="*25)
|
127 |
+
text_encoder = TextEncoder(text_encoder_type = args.text_encoder,
|
128 |
+
max_length = max_length,
|
129 |
+
text_encoder_precision = args.text_encoder_precision,
|
130 |
+
tokenizer_type = args.tokenizer,
|
131 |
+
use_attention_mask = args.use_attention_mask,
|
132 |
+
prompt_template_video = prompt_template_video,
|
133 |
+
hidden_state_skip_layer = args.hidden_state_skip_layer,
|
134 |
+
apply_final_norm = args.apply_final_norm,
|
135 |
+
reproduce = args.reproduce,
|
136 |
+
logger = logger,
|
137 |
+
device = 'cpu' if args.cpu_offload else device ,
|
138 |
+
)
|
139 |
+
text_encoder_2 = None
|
140 |
+
if args.text_encoder_2 is not None:
|
141 |
+
text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
|
142 |
+
max_length=args.text_len_2,
|
143 |
+
text_encoder_precision=args.text_encoder_precision_2,
|
144 |
+
tokenizer_type=args.tokenizer_2,
|
145 |
+
use_attention_mask=args.use_attention_mask,
|
146 |
+
reproduce=args.reproduce,
|
147 |
+
logger=logger,
|
148 |
+
device='cpu' if args.cpu_offload else device ,
|
149 |
+
# if not args.use_cpu_offload else 'cpu'
|
150 |
+
)
|
151 |
+
|
152 |
+
return cls(args=args,
|
153 |
+
vae=vae,
|
154 |
+
vae_kwargs=vae_kwargs,
|
155 |
+
text_encoder=text_encoder,
|
156 |
+
model=model,
|
157 |
+
text_encoder_2=text_encoder_2,
|
158 |
+
device=device,
|
159 |
+
logger=logger)
|
160 |
+
|
161 |
+
@staticmethod
|
162 |
+
def load_state_dict(args, model, ckpt_path):
|
163 |
+
load_key = args.load_key
|
164 |
+
ckpt_path = Path(ckpt_path)
|
165 |
+
if ckpt_path.is_dir():
|
166 |
+
ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
|
167 |
+
state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
|
168 |
+
if load_key in state_dict:
|
169 |
+
state_dict = state_dict[load_key]
|
170 |
+
elif load_key == ".":
|
171 |
+
pass
|
172 |
+
else:
|
173 |
+
raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existed keys: {state_dict.keys()}")
|
174 |
+
model.load_state_dict(state_dict, strict=False)
|
175 |
+
return model
|
176 |
+
|
177 |
+
def get_exp_dir_and_ckpt_id(self):
|
178 |
+
if self.ckpt is None:
|
179 |
+
raise ValueError("The checkpoint path is not provided.")
|
180 |
+
|
181 |
+
ckpt = Path(self.ckpt)
|
182 |
+
if ckpt.parents[1].name == "checkpoints":
|
183 |
+
# It should be a standard checkpoint path. We use the parent directory as the default save directory.
|
184 |
+
exp_dir = ckpt.parents[2]
|
185 |
+
else:
|
186 |
+
raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
|
187 |
+
f"It seems that the checkpoint path is not standard. Please explicitly provide the "
|
188 |
+
f"save path by --save-path.")
|
189 |
+
return exp_dir, ckpt.parent.name
|
190 |
+
|
191 |
+
@staticmethod
|
192 |
+
def parse_size(size):
|
193 |
+
if isinstance(size, int):
|
194 |
+
size = [size]
|
195 |
+
if not isinstance(size, (list, tuple)):
|
196 |
+
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
|
197 |
+
if len(size) == 1:
|
198 |
+
size = [size[0], size[0]]
|
199 |
+
if len(size) != 2:
|
200 |
+
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
|
201 |
+
return size
|
hymm_sp/modules/__init__.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
|
2 |
+
"""
|
3 |
+
This module provides functionality to load the Hunyuan video diffusion transformer model,
|
4 |
+
which is used for video generation tasks with specific configurations and parameters.
|
5 |
+
"""
|
6 |
+
|
7 |
+
def load_model(args, in_channels, out_channels, factor_kwargs):
|
8 |
+
"""
|
9 |
+
Load and initialize the HYVideoDiffusionTransformer model with specified parameters.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
args: Command-line arguments or configuration object containing model settings.
|
13 |
+
Must include 'model' attribute to select the appropriate configuration.
|
14 |
+
in_channels (int): Number of input channels for the model.
|
15 |
+
out_channels (int): Number of output channels the model should produce.
|
16 |
+
factor_kwargs (dict): Additional keyword arguments for factor adjustments
|
17 |
+
in the model architecture.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
HYVideoDiffusionTransformer: Initialized instance of the video diffusion transformer
|
21 |
+
model with the specified configuration.
|
22 |
+
|
23 |
+
Notes:
|
24 |
+
- Uses the HUNYUAN_VIDEO_CONFIG dictionary to retrieve model-specific configurations
|
25 |
+
based on the model name provided in args.
|
26 |
+
- Sets multitask_mask_training_type to "concat" as a default for this loading setup.
|
27 |
+
"""
|
28 |
+
# Initialize the Hunyuan video diffusion transformer with combined configurations
|
29 |
+
# Merges base config from HUNYUAN_VIDEO_CONFIG and additional factor arguments
|
30 |
+
model = HYVideoDiffusionTransformer(
|
31 |
+
args,
|
32 |
+
in_channels=in_channels,
|
33 |
+
out_channels=out_channels,
|
34 |
+
multitask_mask_training_type="concat",
|
35 |
+
**HUNYUAN_VIDEO_CONFIG[args.model], # Unpack model-specific configuration
|
36 |
+
** factor_kwargs, # Unpack additional factor adjustments
|
37 |
+
)
|
38 |
+
return model
|
hymm_sp/modules/activation_layers.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def get_activation_layer(act_type):
|
5 |
+
"""get activation layer
|
6 |
+
|
7 |
+
Args:
|
8 |
+
act_type (str): the activation type
|
9 |
+
|
10 |
+
Returns:
|
11 |
+
torch.nn.functional: the activation layer
|
12 |
+
"""
|
13 |
+
if act_type == "gelu":
|
14 |
+
return lambda: nn.GELU()
|
15 |
+
elif act_type == "gelu_tanh":
|
16 |
+
# Approximate `tanh` requires torch >= 1.13
|
17 |
+
return lambda: nn.GELU(approximate="tanh")
|
18 |
+
elif act_type == "relu":
|
19 |
+
return nn.ReLU
|
20 |
+
elif act_type == "silu":
|
21 |
+
return nn.SiLU
|
22 |
+
else:
|
23 |
+
raise ValueError(f"Unknown activation type: {act_type}")
|
hymm_sp/modules/attn_layers.py
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib.metadata
|
2 |
+
import math
|
3 |
+
from typing import Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
try:
|
9 |
+
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func
|
10 |
+
from flash_attn.bert_padding import index_first_axis
|
11 |
+
except ImportError:
|
12 |
+
flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
|
13 |
+
index_first_axis = None
|
14 |
+
from packaging import version
|
15 |
+
from transformers.utils.import_utils import _is_package_available
|
16 |
+
|
17 |
+
from .norm_layers import get_norm_layer
|
18 |
+
|
19 |
+
|
20 |
+
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
|
21 |
+
"""
|
22 |
+
Reshape frequency tensor for broadcasting it with another tensor.
|
23 |
+
|
24 |
+
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
|
25 |
+
for the purpose of broadcasting the frequency tensor during element-wise operations.
|
26 |
+
|
27 |
+
Notes:
|
28 |
+
When using FlashMHAModified, head_first should be False.
|
29 |
+
When using Attention, head_first should be True.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
|
33 |
+
x (torch.Tensor): Target tensor for broadcasting compatibility.
|
34 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
torch.Tensor: Reshaped frequency tensor.
|
38 |
+
|
39 |
+
Raises:
|
40 |
+
AssertionError: If the frequency tensor doesn't match the expected shape.
|
41 |
+
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
|
42 |
+
"""
|
43 |
+
ndim = x.ndim
|
44 |
+
assert 0 <= 1 < ndim
|
45 |
+
|
46 |
+
if isinstance(freqs_cis, tuple):
|
47 |
+
# freqs_cis: (cos, sin) in real space
|
48 |
+
if head_first:
|
49 |
+
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), \
|
50 |
+
f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
51 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
52 |
+
else:
|
53 |
+
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), \
|
54 |
+
f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
|
55 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
56 |
+
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
57 |
+
else:
|
58 |
+
# freqs_cis: values in complex space
|
59 |
+
if head_first:
|
60 |
+
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), \
|
61 |
+
f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
62 |
+
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
63 |
+
else:
|
64 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), \
|
65 |
+
f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
|
66 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
67 |
+
return freqs_cis.view(*shape)
|
68 |
+
|
69 |
+
|
70 |
+
def rotate_half(x):
|
71 |
+
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
72 |
+
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
73 |
+
|
74 |
+
|
75 |
+
def apply_rotary_emb(
|
76 |
+
xq: torch.Tensor,
|
77 |
+
xk: torch.Tensor,
|
78 |
+
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
79 |
+
head_first: bool = False,
|
80 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
81 |
+
"""
|
82 |
+
Apply rotary embeddings to input tensors using the given frequency tensor.
|
83 |
+
|
84 |
+
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
|
85 |
+
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
|
86 |
+
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
|
87 |
+
returned as real tensors.
|
88 |
+
|
89 |
+
Args:
|
90 |
+
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
|
91 |
+
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
|
92 |
+
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
|
93 |
+
head_first (bool): head dimension first (except batch dim) or not.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
|
97 |
+
|
98 |
+
"""
|
99 |
+
xk_out = None
|
100 |
+
if isinstance(freqs_cis, tuple):
|
101 |
+
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
102 |
+
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
103 |
+
# real * cos - imag * sin
|
104 |
+
# imag * cos + real * sin
|
105 |
+
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
106 |
+
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
107 |
+
else:
|
108 |
+
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
109 |
+
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
110 |
+
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
|
111 |
+
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
112 |
+
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
113 |
+
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
114 |
+
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
|
115 |
+
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
116 |
+
|
117 |
+
return xq_out, xk_out
|
118 |
+
|
119 |
+
|
120 |
+
class BasicAttentionLayer(nn.Module):
|
121 |
+
def __init__(self, attn_mode='flash', deterministic=False):
|
122 |
+
super().__init__()
|
123 |
+
self.attn_mode = attn_mode
|
124 |
+
self.deterministic = deterministic
|
125 |
+
|
126 |
+
def set_attn_mode(self, new_mode):
|
127 |
+
self.attn_mode = new_mode
|
128 |
+
|
129 |
+
def enable_deterministic(self):
|
130 |
+
self.deterministic = True
|
131 |
+
|
132 |
+
def disable_deterministic(self):
|
133 |
+
self.deterministic = False
|
134 |
+
|
135 |
+
|
136 |
+
MEMORY_LAYOUT = {
|
137 |
+
"self_flash": (
|
138 |
+
lambda x: x,
|
139 |
+
lambda x: x,
|
140 |
+
),
|
141 |
+
"cross_flash": (
|
142 |
+
lambda x: x,
|
143 |
+
lambda x: x,
|
144 |
+
),
|
145 |
+
"torch": (
|
146 |
+
lambda x: x.transpose(1, 2),
|
147 |
+
lambda x: x.transpose(1, 2),
|
148 |
+
),
|
149 |
+
"vanilla": (
|
150 |
+
lambda x: x.transpose(1, 2),
|
151 |
+
lambda x: x.transpose(1, 2),
|
152 |
+
),
|
153 |
+
}
|
154 |
+
|
155 |
+
|
156 |
+
# Copyed from https://github.com/huggingface/transformers/blob/
|
157 |
+
# b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
|
158 |
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
159 |
+
"""
|
160 |
+
Retrieves indexing data required to repad unpadded (ragged) tensors.
|
161 |
+
|
162 |
+
Arguments:
|
163 |
+
attention_mask (`torch.Tensor`):
|
164 |
+
Boolean or int tensor of shape (batch_size, sequence_length), 1 means
|
165 |
+
valid and 0 means not valid.
|
166 |
+
|
167 |
+
Return:
|
168 |
+
indices (`torch.Tensor):
|
169 |
+
The indices of non-masked tokens from the flattened input sequence.
|
170 |
+
cu_seqlens (`torch.Tensor`):
|
171 |
+
The cumulative sequence lengths, used to index into ragged (unpadded)
|
172 |
+
tensors. `cu_seqlens` shape is (batch_size + 1,).
|
173 |
+
max_seqlen_in_batch (`int`):
|
174 |
+
Maximum sequence length in batch.
|
175 |
+
"""
|
176 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
177 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
178 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
179 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
|
180 |
+
return (
|
181 |
+
indices,
|
182 |
+
cu_seqlens,
|
183 |
+
max_seqlen_in_batch,
|
184 |
+
)
|
185 |
+
|
186 |
+
|
187 |
+
# Copyed from https://github.com/huggingface/transformers/blob/
|
188 |
+
# b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
|
189 |
+
def is_flash_attn_greater_or_equal(library_version: str):
|
190 |
+
if not _is_package_available("flash_attn"):
|
191 |
+
return False
|
192 |
+
|
193 |
+
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
|
194 |
+
|
195 |
+
|
196 |
+
def get_kv_seqlens_with_mask(attn_mask, k, v):
|
197 |
+
indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
|
198 |
+
b, s1, a, d = k.shape
|
199 |
+
k = index_first_axis(k.reshape(b * s1, a, d), indices_k)
|
200 |
+
v = index_first_axis(v.reshape(b * s1, a, d), indices_k)
|
201 |
+
kv = torch.stack([k, v], dim=1)
|
202 |
+
return cu_seqlens_k, max_seqlen_k, kv
|
203 |
+
|
204 |
+
|
205 |
+
def get_q_seqlens(q):
|
206 |
+
bs, s, a, d = q.shape
|
207 |
+
cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
|
208 |
+
q = q.reshape(bs * s, a, d)
|
209 |
+
return cu_seqlens_q, s, q
|
210 |
+
|
211 |
+
|
212 |
+
def attention(q, k, v, mode, drop_rate=0, attn_mask=None, causal=False, deterministic=False,
|
213 |
+
cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
|
214 |
+
"""
|
215 |
+
Perform QKV self attention.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
|
219 |
+
k (torch.Tensor): Key tensor with shape [b, s1, a, d]
|
220 |
+
v (torch.Tensor): Value tensor with shape [b, s1, a, d]
|
221 |
+
mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
|
222 |
+
drop_rate (float): Dropout rate in attention map. (default: 0)
|
223 |
+
attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
|
224 |
+
(default: None)
|
225 |
+
causal (bool): Whether to use causal attention. (default: False)
|
226 |
+
deterministic (bool): Whether to use deterministic attention. (default: False)
|
227 |
+
cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
228 |
+
used to index into q.
|
229 |
+
max_seqlen (int): The maximum sequence length in the batch of q.
|
230 |
+
cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
|
231 |
+
used to index into kv.
|
232 |
+
max_seqlen_k (int): The maximum sequence length in the batch of k and v.
|
233 |
+
|
234 |
+
Returns:
|
235 |
+
torch.Tensor: Output tensor after self attention with shape [b, s, ad]
|
236 |
+
"""
|
237 |
+
pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
|
238 |
+
q = pre_attn_layout(q)
|
239 |
+
k = pre_attn_layout(k)
|
240 |
+
v = pre_attn_layout(v)
|
241 |
+
|
242 |
+
if mode == 'torch':
|
243 |
+
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
244 |
+
attn_mask = attn_mask.to(q.dtype)
|
245 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
|
246 |
+
|
247 |
+
elif mode == 'vanilla':
|
248 |
+
scale_factor = 1 / math.sqrt(q.size(-1))
|
249 |
+
|
250 |
+
b, a, s, _ = q.shape
|
251 |
+
s1 = k.size(2)
|
252 |
+
attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
|
253 |
+
if causal:
|
254 |
+
# Only applied to self attention
|
255 |
+
assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
|
256 |
+
temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
|
257 |
+
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
|
258 |
+
attn_bias.to(q.dtype)
|
259 |
+
|
260 |
+
if attn_mask is not None:
|
261 |
+
if attn_mask.dtype == torch.bool:
|
262 |
+
attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
|
263 |
+
else:
|
264 |
+
attn_bias += attn_mask
|
265 |
+
|
266 |
+
attn = (q @ k.transpose(-2, -1)) * scale_factor
|
267 |
+
attn += attn_bias
|
268 |
+
attn = attn.softmax(dim=-1)
|
269 |
+
attn = torch.dropout(attn, p=drop_rate, train=True)
|
270 |
+
x = attn @ v
|
271 |
+
else:
|
272 |
+
raise NotImplementedError(f'Unsupported attention mode: {mode}')
|
273 |
+
|
274 |
+
x = post_attn_layout(x)
|
275 |
+
b, s, a, d = x.shape
|
276 |
+
out = x.reshape(b, s, -1)
|
277 |
+
return out
|
278 |
+
|
279 |
+
|
280 |
+
class SelfAttentionLayer(BasicAttentionLayer):
|
281 |
+
def __init__(self,
|
282 |
+
dim,
|
283 |
+
num_heads,
|
284 |
+
qkv_bias=True,
|
285 |
+
qk_norm=True,
|
286 |
+
attn_drop=0,
|
287 |
+
proj_drop=0,
|
288 |
+
dtype=None,
|
289 |
+
device=None,
|
290 |
+
norm_type='layer',
|
291 |
+
attn_mode='self_flash',
|
292 |
+
deterministic=False,
|
293 |
+
) -> None:
|
294 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
295 |
+
super().__init__(attn_mode, deterministic)
|
296 |
+
self.dim = dim
|
297 |
+
self.num_heads = num_heads
|
298 |
+
assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
|
299 |
+
self.head_dim = self.dim // num_heads
|
300 |
+
self.attn_drop = attn_drop
|
301 |
+
|
302 |
+
# This assertion is aligned with flash attention
|
303 |
+
assert (
|
304 |
+
self.head_dim % 8 == 0 and self.head_dim <= 128
|
305 |
+
), "Only support head_dim <= 128 and divisible by 8"
|
306 |
+
|
307 |
+
self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
|
308 |
+
|
309 |
+
norm_layer = get_norm_layer(norm_type)
|
310 |
+
self.q_norm = (
|
311 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
312 |
+
if qk_norm
|
313 |
+
else nn.Identity()
|
314 |
+
)
|
315 |
+
self.k_norm = (
|
316 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
317 |
+
if qk_norm
|
318 |
+
else nn.Identity()
|
319 |
+
)
|
320 |
+
|
321 |
+
self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
|
322 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
323 |
+
|
324 |
+
def forward(self, x, freqs_cis=None, attn_mask=None):
|
325 |
+
"""
|
326 |
+
Args:
|
327 |
+
x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
|
328 |
+
freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image
|
329 |
+
attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
|
330 |
+
"""
|
331 |
+
b, s, d = x.shape
|
332 |
+
|
333 |
+
# Apply QKV projection
|
334 |
+
qkv = self.Wqkv(x)
|
335 |
+
qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d]
|
336 |
+
q, k, v = qkv.unbind(dim=2) # [b, s, a, d]
|
337 |
+
|
338 |
+
# Apply QK-Norm if needed
|
339 |
+
q = self.q_norm(q)
|
340 |
+
k = self.k_norm(k)
|
341 |
+
|
342 |
+
# Apply RoPE if needed
|
343 |
+
if freqs_cis is not None:
|
344 |
+
qq, kk = apply_rotary_emb(q, k, freqs_cis)
|
345 |
+
assert qq.shape == q.shape and kk.shape == k.shape, \
|
346 |
+
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
|
347 |
+
q, k = qq, kk
|
348 |
+
|
349 |
+
# Apply self attention
|
350 |
+
context = attention(q, k, v,
|
351 |
+
drop_rate=self.attn_drop if self.training else 0,
|
352 |
+
attn_mask=attn_mask,
|
353 |
+
mode=self.attn_mode,
|
354 |
+
deterministic=self.deterministic,
|
355 |
+
)
|
356 |
+
out = self.out_proj(context)
|
357 |
+
out = self.proj_drop(out)
|
358 |
+
|
359 |
+
return out
|
360 |
+
|
361 |
+
|
362 |
+
class CrossAttentionLayer(BasicAttentionLayer):
|
363 |
+
def __init__(self,
|
364 |
+
qdim,
|
365 |
+
kdim,
|
366 |
+
num_heads,
|
367 |
+
qkv_bias=True,
|
368 |
+
qk_norm=True,
|
369 |
+
attn_drop=0,
|
370 |
+
proj_drop=0,
|
371 |
+
dtype=None,
|
372 |
+
device=None,
|
373 |
+
norm_type='layer',
|
374 |
+
attn_mode='cross_flash',
|
375 |
+
deterministic=False,
|
376 |
+
):
|
377 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
378 |
+
super().__init__(attn_mode, deterministic)
|
379 |
+
self.qdim = qdim
|
380 |
+
self.kdim = kdim
|
381 |
+
self.num_heads = num_heads
|
382 |
+
assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
|
383 |
+
self.head_dim = self.qdim // num_heads
|
384 |
+
self.attn_drop = attn_drop
|
385 |
+
|
386 |
+
# This assertion is aligned with flash attention
|
387 |
+
assert (
|
388 |
+
self.head_dim % 8 == 0 and self.head_dim <= 128
|
389 |
+
), "Only support head_dim <= 128 and divisible by 8"
|
390 |
+
|
391 |
+
self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
392 |
+
self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
|
393 |
+
|
394 |
+
norm_layer = get_norm_layer(norm_type)
|
395 |
+
self.q_norm = (
|
396 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
397 |
+
if qk_norm
|
398 |
+
else nn.Identity()
|
399 |
+
)
|
400 |
+
self.k_norm = (
|
401 |
+
norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
402 |
+
if qk_norm
|
403 |
+
else nn.Identity()
|
404 |
+
)
|
405 |
+
|
406 |
+
self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
|
407 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
408 |
+
|
409 |
+
def forward(self, x, y, attn_mask=None):
|
410 |
+
"""
|
411 |
+
Args:
|
412 |
+
x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
|
413 |
+
y (torch.Tensor): (batch, seq_len1, hidden_dim1)
|
414 |
+
attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
|
415 |
+
"""
|
416 |
+
b, s, d = x.shape
|
417 |
+
_, s1, d1 = y.shape
|
418 |
+
|
419 |
+
q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
|
420 |
+
kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
|
421 |
+
k, v = kv.unbind(dim=2)
|
422 |
+
|
423 |
+
# Apply QK-Norm if needed
|
424 |
+
q = self.q_norm(q)
|
425 |
+
k = self.k_norm(k)
|
426 |
+
|
427 |
+
# Apply cross attention
|
428 |
+
context = attention(q, k, v,
|
429 |
+
attn_mask=attn_mask,
|
430 |
+
drop_rate=self.attn_drop if self.training else 0,
|
431 |
+
mode=self.attn_mode,
|
432 |
+
deterministic=self.deterministic,
|
433 |
+
)
|
434 |
+
out = self.out_proj(context)
|
435 |
+
out = self.proj_drop(out)
|
436 |
+
|
437 |
+
return out
|
hymm_sp/modules/cameranet.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import einops
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import collections.abc
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.init as init
|
9 |
+
|
10 |
+
from pathlib import Path
|
11 |
+
from einops import rearrange
|
12 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
13 |
+
from diffusers.models.modeling_utils import ModelMixin
|
14 |
+
from itertools import repeat
|
15 |
+
from .embed_layers import PatchEmbed
|
16 |
+
|
17 |
+
|
18 |
+
def _ntuple(n):
|
19 |
+
"""
|
20 |
+
Creates a helper function to convert inputs to tuples of specified length.
|
21 |
+
|
22 |
+
Functionality:
|
23 |
+
- Converts iterable inputs (excluding strings) to tuples, ensuring length n
|
24 |
+
- Repeats single values n times to form a tuple
|
25 |
+
Useful for handling multi-dimensional parameters like kernel sizes and strides.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
n (int): Target length of the tuple
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
function: A parser function that converts inputs to n-length tuples
|
32 |
+
"""
|
33 |
+
def parse(x):
|
34 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
35 |
+
x = tuple(x)
|
36 |
+
if len(x) == 1:
|
37 |
+
x = tuple(repeat(x[0], n))
|
38 |
+
return x
|
39 |
+
return tuple(repeat(x, n))
|
40 |
+
return parse
|
41 |
+
|
42 |
+
|
43 |
+
# Create common tuple conversion functions
|
44 |
+
to_1tuple = _ntuple(1)
|
45 |
+
to_2tuple = _ntuple(2)
|
46 |
+
to_3tuple = _ntuple(3)
|
47 |
+
to_4tuple = _ntuple(4)
|
48 |
+
|
49 |
+
|
50 |
+
class CameraNet(ModelMixin):
|
51 |
+
"""
|
52 |
+
Camera state encoding network that processes camera parameters into feature embeddings.
|
53 |
+
|
54 |
+
This network converts camera state information into suitable feature representations
|
55 |
+
for video generation models through downsampling, convolutional encoding, and
|
56 |
+
temporal dimension compression. Supports loading from pretrained weights.
|
57 |
+
"""
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
in_channels,
|
61 |
+
downscale_coef,
|
62 |
+
out_channels,
|
63 |
+
patch_size,
|
64 |
+
hidden_size,
|
65 |
+
):
|
66 |
+
super().__init__()
|
67 |
+
# Calculate initial channels: PixelUnshuffle moves spatial info to channel dimension
|
68 |
+
# resulting in channels = in_channels * (downscale_coef^2)
|
69 |
+
start_channels = in_channels * (downscale_coef ** 2)
|
70 |
+
input_channels = [start_channels, start_channels // 2, start_channels // 4]
|
71 |
+
self.input_channels = input_channels
|
72 |
+
self.unshuffle = nn.PixelUnshuffle(downscale_coef)
|
73 |
+
|
74 |
+
self.encode_first = nn.Sequential(
|
75 |
+
nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
|
76 |
+
nn.GroupNorm(2, input_channels[1]),
|
77 |
+
nn.ReLU(),
|
78 |
+
)
|
79 |
+
self._initialize_weights(self.encode_first)
|
80 |
+
self.encode_second = nn.Sequential(
|
81 |
+
nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
|
82 |
+
nn.GroupNorm(2, input_channels[2]),
|
83 |
+
nn.ReLU(),
|
84 |
+
)
|
85 |
+
self._initialize_weights(self.encode_second)
|
86 |
+
|
87 |
+
self.final_proj = nn.Conv2d(input_channels[2], out_channels, kernel_size=1)
|
88 |
+
self.zeros_init_linear(self.final_proj)
|
89 |
+
|
90 |
+
self.scale = nn.Parameter(torch.ones(1))
|
91 |
+
|
92 |
+
self.camera_in = PatchEmbed(patch_size=patch_size, in_chans=out_channels, embed_dim=hidden_size)
|
93 |
+
|
94 |
+
|
95 |
+
def zeros_init_linear(self, linear: nn.Module):
|
96 |
+
"""
|
97 |
+
Zero-initializes weights and biases of linear or convolutional layers.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
linear (nn.Module): Linear or convolutional layer to initialize
|
101 |
+
"""
|
102 |
+
if isinstance(linear, (nn.Linear, nn.Conv2d)):
|
103 |
+
if hasattr(linear, "weight"):
|
104 |
+
nn.init.zeros_(linear.weight)
|
105 |
+
if hasattr(linear, "bias"):
|
106 |
+
nn.init.zeros_(linear.bias)
|
107 |
+
|
108 |
+
def _initialize_weights(self, block):
|
109 |
+
"""
|
110 |
+
Initializes convolutional layer weights using He initialization,
|
111 |
+
with biases initialized to zero.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
block (nn.Sequential): Sequential block containing convolutional layers
|
115 |
+
"""
|
116 |
+
for m in block:
|
117 |
+
if isinstance(m, nn.Conv2d):
|
118 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
|
119 |
+
init.normal_(m.weight, mean=0.0, std=np.sqrt(2.0 / n))
|
120 |
+
if m.bias is not None:
|
121 |
+
init.zeros_(m.bias)
|
122 |
+
|
123 |
+
|
124 |
+
def compress_time(self, x, num_frames):
|
125 |
+
"""
|
126 |
+
Temporal dimension compression: reduces number of frames using average pooling
|
127 |
+
while preserving key temporal information.
|
128 |
+
|
129 |
+
Handling logic:
|
130 |
+
- Special frame counts (66 or 34): split into two segments, keep first frame of each
|
131 |
+
segment then pool remaining frames
|
132 |
+
- Odd frame counts: keep first frame, pool remaining frames
|
133 |
+
- Even frame counts: directly pool all frames
|
134 |
+
|
135 |
+
Args:
|
136 |
+
x (torch.Tensor): Input tensor with shape (b*f, c, h, w)
|
137 |
+
num_frames (int): Number of frames in temporal dimension
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
torch.Tensor: Temporally compressed tensor with shape (b*f', c, h, w) where f' < f
|
141 |
+
"""
|
142 |
+
# Reshape: (b*f, c, h, w) -> (b, f, c, h, w)
|
143 |
+
x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
|
144 |
+
batch_size, frames, channels, height, width = x.shape
|
145 |
+
x = rearrange(x, 'b f c h w -> (b h w) c f')
|
146 |
+
|
147 |
+
# print(x.shape)
|
148 |
+
# raise Exception
|
149 |
+
# Handle special frame counts (66 or 34)
|
150 |
+
if x.shape[-1] == 66 or x.shape[-1] == 34:
|
151 |
+
x_len = x.shape[-1]
|
152 |
+
# Process first segment: keep first frame, pool remaining
|
153 |
+
x_clip1 = x[...,:x_len//2]
|
154 |
+
x_clip1_first, x_clip1_rest = x_clip1[..., 0].unsqueeze(-1), x_clip1[..., 1:]
|
155 |
+
x_clip1_rest = F.avg_pool1d(x_clip1_rest, kernel_size=2, stride=2)
|
156 |
+
|
157 |
+
# Process second segment: keep first frame, pool remaining
|
158 |
+
x_clip2 = x[...,x_len//2:x_len]
|
159 |
+
x_clip2_first, x_clip2_rest = x_clip2[..., 0].unsqueeze(-1), x_clip2[..., 1:]
|
160 |
+
x_clip2_rest = F.avg_pool1d(x_clip2_rest, kernel_size=2, stride=2)
|
161 |
+
|
162 |
+
# Concatenate results from both segments
|
163 |
+
x = torch.cat([x_clip1_first, x_clip1_rest, x_clip2_first, x_clip2_rest], dim=-1)
|
164 |
+
|
165 |
+
elif x.shape[-1] % 2 == 1:
|
166 |
+
x_first, x_rest = x[..., 0], x[..., 1:]
|
167 |
+
if x_rest.shape[-1] > 0:
|
168 |
+
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
169 |
+
|
170 |
+
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
171 |
+
else:
|
172 |
+
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
173 |
+
x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
|
174 |
+
return x
|
175 |
+
|
176 |
+
def forward(
|
177 |
+
self,
|
178 |
+
camera_states: torch.Tensor,
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Forward pass: encodes camera states into feature embeddings.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
camera_states (torch.Tensor): Camera state tensor with dimensions
|
185 |
+
(batch, frames, channels, height, width)
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
torch.Tensor: Encoded feature embeddings after patch embedding and scaling
|
189 |
+
"""
|
190 |
+
# import pdb;pdb.set_trace()
|
191 |
+
batch_size, num_frames, channels, height, width = camera_states.shape
|
192 |
+
camera_states = rearrange(camera_states, 'b f c h w -> (b f) c h w')
|
193 |
+
camera_states = self.unshuffle(camera_states)
|
194 |
+
camera_states = self.encode_first(camera_states)
|
195 |
+
camera_states = self.compress_time(camera_states, num_frames=num_frames)
|
196 |
+
num_frames = camera_states.shape[0] // batch_size
|
197 |
+
camera_states = self.encode_second(camera_states)
|
198 |
+
camera_states = self.compress_time(camera_states, num_frames=num_frames)
|
199 |
+
# camera_states = rearrange(camera_states, '(b f) c h w -> b f c h w', b=batch_size)
|
200 |
+
camera_states = self.final_proj(camera_states)
|
201 |
+
camera_states = rearrange(camera_states, "(b f) c h w -> b c f h w", b=batch_size)
|
202 |
+
camera_states = self.camera_in(camera_states)
|
203 |
+
return camera_states * self.scale
|
204 |
+
|
205 |
+
@classmethod
|
206 |
+
def from_pretrained(cls, pretrained_model_path):
|
207 |
+
"""
|
208 |
+
Loads model from pretrained weight file.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
pretrained_model_path (str): Path to pretrained weight file
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
CameraNet: Model instance with loaded pretrained weights
|
215 |
+
"""
|
216 |
+
if not Path(pretrained_model_path).exists():
|
217 |
+
print(f"There is no model file in {pretrained_model_path}")
|
218 |
+
print(f"loaded CameraNet's pretrained weights from {pretrained_model_path}.")
|
219 |
+
|
220 |
+
state_dict = torch.load(pretrained_model_path, map_location="cpu")
|
221 |
+
model = CameraNet(in_channels=6, downscale_coef=8, out_channels=16)
|
222 |
+
model.load_state_dict(state_dict, strict=True)
|
223 |
+
return model
|
224 |
+
|
225 |
+
|
226 |
+
if __name__ == "__main__":
|
227 |
+
# Test model initialization and forward pass
|
228 |
+
model = CameraNet(
|
229 |
+
in_channels=6,
|
230 |
+
downscale_coef=8,
|
231 |
+
out_channels=16,
|
232 |
+
patch_size=[1,2,2],
|
233 |
+
hidden_size=3072
|
234 |
+
)
|
235 |
+
print("Model structure:")
|
236 |
+
print(model)
|
237 |
+
|
238 |
+
# Generate test input (batch 1, 33 frames, 6 channels, 704x1280 resolution)
|
239 |
+
num_frames = 33
|
240 |
+
input_tensor = torch.randn(1, num_frames, 6, 704, 1280)
|
241 |
+
|
242 |
+
# Forward pass
|
243 |
+
output_tensor = model(input_tensor)
|
244 |
+
|
245 |
+
# Print results
|
246 |
+
print(f"Output shape: {output_tensor.shape}") # Expected: torch.Size([1, ...])
|
247 |
+
print("Output tensor example:")
|
248 |
+
print(output_tensor)
|
hymm_sp/modules/embed_layers.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from hymm_sp.helpers import to_2tuple
|
5 |
+
|
6 |
+
|
7 |
+
class PatchEmbed(nn.Module):
|
8 |
+
""" 2D Image to Patch Embedding
|
9 |
+
|
10 |
+
Image to Patch Embedding using Conv2d
|
11 |
+
|
12 |
+
A convolution based approach to patchifying a 2D image w/ embedding projection.
|
13 |
+
|
14 |
+
Based on the impl in https://github.com/google-research/vision_transformer
|
15 |
+
|
16 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
17 |
+
|
18 |
+
Remove the _assert function in forward function to be compatible with multi-resolution images.
|
19 |
+
"""
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
patch_size=16,
|
23 |
+
in_chans=3,
|
24 |
+
embed_dim=768,
|
25 |
+
multitask_mask_training_type=None,
|
26 |
+
norm_layer=None,
|
27 |
+
flatten=True,
|
28 |
+
bias=True,
|
29 |
+
dtype=None,
|
30 |
+
device=None
|
31 |
+
):
|
32 |
+
factory_kwargs = {'dtype': dtype, 'device': device}
|
33 |
+
super().__init__()
|
34 |
+
patch_size = to_2tuple(patch_size)
|
35 |
+
self.patch_size = patch_size
|
36 |
+
self.flatten = flatten
|
37 |
+
|
38 |
+
if multitask_mask_training_type == "concat":
|
39 |
+
orig_in_chans = in_chans
|
40 |
+
in_chans = in_chans * 2 + 1
|
41 |
+
|
42 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias,
|
43 |
+
**factory_kwargs)
|
44 |
+
if multitask_mask_training_type == "concat":
|
45 |
+
nn.init.xavier_uniform_(\
|
46 |
+
self.proj.weight[:, :orig_in_chans].view(self.proj.weight[:, :orig_in_chans].size(0), -1))
|
47 |
+
nn.init.zeros_(self.proj.weight[:, orig_in_chans:].view(self.proj.weight[:, orig_in_chans:].size(0), -1))
|
48 |
+
else:
|
49 |
+
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
|
50 |
+
|
51 |
+
|
52 |
+
if bias:
|
53 |
+
nn.init.zeros_(self.proj.bias)
|
54 |
+
|
55 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = self.proj(x)
|
59 |
+
if self.flatten:
|
60 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
61 |
+
x = self.norm(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class TextProjection(nn.Module):
|
66 |
+
"""
|
67 |
+
Projects text embeddings. Also handles dropout for classifier-free guidance.
|
68 |
+
|
69 |
+
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
70 |
+
"""
|
71 |
+
|
72 |
+
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
|
73 |
+
factory_kwargs = {'dtype': dtype, 'device': device}
|
74 |
+
super().__init__()
|
75 |
+
self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
|
76 |
+
self.act_1 = act_layer()
|
77 |
+
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
|
78 |
+
|
79 |
+
def forward(self, caption):
|
80 |
+
hidden_states = self.linear_1(caption)
|
81 |
+
hidden_states = self.act_1(hidden_states)
|
82 |
+
hidden_states = self.linear_2(hidden_states)
|
83 |
+
return hidden_states
|
84 |
+
|
85 |
+
|
86 |
+
def timestep_embedding(t, dim, max_period=10000):
|
87 |
+
"""
|
88 |
+
Create sinusoidal timestep embeddings.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
|
92 |
+
dim (int): the dimension of the output.
|
93 |
+
max_period (int): controls the minimum frequency of the embeddings.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
|
97 |
+
|
98 |
+
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
99 |
+
"""
|
100 |
+
half = dim // 2
|
101 |
+
freqs = torch.exp(
|
102 |
+
-math.log(max_period)
|
103 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
104 |
+
/ half
|
105 |
+
).to(device=t.device)
|
106 |
+
args = t[:, None].float() * freqs[None]
|
107 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
108 |
+
if dim % 2:
|
109 |
+
embedding = torch.cat(
|
110 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
111 |
+
)
|
112 |
+
return embedding
|
113 |
+
|
114 |
+
|
115 |
+
class TimestepEmbedder(nn.Module):
|
116 |
+
"""
|
117 |
+
Embeds scalar timesteps into vector representations.
|
118 |
+
"""
|
119 |
+
def __init__(self,
|
120 |
+
hidden_size,
|
121 |
+
act_layer,
|
122 |
+
frequency_embedding_size=256,
|
123 |
+
max_period=10000,
|
124 |
+
out_size=None,
|
125 |
+
dtype=None,
|
126 |
+
device=None
|
127 |
+
):
|
128 |
+
factory_kwargs = {'dtype': dtype, 'device': device}
|
129 |
+
super().__init__()
|
130 |
+
self.frequency_embedding_size = frequency_embedding_size
|
131 |
+
self.max_period = max_period
|
132 |
+
if out_size is None:
|
133 |
+
out_size = hidden_size
|
134 |
+
|
135 |
+
self.mlp = nn.Sequential(
|
136 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
|
137 |
+
act_layer(),
|
138 |
+
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
|
139 |
+
)
|
140 |
+
nn.init.normal_(self.mlp[0].weight, std=0.02)
|
141 |
+
nn.init.normal_(self.mlp[2].weight, std=0.02)
|
142 |
+
|
143 |
+
def forward(self, t):
|
144 |
+
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
|
145 |
+
t_emb = self.mlp(t_freq)
|
146 |
+
return t_emb
|
hymm_sp/modules/fp8_optimization.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/neuralmagic/AutoFP8/blob/main/auto_fp8/quantize.py
|
2 |
+
import gc
|
3 |
+
from typing import Tuple
|
4 |
+
import copy
|
5 |
+
import torch
|
6 |
+
import tqdm
|
7 |
+
import triton
|
8 |
+
import triton.language as tl
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def cleanup_memory():
|
13 |
+
gc.collect()
|
14 |
+
torch.cuda.empty_cache()
|
15 |
+
|
16 |
+
|
17 |
+
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
|
18 |
+
"""Quantize a tensor using per-tensor static scaling factor.
|
19 |
+
Args:
|
20 |
+
tensor: The input tensor.
|
21 |
+
"""
|
22 |
+
finfo = torch.finfo(torch.float8_e4m3fn)
|
23 |
+
# Calculate the scale as dtype max divided by absmax.
|
24 |
+
# Since .abs() creates a new tensor, we use aminmax to get
|
25 |
+
# the min and max first and then calculate the absmax.
|
26 |
+
if tensor.numel() == 0:
|
27 |
+
# Deal with empty tensors (triggered by empty MoE experts)
|
28 |
+
min_val, max_val = (
|
29 |
+
torch.tensor(-16.0, dtype=tensor.dtype),
|
30 |
+
torch.tensor(16.0, dtype=tensor.dtype),
|
31 |
+
)
|
32 |
+
else:
|
33 |
+
min_val, max_val = tensor.aminmax()
|
34 |
+
amax = torch.maximum(min_val.abs(), max_val.abs())
|
35 |
+
scale = finfo.max / amax.clamp(min=1e-12)
|
36 |
+
# scale and clamp the tensor to bring it to
|
37 |
+
# the representative range of float8 data type
|
38 |
+
# (as default cast is unsaturated)
|
39 |
+
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
|
40 |
+
# Return both float8 data and the inverse scale (as float),
|
41 |
+
# as both required as inputs to torch._scaled_mm
|
42 |
+
qweight = qweight.to(torch.float8_e4m3fn)
|
43 |
+
scale = scale.float().reciprocal()
|
44 |
+
return qweight, scale
|
45 |
+
|
46 |
+
|
47 |
+
fp8_gemm_configs = [
|
48 |
+
triton.Config({'BLOCK_SIZE_M': block_m,
|
49 |
+
'BLOCK_SIZE_N': block_n,
|
50 |
+
'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
|
51 |
+
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
|
52 |
+
]
|
53 |
+
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
|
54 |
+
@triton.jit
|
55 |
+
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
|
56 |
+
a_scale, b_scale, # 改为单个scale值
|
57 |
+
M, N: tl.constexpr, K: tl.constexpr,
|
58 |
+
BLOCK_SIZE_M: tl.constexpr,
|
59 |
+
BLOCK_SIZE_N: tl.constexpr,
|
60 |
+
BLOCK_SIZE_K: tl.constexpr):
|
61 |
+
"""
|
62 |
+
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
|
63 |
+
"""
|
64 |
+
pid_m = tl.program_id(axis=0)
|
65 |
+
pid_n = tl.program_id(axis=1)
|
66 |
+
|
67 |
+
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
68 |
+
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
69 |
+
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
70 |
+
|
71 |
+
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
|
72 |
+
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
|
73 |
+
|
74 |
+
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
75 |
+
|
76 |
+
for i in range(0, K, BLOCK_SIZE_K):
|
77 |
+
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i, other=0.0)
|
78 |
+
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i, other=0.0)
|
79 |
+
|
80 |
+
accumulator += tl.dot(a, b) * a_scale * b_scale
|
81 |
+
|
82 |
+
a_ptrs += BLOCK_SIZE_K
|
83 |
+
b_ptrs += BLOCK_SIZE_K
|
84 |
+
|
85 |
+
c = accumulator.to(c_ptr.dtype.element_ty)
|
86 |
+
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
87 |
+
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
88 |
+
c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
|
89 |
+
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
|
90 |
+
tl.store(c_ptrs, c, mask=mask)
|
91 |
+
|
92 |
+
|
93 |
+
def triton_fp8_gemm(a: torch.Tensor,
|
94 |
+
b: torch.Tensor,
|
95 |
+
a_scale: float,
|
96 |
+
b_scale: float,
|
97 |
+
out_dtype=torch.bfloat16,
|
98 |
+
bias=None) -> torch.Tensor:
|
99 |
+
"""
|
100 |
+
Perform a matrix multiplication using FP8 precision with per-tensor quantization.
|
101 |
+
"""
|
102 |
+
assert a.is_contiguous() and b.is_contiguous()
|
103 |
+
|
104 |
+
K = a.size(-1)
|
105 |
+
M = a.numel() // K
|
106 |
+
N = b.size(0)
|
107 |
+
c = torch.empty((M, N), dtype=out_dtype, device=a.device)
|
108 |
+
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
|
109 |
+
if isinstance(a_scale, torch.Tensor):
|
110 |
+
a_scale = a_scale.item()
|
111 |
+
if isinstance(b_scale, torch.Tensor):
|
112 |
+
b_scale = b_scale.item()
|
113 |
+
# import pdb; pdb.set_trace()
|
114 |
+
fp8_gemm_kernel[grid](a, b, c, a_scale, b_scale, M, N, K)
|
115 |
+
if bias is not None:
|
116 |
+
|
117 |
+
c += bias
|
118 |
+
|
119 |
+
return c
|
120 |
+
|
121 |
+
|
122 |
+
def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
|
123 |
+
"""
|
124 |
+
Optimized FP8 GEMM implementation, supports both native FP8 and Triton paths,
|
125 |
+
and automatically handles 3D input and bias.
|
126 |
+
"""
|
127 |
+
if A.numel() == 0:
|
128 |
+
# Handle empty tensor (e.g., when MoE expert is empty)
|
129 |
+
return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
|
130 |
+
|
131 |
+
# Check if reshape is needed (support for 3D input)
|
132 |
+
need_reshape = (A.dim() == 3)
|
133 |
+
batch_size = A.shape[0] if need_reshape else None
|
134 |
+
A_input = A.reshape(-1, A.shape[-1]).contiguous() if need_reshape else A
|
135 |
+
|
136 |
+
if native_fp8_support:
|
137 |
+
# Native FP8 support
|
138 |
+
output = torch._scaled_mm(
|
139 |
+
A_input,
|
140 |
+
B.t(),
|
141 |
+
out_dtype=out_dtype,
|
142 |
+
scale_a=torch.tensor(A_scale) if not isinstance(A_scale, torch.Tensor) else A_scale,
|
143 |
+
scale_b=torch.tensor(B_scale) if not isinstance(B_scale, torch.Tensor) else B_scale,
|
144 |
+
bias=bias.to(out_dtype),
|
145 |
+
)
|
146 |
+
else:
|
147 |
+
# Triton implementation
|
148 |
+
output = triton_fp8_gemm(
|
149 |
+
A_input,
|
150 |
+
B.contiguous(),
|
151 |
+
out_dtype=out_dtype,
|
152 |
+
a_scale=A_scale,
|
153 |
+
b_scale=B_scale,
|
154 |
+
bias=None,
|
155 |
+
)
|
156 |
+
if bias is not None:
|
157 |
+
output += bias
|
158 |
+
|
159 |
+
if need_reshape:
|
160 |
+
# Restore original batch dimension
|
161 |
+
output = output.reshape(batch_size, -1, output.shape[-1])
|
162 |
+
|
163 |
+
return output
|
164 |
+
|
165 |
+
|
166 |
+
# Class responsible for quantizing weights
|
167 |
+
class FP8DynamicLinear(torch.nn.Module):
|
168 |
+
def __init__(
|
169 |
+
self,
|
170 |
+
weight: torch.Tensor,
|
171 |
+
weight_scale: torch.Tensor,
|
172 |
+
bias: torch.nn.Parameter,
|
173 |
+
native_fp8_support: bool = False,
|
174 |
+
name: str = ""
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
178 |
+
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
|
179 |
+
self.bias = bias
|
180 |
+
self.native_fp8_support = native_fp8_support
|
181 |
+
self.name = name
|
182 |
+
# @torch.compile
|
183 |
+
def forward(self, x):
|
184 |
+
if x.dtype != torch.float16 and x.dtype != torch.bfloat16:
|
185 |
+
# print(f"Warning: {self.name}'s input is not quantized to float16 or bfloat16")
|
186 |
+
# print(f"input dtype: {x.dtype}")
|
187 |
+
x = x.to(torch.bfloat16)
|
188 |
+
qinput, x_scale = per_tensor_quantize(x)
|
189 |
+
# print("--------------")
|
190 |
+
# print("layer_name:", self.name)
|
191 |
+
# print("A_input.shape:", qinput.shape)
|
192 |
+
# print("B.shape:", self.weight.shape)
|
193 |
+
# print("--------------")
|
194 |
+
output = fp8_gemm(
|
195 |
+
A=qinput,
|
196 |
+
A_scale=x_scale,
|
197 |
+
B=self.weight,
|
198 |
+
B_scale=self.weight_scale,
|
199 |
+
bias=self.bias,
|
200 |
+
out_dtype=x.dtype,
|
201 |
+
native_fp8_support=self.native_fp8_support,
|
202 |
+
)
|
203 |
+
return output
|
204 |
+
|
205 |
+
|
206 |
+
def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
|
207 |
+
if "." in name:
|
208 |
+
parent_name = name.rsplit(".", 1)[0]
|
209 |
+
child_name = name[len(parent_name) + 1 :]
|
210 |
+
parent = model.get_submodule(parent_name)
|
211 |
+
else:
|
212 |
+
parent_name = ""
|
213 |
+
parent = model
|
214 |
+
child_name = name
|
215 |
+
setattr(parent, child_name, new_module)
|
216 |
+
|
217 |
+
|
218 |
+
def convert_fp8_linear(model: torch.nn.Module):
|
219 |
+
# native_fp8_support = (
|
220 |
+
# torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
|
221 |
+
# )
|
222 |
+
native_fp8_support = False
|
223 |
+
named_modules = list(model.named_modules())
|
224 |
+
for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
|
225 |
+
if not isinstance(linear, torch.nn.Linear):
|
226 |
+
continue
|
227 |
+
if "mod" in name:
|
228 |
+
print(f"Warning: {name} is a mod module, skipping")
|
229 |
+
continue
|
230 |
+
if "block" not in name:
|
231 |
+
print(f"Warning: {name} is not in a block module, skipping")
|
232 |
+
continue
|
233 |
+
quant_weight, weight_scale = per_tensor_quantize(linear.weight)
|
234 |
+
bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
|
235 |
+
quant_linear = FP8DynamicLinear(
|
236 |
+
weight=quant_weight,
|
237 |
+
weight_scale=weight_scale,
|
238 |
+
bias=bias,
|
239 |
+
native_fp8_support=native_fp8_support,
|
240 |
+
name = name
|
241 |
+
)
|
242 |
+
replace_module(model, name, quant_linear)
|
243 |
+
del linear.weight
|
244 |
+
del linear.bias
|
245 |
+
del linear
|
246 |
+
cleanup_memory()
|
hymm_sp/modules/mlp_layers.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from timm library:
|
2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
3 |
+
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from .modulate_layers import modulate
|
10 |
+
from hymm_sp.helpers import to_2tuple
|
11 |
+
|
12 |
+
|
13 |
+
class MLP(nn.Module):
|
14 |
+
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
|
15 |
+
"""
|
16 |
+
def __init__(self,
|
17 |
+
in_channels,
|
18 |
+
hidden_channels=None,
|
19 |
+
out_features=None,
|
20 |
+
act_layer=nn.GELU,
|
21 |
+
norm_layer=None,
|
22 |
+
bias=True,
|
23 |
+
drop=0.,
|
24 |
+
use_conv=False,
|
25 |
+
device=None,
|
26 |
+
dtype=None
|
27 |
+
):
|
28 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
29 |
+
super().__init__()
|
30 |
+
out_features = out_features or in_channels
|
31 |
+
hidden_channels = hidden_channels or in_channels
|
32 |
+
bias = to_2tuple(bias)
|
33 |
+
drop_probs = to_2tuple(drop)
|
34 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
35 |
+
|
36 |
+
self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
|
37 |
+
self.act = act_layer()
|
38 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
39 |
+
self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
|
40 |
+
self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
|
41 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
x = self.fc1(x)
|
45 |
+
x = self.act(x)
|
46 |
+
x = self.drop1(x)
|
47 |
+
x = self.norm(x)
|
48 |
+
x = self.fc2(x)
|
49 |
+
x = self.drop2(x)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
class MLPEmbedder(nn.Module):
|
54 |
+
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
|
55 |
+
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
|
56 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
57 |
+
super().__init__()
|
58 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
|
59 |
+
self.silu = nn.SiLU()
|
60 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
|
61 |
+
|
62 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
63 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
64 |
+
|
65 |
+
|
66 |
+
class FinalLayer(nn.Module):
|
67 |
+
"""The final layer of DiT."""
|
68 |
+
|
69 |
+
def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
|
70 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
71 |
+
super().__init__()
|
72 |
+
|
73 |
+
# Just use LayerNorm for the final layer
|
74 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
75 |
+
if isinstance(patch_size, int):
|
76 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs)
|
77 |
+
else:
|
78 |
+
self.linear = nn.Linear(hidden_size,
|
79 |
+
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
|
80 |
+
bias=True)
|
81 |
+
nn.init.zeros_(self.linear.weight)
|
82 |
+
nn.init.zeros_(self.linear.bias)
|
83 |
+
|
84 |
+
# Here we don't distinguish between the modulate types. Just use the simple one.
|
85 |
+
self.adaLN_modulation = nn.Sequential(
|
86 |
+
act_layer(),
|
87 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
|
88 |
+
)
|
89 |
+
# Zero-initialize the modulation
|
90 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
91 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
92 |
+
|
93 |
+
def forward(self, x, c):
|
94 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
95 |
+
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
96 |
+
x = self.linear(x)
|
97 |
+
return x
|
hymm_sp/modules/models.py
ADDED
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Optional, Union, Dict
|
2 |
+
from einops import rearrange
|
3 |
+
|
4 |
+
import torch, os
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from diffusers.models import ModelMixin
|
8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
9 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
10 |
+
|
11 |
+
from .activation_layers import get_activation_layer
|
12 |
+
from .norm_layers import get_norm_layer
|
13 |
+
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
|
14 |
+
from .attn_layers import apply_rotary_emb
|
15 |
+
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
|
16 |
+
from .modulate_layers import ModulateDiT, modulate, apply_gate
|
17 |
+
from .token_refiner import SingleTokenRefiner
|
18 |
+
from .cameranet import CameraNet
|
19 |
+
|
20 |
+
from .parallel_states import (
|
21 |
+
nccl_info,
|
22 |
+
get_cu_seqlens,
|
23 |
+
get_sequence_parallel_state,
|
24 |
+
parallel_attention,
|
25 |
+
all_gather,
|
26 |
+
)
|
27 |
+
|
28 |
+
CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
|
29 |
+
DISABLE_SP = int(os.environ.get("DISABLE_SP", 0))
|
30 |
+
print(f'models: cpu_offload={CPU_OFFLOAD}, DISABLE_SP={DISABLE_SP}')
|
31 |
+
|
32 |
+
class DoubleStreamBlock(nn.Module):
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
hidden_size: int,
|
36 |
+
num_heads: int,
|
37 |
+
mlp_width_ratio: float,
|
38 |
+
mlp_act_type: str = 'gelu_tanh',
|
39 |
+
qk_norm: bool = True,
|
40 |
+
qk_norm_type: str = 'rms',
|
41 |
+
qkv_bias: bool = False,
|
42 |
+
dtype: Optional[torch.dtype] = None,
|
43 |
+
device: Optional[torch.device] = None,
|
44 |
+
):
|
45 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
self.deterministic = False
|
49 |
+
self.num_heads = num_heads
|
50 |
+
head_dim = hidden_size // num_heads
|
51 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
52 |
+
|
53 |
+
self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
54 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
55 |
+
|
56 |
+
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
57 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
58 |
+
self.img_attn_q_norm = (
|
59 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
60 |
+
if qk_norm
|
61 |
+
else nn.Identity()
|
62 |
+
)
|
63 |
+
self.img_attn_k_norm = (
|
64 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
65 |
+
if qk_norm
|
66 |
+
else nn.Identity()
|
67 |
+
)
|
68 |
+
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
69 |
+
|
70 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
71 |
+
self.img_mlp = MLP(
|
72 |
+
hidden_size,
|
73 |
+
mlp_hidden_dim,
|
74 |
+
act_layer=get_activation_layer(mlp_act_type),
|
75 |
+
bias=True,
|
76 |
+
**factory_kwargs
|
77 |
+
)
|
78 |
+
|
79 |
+
self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
80 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
81 |
+
|
82 |
+
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
83 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
84 |
+
self.txt_attn_q_norm = (
|
85 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
86 |
+
if qk_norm
|
87 |
+
else nn.Identity()
|
88 |
+
)
|
89 |
+
self.txt_attn_k_norm = (
|
90 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
91 |
+
if qk_norm
|
92 |
+
else nn.Identity()
|
93 |
+
)
|
94 |
+
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
95 |
+
|
96 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
97 |
+
self.txt_mlp = MLP(
|
98 |
+
hidden_size,
|
99 |
+
mlp_hidden_dim,
|
100 |
+
act_layer=get_activation_layer(mlp_act_type),
|
101 |
+
bias=True,
|
102 |
+
**factory_kwargs
|
103 |
+
)
|
104 |
+
|
105 |
+
def enable_deterministic(self):
|
106 |
+
self.deterministic = True
|
107 |
+
|
108 |
+
def disable_deterministic(self):
|
109 |
+
self.deterministic = False
|
110 |
+
|
111 |
+
def forward(
|
112 |
+
self,
|
113 |
+
img: torch.Tensor,
|
114 |
+
txt: torch.Tensor,
|
115 |
+
vec: torch.Tensor,
|
116 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
117 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
118 |
+
max_seqlen_q: Optional[int] = None,
|
119 |
+
max_seqlen_kv: Optional[int] = None,
|
120 |
+
freqs_cis: tuple = None,
|
121 |
+
use_sage: bool = True,
|
122 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
123 |
+
img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = (
|
124 |
+
self.img_mod(vec).chunk(6, dim=-1)
|
125 |
+
)
|
126 |
+
txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = (
|
127 |
+
self.txt_mod(vec).chunk(6, dim=-1)
|
128 |
+
)
|
129 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
130 |
+
|
131 |
+
# Prepare image for attention.
|
132 |
+
img_modulated = self.img_norm1(img)
|
133 |
+
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
|
134 |
+
img_qkv = self.img_attn_qkv(img_modulated)
|
135 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
136 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
137 |
+
# Apply QK-Norm if needed
|
138 |
+
img_q = self.img_attn_q_norm(img_q).to(img_v)
|
139 |
+
img_k = self.img_attn_k_norm(img_k).to(img_v)
|
140 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
141 |
+
|
142 |
+
# Apply RoPE if needed.
|
143 |
+
if freqs_cis is not None:
|
144 |
+
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
145 |
+
assert img_qq.shape == img_q.shape and img_kk.shape == img_k.shape, \
|
146 |
+
f'img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}'
|
147 |
+
img_q, img_k = img_qq, img_kk
|
148 |
+
|
149 |
+
# Prepare txt for attention.
|
150 |
+
txt_modulated = self.txt_norm1(txt)
|
151 |
+
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
|
152 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
153 |
+
txt_qkv = self.txt_attn_qkv(txt_modulated)
|
154 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
155 |
+
# Apply QK-Norm if needed.
|
156 |
+
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
|
157 |
+
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
|
158 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
159 |
+
|
160 |
+
# Run actual attention.
|
161 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
162 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
163 |
+
v = torch.cat((img_v, txt_v), dim=1)
|
164 |
+
|
165 |
+
# Compute attention.
|
166 |
+
if DISABLE_SP:
|
167 |
+
assert cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
|
168 |
+
|
169 |
+
q, k, v = [
|
170 |
+
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
171 |
+
for x in [q, k, v]
|
172 |
+
]
|
173 |
+
attn = flash_attn_varlen_func(
|
174 |
+
q,
|
175 |
+
k,
|
176 |
+
v,
|
177 |
+
cu_seqlens_q,
|
178 |
+
cu_seqlens_kv,
|
179 |
+
max_seqlen_q,
|
180 |
+
max_seqlen_kv,
|
181 |
+
)
|
182 |
+
attn = attn.view(img_k.shape[0], max_seqlen_q, -1).contiguous()
|
183 |
+
else:
|
184 |
+
attn, _ = parallel_attention(
|
185 |
+
(img_q, txt_q),
|
186 |
+
(img_k, txt_k),
|
187 |
+
(img_v, txt_v),
|
188 |
+
img_q_len=img_q.shape[1],
|
189 |
+
img_kv_len=img_k.shape[1],
|
190 |
+
cu_seqlens_q=cu_seqlens_q,
|
191 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
192 |
+
max_seqlen_q=max_seqlen_q,
|
193 |
+
max_seqlen_kv=max_seqlen_kv,
|
194 |
+
use_sage=use_sage,
|
195 |
+
)
|
196 |
+
img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
|
197 |
+
|
198 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
199 |
+
|
200 |
+
# Calculate the img bloks.
|
201 |
+
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
|
202 |
+
img = img + apply_gate(self.img_mlp(modulate(
|
203 |
+
self.img_norm2(img),
|
204 |
+
shift=img_mod2_shift,
|
205 |
+
scale=img_mod2_scale)), gate=img_mod2_gate)
|
206 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
207 |
+
# Calculate the txt bloks.
|
208 |
+
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
|
209 |
+
txt = txt + apply_gate(self.txt_mlp(modulate(self.txt_norm2(txt),
|
210 |
+
shift=txt_mod2_shift,
|
211 |
+
scale=txt_mod2_scale)), gate=txt_mod2_gate)
|
212 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
213 |
+
return img, txt
|
214 |
+
|
215 |
+
|
216 |
+
class SingleStreamBlock(nn.Module):
|
217 |
+
"""
|
218 |
+
A DiT block with parallel linear layers as described in
|
219 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
220 |
+
"""
|
221 |
+
|
222 |
+
def __init__(
|
223 |
+
self,
|
224 |
+
hidden_size: int,
|
225 |
+
num_heads: int,
|
226 |
+
mlp_width_ratio: float = 4.0,
|
227 |
+
mlp_act_type: str = 'gelu_tanh',
|
228 |
+
qk_norm: bool = True,
|
229 |
+
qk_norm_type: str = 'rms',
|
230 |
+
qk_scale: float = None,
|
231 |
+
dtype: Optional[torch.dtype] = None,
|
232 |
+
device: Optional[torch.device] = None,
|
233 |
+
):
|
234 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
235 |
+
super().__init__()
|
236 |
+
|
237 |
+
self.deterministic = False
|
238 |
+
self.hidden_size = hidden_size
|
239 |
+
self.num_heads = num_heads
|
240 |
+
head_dim = hidden_size // num_heads
|
241 |
+
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
|
242 |
+
self.mlp_hidden_dim = mlp_hidden_dim
|
243 |
+
self.scale = qk_scale or head_dim**-0.5
|
244 |
+
|
245 |
+
# qkv and mlp_in
|
246 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
|
247 |
+
# proj and mlp_out
|
248 |
+
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
|
249 |
+
|
250 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
251 |
+
self.q_norm = (
|
252 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
253 |
+
if qk_norm
|
254 |
+
else nn.Identity()
|
255 |
+
)
|
256 |
+
self.k_norm = (
|
257 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
258 |
+
if qk_norm
|
259 |
+
else nn.Identity()
|
260 |
+
)
|
261 |
+
|
262 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
|
263 |
+
|
264 |
+
self.mlp_act = get_activation_layer(mlp_act_type)()
|
265 |
+
self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
|
266 |
+
|
267 |
+
def enable_deterministic(self):
|
268 |
+
self.deterministic = True
|
269 |
+
|
270 |
+
def disable_deterministic(self):
|
271 |
+
self.deterministic = False
|
272 |
+
|
273 |
+
def forward(
|
274 |
+
self,
|
275 |
+
x: torch.Tensor,
|
276 |
+
vec: torch.Tensor,
|
277 |
+
txt_len: int,
|
278 |
+
cu_seqlens_q: Optional[torch.Tensor] = None,
|
279 |
+
cu_seqlens_kv: Optional[torch.Tensor] = None,
|
280 |
+
max_seqlen_q: Optional[int] = None,
|
281 |
+
max_seqlen_kv: Optional[int] = None,
|
282 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
283 |
+
use_sage: bool = True,
|
284 |
+
) -> torch.Tensor:
|
285 |
+
mod_shift, mod_scale, mod_gate = (
|
286 |
+
self.modulation(vec).chunk(3, dim=-1)
|
287 |
+
)
|
288 |
+
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
289 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
290 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
291 |
+
|
292 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
293 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
294 |
+
|
295 |
+
# Apply QK-Norm if needed.
|
296 |
+
q = self.q_norm(q).to(v)
|
297 |
+
k = self.k_norm(k).to(v)
|
298 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
299 |
+
|
300 |
+
# Apply RoPE if needed.
|
301 |
+
if freqs_cis is not None:
|
302 |
+
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
303 |
+
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
304 |
+
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
|
305 |
+
assert img_qq.shape == img_q.shape and img_kk.shape == img_k.shape, \
|
306 |
+
f'img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}'
|
307 |
+
img_q, img_k = img_qq, img_kk
|
308 |
+
q = torch.cat((img_q, txt_q), dim=1)
|
309 |
+
k = torch.cat((img_k, txt_k), dim=1)
|
310 |
+
|
311 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
312 |
+
|
313 |
+
# Compute attention.
|
314 |
+
if DISABLE_SP:
|
315 |
+
assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, \
|
316 |
+
f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
|
317 |
+
# [b, s+l, a, d] -> [s+l, b, a, d]
|
318 |
+
q, k, v = [
|
319 |
+
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
320 |
+
for x in [q, k, v]
|
321 |
+
]
|
322 |
+
|
323 |
+
attn = flash_attn_varlen_func(
|
324 |
+
q,
|
325 |
+
k,
|
326 |
+
v,
|
327 |
+
cu_seqlens_q,
|
328 |
+
cu_seqlens_kv,
|
329 |
+
max_seqlen_q,
|
330 |
+
max_seqlen_kv,
|
331 |
+
)
|
332 |
+
attn = attn.view(x.shape[0], max_seqlen_q, -1).contiguous()
|
333 |
+
else:
|
334 |
+
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
|
335 |
+
attn, _ = parallel_attention(
|
336 |
+
(img_q, txt_q),
|
337 |
+
(img_k, txt_k),
|
338 |
+
(img_v, txt_v),
|
339 |
+
img_q_len=img_q.shape[1],
|
340 |
+
img_kv_len=img_k.shape[1],
|
341 |
+
cu_seqlens_q=cu_seqlens_q,
|
342 |
+
cu_seqlens_kv=cu_seqlens_kv,
|
343 |
+
max_seqlen_q=max_seqlen_q,
|
344 |
+
max_seqlen_kv=max_seqlen_kv,
|
345 |
+
use_sage=use_sage,
|
346 |
+
)
|
347 |
+
if CPU_OFFLOAD:
|
348 |
+
torch.cuda.empty_cache()
|
349 |
+
tmp = torch.cat((attn, self.mlp_act(mlp)), 2)
|
350 |
+
torch.cuda.empty_cache()
|
351 |
+
output = self.linear2(tmp)
|
352 |
+
else:
|
353 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
354 |
+
return x + apply_gate(output, gate=mod_gate)
|
355 |
+
|
356 |
+
|
357 |
+
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
|
358 |
+
"""
|
359 |
+
HunyuanVideo Transformer backbone
|
360 |
+
|
361 |
+
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
|
362 |
+
|
363 |
+
Reference:
|
364 |
+
[1] Flux.1: https://github.com/black-forest-labs/flux
|
365 |
+
[2] MMDiT: http://arxiv.org/abs/2403.03206,
|
366 |
+
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
|
367 |
+
|
368 |
+
"""
|
369 |
+
@register_to_config
|
370 |
+
def __init__(
|
371 |
+
self,
|
372 |
+
args,
|
373 |
+
patch_size: list = [1,2,2],
|
374 |
+
in_channels: int = 4, # Should be VAE.config.latent_channels.
|
375 |
+
out_channels: int = None,
|
376 |
+
hidden_size: int = 3072,
|
377 |
+
mlp_width_ratio: float = 4.0,
|
378 |
+
mlp_act_type: str = 'gelu_tanh',
|
379 |
+
num_heads: int = 24,
|
380 |
+
depth_double_blocks: int = 19,
|
381 |
+
depth_single_blocks: int = 38,
|
382 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
383 |
+
qkv_bias: bool = True,
|
384 |
+
qk_norm: bool = True,
|
385 |
+
qk_norm_type: str = 'rms',
|
386 |
+
guidance_embed: bool = False, # For modulation.
|
387 |
+
dtype: Optional[torch.dtype] = None,
|
388 |
+
device: Optional[torch.device] = None,
|
389 |
+
multitask_mask_training_type: Optional[str] = None,
|
390 |
+
camera_in_channels: int = 6,
|
391 |
+
camera_down_coef: int = 8,
|
392 |
+
):
|
393 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
394 |
+
super().__init__()
|
395 |
+
|
396 |
+
# Text projection. Default to linear projection.
|
397 |
+
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
|
398 |
+
self.text_projection = args.text_projection
|
399 |
+
self.text_states_dim = args.text_states_dim
|
400 |
+
self.use_attention_mask = args.use_attention_mask
|
401 |
+
self.text_states_dim_2 = args.text_states_dim_2
|
402 |
+
|
403 |
+
# Now we only use above configs from args.
|
404 |
+
self.patch_size = patch_size
|
405 |
+
self.in_channels = in_channels
|
406 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
407 |
+
self.unpatchify_channels = self.out_channels
|
408 |
+
self.guidance_embed = guidance_embed
|
409 |
+
self.rope_dim_list = rope_dim_list
|
410 |
+
self.multitask_mask_training_type = multitask_mask_training_type
|
411 |
+
|
412 |
+
if hidden_size % num_heads != 0:
|
413 |
+
raise ValueError(
|
414 |
+
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
415 |
+
)
|
416 |
+
pe_dim = hidden_size // num_heads
|
417 |
+
if sum(rope_dim_list) != pe_dim:
|
418 |
+
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
|
419 |
+
self.hidden_size = hidden_size
|
420 |
+
self.num_heads = num_heads
|
421 |
+
|
422 |
+
# image projection
|
423 |
+
self.img_in = PatchEmbed(
|
424 |
+
self.patch_size, self.in_channels, self.hidden_size, self.multitask_mask_training_type, **factory_kwargs
|
425 |
+
)
|
426 |
+
|
427 |
+
# text projection
|
428 |
+
if self.text_projection == "linear":
|
429 |
+
self.txt_in = TextProjection(
|
430 |
+
self.text_states_dim,
|
431 |
+
self.hidden_size,
|
432 |
+
get_activation_layer("silu"),
|
433 |
+
**factory_kwargs
|
434 |
+
)
|
435 |
+
elif self.text_projection == "single_refiner":
|
436 |
+
self.txt_in = SingleTokenRefiner(
|
437 |
+
self.text_states_dim, hidden_size, num_heads, depth=2, **factory_kwargs
|
438 |
+
)
|
439 |
+
else:
|
440 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
441 |
+
|
442 |
+
# time modulation
|
443 |
+
self.time_in = TimestepEmbedder(
|
444 |
+
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
445 |
+
)
|
446 |
+
|
447 |
+
# text modulation
|
448 |
+
self.vector_in = MLPEmbedder(
|
449 |
+
self.text_states_dim_2, self.hidden_size, **factory_kwargs
|
450 |
+
)
|
451 |
+
|
452 |
+
# guidance modulation
|
453 |
+
self.guidance_in = TimestepEmbedder(
|
454 |
+
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
|
455 |
+
) if guidance_embed else None
|
456 |
+
|
457 |
+
# double blocks
|
458 |
+
self.double_blocks = nn.ModuleList(
|
459 |
+
[
|
460 |
+
DoubleStreamBlock(
|
461 |
+
self.hidden_size,
|
462 |
+
self.num_heads,
|
463 |
+
mlp_width_ratio=mlp_width_ratio,
|
464 |
+
mlp_act_type=mlp_act_type,
|
465 |
+
qk_norm=qk_norm,
|
466 |
+
qk_norm_type=qk_norm_type,
|
467 |
+
qkv_bias=qkv_bias,
|
468 |
+
**factory_kwargs
|
469 |
+
)
|
470 |
+
for _ in range(depth_double_blocks)
|
471 |
+
]
|
472 |
+
)
|
473 |
+
|
474 |
+
# single blocks
|
475 |
+
self.single_blocks = nn.ModuleList(
|
476 |
+
[
|
477 |
+
SingleStreamBlock(
|
478 |
+
self.hidden_size,
|
479 |
+
self.num_heads,
|
480 |
+
mlp_width_ratio=mlp_width_ratio,
|
481 |
+
mlp_act_type=mlp_act_type,
|
482 |
+
qk_norm=qk_norm,
|
483 |
+
qk_norm_type=qk_norm_type,
|
484 |
+
**factory_kwargs
|
485 |
+
)
|
486 |
+
for _ in range(depth_single_blocks)
|
487 |
+
]
|
488 |
+
)
|
489 |
+
|
490 |
+
self.final_layer = FinalLayer(
|
491 |
+
self.hidden_size,
|
492 |
+
self.patch_size,
|
493 |
+
self.out_channels,
|
494 |
+
get_activation_layer("silu"),
|
495 |
+
**factory_kwargs
|
496 |
+
)
|
497 |
+
|
498 |
+
self.camera_net = CameraNet(in_channels=camera_in_channels,
|
499 |
+
out_channels=in_channels,
|
500 |
+
downscale_coef=camera_down_coef,
|
501 |
+
patch_size=self.patch_size,
|
502 |
+
hidden_size=self.hidden_size,
|
503 |
+
)
|
504 |
+
|
505 |
+
def enable_deterministic(self):
|
506 |
+
for block in self.double_blocks:
|
507 |
+
block.enable_deterministic()
|
508 |
+
for block in self.single_blocks:
|
509 |
+
block.enable_deterministic()
|
510 |
+
|
511 |
+
def disable_deterministic(self):
|
512 |
+
for block in self.double_blocks:
|
513 |
+
block.disable_deterministic()
|
514 |
+
for block in self.single_blocks:
|
515 |
+
block.disable_deterministic()
|
516 |
+
|
517 |
+
def forward(
|
518 |
+
self,
|
519 |
+
x: torch.Tensor,
|
520 |
+
t: torch.Tensor, # Should be in range(0, 1000).
|
521 |
+
text_states: torch.Tensor = None,
|
522 |
+
text_mask: torch.Tensor = None, # Now we don't use it.
|
523 |
+
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
|
524 |
+
freqs_cos: Optional[torch.Tensor] = None,
|
525 |
+
freqs_sin: Optional[torch.Tensor] = None,
|
526 |
+
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
|
527 |
+
return_dict: bool = True,
|
528 |
+
is_cache: bool = False,
|
529 |
+
cam_latents = None,
|
530 |
+
use_sage: bool = False,
|
531 |
+
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
532 |
+
out = {}
|
533 |
+
img = x
|
534 |
+
txt = text_states
|
535 |
+
_, _, ot, oh, ow = x.shape
|
536 |
+
tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
|
537 |
+
|
538 |
+
# Prepare modulation vectors.
|
539 |
+
vec = self.time_in(t)
|
540 |
+
|
541 |
+
# text modulation
|
542 |
+
vec = vec + self.vector_in(text_states_2)
|
543 |
+
|
544 |
+
# guidance modulation
|
545 |
+
if self.guidance_embed:
|
546 |
+
if guidance is None:
|
547 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
548 |
+
else:
|
549 |
+
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
|
550 |
+
vec = vec + self.guidance_in(guidance)
|
551 |
+
|
552 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
553 |
+
|
554 |
+
camera_condition = cam_latents
|
555 |
+
assert camera_condition is not None, print("plucker_embedding is not provided")
|
556 |
+
latent_len = img.shape[2]
|
557 |
+
|
558 |
+
|
559 |
+
# Embed image and text.
|
560 |
+
img = self.img_in(img)
|
561 |
+
# ref_latents = self.img_in(ref_latents) # off in latent concat
|
562 |
+
if self.text_projection == "linear":
|
563 |
+
txt = self.txt_in(txt)
|
564 |
+
elif self.text_projection == "single_refiner":
|
565 |
+
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
|
566 |
+
else:
|
567 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
568 |
+
|
569 |
+
if camera_condition is not None:
|
570 |
+
|
571 |
+
if latent_len == 18:
|
572 |
+
camera_latents = torch.cat([self.camera_net(torch.zeros_like(camera_condition)), \
|
573 |
+
self.camera_net(camera_condition)], dim=1)
|
574 |
+
elif latent_len == 9:
|
575 |
+
camera_latents = self.camera_net(camera_condition)
|
576 |
+
elif latent_len == 10:
|
577 |
+
camera_latents = torch.cat([self.camera_net(torch.zeros_like(camera_condition)[:,0:4,:,:,:]), \
|
578 |
+
self.camera_net(camera_condition)], dim=1)
|
579 |
+
img = img + camera_latents
|
580 |
+
|
581 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
582 |
+
# ref_length = ref_latents.shape[-2]
|
583 |
+
# img = torch.cat([ref_latents, img], dim=-2) # t c
|
584 |
+
txt_seq_len = txt.shape[1]
|
585 |
+
img_seq_len = img.shape[1]
|
586 |
+
# Compute 'self-attention mask'.
|
587 |
+
|
588 |
+
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
|
589 |
+
cu_seqlens_kv = cu_seqlens_q
|
590 |
+
max_seqlen_q = img_seq_len + txt_seq_len
|
591 |
+
max_seqlen_kv = max_seqlen_q
|
592 |
+
|
593 |
+
if get_sequence_parallel_state():
|
594 |
+
sp_size = nccl_info.sp_size
|
595 |
+
sp_rank = nccl_info.rank_within_group
|
596 |
+
assert img.shape[1] % sp_size == 0, f"Cannot split video sequence into ulysses SP ({sp_size}) parts evenly"
|
597 |
+
img = torch.chunk(img, sp_size, dim=1)[sp_rank]
|
598 |
+
freqs_cos = torch.chunk(freqs_cos, sp_size, dim=0)[sp_rank]
|
599 |
+
freqs_sin = torch.chunk(freqs_sin, sp_size, dim=0)[sp_rank]
|
600 |
+
|
601 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
602 |
+
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
|
603 |
+
# --------------------- Pass through DiT blocks ------------------------
|
604 |
+
if not is_cache:
|
605 |
+
for layer_num, block in enumerate(self.double_blocks):
|
606 |
+
double_block_args = [img, txt, vec, cu_seqlens_q, cu_seqlens_kv, \
|
607 |
+
max_seqlen_q, max_seqlen_kv, freqs_cis, use_sage]
|
608 |
+
img, txt = block(*double_block_args)
|
609 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
610 |
+
|
611 |
+
# Merge txt and img to pass through single stream blocks.
|
612 |
+
x = torch.cat((img, txt), 1)
|
613 |
+
# Compatible with MMDiT.
|
614 |
+
if len(self.single_blocks) > 0:
|
615 |
+
for layer_num, block in enumerate(self.single_blocks):
|
616 |
+
if layer_num == (len(self.single_blocks) - 1):
|
617 |
+
self.cache_out = x
|
618 |
+
single_block_args = [x, vec, txt_seq_len, cu_seqlens_q, \
|
619 |
+
cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, (freqs_cos, freqs_sin), use_sage]
|
620 |
+
x = block(*single_block_args)
|
621 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
622 |
+
else:
|
623 |
+
x = self.cache_out
|
624 |
+
if len(self.single_blocks) > 0:
|
625 |
+
for layer_num, block in enumerate(self.single_blocks):
|
626 |
+
if layer_num < (len(self.single_blocks) - 1):
|
627 |
+
continue
|
628 |
+
single_block_args = [x, vec, txt_seq_len, cu_seqlens_q, \
|
629 |
+
cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, (freqs_cos, freqs_sin), use_sage]
|
630 |
+
x = block(*single_block_args)
|
631 |
+
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
632 |
+
|
633 |
+
img = x[:, :-txt_seq_len, ...]
|
634 |
+
|
635 |
+
if get_sequence_parallel_state():
|
636 |
+
img = all_gather(img, dim=1)
|
637 |
+
|
638 |
+
# img = img[:, ref_length:]
|
639 |
+
# ---------------------------- Final layer ------------------------------
|
640 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
641 |
+
img = self.unpatchify(img, tt, th, tw)
|
642 |
+
|
643 |
+
if return_dict:
|
644 |
+
out['x'] = img
|
645 |
+
return out
|
646 |
+
return img
|
647 |
+
|
648 |
+
def unpatchify(self, x, t, h, w):
|
649 |
+
"""
|
650 |
+
x: (N, T, patch_size**2 * C)
|
651 |
+
imgs: (N, H, W, C)
|
652 |
+
"""
|
653 |
+
c = self.unpatchify_channels
|
654 |
+
pt, ph, pw = self.patch_size
|
655 |
+
assert t * h * w == x.shape[1]
|
656 |
+
|
657 |
+
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
|
658 |
+
x = torch.einsum('nthwcopq->nctohpwq', x)
|
659 |
+
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
|
660 |
+
|
661 |
+
return imgs
|
662 |
+
|
663 |
+
def params_count(self):
|
664 |
+
counts = {
|
665 |
+
"double": sum([
|
666 |
+
sum(p.numel() for p in block.img_attn_qkv.parameters()) +
|
667 |
+
sum(p.numel() for p in block.img_attn_proj.parameters()) +
|
668 |
+
sum(p.numel() for p in block.img_mlp.parameters()) +
|
669 |
+
sum(p.numel() for p in block.txt_attn_qkv.parameters()) +
|
670 |
+
sum(p.numel() for p in block.txt_attn_proj.parameters()) +
|
671 |
+
sum(p.numel() for p in block.txt_mlp.parameters())
|
672 |
+
for block in self.double_blocks
|
673 |
+
]),
|
674 |
+
"single": sum([
|
675 |
+
sum(p.numel() for p in block.linear1.parameters()) +
|
676 |
+
sum(p.numel() for p in block.linear2.parameters())
|
677 |
+
for block in self.single_blocks
|
678 |
+
]),
|
679 |
+
"total": sum(p.numel() for p in self.parameters()),
|
680 |
+
}
|
681 |
+
counts["attn+mlp"] = counts["double"] + counts["single"]
|
682 |
+
return counts
|
683 |
+
|
684 |
+
#################################################################################
|
685 |
+
# HunyuanVideo Configs #
|
686 |
+
#################################################################################
|
687 |
+
|
688 |
+
HUNYUAN_VIDEO_CONFIG = { # Attn+MLP / Total
|
689 |
+
'HYVideo-T/2': { # 9.0B / 12.5B
|
690 |
+
'depth_double_blocks': 20,
|
691 |
+
'depth_single_blocks': 40,
|
692 |
+
'rope_dim_list': [16, 56, 56],
|
693 |
+
'hidden_size': 3072,
|
694 |
+
'num_heads': 24,
|
695 |
+
'mlp_width_ratio': 4,
|
696 |
+
},
|
697 |
+
}
|
hymm_sp/modules/modulate_layers.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
|
7 |
+
class ModulateDiT(nn.Module):
|
8 |
+
"""Modulation layer for DiT."""
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
hidden_size: int,
|
12 |
+
factor: int,
|
13 |
+
act_layer: Callable,
|
14 |
+
dtype=None,
|
15 |
+
device=None,
|
16 |
+
):
|
17 |
+
factory_kwargs = {"dtype": dtype, "device": device}
|
18 |
+
super().__init__()
|
19 |
+
self.act = act_layer()
|
20 |
+
self.linear = nn.Linear(
|
21 |
+
hidden_size, factor * hidden_size, bias=True, **factory_kwargs
|
22 |
+
)
|
23 |
+
# Zero-initialize the modulation
|
24 |
+
nn.init.zeros_(self.linear.weight)
|
25 |
+
nn.init.zeros_(self.linear.bias)
|
26 |
+
|
27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
28 |
+
return self.linear(self.act(x))
|
29 |
+
|
30 |
+
|
31 |
+
def modulate(x, shift=None, scale=None):
|
32 |
+
"""modulate by shift and scale
|
33 |
+
|
34 |
+
Args:
|
35 |
+
x (torch.Tensor): input tensor.
|
36 |
+
shift (torch.Tensor, optional): shift tensor. Defaults to None.
|
37 |
+
scale (torch.Tensor, optional): scale tensor. Defaults to None.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: the output tensor after modulate.
|
41 |
+
"""
|
42 |
+
if scale is None and shift is None:
|
43 |
+
return x
|
44 |
+
elif shift is None:
|
45 |
+
return x * (1 + scale.unsqueeze(1))
|
46 |
+
elif scale is None:
|
47 |
+
return x + shift.unsqueeze(1)
|
48 |
+
else:
|
49 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
50 |
+
|
51 |
+
|
52 |
+
def apply_gate(x, gate=None, tanh=False):
|
53 |
+
"""AI is creating summary for apply_gate
|
54 |
+
|
55 |
+
Args:
|
56 |
+
x (torch.Tensor): input tensor.
|
57 |
+
gate (torch.Tensor, optional): gate tensor. Defaults to None.
|
58 |
+
tanh (bool, optional): whether to use tanh function. Defaults to False.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
torch.Tensor: the output tensor after apply gate.
|
62 |
+
"""
|
63 |
+
if gate is None:
|
64 |
+
return x
|
65 |
+
if tanh:
|
66 |
+
return x * gate.unsqueeze(1).tanh()
|
67 |
+
else:
|
68 |
+
return x * gate.unsqueeze(1)
|
69 |
+
|
70 |
+
|
71 |
+
def ckpt_wrapper(module):
|
72 |
+
def ckpt_forward(*inputs):
|
73 |
+
outputs = module(*inputs)
|
74 |
+
return outputs
|
75 |
+
|
76 |
+
return ckpt_forward
|
hymm_sp/modules/norm_layers.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class RMSNorm(nn.Module):
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
dim: int,
|
9 |
+
elementwise_affine=True,
|
10 |
+
eps: float = 1e-6,
|
11 |
+
device=None,
|
12 |
+
dtype=None,
|
13 |
+
):
|
14 |
+
"""
|
15 |
+
Initialize the RMSNorm normalization layer.
|
16 |
+
|
17 |
+
Args:
|
18 |
+
dim (int): The dimension of the input tensor.
|
19 |
+
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
20 |
+
|
21 |
+
Attributes:
|
22 |
+
eps (float): A small value added to the denominator for numerical stability.
|
23 |
+
weight (nn.Parameter): Learnable scaling parameter.
|
24 |
+
|
25 |
+
"""
|
26 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
27 |
+
super().__init__()
|
28 |
+
self.eps = eps
|
29 |
+
if elementwise_affine:
|
30 |
+
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
31 |
+
|
32 |
+
def _norm(self, x):
|
33 |
+
"""
|
34 |
+
Apply the RMSNorm normalization to the input tensor.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
x (torch.Tensor): The input tensor.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
torch.Tensor: The normalized tensor.
|
41 |
+
|
42 |
+
"""
|
43 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
"""
|
47 |
+
Forward pass through the RMSNorm layer.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
x (torch.Tensor): The input tensor.
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
torch.Tensor: The output tensor after applying RMSNorm.
|
54 |
+
|
55 |
+
"""
|
56 |
+
output = self._norm(x.float()).type_as(x)
|
57 |
+
if hasattr(self, "weight"):
|
58 |
+
output = output * self.weight
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
62 |
+
def get_norm_layer(norm_layer):
|
63 |
+
"""
|
64 |
+
Get the normalization layer.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
norm_layer (str): The type of normalization layer.
|
68 |
+
|
69 |
+
Returns:
|
70 |
+
norm_layer (nn.Module): The normalization layer.
|
71 |
+
"""
|
72 |
+
if norm_layer == "layer":
|
73 |
+
return nn.LayerNorm
|
74 |
+
elif norm_layer == "rms":
|
75 |
+
return RMSNorm
|
76 |
+
else:
|
77 |
+
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
hymm_sp/modules/parallel_states.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import datetime
|
4 |
+
import torch.distributed as dist
|
5 |
+
from typing import Any, Tuple
|
6 |
+
from torch import Tensor
|
7 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
8 |
+
|
9 |
+
|
10 |
+
class COMM_INFO:
|
11 |
+
def __init__(self):
|
12 |
+
self.group = None
|
13 |
+
self.sp_size = 1
|
14 |
+
self.global_rank = 0
|
15 |
+
self.rank_within_group = 0
|
16 |
+
self.group_id = 0
|
17 |
+
|
18 |
+
|
19 |
+
nccl_info = COMM_INFO()
|
20 |
+
_SEQUENCE_PARALLEL_STATE = False
|
21 |
+
|
22 |
+
|
23 |
+
def get_cu_seqlens(text_mask, img_len):
|
24 |
+
"""Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
|
25 |
+
|
26 |
+
Args:
|
27 |
+
text_mask (torch.Tensor): the mask of text
|
28 |
+
img_len (int): the length of image
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
torch.Tensor: the calculated cu_seqlens for flash attention
|
32 |
+
"""
|
33 |
+
batch_size = text_mask.shape[0]
|
34 |
+
text_len = text_mask.sum(dim=1)
|
35 |
+
max_len = text_mask.shape[1] + img_len
|
36 |
+
|
37 |
+
cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
|
38 |
+
|
39 |
+
for i in range(batch_size):
|
40 |
+
s = text_len[i] + img_len
|
41 |
+
s1 = i * max_len + s
|
42 |
+
s2 = (i + 1) * max_len
|
43 |
+
cu_seqlens[2 * i + 1] = s1
|
44 |
+
cu_seqlens[2 * i + 2] = s2
|
45 |
+
|
46 |
+
return cu_seqlens
|
47 |
+
|
48 |
+
def initialize_sequence_parallel_state(sequence_parallel_size):
|
49 |
+
global _SEQUENCE_PARALLEL_STATE
|
50 |
+
if sequence_parallel_size > 1:
|
51 |
+
_SEQUENCE_PARALLEL_STATE = True
|
52 |
+
initialize_sequence_parallel_group(sequence_parallel_size)
|
53 |
+
else:
|
54 |
+
nccl_info.sp_size = 1
|
55 |
+
nccl_info.global_rank = int(os.getenv("RANK", "0"))
|
56 |
+
nccl_info.rank_within_group = 0
|
57 |
+
nccl_info.group_id = int(os.getenv("RANK", "0"))
|
58 |
+
|
59 |
+
def get_sequence_parallel_state():
|
60 |
+
return _SEQUENCE_PARALLEL_STATE
|
61 |
+
|
62 |
+
def initialize_sequence_parallel_group(sequence_parallel_size):
|
63 |
+
"""Initialize the sequence parallel group."""
|
64 |
+
rank = int(os.getenv("RANK", "0"))
|
65 |
+
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
66 |
+
assert (
|
67 |
+
world_size % sequence_parallel_size == 0
|
68 |
+
), "world_size must be divisible by sequence_parallel_size, \
|
69 |
+
but got world_size: {}, sequence_parallel_size: {}".format(
|
70 |
+
world_size, sequence_parallel_size)
|
71 |
+
nccl_info.sp_size = sequence_parallel_size
|
72 |
+
nccl_info.global_rank = rank
|
73 |
+
num_sequence_parallel_groups: int = world_size // sequence_parallel_size
|
74 |
+
for i in range(num_sequence_parallel_groups):
|
75 |
+
ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
|
76 |
+
group = dist.new_group(ranks)
|
77 |
+
if rank in ranks:
|
78 |
+
nccl_info.group = group
|
79 |
+
nccl_info.rank_within_group = rank - i * sequence_parallel_size
|
80 |
+
nccl_info.group_id = i
|
81 |
+
|
82 |
+
def initialize_distributed(seed):
|
83 |
+
local_rank = int(os.getenv("RANK", 0))
|
84 |
+
world_size = int(os.getenv("WORLD_SIZE", 1))
|
85 |
+
torch.cuda.set_device(local_rank)
|
86 |
+
dist.init_process_group(backend="nccl",
|
87 |
+
init_method="env://",
|
88 |
+
timeout=datetime.timedelta(seconds=2**31-1),
|
89 |
+
world_size=world_size,
|
90 |
+
rank=local_rank)
|
91 |
+
torch.manual_seed(seed)
|
92 |
+
torch.cuda.manual_seed_all(seed)
|
93 |
+
initialize_sequence_parallel_state(world_size)
|
94 |
+
|
95 |
+
def _all_to_all_4D(input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.tensor:
|
96 |
+
"""
|
97 |
+
all-to-all for QKV
|
98 |
+
|
99 |
+
Args:
|
100 |
+
input (torch.tensor): a tensor sharded along dim scatter dim
|
101 |
+
scatter_idx (int): default 1
|
102 |
+
gather_idx (int): default 2
|
103 |
+
group : torch process group
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
|
107 |
+
"""
|
108 |
+
assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
|
109 |
+
|
110 |
+
seq_world_size = dist.get_world_size(group)
|
111 |
+
if scatter_idx == 2 and gather_idx == 1:
|
112 |
+
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
|
113 |
+
bs, shard_seqlen, hc, hs = input.shape
|
114 |
+
seqlen = shard_seqlen * seq_world_size
|
115 |
+
shard_hc = hc // seq_world_size
|
116 |
+
|
117 |
+
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
|
118 |
+
# (bs, seqlen/P, hc, hs) -reshape->
|
119 |
+
# (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)->
|
120 |
+
# (P, seq_len/P, bs, hc/P, hs)
|
121 |
+
input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous())
|
122 |
+
|
123 |
+
output = torch.empty_like(input_t)
|
124 |
+
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
|
125 |
+
# (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
|
126 |
+
if seq_world_size > 1:
|
127 |
+
dist.all_to_all_single(output, input_t, group=group)
|
128 |
+
torch.cuda.synchronize()
|
129 |
+
else:
|
130 |
+
output = input_t
|
131 |
+
# if scattering the seq-dim, transpose the heads back to the original dimension
|
132 |
+
output = output.reshape(seqlen, bs, shard_hc, hs)
|
133 |
+
|
134 |
+
# (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
|
135 |
+
output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
|
136 |
+
|
137 |
+
return output
|
138 |
+
|
139 |
+
elif scatter_idx == 1 and gather_idx == 2:
|
140 |
+
# input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
|
141 |
+
bs, seqlen, shard_hc, hs = input.shape
|
142 |
+
hc = shard_hc * seq_world_size
|
143 |
+
shard_seqlen = seqlen // seq_world_size
|
144 |
+
seq_world_size = dist.get_world_size(group)
|
145 |
+
|
146 |
+
# transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
|
147 |
+
# (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)->
|
148 |
+
# (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
|
149 |
+
input_t = (input.reshape(bs, seq_world_size, shard_seqlen, shard_hc,
|
150 |
+
hs).transpose(0,
|
151 |
+
3).transpose(0,
|
152 |
+
1).contiguous().reshape(seq_world_size, shard_hc,
|
153 |
+
shard_seqlen, bs, hs))
|
154 |
+
|
155 |
+
output = torch.empty_like(input_t)
|
156 |
+
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
|
157 |
+
# (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
|
158 |
+
if seq_world_size > 1:
|
159 |
+
dist.all_to_all_single(output, input_t, group=group)
|
160 |
+
torch.cuda.synchronize()
|
161 |
+
else:
|
162 |
+
output = input_t
|
163 |
+
|
164 |
+
# if scattering the seq-dim, transpose the heads back to the original dimension
|
165 |
+
output = output.reshape(hc, shard_seqlen, bs, hs)
|
166 |
+
|
167 |
+
# (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
|
168 |
+
output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)
|
169 |
+
|
170 |
+
return output
|
171 |
+
else:
|
172 |
+
raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
|
173 |
+
|
174 |
+
|
175 |
+
class SeqAllToAll4D(torch.autograd.Function):
|
176 |
+
@staticmethod
|
177 |
+
def forward(
|
178 |
+
ctx: Any,
|
179 |
+
group: dist.ProcessGroup,
|
180 |
+
input: Tensor,
|
181 |
+
scatter_idx: int,
|
182 |
+
gather_idx: int,
|
183 |
+
) -> Tensor:
|
184 |
+
ctx.group = group
|
185 |
+
ctx.scatter_idx = scatter_idx
|
186 |
+
ctx.gather_idx = gather_idx
|
187 |
+
|
188 |
+
return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)
|
189 |
+
|
190 |
+
@staticmethod
|
191 |
+
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
|
192 |
+
return (
|
193 |
+
None,
|
194 |
+
SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
|
195 |
+
None,
|
196 |
+
None,
|
197 |
+
)
|
198 |
+
|
199 |
+
|
200 |
+
def all_to_all_4D(
|
201 |
+
input_: torch.Tensor,
|
202 |
+
scatter_dim: int = 2,
|
203 |
+
gather_dim: int = 1,
|
204 |
+
):
|
205 |
+
return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
|
206 |
+
|
207 |
+
|
208 |
+
def _all_to_all(
|
209 |
+
input_: torch.Tensor,
|
210 |
+
world_size: int,
|
211 |
+
group: dist.ProcessGroup,
|
212 |
+
scatter_dim: int,
|
213 |
+
gather_dim: int,
|
214 |
+
):
|
215 |
+
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
216 |
+
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
217 |
+
dist.all_to_all(output_list, input_list, group=group)
|
218 |
+
return torch.cat(output_list, dim=gather_dim).contiguous()
|
219 |
+
|
220 |
+
|
221 |
+
class _AllToAll(torch.autograd.Function):
|
222 |
+
"""All-to-all communication.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
input_: input matrix
|
226 |
+
process_group: communication group
|
227 |
+
scatter_dim: scatter dimension
|
228 |
+
gather_dim: gather dimension
|
229 |
+
"""
|
230 |
+
|
231 |
+
@staticmethod
|
232 |
+
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
233 |
+
ctx.process_group = process_group
|
234 |
+
ctx.scatter_dim = scatter_dim
|
235 |
+
ctx.gather_dim = gather_dim
|
236 |
+
ctx.world_size = dist.get_world_size(process_group)
|
237 |
+
output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
|
238 |
+
return output
|
239 |
+
|
240 |
+
@staticmethod
|
241 |
+
def backward(ctx, grad_output):
|
242 |
+
grad_output = _all_to_all(
|
243 |
+
grad_output,
|
244 |
+
ctx.world_size,
|
245 |
+
ctx.process_group,
|
246 |
+
ctx.gather_dim,
|
247 |
+
ctx.scatter_dim,
|
248 |
+
)
|
249 |
+
return (
|
250 |
+
grad_output,
|
251 |
+
None,
|
252 |
+
None,
|
253 |
+
None,
|
254 |
+
)
|
255 |
+
|
256 |
+
def all_to_all(
|
257 |
+
input_: torch.Tensor,
|
258 |
+
scatter_dim: int = 2,
|
259 |
+
gather_dim: int = 1,
|
260 |
+
):
|
261 |
+
return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
|
262 |
+
|
263 |
+
|
264 |
+
class _AllGather(torch.autograd.Function):
|
265 |
+
"""All-gather communication with autograd support.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
input_: input tensor
|
269 |
+
dim: dimension along which to concatenate
|
270 |
+
"""
|
271 |
+
|
272 |
+
@staticmethod
|
273 |
+
def forward(ctx, input_, dim):
|
274 |
+
ctx.dim = dim
|
275 |
+
world_size = nccl_info.sp_size
|
276 |
+
group = nccl_info.group
|
277 |
+
input_size = list(input_.size())
|
278 |
+
|
279 |
+
ctx.input_size = input_size[dim]
|
280 |
+
|
281 |
+
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
282 |
+
input_ = input_.contiguous()
|
283 |
+
dist.all_gather(tensor_list, input_, group=group)
|
284 |
+
|
285 |
+
output = torch.cat(tensor_list, dim=dim)
|
286 |
+
return output
|
287 |
+
|
288 |
+
@staticmethod
|
289 |
+
def backward(ctx, grad_output):
|
290 |
+
world_size = nccl_info.sp_size
|
291 |
+
rank = nccl_info.rank_within_group
|
292 |
+
dim = ctx.dim
|
293 |
+
input_size = ctx.input_size
|
294 |
+
|
295 |
+
sizes = [input_size] * world_size
|
296 |
+
|
297 |
+
grad_input_list = torch.split(grad_output, sizes, dim=dim)
|
298 |
+
grad_input = grad_input_list[rank]
|
299 |
+
|
300 |
+
return grad_input, None
|
301 |
+
|
302 |
+
|
303 |
+
def all_gather(input_: torch.Tensor, dim: int = 1):
|
304 |
+
"""Performs an all-gather operation on the input tensor along the specified dimension.
|
305 |
+
|
306 |
+
Args:
|
307 |
+
input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
|
308 |
+
dim (int, optional): Dimension along which to concatenate. Defaults to 1.
|
309 |
+
|
310 |
+
Returns:
|
311 |
+
torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
|
312 |
+
"""
|
313 |
+
return _AllGather.apply(input_, dim)
|
314 |
+
|
315 |
+
def parallel_attention(q, k, v,
|
316 |
+
img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv,
|
317 |
+
max_seqlen_q, max_seqlen_kv, use_sage):
|
318 |
+
"""
|
319 |
+
img_q_len,img_kv_len: 32256
|
320 |
+
text_mask: 2x256
|
321 |
+
query: [2, 32256, 24, 128])
|
322 |
+
encoder_query: [2, 256, 24, 128]
|
323 |
+
"""
|
324 |
+
query, encoder_query = q
|
325 |
+
key, encoder_key = k
|
326 |
+
value, encoder_value = v
|
327 |
+
rank = torch.distributed.get_rank()
|
328 |
+
if get_sequence_parallel_state():
|
329 |
+
query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128]
|
330 |
+
key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
|
331 |
+
value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)
|
332 |
+
def shrink_head(encoder_state, dim):
|
333 |
+
local_heads = encoder_state.shape[dim] // nccl_info.sp_size
|
334 |
+
return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
|
335 |
+
encoder_query = shrink_head(encoder_query, dim=2)
|
336 |
+
encoder_key = shrink_head(encoder_key, dim=2)
|
337 |
+
encoder_value = shrink_head(encoder_value, dim=2)
|
338 |
+
|
339 |
+
sequence_length = query.size(1) # 32256
|
340 |
+
encoder_sequence_length = encoder_query.size(1) # 256
|
341 |
+
|
342 |
+
query = torch.cat([query, encoder_query], dim=1)
|
343 |
+
key = torch.cat([key, encoder_key], dim=1)
|
344 |
+
value = torch.cat([value, encoder_value], dim=1)
|
345 |
+
bsz = query.shape[0]
|
346 |
+
head = query.shape[-2]
|
347 |
+
head_dim = query.shape[-1]
|
348 |
+
if use_sage:
|
349 |
+
from sageattention import sageattn
|
350 |
+
hidden_states = sageattn(query, key, value, tensor_layout="NHD")
|
351 |
+
else:
|
352 |
+
query, key, value = [
|
353 |
+
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
354 |
+
for x in [query, key, value]
|
355 |
+
]
|
356 |
+
hidden_states = flash_attn_varlen_func(
|
357 |
+
query,
|
358 |
+
key,
|
359 |
+
value,
|
360 |
+
cu_seqlens_q,
|
361 |
+
cu_seqlens_kv,
|
362 |
+
max_seqlen_q,
|
363 |
+
max_seqlen_kv,
|
364 |
+
)
|
365 |
+
|
366 |
+
# B, S, 3, H, D
|
367 |
+
hidden_states = hidden_states.view(bsz, max_seqlen_q, head, head_dim).contiguous()
|
368 |
+
|
369 |
+
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length),
|
370 |
+
dim=1)
|
371 |
+
if get_sequence_parallel_state():
|
372 |
+
hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
|
373 |
+
encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()
|
374 |
+
hidden_states = hidden_states.to(query.dtype)
|
375 |
+
encoder_hidden_states = encoder_hidden_states.to(query.dtype)
|
376 |
+
|
377 |
+
attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
378 |
+
|
379 |
+
b, s, _, _= attn.shape
|
380 |
+
attn = attn.reshape(b, s, -1)
|
381 |
+
return attn, None
|
hymm_sp/modules/posemb_layers.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from typing import Union, Tuple, List
|
3 |
+
|
4 |
+
|
5 |
+
def _to_tuple(x, dim=2):
|
6 |
+
if isinstance(x, int):
|
7 |
+
return (x,) * dim
|
8 |
+
elif len(x) == dim:
|
9 |
+
return x
|
10 |
+
else:
|
11 |
+
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
12 |
+
|
13 |
+
|
14 |
+
def get_meshgrid_nd(start, *args, dim=2):
|
15 |
+
"""
|
16 |
+
Get n-D meshgrid with start, stop and num.
|
17 |
+
|
18 |
+
Args:
|
19 |
+
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
20 |
+
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
21 |
+
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
22 |
+
n-tuples.
|
23 |
+
*args: See above.
|
24 |
+
dim (int): Dimension of the meshgrid. Defaults to 2.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
grid (np.ndarray): [dim, ...]
|
28 |
+
"""
|
29 |
+
if len(args) == 0:
|
30 |
+
# start is grid_size
|
31 |
+
num = _to_tuple(start, dim=dim)
|
32 |
+
start = (0,) * dim
|
33 |
+
stop = num
|
34 |
+
elif len(args) == 1:
|
35 |
+
# start is start, args[0] is stop, step is 1
|
36 |
+
start = _to_tuple(start, dim=dim)
|
37 |
+
stop = _to_tuple(args[0], dim=dim)
|
38 |
+
num = [stop[i] - start[i] for i in range(dim)]
|
39 |
+
elif len(args) == 2:
|
40 |
+
# start is start, args[0] is stop, args[1] is num
|
41 |
+
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
42 |
+
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
43 |
+
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
44 |
+
else:
|
45 |
+
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
46 |
+
|
47 |
+
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
48 |
+
axis_grid = []
|
49 |
+
for i in range(dim):
|
50 |
+
a, b, n = start[i], stop[i], num[i]
|
51 |
+
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
52 |
+
axis_grid.append(g)
|
53 |
+
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
54 |
+
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
55 |
+
|
56 |
+
return grid
|
57 |
+
|
58 |
+
|
59 |
+
#################################################################################
|
60 |
+
# Rotary Positional Embedding Functions #
|
61 |
+
#################################################################################
|
62 |
+
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
|
63 |
+
|
64 |
+
|
65 |
+
def get_1d_rotary_pos_embed(dim: int,
|
66 |
+
pos: Union[torch.FloatTensor, int],
|
67 |
+
theta: float = 10000.0,
|
68 |
+
use_real: bool = False,
|
69 |
+
theta_rescale_factor: float = 1.0,
|
70 |
+
interpolation_factor: float = 1.0,
|
71 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
72 |
+
"""
|
73 |
+
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
74 |
+
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
75 |
+
|
76 |
+
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
77 |
+
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
78 |
+
The returned tensor contains complex values in complex64 data type.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
dim (int): Dimension of the frequency tensor.
|
82 |
+
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
83 |
+
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
84 |
+
use_real (bool, optional): If True, return real part and imaginary part separately.
|
85 |
+
Otherwise, return complex numbers.
|
86 |
+
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
90 |
+
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
91 |
+
"""
|
92 |
+
if isinstance(pos, int):
|
93 |
+
pos = torch.arange(pos).float()
|
94 |
+
|
95 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
96 |
+
# has some connection to NTK literature
|
97 |
+
if theta_rescale_factor != 1.0:
|
98 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
99 |
+
|
100 |
+
freqs = 1.0 / (
|
101 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
102 |
+
) # [D/2]
|
103 |
+
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
104 |
+
if use_real:
|
105 |
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
106 |
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
107 |
+
return freqs_cos, freqs_sin
|
108 |
+
else:
|
109 |
+
freqs_cis = torch.polar(
|
110 |
+
torch.ones_like(freqs), freqs
|
111 |
+
) # complex64 # [S, D/2]
|
112 |
+
return freqs_cis
|
hymm_sp/modules/token_refiner.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
from einops import rearrange
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from .activation_layers import get_activation_layer
|
8 |
+
from .attn_layers import attention
|
9 |
+
from .norm_layers import get_norm_layer
|
10 |
+
from .embed_layers import TimestepEmbedder, TextProjection
|
11 |
+
from .attn_layers import attention
|
12 |
+
from .mlp_layers import MLP
|
13 |
+
from .modulate_layers import apply_gate
|
14 |
+
|
15 |
+
|
16 |
+
class IndividualTokenRefinerBlock(nn.Module):
|
17 |
+
"""
|
18 |
+
Transformer block for refining individual tokens with adaptive layer normalization.
|
19 |
+
|
20 |
+
Combines self-attention and feed-forward network (FFN) layers with modulation
|
21 |
+
based on conditional inputs (timestep and context embeddings). Supports query-key
|
22 |
+
normalization for improved attention stability.
|
23 |
+
"""
|
24 |
+
def __init__(
|
25 |
+
self,
|
26 |
+
hidden_size,
|
27 |
+
num_heads,
|
28 |
+
mlp_ratio: str = 4.0,
|
29 |
+
mlp_drop_rate: float = 0.0,
|
30 |
+
act_type: str = "silu",
|
31 |
+
qk_norm: bool = False,
|
32 |
+
qk_norm_type: str = "layer",
|
33 |
+
qkv_bias: bool = True,
|
34 |
+
dtype: Optional[torch.dtype] = None,
|
35 |
+
device: Optional[torch.device] = None,
|
36 |
+
):
|
37 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
38 |
+
super().__init__()
|
39 |
+
self.num_heads = num_heads
|
40 |
+
head_dim = hidden_size // num_heads
|
41 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
42 |
+
|
43 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
44 |
+
self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
|
45 |
+
qk_norm_layer = get_norm_layer(qk_norm_type)
|
46 |
+
self.self_attn_q_norm = (
|
47 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
48 |
+
if qk_norm
|
49 |
+
else nn.Identity()
|
50 |
+
)
|
51 |
+
self.self_attn_k_norm = (
|
52 |
+
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
53 |
+
if qk_norm
|
54 |
+
else nn.Identity()
|
55 |
+
)
|
56 |
+
self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
|
57 |
+
|
58 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
|
59 |
+
act_layer = get_activation_layer(act_type)
|
60 |
+
self.mlp = MLP(
|
61 |
+
in_channels=hidden_size,
|
62 |
+
hidden_channels=mlp_hidden_dim,
|
63 |
+
act_layer=act_layer,
|
64 |
+
drop=mlp_drop_rate,
|
65 |
+
**factory_kwargs,
|
66 |
+
)
|
67 |
+
|
68 |
+
self.adaLN_modulation = nn.Sequential(
|
69 |
+
act_layer(),
|
70 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
|
71 |
+
)
|
72 |
+
# Zero-initialize the modulation
|
73 |
+
nn.init.zeros_(self.adaLN_modulation[1].weight)
|
74 |
+
nn.init.zeros_(self.adaLN_modulation[1].bias)
|
75 |
+
|
76 |
+
def forward(
|
77 |
+
self,
|
78 |
+
x: torch.Tensor,
|
79 |
+
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
|
80 |
+
attn_mask: torch.Tensor = None,
|
81 |
+
):
|
82 |
+
"""
|
83 |
+
Forward pass of the transformer block.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
x: Input token embeddings (batch_size, seq_len, hidden_size)
|
87 |
+
c: Conditional embeddings (batch_size, hidden_size)
|
88 |
+
attn_mask: Attention mask (batch_size, 1, seq_len, seq_len)
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
Updated token embeddings after self-attention and FFN
|
92 |
+
"""
|
93 |
+
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
94 |
+
|
95 |
+
norm_x = self.norm1(x)
|
96 |
+
qkv = self.self_attn_qkv(norm_x)
|
97 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
98 |
+
# Apply QK-Norm if needed
|
99 |
+
q = self.self_attn_q_norm(q).to(v)
|
100 |
+
k = self.self_attn_k_norm(k).to(v)
|
101 |
+
|
102 |
+
# Self-Attention
|
103 |
+
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
|
104 |
+
|
105 |
+
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
|
106 |
+
|
107 |
+
# FFN Layer
|
108 |
+
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
|
109 |
+
|
110 |
+
return x
|
111 |
+
|
112 |
+
|
113 |
+
class IndividualTokenRefiner(nn.Module):
|
114 |
+
"""
|
115 |
+
Stack of IndividualTokenRefinerBlocks for sequential token refinement.
|
116 |
+
|
117 |
+
Processes token sequences through multiple transformer blocks with
|
118 |
+
attention masking support for handling variable-length sequences.
|
119 |
+
"""
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
hidden_size,
|
123 |
+
num_heads,
|
124 |
+
depth,
|
125 |
+
mlp_ratio: float = 4.0,
|
126 |
+
mlp_drop_rate: float = 0.0,
|
127 |
+
act_type: str = "silu",
|
128 |
+
qk_norm: bool = False,
|
129 |
+
qk_norm_type: str = "layer",
|
130 |
+
qkv_bias: bool = True,
|
131 |
+
dtype: Optional[torch.dtype] = None,
|
132 |
+
device: Optional[torch.device] = None,
|
133 |
+
):
|
134 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
135 |
+
super().__init__()
|
136 |
+
self.blocks = nn.ModuleList([
|
137 |
+
IndividualTokenRefinerBlock(
|
138 |
+
hidden_size=hidden_size,
|
139 |
+
num_heads=num_heads,
|
140 |
+
mlp_ratio=mlp_ratio,
|
141 |
+
mlp_drop_rate=mlp_drop_rate,
|
142 |
+
act_type=act_type,
|
143 |
+
qk_norm=qk_norm,
|
144 |
+
qk_norm_type=qk_norm_type,
|
145 |
+
qkv_bias=qkv_bias,
|
146 |
+
**factory_kwargs,
|
147 |
+
) for _ in range(depth)
|
148 |
+
])
|
149 |
+
|
150 |
+
def forward(
|
151 |
+
self,
|
152 |
+
x: torch.Tensor,
|
153 |
+
c: torch.LongTensor,
|
154 |
+
mask: Optional[torch.Tensor] = None,
|
155 |
+
):
|
156 |
+
"""
|
157 |
+
Forward pass through the stack of transformer blocks.
|
158 |
+
|
159 |
+
Args:
|
160 |
+
x: Input token embeddings (batch_size, seq_len, hidden_size)
|
161 |
+
c: Conditional embeddings (batch_size, hidden_size)
|
162 |
+
mask: Sequence mask indicating valid tokens (batch_size, seq_len)
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
Refined token embeddings after all blocks
|
166 |
+
"""
|
167 |
+
self_attn_mask = None
|
168 |
+
if mask is not None:
|
169 |
+
batch_size = mask.shape[0]
|
170 |
+
seq_len = mask.shape[1]
|
171 |
+
mask = mask.to(x.device)
|
172 |
+
# batch_size x 1 x seq_len x seq_len
|
173 |
+
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
174 |
+
# batch_size x 1 x seq_len x seq_len
|
175 |
+
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
176 |
+
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
|
177 |
+
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
178 |
+
# avoids self-attention weight being NaN for padding tokens
|
179 |
+
self_attn_mask[:, :, :, 0] = True
|
180 |
+
|
181 |
+
for block in self.blocks:
|
182 |
+
x = block(x, c, self_attn_mask)
|
183 |
+
return x
|
184 |
+
|
185 |
+
|
186 |
+
class SingleTokenRefiner(nn.Module):
|
187 |
+
"""
|
188 |
+
Complete token refinement module with input embedding and conditional modulation.
|
189 |
+
|
190 |
+
Integrates timestep embedding, context projection, and a stack of transformer
|
191 |
+
blocks to refine token sequences based on both input data and conditional inputs.
|
192 |
+
"""
|
193 |
+
def __init__(
|
194 |
+
self,
|
195 |
+
in_channels,
|
196 |
+
hidden_size,
|
197 |
+
num_heads,
|
198 |
+
depth,
|
199 |
+
mlp_ratio: float = 4.0,
|
200 |
+
mlp_drop_rate: float = 0.0,
|
201 |
+
act_type: str = "silu",
|
202 |
+
qk_norm: bool = False,
|
203 |
+
qk_norm_type: str = "layer",
|
204 |
+
qkv_bias: bool = True,
|
205 |
+
dtype: Optional[torch.dtype] = None,
|
206 |
+
device: Optional[torch.device] = None,
|
207 |
+
):
|
208 |
+
factory_kwargs = {'device': device, 'dtype': dtype}
|
209 |
+
super().__init__()
|
210 |
+
|
211 |
+
self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
|
212 |
+
|
213 |
+
act_layer = get_activation_layer(act_type)
|
214 |
+
# Build timestep embedding layer
|
215 |
+
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
|
216 |
+
# Build context embedding layer
|
217 |
+
self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
|
218 |
+
|
219 |
+
self.individual_token_refiner = IndividualTokenRefiner(
|
220 |
+
hidden_size=hidden_size,
|
221 |
+
num_heads=num_heads,
|
222 |
+
depth=depth,
|
223 |
+
mlp_ratio=mlp_ratio,
|
224 |
+
mlp_drop_rate=mlp_drop_rate,
|
225 |
+
act_type=act_type,
|
226 |
+
qk_norm=qk_norm,
|
227 |
+
qk_norm_type=qk_norm_type,
|
228 |
+
qkv_bias=qkv_bias,
|
229 |
+
**factory_kwargs
|
230 |
+
)
|
231 |
+
|
232 |
+
def forward(
|
233 |
+
self,
|
234 |
+
x: torch.Tensor,
|
235 |
+
t: torch.LongTensor,
|
236 |
+
mask: Optional[torch.LongTensor] = None,
|
237 |
+
):
|
238 |
+
"""
|
239 |
+
Forward pass of the complete token refiner.
|
240 |
+
|
241 |
+
Args:
|
242 |
+
x: Input features (batch_size, seq_len, in_channels)
|
243 |
+
t: Timestep indices (batch_size,)
|
244 |
+
mask: Sequence mask for variable-length inputs (batch_size, seq_len)
|
245 |
+
|
246 |
+
Returns:
|
247 |
+
Refined token embeddings (batch_size, seq_len, hidden_size)
|
248 |
+
"""
|
249 |
+
timestep_aware_representations = self.t_embedder(t)
|
250 |
+
|
251 |
+
if mask is None:
|
252 |
+
context_aware_representations = x.mean(dim=1)
|
253 |
+
else:
|
254 |
+
mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
|
255 |
+
context_aware_representations = (
|
256 |
+
(x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
257 |
+
)
|
258 |
+
context_aware_representations = self.c_embedder(context_aware_representations)
|
259 |
+
c = timestep_aware_representations + context_aware_representations
|
260 |
+
|
261 |
+
x = self.input_embedder(x)
|
262 |
+
|
263 |
+
x = self.individual_token_refiner(x, c, mask)
|
264 |
+
|
265 |
+
return x
|
hymm_sp/sample_batch.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
from loguru import logger
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import torch.distributed
|
7 |
+
import random
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
from PIL import Image
|
10 |
+
import cv2
|
11 |
+
|
12 |
+
from torch.utils.data.distributed import DistributedSampler
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from hymm_sp.config import parse_args
|
15 |
+
from hymm_sp.sample_inference import HunyuanVideoSampler
|
16 |
+
from hymm_sp.data_kits.video_dataset import VideoCSVDataset
|
17 |
+
from hymm_sp.data_kits.data_tools import save_videos_grid
|
18 |
+
from hymm_sp.modules.parallel_states import (
|
19 |
+
initialize_distributed,
|
20 |
+
nccl_info,
|
21 |
+
)
|
22 |
+
|
23 |
+
class CropResize:
|
24 |
+
"""
|
25 |
+
Custom transform to resize and crop images to a target size while preserving aspect ratio.
|
26 |
+
|
27 |
+
Resizes the image to ensure it covers the target dimensions, then center-crops to the exact size.
|
28 |
+
Useful for preparing consistent input dimensions for video generation models.
|
29 |
+
"""
|
30 |
+
def __init__(self, size=(704, 1216)):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
size (tuple): Target dimensions (height, width) for the output image
|
34 |
+
"""
|
35 |
+
self.target_h, self.target_w = size
|
36 |
+
|
37 |
+
def __call__(self, img):
|
38 |
+
"""
|
39 |
+
Apply the transform to an image.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
img (PIL.Image): Input image to transform
|
43 |
+
|
44 |
+
Returns:
|
45 |
+
PIL.Image: Resized and cropped image with target dimensions
|
46 |
+
"""
|
47 |
+
# Get original image dimensions
|
48 |
+
w, h = img.size
|
49 |
+
|
50 |
+
# Calculate scaling factor to ensure image covers target size
|
51 |
+
scale = max(
|
52 |
+
self.target_w / w, # Scale needed to cover target width
|
53 |
+
self.target_h / h # Scale needed to cover target height
|
54 |
+
)
|
55 |
+
|
56 |
+
# Resize image while preserving aspect ratio
|
57 |
+
new_size = (int(h * scale), int(w * scale))
|
58 |
+
resize_transform = transforms.Resize(
|
59 |
+
new_size,
|
60 |
+
interpolation=transforms.InterpolationMode.BILINEAR
|
61 |
+
)
|
62 |
+
resized_img = resize_transform(img)
|
63 |
+
|
64 |
+
# Center-crop to exact target dimensions
|
65 |
+
crop_transform = transforms.CenterCrop((self.target_h, self.target_w))
|
66 |
+
return crop_transform(resized_img)
|
67 |
+
|
68 |
+
|
69 |
+
def main():
|
70 |
+
"""
|
71 |
+
Main function for video generation using the Hunyuan multimodal model.
|
72 |
+
|
73 |
+
Handles argument parsing, distributed setup, model loading, data preparation,
|
74 |
+
and video generation with action-controlled transitions. Supports both image-to-video
|
75 |
+
and video-to-video generation tasks.
|
76 |
+
"""
|
77 |
+
# Parse command-line arguments and configuration
|
78 |
+
args = parse_args()
|
79 |
+
models_root_path = Path(args.ckpt)
|
80 |
+
action_list = args.action_list
|
81 |
+
action_speed_list = args.action_speed_list
|
82 |
+
negative_prompt = args.add_neg_prompt
|
83 |
+
|
84 |
+
# Initialize distributed training/evaluation environment
|
85 |
+
logger.info("*" * 20)
|
86 |
+
initialize_distributed(args.seed)
|
87 |
+
|
88 |
+
# Validate model checkpoint path exists
|
89 |
+
if not models_root_path.exists():
|
90 |
+
raise ValueError(f"Model checkpoint path does not exist: {models_root_path}")
|
91 |
+
logger.info("+" * 20)
|
92 |
+
|
93 |
+
# Set up output directory
|
94 |
+
save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}'
|
95 |
+
os.makedirs(save_path, exist_ok=True)
|
96 |
+
logger.info(f"Generated videos will be saved to: {save_path}")
|
97 |
+
|
98 |
+
# Initialize device configuration for distributed processing
|
99 |
+
rank = 0
|
100 |
+
device = torch.device("cuda")
|
101 |
+
if nccl_info.sp_size > 1:
|
102 |
+
# Use specific GPU based on process rank in distributed setup
|
103 |
+
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
104 |
+
rank = torch.distributed.get_rank()
|
105 |
+
|
106 |
+
# Load the Hunyuan video sampler model from checkpoint
|
107 |
+
logger.info(f"Loading model from checkpoint: {args.ckpt}")
|
108 |
+
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(
|
109 |
+
args.ckpt,
|
110 |
+
args=args,
|
111 |
+
device=device if not args.cpu_offload else torch.device("cpu")
|
112 |
+
)
|
113 |
+
# Update args with model-specific configurations from the checkpoint
|
114 |
+
args = hunyuan_video_sampler.args
|
115 |
+
|
116 |
+
# Enable CPU offloading if specified to reduce GPU memory usage
|
117 |
+
if args.cpu_offload:
|
118 |
+
from diffusers.hooks import apply_group_offloading
|
119 |
+
onload_device = torch.device("cuda")
|
120 |
+
apply_group_offloading(
|
121 |
+
hunyuan_video_sampler.pipeline.transformer,
|
122 |
+
onload_device=onload_device,
|
123 |
+
offload_type="block_level",
|
124 |
+
num_blocks_per_group=1
|
125 |
+
)
|
126 |
+
logger.info("Enabled CPU offloading for transformer blocks")
|
127 |
+
|
128 |
+
# Process each batch in the dataset
|
129 |
+
|
130 |
+
prompt = args.prompt
|
131 |
+
image_paths = [args.image_path]
|
132 |
+
logger.info(f"Prompt: {prompt}, Image Path {args.image_path}")
|
133 |
+
# Generate random seed for reproducibility
|
134 |
+
seed = args.seed if args.seed else random.randint(0, 1_000_000)
|
135 |
+
|
136 |
+
# Define image transformation pipeline for input reference images
|
137 |
+
closest_size = (704, 1216)
|
138 |
+
ref_image_transform = transforms.Compose([
|
139 |
+
CropResize(closest_size),
|
140 |
+
transforms.CenterCrop(closest_size),
|
141 |
+
transforms.ToTensor(),
|
142 |
+
transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1] range
|
143 |
+
])
|
144 |
+
|
145 |
+
# Handle image-based generation (start from a single image)
|
146 |
+
if args.image_start:
|
147 |
+
# Load and preprocess reference images
|
148 |
+
raw_ref_images = [Image.open(image_path).convert('RGB') for image_path in image_paths]
|
149 |
+
|
150 |
+
# Apply transformations and prepare tensor for model input
|
151 |
+
ref_images_pixel_values = [ref_image_transform(ref_image) for ref_image in raw_ref_images]
|
152 |
+
ref_images_pixel_values = torch.cat(ref_images_pixel_values).unsqueeze(0).unsqueeze(2).to(device)
|
153 |
+
|
154 |
+
# Encode reference images to latent space using VAE
|
155 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
|
156 |
+
if args.cpu_offload:
|
157 |
+
# Move VAE components to GPU temporarily for encoding
|
158 |
+
hunyuan_video_sampler.vae.quant_conv.to('cuda')
|
159 |
+
hunyuan_video_sampler.vae.encoder.to('cuda')
|
160 |
+
|
161 |
+
# Enable tiling for VAE to handle large images efficiently
|
162 |
+
hunyuan_video_sampler.pipeline.vae.enable_tiling()
|
163 |
+
|
164 |
+
# Encode image to latents and scale by VAE's scaling factor
|
165 |
+
raw_last_latents = hunyuan_video_sampler.vae.encode(
|
166 |
+
ref_images_pixel_values
|
167 |
+
).latent_dist.sample().to(dtype=torch.float16) # Shape: (B, C, F, H, W)
|
168 |
+
raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
|
169 |
+
raw_ref_latents = raw_last_latents.clone()
|
170 |
+
|
171 |
+
# Clean up
|
172 |
+
hunyuan_video_sampler.pipeline.vae.disable_tiling()
|
173 |
+
if args.cpu_offload:
|
174 |
+
# Move VAE components back to CPU after encoding
|
175 |
+
hunyuan_video_sampler.vae.quant_conv.to('cpu')
|
176 |
+
hunyuan_video_sampler.vae.encoder.to('cpu')
|
177 |
+
|
178 |
+
|
179 |
+
# Handle video-based generation (start from an existing video)
|
180 |
+
else:
|
181 |
+
from decord import VideoReader # Lazy import for video handling
|
182 |
+
|
183 |
+
# Validate video file exists
|
184 |
+
video_path = args.video_path
|
185 |
+
if not os.path.exists(video_path):
|
186 |
+
raise RuntimeError(f"Video file not found: {video_path}")
|
187 |
+
|
188 |
+
# Load reference images from video metadata
|
189 |
+
raw_ref_images = [Image.open(image_path).convert('RGB') for image_path in image_paths]
|
190 |
+
|
191 |
+
# Load video and extract frames
|
192 |
+
ref_video = VideoReader(video_path)
|
193 |
+
ref_frames_length = len(ref_video)
|
194 |
+
logger.info(f"Loaded reference video with {ref_frames_length} frames")
|
195 |
+
|
196 |
+
# Preprocess video frames
|
197 |
+
transformed_images = []
|
198 |
+
for index in range(ref_frames_length):
|
199 |
+
# Convert video frame to PIL image and apply transformations
|
200 |
+
video_image = ref_video[index].numpy()
|
201 |
+
transformed_image = ref_image_transform(Image.fromarray(video_image))
|
202 |
+
transformed_images.append(transformed_image)
|
203 |
+
|
204 |
+
# Prepare tensor for model input
|
205 |
+
transformed_images = torch.stack(transformed_images, dim=1).unsqueeze(0).to(
|
206 |
+
device=hunyuan_video_sampler.device,
|
207 |
+
dtype=torch.float16
|
208 |
+
)
|
209 |
+
|
210 |
+
# Encode video frames to latent space using VAE
|
211 |
+
with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
|
212 |
+
if args.cpu_offload:
|
213 |
+
hunyuan_video_sampler.vae.quant_conv.to('cuda')
|
214 |
+
hunyuan_video_sampler.vae.encoder.to('cuda')
|
215 |
+
|
216 |
+
hunyuan_video_sampler.pipeline.vae.enable_tiling()
|
217 |
+
|
218 |
+
# Encode last 33 frames of video (model-specific requirement)
|
219 |
+
raw_last_latents = hunyuan_video_sampler.vae.encode(
|
220 |
+
transformed_images[:, :, -33:, ...]
|
221 |
+
).latent_dist.sample().to(dtype=torch.float16)
|
222 |
+
raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
|
223 |
+
|
224 |
+
# Encode a single reference frame from the video
|
225 |
+
raw_ref_latents = hunyuan_video_sampler.vae.encode(
|
226 |
+
transformed_images[:, :, -33:-32, ...]
|
227 |
+
).latent_dist.sample().to(dtype=torch.float16)
|
228 |
+
raw_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
|
229 |
+
|
230 |
+
# Clean up
|
231 |
+
hunyuan_video_sampler.pipeline.vae.disable_tiling()
|
232 |
+
if args.cpu_offload:
|
233 |
+
hunyuan_video_sampler.vae.quant_conv.to('cpu')
|
234 |
+
hunyuan_video_sampler.vae.encoder.to('cpu')
|
235 |
+
|
236 |
+
# Store references for generation loop
|
237 |
+
ref_images = raw_ref_images
|
238 |
+
last_latents = raw_last_latents
|
239 |
+
ref_latents = raw_ref_latents
|
240 |
+
|
241 |
+
# Generate video segments for each action in the action list
|
242 |
+
for idx, action_id in enumerate(action_list):
|
243 |
+
# Determine if this is the first action and using image start
|
244 |
+
is_image = (idx == 0 and args.image_start)
|
245 |
+
|
246 |
+
logger.info(f"Generating segment {idx+1}/{len(action_list)} with action ID: {action_id}")
|
247 |
+
# Generate video segment with the current action
|
248 |
+
outputs = hunyuan_video_sampler.predict(
|
249 |
+
prompt=prompt,
|
250 |
+
action_id=action_id,
|
251 |
+
action_speed=action_speed_list[idx],
|
252 |
+
is_image=is_image,
|
253 |
+
size=(704, 1216),
|
254 |
+
seed=seed,
|
255 |
+
last_latents=last_latents, # Previous frame latents for continuity
|
256 |
+
ref_latents=ref_latents, # Reference latents for style consistency
|
257 |
+
video_length=args.sample_n_frames,
|
258 |
+
guidance_scale=args.cfg_scale,
|
259 |
+
num_images_per_prompt=args.num_images,
|
260 |
+
negative_prompt=negative_prompt,
|
261 |
+
infer_steps=args.infer_steps,
|
262 |
+
flow_shift=args.flow_shift_eval_video,
|
263 |
+
use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
|
264 |
+
linear_schedule_end=args.linear_schedule_end,
|
265 |
+
use_deepcache=args.use_deepcache,
|
266 |
+
cpu_offload=args.cpu_offload,
|
267 |
+
ref_images=ref_images,
|
268 |
+
output_dir=save_path,
|
269 |
+
return_latents=True,
|
270 |
+
use_sage=args.use_sage,
|
271 |
+
)
|
272 |
+
|
273 |
+
# Update latents for next iteration (maintain temporal consistency)
|
274 |
+
ref_latents = outputs["ref_latents"]
|
275 |
+
last_latents = outputs["last_latents"]
|
276 |
+
|
277 |
+
# Save generated video segments if this is the main process (rank 0)
|
278 |
+
if rank == 0:
|
279 |
+
sub_samples = outputs['samples'][0]
|
280 |
+
|
281 |
+
# Initialize or concatenate video segments
|
282 |
+
if idx == 0:
|
283 |
+
if args.image_start:
|
284 |
+
out_cat = sub_samples
|
285 |
+
else:
|
286 |
+
# Combine original video frames with generated frames
|
287 |
+
out_cat = torch.cat([(transformed_images.detach().cpu() + 1) / 2.0, sub_samples], dim=2)
|
288 |
+
else:
|
289 |
+
# Append new segment to existing video
|
290 |
+
out_cat = torch.cat([out_cat, sub_samples], dim=2)
|
291 |
+
|
292 |
+
# Save final combined video
|
293 |
+
save_path_mp4 = f"{save_path}/{os.path.basename(args.image_path).split('.')[0]}.mp4"
|
294 |
+
save_videos_grid(out_cat, save_path_mp4, n_rows=1, fps=24)
|
295 |
+
logger.info(f"Saved generated video to: {save_path_mp4}")
|
296 |
+
|
297 |
+
if __name__ == "__main__":
|
298 |
+
main()
|
hymm_sp/sample_inference.py
ADDED
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
from loguru import logger
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib as mpl
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from matplotlib.patches import Patch
|
10 |
+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
11 |
+
|
12 |
+
from hymm_sp.diffusion import load_diffusion_pipeline
|
13 |
+
from hymm_sp.helpers import get_nd_rotary_pos_embed_new
|
14 |
+
from hymm_sp.inference import Inference
|
15 |
+
from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
|
16 |
+
from packaging import version as pver
|
17 |
+
|
18 |
+
ACTION_DICT = {"w": "forward", "a": "left", "d": "right", "s": "backward"}
|
19 |
+
|
20 |
+
def custom_meshgrid(*args):
|
21 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
22 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
23 |
+
return torch.meshgrid(*args)
|
24 |
+
else:
|
25 |
+
return torch.meshgrid(*args, indexing='ij')
|
26 |
+
|
27 |
+
def get_relative_pose(cam_params):
|
28 |
+
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
|
29 |
+
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
|
30 |
+
source_cam_c2w = abs_c2ws[0]
|
31 |
+
cam_to_origin = 0
|
32 |
+
target_cam_c2w = np.array([
|
33 |
+
[1, 0, 0, 0],
|
34 |
+
[0, 1, 0, -cam_to_origin],
|
35 |
+
[0, 0, 1, 0],
|
36 |
+
[0, 0, 0, 1]
|
37 |
+
])
|
38 |
+
abs2rel = target_cam_c2w @ abs_w2cs[0]
|
39 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
|
40 |
+
for pose in ret_poses:
|
41 |
+
pose[:3, -1:] *= 10
|
42 |
+
ret_poses = np.array(ret_poses, dtype=np.float32)
|
43 |
+
return ret_poses
|
44 |
+
|
45 |
+
def ray_condition(K, c2w, H, W, device, flip_flag=None):
|
46 |
+
# c2w: B, V, 4, 4
|
47 |
+
# K: B, V, 4
|
48 |
+
|
49 |
+
B, V = K.shape[:2]
|
50 |
+
|
51 |
+
j, i = custom_meshgrid(
|
52 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
53 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
54 |
+
)
|
55 |
+
i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
|
56 |
+
j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
|
57 |
+
|
58 |
+
n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
|
59 |
+
if n_flip > 0:
|
60 |
+
j_flip, i_flip = custom_meshgrid(
|
61 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
62 |
+
torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
|
63 |
+
)
|
64 |
+
i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
|
65 |
+
j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
|
66 |
+
i[:, flip_flag, ...] = i_flip
|
67 |
+
j[:, flip_flag, ...] = j_flip
|
68 |
+
|
69 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
70 |
+
|
71 |
+
zs = torch.ones_like(i) # [B, V, HxW]
|
72 |
+
xs = (i - cx) / fx * zs
|
73 |
+
ys = (j - cy) / fy * zs
|
74 |
+
zs = zs.expand_as(ys)
|
75 |
+
|
76 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
77 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
78 |
+
|
79 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3
|
80 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
81 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3
|
82 |
+
# c2w @ dirctions
|
83 |
+
rays_dxo = torch.cross(rays_o, rays_d) # B, V, HW, 3
|
84 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
85 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
86 |
+
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
87 |
+
return plucker
|
88 |
+
|
89 |
+
def get_c2w(w2cs, transform_matrix, relative_c2w):
|
90 |
+
if relative_c2w:
|
91 |
+
target_cam_c2w = np.array([
|
92 |
+
[1, 0, 0, 0],
|
93 |
+
[0, 1, 0, 0],
|
94 |
+
[0, 0, 1, 0],
|
95 |
+
[0, 0, 0, 1]
|
96 |
+
])
|
97 |
+
abs2rel = target_cam_c2w @ w2cs[0]
|
98 |
+
ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
|
99 |
+
for pose in ret_poses:
|
100 |
+
pose[:3, -1:] *= 2
|
101 |
+
# ret_poses = [poses[:, :3]*2 for poses in ret_poses]
|
102 |
+
# ret_poses[:, :, :3] *= 2
|
103 |
+
else:
|
104 |
+
ret_poses = [np.linalg.inv(w2c) for w2c in w2cs]
|
105 |
+
ret_poses = [transform_matrix @ x for x in ret_poses]
|
106 |
+
return np.array(ret_poses, dtype=np.float32)
|
107 |
+
|
108 |
+
def generate_motion_segment(current_pose,
|
109 |
+
motion_type: str,
|
110 |
+
value: float,
|
111 |
+
duration: int = 30):
|
112 |
+
"""
|
113 |
+
Parameters:
|
114 |
+
motion_type: ('forward', 'backward', 'left', 'right',
|
115 |
+
'rotate_left', 'rotate_right', 'rotate_up', 'rotate_down')
|
116 |
+
value: Translation(m) or Rotation(degree)
|
117 |
+
duration: frames
|
118 |
+
|
119 |
+
Return:
|
120 |
+
positions: [np.array(x,y,z), ...]
|
121 |
+
rotations: [np.array(pitch,yaw,roll), ...]
|
122 |
+
"""
|
123 |
+
positions = []
|
124 |
+
rotations = []
|
125 |
+
|
126 |
+
if motion_type in ['forward', 'backward']:
|
127 |
+
yaw_rad = np.radians(current_pose['rotation'][1])
|
128 |
+
pitch_rad = np.radians(current_pose['rotation'][0])
|
129 |
+
|
130 |
+
forward_vec = np.array([
|
131 |
+
-math.sin(yaw_rad) * math.cos(pitch_rad),
|
132 |
+
math.sin(pitch_rad),
|
133 |
+
-math.cos(yaw_rad) * math.cos(pitch_rad)
|
134 |
+
])
|
135 |
+
|
136 |
+
direction = 1 if motion_type == 'forward' else -1
|
137 |
+
total_move = forward_vec * value * direction
|
138 |
+
step = total_move / duration
|
139 |
+
|
140 |
+
for i in range(1, duration+1):
|
141 |
+
new_pos = current_pose['position'] + step * i
|
142 |
+
positions.append(new_pos.copy())
|
143 |
+
rotations.append(current_pose['rotation'].copy())
|
144 |
+
|
145 |
+
current_pose['position'] = positions[-1]
|
146 |
+
|
147 |
+
elif motion_type in ['left', 'right']:
|
148 |
+
yaw_rad = np.radians(current_pose['rotation'][1])
|
149 |
+
right_vec = np.array([math.cos(yaw_rad), 0, -math.sin(yaw_rad)])
|
150 |
+
|
151 |
+
direction = -1 if motion_type == 'right' else 1
|
152 |
+
total_move = right_vec * value * direction
|
153 |
+
step = total_move / duration
|
154 |
+
|
155 |
+
for i in range(1, duration+1):
|
156 |
+
new_pos = current_pose['position'] + step * i
|
157 |
+
positions.append(new_pos.copy())
|
158 |
+
rotations.append(current_pose['rotation'].copy())
|
159 |
+
|
160 |
+
current_pose['position'] = positions[-1]
|
161 |
+
|
162 |
+
elif motion_type.endswith('rot'):
|
163 |
+
axis = motion_type.split('_')[0]
|
164 |
+
total_rotation = np.zeros(3)
|
165 |
+
|
166 |
+
if axis == 'left':
|
167 |
+
total_rotation[0] = value
|
168 |
+
elif axis == 'right':
|
169 |
+
total_rotation[0] = -value
|
170 |
+
elif axis == 'up':
|
171 |
+
total_rotation[2] = -value
|
172 |
+
elif axis == 'down':
|
173 |
+
total_rotation[2] = value
|
174 |
+
|
175 |
+
step = total_rotation / duration
|
176 |
+
|
177 |
+
for i in range(1, duration+1):
|
178 |
+
positions.append(current_pose['position'].copy())
|
179 |
+
new_rot = current_pose['rotation'] + step * i
|
180 |
+
rotations.append(new_rot.copy())
|
181 |
+
|
182 |
+
current_pose['rotation'] = rotations[-1]
|
183 |
+
|
184 |
+
return positions, rotations, current_pose
|
185 |
+
|
186 |
+
def euler_to_quaternion(angles):
|
187 |
+
pitch, yaw, roll = np.radians(angles)
|
188 |
+
|
189 |
+
cy = math.cos(yaw * 0.5)
|
190 |
+
sy = math.sin(yaw * 0.5)
|
191 |
+
cp = math.cos(pitch * 0.5)
|
192 |
+
sp = math.sin(pitch * 0.5)
|
193 |
+
cr = math.cos(roll * 0.5)
|
194 |
+
sr = math.sin(roll * 0.5)
|
195 |
+
|
196 |
+
qw = cy * cp * cr + sy * sp * sr
|
197 |
+
qx = cy * cp * sr - sy * sp * cr
|
198 |
+
qy = sy * cp * sr + cy * sp * cr
|
199 |
+
qz = sy * cp * cr - cy * sp * sr
|
200 |
+
|
201 |
+
return [qw, qx, qy, qz]
|
202 |
+
|
203 |
+
def quaternion_to_rotation_matrix(q):
|
204 |
+
qw, qx, qy, qz = q
|
205 |
+
return np.array([
|
206 |
+
[1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz), 2*(qx*qz + qw*qy)],
|
207 |
+
[2*(qx*qy + qw*qz), 1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
|
208 |
+
[2*(qx*qz - qw*qy), 2*(qy*qz + qw*qx), 1 - 2*(qx**2 + qy**2)]
|
209 |
+
])
|
210 |
+
|
211 |
+
def ActionToPoseFromID(action_id, value=0.2, duration=33):
|
212 |
+
|
213 |
+
all_positions = []
|
214 |
+
all_rotations = []
|
215 |
+
current_pose = {
|
216 |
+
'position': np.array([0.0, 0.0, 0.0]), # XYZ
|
217 |
+
'rotation': np.array([0.0, 0.0, 0.0]) # (pitch, yaw, roll)
|
218 |
+
}
|
219 |
+
intrinsic = [0.50505, 0.8979, 0.5, 0.5]
|
220 |
+
motion_type = ACTION_DICT[action_id]
|
221 |
+
positions, rotations, current_pose = generate_motion_segment(current_pose, motion_type, value, duration)
|
222 |
+
all_positions.extend(positions)
|
223 |
+
all_rotations.extend(rotations)
|
224 |
+
|
225 |
+
pose_list = []
|
226 |
+
|
227 |
+
row = [0] + intrinsic + [0, 0] + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
|
228 |
+
first_row = " ".join(map(str, row))
|
229 |
+
pose_list.append(first_row)
|
230 |
+
for i, (pos, rot) in enumerate(zip(all_positions, all_rotations)):
|
231 |
+
quat = euler_to_quaternion(rot)
|
232 |
+
R = quaternion_to_rotation_matrix(quat)
|
233 |
+
extrinsic = np.hstack([R, pos.reshape(3, 1)])
|
234 |
+
|
235 |
+
row = [i] + intrinsic + [0, 0] + extrinsic.flatten().tolist()
|
236 |
+
pose_list.append(" ".join(map(str, row)))
|
237 |
+
|
238 |
+
return pose_list
|
239 |
+
|
240 |
+
class Camera(object):
|
241 |
+
def __init__(self, entry):
|
242 |
+
fx, fy, cx, cy = entry[1:5]
|
243 |
+
self.fx = fx
|
244 |
+
self.fy = fy
|
245 |
+
self.cx = cx
|
246 |
+
self.cy = cy
|
247 |
+
w2c_mat = np.array(entry[7:]).reshape(3, 4)
|
248 |
+
w2c_mat_4x4 = np.eye(4)
|
249 |
+
w2c_mat_4x4[:3, :] = w2c_mat
|
250 |
+
self.w2c_mat = w2c_mat_4x4
|
251 |
+
self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
|
252 |
+
|
253 |
+
class CameraPoseVisualizer:
|
254 |
+
def __init__(self, xlim, ylim, zlim):
|
255 |
+
self.fig = plt.figure(figsize=(7, 7))
|
256 |
+
self.ax = self.fig.add_subplot(projection='3d')
|
257 |
+
# self.ax.view_init(elev=25, azim=-90)
|
258 |
+
self.plotly_data = None # plotly data traces
|
259 |
+
self.ax.set_aspect("auto")
|
260 |
+
self.ax.set_xlim(xlim)
|
261 |
+
self.ax.set_ylim(ylim)
|
262 |
+
self.ax.set_zlim(zlim)
|
263 |
+
self.ax.set_xlabel('x')
|
264 |
+
self.ax.set_ylabel('y')
|
265 |
+
self.ax.set_zlabel('z')
|
266 |
+
print('initialize camera pose visualizer')
|
267 |
+
|
268 |
+
def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16, base_xval=1, zval=3):
|
269 |
+
vertex_std = np.array([[0, 0, 0, 1],
|
270 |
+
[base_xval, -base_xval * hw_ratio, zval, 1],
|
271 |
+
[base_xval, base_xval * hw_ratio, zval, 1],
|
272 |
+
[-base_xval, base_xval * hw_ratio, zval, 1],
|
273 |
+
[-base_xval, -base_xval * hw_ratio, zval, 1]])
|
274 |
+
vertex_transformed = vertex_std @ extrinsic.T
|
275 |
+
meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
|
276 |
+
[vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
|
277 |
+
[vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
|
278 |
+
[vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
|
279 |
+
[vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
|
280 |
+
vertex_transformed[4, :-1]]]
|
281 |
+
|
282 |
+
color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
|
283 |
+
|
284 |
+
self.ax.add_collection3d(
|
285 |
+
Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
|
286 |
+
|
287 |
+
def customize_legend(self, list_label):
|
288 |
+
list_handle = []
|
289 |
+
for idx, label in enumerate(list_label):
|
290 |
+
color = plt.cm.rainbow(idx / len(list_label))
|
291 |
+
patch = Patch(color=color, label=label)
|
292 |
+
list_handle.append(patch)
|
293 |
+
plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)
|
294 |
+
|
295 |
+
def colorbar(self, max_frame_length):
|
296 |
+
cmap = mpl.cm.rainbow
|
297 |
+
norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
|
298 |
+
self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
|
299 |
+
ax=self.ax, orientation='vertical', label='Frame Number')
|
300 |
+
|
301 |
+
def show(self, file_name):
|
302 |
+
plt.title('Extrinsic Parameters')
|
303 |
+
# plt.show()
|
304 |
+
plt.savefig(file_name)
|
305 |
+
|
306 |
+
|
307 |
+
def align_to(value, alignment):
|
308 |
+
return int(math.ceil(value / alignment) * alignment)
|
309 |
+
|
310 |
+
|
311 |
+
def GetPoseEmbedsFromPoses(poses, h, w, target_length, flip=False, start_index=None):
|
312 |
+
|
313 |
+
poses = [pose.split(' ') for pose in poses]
|
314 |
+
|
315 |
+
start_idx = start_index
|
316 |
+
sample_id = [start_idx + i for i in range(target_length)]
|
317 |
+
|
318 |
+
poses = [poses[i] for i in sample_id]
|
319 |
+
|
320 |
+
frame = len(poses)
|
321 |
+
w2cs = [np.asarray([float(p) for p in pose[7:]]).reshape(3, 4) for pose in poses]
|
322 |
+
transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
|
323 |
+
last_row = np.zeros((1, 4))
|
324 |
+
last_row[0, -1] = 1.0
|
325 |
+
w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
|
326 |
+
c2ws = get_c2w(w2cs, transform_matrix, relative_c2w=True)
|
327 |
+
|
328 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
329 |
+
assert len(cam_params) == target_length
|
330 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
331 |
+
|
332 |
+
monst3r_w = cam_params[0].cx * 2
|
333 |
+
monst3r_h = cam_params[0].cy * 2
|
334 |
+
ratio_w, ratio_h = w/monst3r_w, h/monst3r_h
|
335 |
+
intrinsics = np.asarray([[cam_param.fx * ratio_w,
|
336 |
+
cam_param.fy * ratio_h,
|
337 |
+
cam_param.cx * ratio_w,
|
338 |
+
cam_param.cy * ratio_h]
|
339 |
+
for cam_param in cam_params], dtype=np.float32)
|
340 |
+
intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
|
341 |
+
relative_pose = True
|
342 |
+
if relative_pose:
|
343 |
+
c2w_poses = get_relative_pose(cam_params)
|
344 |
+
else:
|
345 |
+
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
|
346 |
+
c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
|
347 |
+
uncond_c2w = torch.zeros_like(c2w)
|
348 |
+
uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
|
349 |
+
flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
|
350 |
+
plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
|
351 |
+
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
|
352 |
+
uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
|
353 |
+
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
|
354 |
+
|
355 |
+
return plucker_embedding, uncond_plucker_embedding, poses
|
356 |
+
|
357 |
+
def GetPoseEmbedsFromTxt(pose_dir, h, w, target_length, flip=False, start_index=None, step=1):
|
358 |
+
# get camera pose
|
359 |
+
with open(pose_dir, 'r') as f:
|
360 |
+
poses = f.readlines()
|
361 |
+
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
362 |
+
start_idx = start_index
|
363 |
+
sample_id = [start_idx + i*step for i in range(target_length)]
|
364 |
+
poses = [poses[i] for i in sample_id]
|
365 |
+
|
366 |
+
cam_params = [[float(x) for x in pose] for pose in poses]
|
367 |
+
assert len(cam_params) == target_length
|
368 |
+
cam_params = [Camera(cam_param) for cam_param in cam_params]
|
369 |
+
|
370 |
+
monst3r_w = cam_params[0].cx * 2
|
371 |
+
monst3r_h = cam_params[0].cy * 2
|
372 |
+
ratio_w, ratio_h = w/monst3r_w, h/monst3r_h
|
373 |
+
intrinsics = np.asarray([[cam_param.fx * ratio_w,
|
374 |
+
cam_param.fy * ratio_h,
|
375 |
+
cam_param.cx * ratio_w,
|
376 |
+
cam_param.cy * ratio_h]
|
377 |
+
for cam_param in cam_params], dtype=np.float32)
|
378 |
+
intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
|
379 |
+
relative_pose = True
|
380 |
+
if relative_pose:
|
381 |
+
c2w_poses = get_relative_pose(cam_params)
|
382 |
+
else:
|
383 |
+
c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
|
384 |
+
c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
|
385 |
+
uncond_c2w = torch.zeros_like(c2w)
|
386 |
+
uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
|
387 |
+
if flip:
|
388 |
+
flip_flag = torch.ones(target_length, dtype=torch.bool, device=c2w.device)
|
389 |
+
else:
|
390 |
+
flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
|
391 |
+
plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
|
392 |
+
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
|
393 |
+
uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
|
394 |
+
flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
|
395 |
+
|
396 |
+
return plucker_embedding, uncond_plucker_embedding, poses
|
397 |
+
|
398 |
+
|
399 |
+
class HunyuanVideoSampler(Inference):
|
400 |
+
def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
|
401 |
+
device=0, logger=None):
|
402 |
+
super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
|
403 |
+
pipeline=pipeline, device=device, logger=logger)
|
404 |
+
|
405 |
+
self.args = args
|
406 |
+
self.pipeline = load_diffusion_pipeline(
|
407 |
+
args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
|
408 |
+
device=self.device)
|
409 |
+
print('load hunyuan model successful... ')
|
410 |
+
|
411 |
+
def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
|
412 |
+
target_ndim = 3
|
413 |
+
ndim = 5 - 2
|
414 |
+
if '884' in self.args.vae:
|
415 |
+
latents_size = [(video_length-1)//4+1 , height//8, width//8]
|
416 |
+
else:
|
417 |
+
latents_size = [video_length , height//8, width//8]
|
418 |
+
|
419 |
+
if isinstance(self.model.patch_size, int):
|
420 |
+
assert all(s % self.model.patch_size == 0 for s in latents_size), \
|
421 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
|
422 |
+
f"but got {latents_size}."
|
423 |
+
rope_sizes = [s // self.model.patch_size for s in latents_size]
|
424 |
+
elif isinstance(self.model.patch_size, list):
|
425 |
+
assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
|
426 |
+
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
|
427 |
+
f"but got {latents_size}."
|
428 |
+
rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
|
429 |
+
|
430 |
+
if len(rope_sizes) != target_ndim:
|
431 |
+
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
|
432 |
+
head_dim = self.model.hidden_size // self.model.num_heads
|
433 |
+
rope_dim_list = self.model.rope_dim_list
|
434 |
+
if rope_dim_list is None:
|
435 |
+
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
|
436 |
+
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
|
437 |
+
freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
|
438 |
+
rope_sizes,
|
439 |
+
theta=self.args.rope_theta,
|
440 |
+
use_real=True,
|
441 |
+
theta_rescale_factor=1,
|
442 |
+
concat_dict=concat_dict)
|
443 |
+
return freqs_cos, freqs_sin
|
444 |
+
|
445 |
+
@torch.no_grad()
|
446 |
+
def predict(self,
|
447 |
+
prompt,
|
448 |
+
is_image=True,
|
449 |
+
size=(720, 1280),
|
450 |
+
video_length=129,
|
451 |
+
seed=None,
|
452 |
+
negative_prompt=None,
|
453 |
+
infer_steps=50,
|
454 |
+
guidance_scale=6.0,
|
455 |
+
flow_shift=5.0,
|
456 |
+
batch_size=1,
|
457 |
+
num_videos_per_prompt=1,
|
458 |
+
verbose=1,
|
459 |
+
output_type="pil",
|
460 |
+
**kwargs):
|
461 |
+
"""
|
462 |
+
Predict the image from the given text.
|
463 |
+
|
464 |
+
Args:
|
465 |
+
prompt (str or List[str]): The input text.
|
466 |
+
kwargs:
|
467 |
+
size (int): The (height, width) of the output image/video. Default is (256, 256).
|
468 |
+
video_length (int): The frame number of the output video. Default is 1.
|
469 |
+
seed (int or List[str]): The random seed for the generation. Default is a random integer.
|
470 |
+
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
|
471 |
+
infer_steps (int): The number of inference steps. Default is 100.
|
472 |
+
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
|
473 |
+
num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
|
474 |
+
verbose (int): 0 for no log, 1 for all log, 2 for fewer log. Default is 1.
|
475 |
+
output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
476 |
+
Default is 'pil'.
|
477 |
+
"""
|
478 |
+
|
479 |
+
out_dict = dict()
|
480 |
+
|
481 |
+
# ---------------------------------
|
482 |
+
# Prompt
|
483 |
+
# ---------------------------------
|
484 |
+
prompt_embeds = kwargs.get("prompt_embeds", None)
|
485 |
+
attention_mask = kwargs.get("attention_mask", None)
|
486 |
+
negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
|
487 |
+
negative_attention_mask = kwargs.get("negative_attention_mask", None)
|
488 |
+
ref_latents = kwargs.get("ref_latents", None)
|
489 |
+
uncond_ref_latents = kwargs.get("uncond_ref_latents", None)
|
490 |
+
return_latents = kwargs.get("return_latents", False)
|
491 |
+
negative_prompt = kwargs.get("negative_prompt", None)
|
492 |
+
|
493 |
+
action_id = kwargs.get("action_id", None)
|
494 |
+
action_speed = kwargs.get("action_speed", None)
|
495 |
+
start_index = kwargs.get("start_index", None)
|
496 |
+
last_latents = kwargs.get("last_latents", None)
|
497 |
+
ref_latents = kwargs.get("ref_latents", None)
|
498 |
+
input_pose = kwargs.get("input_pose", None)
|
499 |
+
step = kwargs.get("step", 1)
|
500 |
+
use_sage = kwargs.get("use_sage", False)
|
501 |
+
|
502 |
+
size = self.parse_size(size)
|
503 |
+
target_height = align_to(size[0], 16)
|
504 |
+
target_width = align_to(size[1], 16)
|
505 |
+
# target_video_length = video_length
|
506 |
+
|
507 |
+
if input_pose is not None:
|
508 |
+
pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromTxt(input_pose,
|
509 |
+
target_height,
|
510 |
+
target_width,
|
511 |
+
33,
|
512 |
+
kwargs.get("flip", False),
|
513 |
+
start_index,
|
514 |
+
step)
|
515 |
+
else:
|
516 |
+
pose = ActionToPoseFromID(action_id, value=action_speed)
|
517 |
+
pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromPoses(pose,
|
518 |
+
target_height,
|
519 |
+
target_width,
|
520 |
+
33,
|
521 |
+
kwargs.get("flip", False),
|
522 |
+
0)
|
523 |
+
|
524 |
+
if is_image:
|
525 |
+
target_length = 34
|
526 |
+
else:
|
527 |
+
target_length = 66
|
528 |
+
|
529 |
+
out_dict['frame'] = target_length
|
530 |
+
# print("pose embeds: ", pose_embeds.shape, uncond_pose_embeds.shape)
|
531 |
+
|
532 |
+
pose_embeds = pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
|
533 |
+
uncond_pose_embeds = uncond_pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
|
534 |
+
|
535 |
+
|
536 |
+
|
537 |
+
cpu_offload = kwargs.get("cpu_offload", 0)
|
538 |
+
use_deepcache = kwargs.get("use_deepcache", 1)
|
539 |
+
denoise_strength = kwargs.get("denoise_strength", 1.0)
|
540 |
+
init_latents = kwargs.get("init_latents", None)
|
541 |
+
mask = kwargs.get("mask", None)
|
542 |
+
if prompt is None:
|
543 |
+
# prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should not be None
|
544 |
+
# pipeline will help to check this
|
545 |
+
prompt = None
|
546 |
+
negative_prompt = None
|
547 |
+
batch_size = prompt_embeds.shape[0]
|
548 |
+
assert prompt_embeds is not None
|
549 |
+
else:
|
550 |
+
# prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should be None
|
551 |
+
# pipeline will help to check this
|
552 |
+
if isinstance(prompt, str):
|
553 |
+
batch_size = 1
|
554 |
+
prompt = [prompt]
|
555 |
+
elif isinstance(prompt, (list, tuple)):
|
556 |
+
batch_size = len(prompt)
|
557 |
+
else:
|
558 |
+
raise ValueError(f"Prompt must be a string or a list of strings, got {prompt}.")
|
559 |
+
|
560 |
+
if negative_prompt is None:
|
561 |
+
negative_prompt = [""] * batch_size
|
562 |
+
if isinstance(negative_prompt, str):
|
563 |
+
negative_prompt = [negative_prompt] * batch_size
|
564 |
+
|
565 |
+
# ---------------------------------
|
566 |
+
# Other arguments
|
567 |
+
# ---------------------------------
|
568 |
+
scheduler = FlowMatchDiscreteScheduler(shift=flow_shift,
|
569 |
+
reverse=self.args.flow_reverse,
|
570 |
+
solver=self.args.flow_solver,
|
571 |
+
)
|
572 |
+
self.pipeline.scheduler = scheduler
|
573 |
+
|
574 |
+
# ---------------------------------
|
575 |
+
# Random seed
|
576 |
+
# ---------------------------------
|
577 |
+
|
578 |
+
if isinstance(seed, torch.Tensor):
|
579 |
+
seed = seed.tolist()
|
580 |
+
if seed is None:
|
581 |
+
seeds = [random.randint(0, 1_000_000) for _ in range(batch_size * num_videos_per_prompt)]
|
582 |
+
elif isinstance(seed, int):
|
583 |
+
seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
|
584 |
+
elif isinstance(seed, (list, tuple)):
|
585 |
+
if len(seed) == batch_size:
|
586 |
+
seeds = [int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt)]
|
587 |
+
elif len(seed) == batch_size * num_videos_per_prompt:
|
588 |
+
seeds = [int(s) for s in seed]
|
589 |
+
else:
|
590 |
+
raise ValueError(
|
591 |
+
f"Length of seed must be equal to number of prompt(batch_size) or "
|
592 |
+
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
|
593 |
+
)
|
594 |
+
else:
|
595 |
+
raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.")
|
596 |
+
generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
|
597 |
+
|
598 |
+
# ---------------------------------
|
599 |
+
# Image/Video size and frame
|
600 |
+
# ---------------------------------
|
601 |
+
|
602 |
+
|
603 |
+
out_dict['size'] = (target_height, target_width)
|
604 |
+
out_dict['video_length'] = target_length
|
605 |
+
out_dict['seeds'] = seeds
|
606 |
+
out_dict['negative_prompt'] = negative_prompt
|
607 |
+
# ---------------------------------
|
608 |
+
# Build RoPE
|
609 |
+
# ---------------------------------
|
610 |
+
|
611 |
+
concat_dict = {'mode': 'timecat', 'bias': -1}
|
612 |
+
if is_image:
|
613 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed(37, target_height, target_width)
|
614 |
+
else:
|
615 |
+
freqs_cos, freqs_sin = self.get_rotary_pos_embed(69, target_height, target_width)
|
616 |
+
|
617 |
+
n_tokens = freqs_cos.shape[0]
|
618 |
+
|
619 |
+
# ---------------------------------
|
620 |
+
# Inference
|
621 |
+
# ---------------------------------
|
622 |
+
output_dir = kwargs.get("output_dir", None)
|
623 |
+
|
624 |
+
if verbose == 1:
|
625 |
+
debug_str = f"""
|
626 |
+
size: {out_dict['size']}
|
627 |
+
video_length: {target_length}
|
628 |
+
prompt: {prompt}
|
629 |
+
neg_prompt: {negative_prompt}
|
630 |
+
seed: {seed}
|
631 |
+
infer_steps: {infer_steps}
|
632 |
+
denoise_strength: {denoise_strength}
|
633 |
+
use_deepcache: {use_deepcache}
|
634 |
+
use_sage: {use_sage}
|
635 |
+
cpu_offload: {cpu_offload}
|
636 |
+
num_images_per_prompt: {num_videos_per_prompt}
|
637 |
+
guidance_scale: {guidance_scale}
|
638 |
+
n_tokens: {n_tokens}
|
639 |
+
flow_shift: {flow_shift}
|
640 |
+
output: {output_dir}"""
|
641 |
+
self.logger.info(debug_str)
|
642 |
+
|
643 |
+
start_time = time.time()
|
644 |
+
samples = self.pipeline(prompt=prompt,
|
645 |
+
last_latents=last_latents,
|
646 |
+
cam_latents=pose_embeds,
|
647 |
+
uncond_cam_latents=uncond_pose_embeds,
|
648 |
+
height=target_height,
|
649 |
+
width=target_width,
|
650 |
+
video_length=target_length,
|
651 |
+
gt_latents = ref_latents,
|
652 |
+
num_inference_steps=infer_steps,
|
653 |
+
guidance_scale=guidance_scale,
|
654 |
+
negative_prompt=negative_prompt,
|
655 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
656 |
+
generator=generator,
|
657 |
+
prompt_embeds=prompt_embeds,
|
658 |
+
ref_latents=ref_latents,
|
659 |
+
latents=init_latents,
|
660 |
+
denoise_strength=denoise_strength,
|
661 |
+
mask=mask,
|
662 |
+
uncond_ref_latents=uncond_ref_latents,
|
663 |
+
ip_cfg_scale=self.args.ip_cfg_scale,
|
664 |
+
use_deepcache=use_deepcache,
|
665 |
+
attention_mask=attention_mask,
|
666 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
667 |
+
negative_attention_mask=negative_attention_mask,
|
668 |
+
output_type=output_type,
|
669 |
+
freqs_cis=(freqs_cos, freqs_sin),
|
670 |
+
n_tokens=n_tokens,
|
671 |
+
data_type='video' if target_length > 1 else 'image',
|
672 |
+
is_progress_bar=True,
|
673 |
+
vae_ver=self.args.vae,
|
674 |
+
enable_tiling=self.args.vae_tiling,
|
675 |
+
cpu_offload=cpu_offload,
|
676 |
+
return_latents=return_latents,
|
677 |
+
use_sage=use_sage,
|
678 |
+
)
|
679 |
+
if samples is None:
|
680 |
+
return None
|
681 |
+
out_dict['samples'] = []
|
682 |
+
out_dict["prompts"] = prompt
|
683 |
+
out_dict['pose'] = poses
|
684 |
+
|
685 |
+
if return_latents:
|
686 |
+
print("return_latents | TRUE")
|
687 |
+
latents, timesteps, last_latents, ref_latents = samples[1], samples[2], samples[3], samples[4]
|
688 |
+
# samples = samples[0][0]
|
689 |
+
if samples[0] is not None and len(samples[0]) > 0:
|
690 |
+
samples = samples[0][0]
|
691 |
+
else:
|
692 |
+
samples = None
|
693 |
+
out_dict["denoised_lantents"] = latents
|
694 |
+
out_dict["timesteps"] = timesteps
|
695 |
+
out_dict["ref_latents"] = ref_latents
|
696 |
+
out_dict["last_latents"] = last_latents
|
697 |
+
|
698 |
+
else:
|
699 |
+
samples = samples[0]
|
700 |
+
|
701 |
+
if samples is not None:
|
702 |
+
for i, sample in enumerate(samples):
|
703 |
+
sample = samples[i].unsqueeze(0)
|
704 |
+
sub_samples = []
|
705 |
+
sub_samples.append(sample)
|
706 |
+
sample_num = len(sub_samples)
|
707 |
+
sub_samples = torch.concat(sub_samples)
|
708 |
+
# only save in tp rank 0
|
709 |
+
out_dict['samples'].append(sub_samples)
|
710 |
+
|
711 |
+
# visualize pose
|
712 |
+
|
713 |
+
gen_time = time.time() - start_time
|
714 |
+
logger.info(f"Success, time: {gen_time}")
|
715 |
+
return out_dict
|
716 |
+
|
hymm_sp/text_encoder/__init__.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
from copy import deepcopy
|
4 |
+
|
5 |
+
import torch, os
|
6 |
+
import torch.nn as nn
|
7 |
+
from transformers import (
|
8 |
+
CLIPTextModel, CLIPTokenizer, LlavaForConditionalGeneration,LlamaModel,
|
9 |
+
LlamaTokenizerFast
|
10 |
+
)
|
11 |
+
from transformers.utils import ModelOutput
|
12 |
+
from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH, PRECISION_TO_TYPE
|
13 |
+
|
14 |
+
CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
|
15 |
+
print(f'text_encoder: cpu_offload={CPU_OFFLOAD}')
|
16 |
+
|
17 |
+
def use_default(value, default):
|
18 |
+
return value if value is not None else default
|
19 |
+
|
20 |
+
def load_text_encoder(text_encoder_type,
|
21 |
+
text_encoder_precision=None,
|
22 |
+
text_encoder_path=None,
|
23 |
+
logger=None,
|
24 |
+
device=None
|
25 |
+
):
|
26 |
+
if text_encoder_path is None:
|
27 |
+
text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
|
28 |
+
if logger is not None:
|
29 |
+
logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
|
30 |
+
|
31 |
+
if text_encoder_type == "clipL":
|
32 |
+
text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
|
33 |
+
text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
|
34 |
+
elif text_encoder_type == "llava-llama-3-8b":
|
35 |
+
text_encoder = LlavaForConditionalGeneration.from_pretrained(text_encoder_path, low_cpu_mem_usage=True)
|
36 |
+
import transformers
|
37 |
+
transformers_version = transformers.__version__
|
38 |
+
if transformers_version >= "4.53.0":
|
39 |
+
text_encoder.final_layer_norm = text_encoder.language_model.norm
|
40 |
+
else:
|
41 |
+
text_encoder.final_layer_norm = text_encoder.language_model.model.norm
|
42 |
+
|
43 |
+
else:
|
44 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
45 |
+
|
46 |
+
if text_encoder_precision is not None:
|
47 |
+
text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
|
48 |
+
|
49 |
+
text_encoder.requires_grad_(False)
|
50 |
+
|
51 |
+
if logger is not None:
|
52 |
+
logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
|
53 |
+
|
54 |
+
if device is not None:
|
55 |
+
text_encoder = text_encoder.to(device)
|
56 |
+
|
57 |
+
return text_encoder, text_encoder_path
|
58 |
+
|
59 |
+
def load_tokenizer(tokenizer_type,
|
60 |
+
tokenizer_path=None,
|
61 |
+
padding_side="right",
|
62 |
+
logger=None
|
63 |
+
):
|
64 |
+
if tokenizer_path is None:
|
65 |
+
tokenizer_path = TOKENIZER_PATH[tokenizer_type]
|
66 |
+
if logger is not None:
|
67 |
+
logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
|
68 |
+
|
69 |
+
if tokenizer_type == "clipL":
|
70 |
+
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
|
71 |
+
elif tokenizer_type == "llava-llama-3-8b":
|
72 |
+
tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path, padding_side=padding_side)
|
73 |
+
else:
|
74 |
+
raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
|
75 |
+
|
76 |
+
return tokenizer, tokenizer_path
|
77 |
+
|
78 |
+
|
79 |
+
@dataclass
|
80 |
+
class TextEncoderModelOutput(ModelOutput):
|
81 |
+
"""
|
82 |
+
Base class for model's outputs that also contains a pooling of the last hidden states.
|
83 |
+
|
84 |
+
Args:
|
85 |
+
hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
86 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
87 |
+
attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
88 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
89 |
+
hidden_states_list (`tuple(torch.FloatTensor)`, *optional*,
|
90 |
+
returned when `output_hidden_states=True` is passed):
|
91 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
92 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
93 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
94 |
+
text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
|
95 |
+
List of decoded texts.
|
96 |
+
"""
|
97 |
+
|
98 |
+
hidden_state: torch.FloatTensor = None
|
99 |
+
attention_mask: Optional[torch.LongTensor] = None
|
100 |
+
hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
|
101 |
+
text_outputs: Optional[list] = None
|
102 |
+
|
103 |
+
|
104 |
+
class TextEncoder(nn.Module):
|
105 |
+
def __init__(self,
|
106 |
+
text_encoder_type: str,
|
107 |
+
max_length: int,
|
108 |
+
text_encoder_precision: Optional[str] = None,
|
109 |
+
text_encoder_path: Optional[str] = None,
|
110 |
+
tokenizer_type: Optional[str] = None,
|
111 |
+
tokenizer_path: Optional[str] = None,
|
112 |
+
output_key: Optional[str] = None,
|
113 |
+
use_attention_mask: bool = True,
|
114 |
+
input_max_length: Optional[int] = None,
|
115 |
+
prompt_template_video: Optional[dict] = None,
|
116 |
+
hidden_state_skip_layer: Optional[int] = None,
|
117 |
+
apply_final_norm: bool = False,
|
118 |
+
reproduce: bool = False,
|
119 |
+
logger=None,
|
120 |
+
device=None,
|
121 |
+
):
|
122 |
+
super().__init__()
|
123 |
+
self.text_encoder_type = text_encoder_type
|
124 |
+
self.max_length = max_length
|
125 |
+
self.precision = text_encoder_precision
|
126 |
+
self.model_path = text_encoder_path
|
127 |
+
self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
|
128 |
+
self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
|
129 |
+
self.use_attention_mask = use_attention_mask
|
130 |
+
if prompt_template_video is not None:
|
131 |
+
assert use_attention_mask is True, "Attention mask is True required when training videos."
|
132 |
+
self.input_max_length = input_max_length if input_max_length is not None else max_length
|
133 |
+
self.prompt_template_video = prompt_template_video
|
134 |
+
self.hidden_state_skip_layer = hidden_state_skip_layer
|
135 |
+
self.apply_final_norm = apply_final_norm
|
136 |
+
self.reproduce = reproduce
|
137 |
+
self.logger = logger
|
138 |
+
|
139 |
+
self.use_video_template = self.prompt_template_video is not None
|
140 |
+
if self.use_video_template:
|
141 |
+
if self.prompt_template_video is not None:
|
142 |
+
assert isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video, (
|
143 |
+
f"`prompt_template_video` must be a dictionary with a key 'template', \
|
144 |
+
got {self.prompt_template_video}"
|
145 |
+
)
|
146 |
+
assert '{}' in str(self.prompt_template_video["template"]), (
|
147 |
+
"`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
|
148 |
+
f"got {self.prompt_template_video['template']}"
|
149 |
+
)
|
150 |
+
|
151 |
+
if "clip" in text_encoder_type:
|
152 |
+
self.output_key = output_key or "pooler_output"
|
153 |
+
elif "llama" in text_encoder_type:
|
154 |
+
self.output_key = output_key or "last_hidden_state"
|
155 |
+
else:
|
156 |
+
raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
|
157 |
+
|
158 |
+
self.model, self.model_path = load_text_encoder(
|
159 |
+
text_encoder_type=self.text_encoder_type,
|
160 |
+
text_encoder_precision=self.precision,
|
161 |
+
text_encoder_path=self.model_path,
|
162 |
+
logger=self.logger,
|
163 |
+
device=device
|
164 |
+
)
|
165 |
+
self.dtype = self.model.dtype
|
166 |
+
self.device = self.model.device
|
167 |
+
|
168 |
+
self.tokenizer, self.tokenizer_path = load_tokenizer(
|
169 |
+
tokenizer_type=self.tokenizer_type,
|
170 |
+
tokenizer_path=self.tokenizer_path,
|
171 |
+
padding_side="right",
|
172 |
+
logger=self.logger
|
173 |
+
)
|
174 |
+
|
175 |
+
def __repr__(self):
|
176 |
+
return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
|
177 |
+
|
178 |
+
@staticmethod
|
179 |
+
def apply_text_to_template(text, template):
|
180 |
+
"""
|
181 |
+
Apply text to template.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
text (str): Input text.
|
185 |
+
template (str or list): Template string or list of chat conversation.
|
186 |
+
prevent_empty_text (bool): If Ture, we will prevent the user text from being empty
|
187 |
+
by adding a space. Defaults to True.
|
188 |
+
"""
|
189 |
+
if isinstance(template, str):
|
190 |
+
# Will send string to tokenizer. Used for llm
|
191 |
+
return template.format(text)
|
192 |
+
else:
|
193 |
+
raise TypeError(f"Unsupported template type: {type(template)}")
|
194 |
+
|
195 |
+
def text2tokens(self, text, data_type='video', name='person'):
|
196 |
+
"""
|
197 |
+
Tokenize the input text.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
text (str or list): Input text.
|
201 |
+
"""
|
202 |
+
tokenize_input_type = 'str'
|
203 |
+
if self.use_video_template:
|
204 |
+
if data_type == 'video':
|
205 |
+
prompt_template = self.prompt_template_video["template"]
|
206 |
+
else:
|
207 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
208 |
+
if isinstance(text, (list, tuple)):
|
209 |
+
text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
|
210 |
+
if isinstance(text[0], list):
|
211 |
+
tokenize_input_type = 'list'
|
212 |
+
elif isinstance(text, str):
|
213 |
+
text = self.apply_text_to_template(text, prompt_template)
|
214 |
+
if isinstance(text, list):
|
215 |
+
tokenize_input_type = 'list'
|
216 |
+
else:
|
217 |
+
raise TypeError(f"Unsupported text type: {type(text)}")
|
218 |
+
|
219 |
+
kwargs = dict(truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
|
220 |
+
if self.text_encoder_type == "llava-llama-3-8b":
|
221 |
+
if isinstance(text, list):
|
222 |
+
for i in range(len(text)):
|
223 |
+
text[i] = text[i] + '\nThe %s looks like<image>' % name
|
224 |
+
elif isinstance(text, str):
|
225 |
+
text = text + '\nThe %s looks like<image>' % name
|
226 |
+
else:
|
227 |
+
raise NotImplementedError
|
228 |
+
|
229 |
+
if tokenize_input_type == 'str':
|
230 |
+
return self.tokenizer(text,
|
231 |
+
return_length=False,
|
232 |
+
return_overflowing_tokens=False,
|
233 |
+
return_attention_mask=True,
|
234 |
+
**kwargs, )
|
235 |
+
elif tokenize_input_type == 'list':
|
236 |
+
return self.tokenizer.apply_chat_template(text,
|
237 |
+
add_generation_prompt=True,
|
238 |
+
tokenize=True,
|
239 |
+
return_dict=True,
|
240 |
+
**kwargs, )
|
241 |
+
else:
|
242 |
+
raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
|
243 |
+
|
244 |
+
def encode(self, batch_encoding, use_attention_mask=None, output_hidden_states=False, do_sample=None,
|
245 |
+
hidden_state_skip_layer=None, return_texts=False, data_type='image'):
|
246 |
+
"""
|
247 |
+
Args:
|
248 |
+
batch_encoding (dict): Batch encoding from tokenizer.
|
249 |
+
use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
|
250 |
+
Defaults to None.
|
251 |
+
output_hidden_states (bool): Whether to output hidden states. If False, return the value of
|
252 |
+
self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
|
253 |
+
output_hidden_states will be set True. Defaults to False.
|
254 |
+
do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
|
255 |
+
When self.produce is False, do_sample is set to True by default.
|
256 |
+
hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
|
257 |
+
If None, self.output_key will be used. Defaults to None.
|
258 |
+
return_texts (bool): Whether to return the decoded texts. Defaults to False.
|
259 |
+
"""
|
260 |
+
use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
|
261 |
+
hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
|
262 |
+
do_sample = use_default(do_sample, not self.reproduce)
|
263 |
+
if CPU_OFFLOAD:
|
264 |
+
self.model.to('cuda')
|
265 |
+
print(f'encode prompt: move text_encoder to cuda')
|
266 |
+
|
267 |
+
attention_mask = batch_encoding["attention_mask"].to(self.model.device) if use_attention_mask else None
|
268 |
+
if 'pixel_value_llava' in batch_encoding:
|
269 |
+
outputs = self.model(
|
270 |
+
input_ids=batch_encoding["input_ids"].to(self.model.device),
|
271 |
+
attention_mask=attention_mask,
|
272 |
+
pixel_values=batch_encoding["pixel_value_llava"].to(self.model.device),
|
273 |
+
output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None)
|
274 |
+
else:
|
275 |
+
outputs = self.model(
|
276 |
+
input_ids=batch_encoding["input_ids"].to(self.model.device),
|
277 |
+
attention_mask=attention_mask,
|
278 |
+
output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,)
|
279 |
+
if hidden_state_skip_layer is not None:
|
280 |
+
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
281 |
+
# Real last hidden state already has layer norm applied. So here we only apply it
|
282 |
+
# for intermediate layers.
|
283 |
+
if hidden_state_skip_layer > 0 and self.apply_final_norm:
|
284 |
+
last_hidden_state = self.model.final_layer_norm(last_hidden_state)
|
285 |
+
else:
|
286 |
+
last_hidden_state = outputs[self.output_key]
|
287 |
+
|
288 |
+
# Remove hidden states of instruction tokens, only keep prompt tokens.
|
289 |
+
if self.use_video_template:
|
290 |
+
if data_type == 'video':
|
291 |
+
crop_start = self.prompt_template_video.get("crop_start", -1)
|
292 |
+
else:
|
293 |
+
raise ValueError(f"Unsupported data type: {data_type}")
|
294 |
+
if crop_start > 0:
|
295 |
+
last_hidden_state = last_hidden_state[:, crop_start:]
|
296 |
+
attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
|
297 |
+
if CPU_OFFLOAD:
|
298 |
+
self.model.to('cpu')
|
299 |
+
torch.cuda.empty_cache()
|
300 |
+
print(f'encode prompt successful: move text_encoder to cpu')
|
301 |
+
if output_hidden_states:
|
302 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
|
303 |
+
return TextEncoderModelOutput(last_hidden_state, attention_mask)
|
304 |
+
|
305 |
+
def forward(self, text, use_attention_mask=None, output_hidden_states=False, do_sample=False,
|
306 |
+
hidden_state_skip_layer=None, return_texts=False):
|
307 |
+
batch_encoding = self.text2tokens(text)
|
308 |
+
return self.encode(batch_encoding, use_attention_mask=use_attention_mask,
|
309 |
+
output_hidden_states=output_hidden_states, do_sample=do_sample,
|
310 |
+
hidden_state_skip_layer=hidden_state_skip_layer, return_texts=return_texts)
|
hymm_sp/vae/__init__.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pathlib import Path
|
3 |
+
from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
4 |
+
from ..constants import VAE_PATH, PRECISION_TO_TYPE
|
5 |
+
|
6 |
+
def load_vae(vae_type,
|
7 |
+
vae_precision=None,
|
8 |
+
sample_size=None,
|
9 |
+
vae_path=None,
|
10 |
+
logger=None,
|
11 |
+
device=None
|
12 |
+
):
|
13 |
+
"""
|
14 |
+
Load and configure a Variational Autoencoder (VAE) model.
|
15 |
+
|
16 |
+
This function handles loading 3D causal VAE models, including configuration,
|
17 |
+
weight loading, precision setting, and device placement. It ensures the model
|
18 |
+
is properly initialized for inference.
|
19 |
+
|
20 |
+
Parameters:
|
21 |
+
vae_type (str): Type identifier for the VAE, must follow '???-*' format for 3D VAEs
|
22 |
+
vae_precision (str, optional): Desired precision type (e.g., 'fp16', 'fp32').
|
23 |
+
Uses model's default if not specified.
|
24 |
+
sample_size (tuple, optional): Input sample dimensions to override config defaults
|
25 |
+
vae_path (str, optional): Path to VAE model files. Uses predefined path from
|
26 |
+
VAE_PATH constant if not specified.
|
27 |
+
logger (logging.Logger, optional): Logger instance for progress/debug messages
|
28 |
+
device (torch.device, optional): Target device to place the model (e.g., 'cuda' or 'cpu')
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
tuple: Contains:
|
32 |
+
- vae (AutoencoderKLCausal3D): Loaded and configured VAE model
|
33 |
+
- vae_path (str): Actual path used to load the VAE
|
34 |
+
- spatial_compression_ratio (int): Spatial dimension compression factor
|
35 |
+
- time_compression_ratio (int): Temporal dimension compression factor
|
36 |
+
|
37 |
+
Raises:
|
38 |
+
ValueError: If vae_type does not follow the required 3D VAE format '???-*'
|
39 |
+
"""
|
40 |
+
if vae_path is None:
|
41 |
+
vae_path = VAE_PATH[vae_type]
|
42 |
+
vae_compress_spec, _, _ = vae_type.split("-")
|
43 |
+
length = len(vae_compress_spec)
|
44 |
+
# Process 3D VAE (valid format with 3-character compression spec)
|
45 |
+
if length == 3:
|
46 |
+
if logger is not None:
|
47 |
+
logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
|
48 |
+
config = AutoencoderKLCausal3D.load_config(vae_path)
|
49 |
+
if sample_size:
|
50 |
+
vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
|
51 |
+
else:
|
52 |
+
vae = AutoencoderKLCausal3D.from_config(config)
|
53 |
+
ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device)
|
54 |
+
if "state_dict" in ckpt:
|
55 |
+
ckpt = ckpt["state_dict"]
|
56 |
+
vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
|
57 |
+
vae.load_state_dict(vae_ckpt)
|
58 |
+
|
59 |
+
spatial_compression_ratio = vae.config.spatial_compression_ratio
|
60 |
+
time_compression_ratio = vae.config.time_compression_ratio
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.")
|
63 |
+
|
64 |
+
if vae_precision is not None:
|
65 |
+
vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
|
66 |
+
|
67 |
+
vae.requires_grad_(False)
|
68 |
+
|
69 |
+
if logger is not None:
|
70 |
+
logger.info(f"VAE to dtype: {vae.dtype}")
|
71 |
+
|
72 |
+
if device is not None:
|
73 |
+
vae = vae.to(device)
|
74 |
+
|
75 |
+
# Ensure model is in evaluation mode (disables dropout/batch norm training behavior)
|
76 |
+
# Note: Even with dropout rate 0, eval mode is recommended for consistent inference
|
77 |
+
vae.eval()
|
78 |
+
|
79 |
+
return vae, vae_path, spatial_compression_ratio, time_compression_ratio
|
hymm_sp/vae/autoencoder_kl_causal_3d.py
ADDED
@@ -0,0 +1,781 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
from typing import Dict, Optional, Tuple, Union
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from torch import distributed as dist
|
6 |
+
import loguru
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.distributed
|
10 |
+
|
11 |
+
from torch import distributed as dist
|
12 |
+
|
13 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
14 |
+
try:
|
15 |
+
# This diffusers is modified and packed in the mirror.
|
16 |
+
from diffusers.loaders import FromOriginalVAEMixin
|
17 |
+
except ImportError:
|
18 |
+
# Use this to be compatible with the original diffusers.
|
19 |
+
from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
|
20 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
21 |
+
from diffusers.models.attention_processor import (
|
22 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
23 |
+
CROSS_ATTENTION_PROCESSORS,
|
24 |
+
Attention,
|
25 |
+
AttentionProcessor,
|
26 |
+
AttnAddedKVProcessor,
|
27 |
+
AttnProcessor,
|
28 |
+
)
|
29 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
30 |
+
from diffusers.models.modeling_utils import ModelMixin
|
31 |
+
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
|
32 |
+
|
33 |
+
import threading
|
34 |
+
from hymm_sp.modules.parallel_states import (
|
35 |
+
initialize_sequence_parallel_state,
|
36 |
+
get_sequence_parallel_state,
|
37 |
+
nccl_info,
|
38 |
+
)
|
39 |
+
|
40 |
+
def cur_rank():
|
41 |
+
return nccl_info.rank_within_group
|
42 |
+
|
43 |
+
def cur_world_size():
|
44 |
+
return nccl_info.sp_size
|
45 |
+
|
46 |
+
"""
|
47 |
+
use trt need install polygraphy and onnx-graphsurgeon
|
48 |
+
python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
|
49 |
+
"""
|
50 |
+
try:
|
51 |
+
from polygraphy.backend.trt import ( TrtRunner, EngineFromBytes)
|
52 |
+
from polygraphy.backend.common import BytesFromPath
|
53 |
+
except:
|
54 |
+
print("TrtRunner or EngineFromBytes is not available, you can not use trt engine")
|
55 |
+
|
56 |
+
@dataclass
|
57 |
+
class DecoderOutput2(BaseOutput):
|
58 |
+
sample: torch.FloatTensor
|
59 |
+
posterior: Optional[DiagonalGaussianDistribution] = None
|
60 |
+
|
61 |
+
|
62 |
+
MODEL_OUTPUT_PATH = os.environ.get('MODEL_OUTPUT_PATH')
|
63 |
+
MODEL_BASE = os.environ.get('MODEL_BASE')
|
64 |
+
|
65 |
+
CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
|
66 |
+
DISABLE_SP = int(os.environ.get("DISABLE_SP", 0))
|
67 |
+
|
68 |
+
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
69 |
+
r"""
|
70 |
+
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
71 |
+
|
72 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
73 |
+
for all models (such as downloading or saving).
|
74 |
+
|
75 |
+
Parameters:
|
76 |
+
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
77 |
+
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
|
78 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
79 |
+
Tuple of downsample block types.
|
80 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
81 |
+
Tuple of upsample block types.
|
82 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
|
83 |
+
Tuple of block output channels.
|
84 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
85 |
+
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
|
86 |
+
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
87 |
+
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
88 |
+
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
89 |
+
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
90 |
+
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
91 |
+
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
|
92 |
+
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
|
93 |
+
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
94 |
+
force_upcast (`bool`, *optional*, default to `True`):
|
95 |
+
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
|
96 |
+
can be fine-tuned / trained to a lower range without loosing too much precision in which case
|
97 |
+
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
|
98 |
+
"""
|
99 |
+
|
100 |
+
_supports_gradient_checkpointing = True
|
101 |
+
|
102 |
+
@register_to_config
|
103 |
+
def __init__(
|
104 |
+
self,
|
105 |
+
in_channels: int = 3,
|
106 |
+
out_channels: int = 3,
|
107 |
+
down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
|
108 |
+
up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
|
109 |
+
block_out_channels: Tuple[int] = (64,),
|
110 |
+
layers_per_block: int = 1,
|
111 |
+
act_fn: str = "silu",
|
112 |
+
latent_channels: int = 4,
|
113 |
+
norm_num_groups: int = 32,
|
114 |
+
sample_size: int = 32,
|
115 |
+
sample_tsize: int = 64,
|
116 |
+
scaling_factor: float = 0.18215,
|
117 |
+
force_upcast: float = True,
|
118 |
+
spatial_compression_ratio: int = 8,
|
119 |
+
time_compression_ratio: int = 4,
|
120 |
+
disable_causal_conv: bool = False,
|
121 |
+
mid_block_add_attention: bool = True,
|
122 |
+
mid_block_causal_attn: bool = False,
|
123 |
+
use_trt_engine: bool = False,
|
124 |
+
nccl_gather: bool = True,
|
125 |
+
engine_path: str = f"{MODEL_BASE}/HYVAE_decoder+conv_256x256xT_fp16_H20.engine",
|
126 |
+
):
|
127 |
+
super().__init__()
|
128 |
+
|
129 |
+
self.disable_causal_conv = disable_causal_conv
|
130 |
+
self.time_compression_ratio = time_compression_ratio
|
131 |
+
|
132 |
+
self.encoder = EncoderCausal3D(
|
133 |
+
in_channels=in_channels,
|
134 |
+
out_channels=latent_channels,
|
135 |
+
down_block_types=down_block_types,
|
136 |
+
block_out_channels=block_out_channels,
|
137 |
+
layers_per_block=layers_per_block,
|
138 |
+
act_fn=act_fn,
|
139 |
+
norm_num_groups=norm_num_groups,
|
140 |
+
double_z=True,
|
141 |
+
time_compression_ratio=time_compression_ratio,
|
142 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
143 |
+
disable_causal=disable_causal_conv,
|
144 |
+
mid_block_add_attention=mid_block_add_attention,
|
145 |
+
mid_block_causal_attn=mid_block_causal_attn,
|
146 |
+
)
|
147 |
+
|
148 |
+
self.decoder = DecoderCausal3D(
|
149 |
+
in_channels=latent_channels,
|
150 |
+
out_channels=out_channels,
|
151 |
+
up_block_types=up_block_types,
|
152 |
+
block_out_channels=block_out_channels,
|
153 |
+
layers_per_block=layers_per_block,
|
154 |
+
norm_num_groups=norm_num_groups,
|
155 |
+
act_fn=act_fn,
|
156 |
+
time_compression_ratio=time_compression_ratio,
|
157 |
+
spatial_compression_ratio=spatial_compression_ratio,
|
158 |
+
disable_causal=disable_causal_conv,
|
159 |
+
mid_block_add_attention=mid_block_add_attention,
|
160 |
+
mid_block_causal_attn=mid_block_causal_attn,
|
161 |
+
)
|
162 |
+
|
163 |
+
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
|
164 |
+
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
|
165 |
+
|
166 |
+
self.use_slicing = False
|
167 |
+
self.use_spatial_tiling = False
|
168 |
+
self.use_temporal_tiling = False
|
169 |
+
|
170 |
+
|
171 |
+
# only relevant if vae tiling is enabled
|
172 |
+
self.tile_sample_min_tsize = sample_tsize
|
173 |
+
self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
|
174 |
+
|
175 |
+
self.tile_sample_min_size = self.config.sample_size
|
176 |
+
sample_size = (
|
177 |
+
self.config.sample_size[0]
|
178 |
+
if isinstance(self.config.sample_size, (list, tuple))
|
179 |
+
else self.config.sample_size
|
180 |
+
)
|
181 |
+
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
182 |
+
self.tile_overlap_factor = 0.25
|
183 |
+
|
184 |
+
# ============= parallism related code ===================
|
185 |
+
world_size = cur_world_size()
|
186 |
+
self.parallel_decode = False if CPU_OFFLOAD else get_sequence_parallel_state()
|
187 |
+
print("WORLD SIZE: ", world_size)
|
188 |
+
|
189 |
+
|
190 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
191 |
+
if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
|
192 |
+
module.gradient_checkpointing = value
|
193 |
+
|
194 |
+
def enable_temporal_tiling(self, use_tiling: bool = True):
|
195 |
+
self.use_temporal_tiling = use_tiling
|
196 |
+
|
197 |
+
def disable_temporal_tiling(self):
|
198 |
+
self.enable_temporal_tiling(False)
|
199 |
+
|
200 |
+
def enable_spatial_tiling(self, use_tiling: bool = True):
|
201 |
+
self.use_spatial_tiling = use_tiling
|
202 |
+
|
203 |
+
def disable_spatial_tiling(self):
|
204 |
+
self.enable_spatial_tiling(False)
|
205 |
+
|
206 |
+
def enable_tiling(self, use_tiling: bool = True):
|
207 |
+
r"""
|
208 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
209 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
210 |
+
processing larger images.
|
211 |
+
"""
|
212 |
+
self.enable_spatial_tiling(use_tiling)
|
213 |
+
self.enable_temporal_tiling(use_tiling)
|
214 |
+
|
215 |
+
def disable_tiling(self):
|
216 |
+
r"""
|
217 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
218 |
+
decoding in one step.
|
219 |
+
"""
|
220 |
+
self.disable_spatial_tiling()
|
221 |
+
self.disable_temporal_tiling()
|
222 |
+
|
223 |
+
def enable_slicing(self):
|
224 |
+
r"""
|
225 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
226 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
227 |
+
"""
|
228 |
+
self.use_slicing = True
|
229 |
+
|
230 |
+
def disable_slicing(self):
|
231 |
+
r"""
|
232 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
233 |
+
decoding in one step.
|
234 |
+
"""
|
235 |
+
self.use_slicing = False
|
236 |
+
|
237 |
+
|
238 |
+
def load_trt_decoder(self):
|
239 |
+
self.use_trt_decoder = True
|
240 |
+
self.engine = EngineFromBytes(BytesFromPath(self.engine_path))
|
241 |
+
|
242 |
+
self.trt_decoder_runner = TrtRunner(self.engine)
|
243 |
+
self.activate_trt_decoder()
|
244 |
+
|
245 |
+
def disable_trt_decoder(self):
|
246 |
+
self.use_trt_decoder = False
|
247 |
+
del self.engine
|
248 |
+
|
249 |
+
def activate_trt_decoder(self):
|
250 |
+
self.trt_decoder_runner.activate()
|
251 |
+
|
252 |
+
def deactivate_trt_decoder(self):
|
253 |
+
self.trt_decoder_runner.deactivate()
|
254 |
+
|
255 |
+
@property
|
256 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
257 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
258 |
+
r"""
|
259 |
+
Returns:
|
260 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
261 |
+
indexed by its weight name.
|
262 |
+
"""
|
263 |
+
# set recursively
|
264 |
+
processors = {}
|
265 |
+
|
266 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
267 |
+
if hasattr(module, "get_processor"):
|
268 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
269 |
+
|
270 |
+
for sub_name, child in module.named_children():
|
271 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
272 |
+
|
273 |
+
return processors
|
274 |
+
|
275 |
+
for name, module in self.named_children():
|
276 |
+
fn_recursive_add_processors(name, module, processors)
|
277 |
+
|
278 |
+
return processors
|
279 |
+
|
280 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
281 |
+
def set_attn_processor(
|
282 |
+
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
283 |
+
):
|
284 |
+
r"""
|
285 |
+
Sets the attention processor to use to compute attention.
|
286 |
+
|
287 |
+
Parameters:
|
288 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
289 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
290 |
+
for **all** `Attention` layers.
|
291 |
+
|
292 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
293 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
294 |
+
|
295 |
+
"""
|
296 |
+
count = len(self.attn_processors.keys())
|
297 |
+
|
298 |
+
if isinstance(processor, dict) and len(processor) != count:
|
299 |
+
raise ValueError(
|
300 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
301 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
302 |
+
)
|
303 |
+
|
304 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
305 |
+
if hasattr(module, "set_processor"):
|
306 |
+
if not isinstance(processor, dict):
|
307 |
+
module.set_processor(processor, _remove_lora=_remove_lora)
|
308 |
+
else:
|
309 |
+
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
310 |
+
|
311 |
+
for sub_name, child in module.named_children():
|
312 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
313 |
+
|
314 |
+
for name, module in self.named_children():
|
315 |
+
fn_recursive_attn_processor(name, module, processor)
|
316 |
+
|
317 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
318 |
+
def set_default_attn_processor(self):
|
319 |
+
"""
|
320 |
+
Disables custom attention processors and sets the default attention implementation.
|
321 |
+
"""
|
322 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
323 |
+
processor = AttnAddedKVProcessor()
|
324 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
325 |
+
processor = AttnProcessor()
|
326 |
+
else:
|
327 |
+
raise ValueError(
|
328 |
+
f"Cannot call `set_default_attn_processor` \
|
329 |
+
when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
330 |
+
)
|
331 |
+
|
332 |
+
self.set_attn_processor(processor, _remove_lora=True)
|
333 |
+
|
334 |
+
@apply_forward_hook
|
335 |
+
def encode(
|
336 |
+
self, x: torch.FloatTensor, return_dict: bool = True
|
337 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
338 |
+
"""
|
339 |
+
Encode a batch of images into latents.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
x (`torch.FloatTensor`): Input batch of images.
|
343 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
344 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
The latent representations of the encoded images. If `return_dict` is True, a
|
348 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
349 |
+
"""
|
350 |
+
assert len(x.shape) == 5, "The input tensor should have 5 dimensions"
|
351 |
+
|
352 |
+
if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
|
353 |
+
return self.temporal_tiled_encode(x, return_dict=return_dict)
|
354 |
+
|
355 |
+
if self.use_spatial_tiling and \
|
356 |
+
(x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
357 |
+
return self.spatial_tiled_encode(x, return_dict=return_dict)
|
358 |
+
|
359 |
+
if self.use_slicing and x.shape[0] > 1:
|
360 |
+
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
|
361 |
+
h = torch.cat(encoded_slices)
|
362 |
+
else:
|
363 |
+
h = self.encoder(x)
|
364 |
+
|
365 |
+
moments = self.quant_conv(h)
|
366 |
+
posterior = DiagonalGaussianDistribution(moments)
|
367 |
+
|
368 |
+
if not return_dict:
|
369 |
+
return (posterior,)
|
370 |
+
|
371 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
372 |
+
|
373 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
374 |
+
assert len(z.shape) == 5, "The input tensor should have 5 dimensions"
|
375 |
+
|
376 |
+
if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
|
377 |
+
return self.temporal_tiled_decode(z, return_dict=return_dict)
|
378 |
+
|
379 |
+
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or \
|
380 |
+
z.shape[-2] > self.tile_latent_min_size):
|
381 |
+
return self.spatial_tiled_decode(z, return_dict=return_dict)
|
382 |
+
|
383 |
+
if self.use_trt_decoder:
|
384 |
+
# For unknown reason, `copy_outputs_to_host` must be set to True
|
385 |
+
dec = self.trt_decoder_runner.infer({"input": z.to(RECOMMENDED_DTYPE).contiguous()}, \
|
386 |
+
copy_outputs_to_host=True)["output"].to(device=z.device, dtype=z.dtype)
|
387 |
+
else:
|
388 |
+
z = self.post_quant_conv(z)
|
389 |
+
dec = self.decoder(z)
|
390 |
+
|
391 |
+
if not return_dict:
|
392 |
+
return (dec,)
|
393 |
+
|
394 |
+
return DecoderOutput(sample=dec)
|
395 |
+
|
396 |
+
@apply_forward_hook
|
397 |
+
def decode(
|
398 |
+
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
399 |
+
) -> Union[DecoderOutput, torch.FloatTensor]:
|
400 |
+
"""
|
401 |
+
Decode a batch of images.
|
402 |
+
|
403 |
+
Args:
|
404 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
405 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
406 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
407 |
+
|
408 |
+
Returns:
|
409 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
410 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
411 |
+
returned.
|
412 |
+
|
413 |
+
"""
|
414 |
+
|
415 |
+
if self.use_slicing and z.shape[0] > 1:
|
416 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
417 |
+
decoded = torch.cat(decoded_slices)
|
418 |
+
else:
|
419 |
+
decoded = self._decode(z).sample
|
420 |
+
|
421 |
+
if not return_dict:
|
422 |
+
return (decoded,)
|
423 |
+
|
424 |
+
return DecoderOutput(sample=decoded)
|
425 |
+
|
426 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
427 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
428 |
+
if blend_extent == 0:
|
429 |
+
return b
|
430 |
+
|
431 |
+
a_region = a[..., -blend_extent:, :]
|
432 |
+
b_region = b[..., :blend_extent, :]
|
433 |
+
|
434 |
+
weights = torch.arange(blend_extent, device=a.device, dtype=a.dtype) / blend_extent
|
435 |
+
weights = weights.view(1, 1, 1, blend_extent, 1)
|
436 |
+
|
437 |
+
blended = a_region * (1 - weights) + b_region * weights
|
438 |
+
|
439 |
+
b[..., :blend_extent, :] = blended
|
440 |
+
return b
|
441 |
+
|
442 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
443 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
444 |
+
if blend_extent == 0:
|
445 |
+
return b
|
446 |
+
|
447 |
+
a_region = a[..., -blend_extent:]
|
448 |
+
b_region = b[..., :blend_extent]
|
449 |
+
|
450 |
+
weights = torch.arange(blend_extent, device=a.device, dtype=a.dtype) / blend_extent
|
451 |
+
weights = weights.view(1, 1, 1, 1, blend_extent)
|
452 |
+
|
453 |
+
blended = a_region * (1 - weights) + b_region * weights
|
454 |
+
|
455 |
+
b[..., :blend_extent] = blended
|
456 |
+
return b
|
457 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
458 |
+
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
|
459 |
+
if blend_extent == 0:
|
460 |
+
return b
|
461 |
+
|
462 |
+
a_region = a[..., -blend_extent:, :, :]
|
463 |
+
b_region = b[..., :blend_extent, :, :]
|
464 |
+
|
465 |
+
weights = torch.arange(blend_extent, device=a.device, dtype=a.dtype) / blend_extent
|
466 |
+
weights = weights.view(1, 1, blend_extent, 1, 1)
|
467 |
+
|
468 |
+
blended = a_region * (1 - weights) + b_region * weights
|
469 |
+
|
470 |
+
b[..., :blend_extent, :, :] = blended
|
471 |
+
return b
|
472 |
+
|
473 |
+
def spatial_tiled_encode(self,
|
474 |
+
x: torch.FloatTensor,
|
475 |
+
return_dict: bool = True,
|
476 |
+
return_moments: bool = False) -> AutoencoderKLOutput:
|
477 |
+
r"""Encode a batch of images using a tiled encoder.
|
478 |
+
|
479 |
+
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
480 |
+
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
481 |
+
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
482 |
+
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
483 |
+
output, but they should be much less noticeable.
|
484 |
+
|
485 |
+
Args:
|
486 |
+
x (`torch.FloatTensor`): Input batch of images.
|
487 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
488 |
+
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
492 |
+
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
493 |
+
`tuple` is returned.
|
494 |
+
"""
|
495 |
+
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
496 |
+
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
497 |
+
row_limit = self.tile_latent_min_size - blend_extent
|
498 |
+
|
499 |
+
# Split video into tiles and encode them separately.
|
500 |
+
rows = []
|
501 |
+
for i in range(0, x.shape[-2], overlap_size):
|
502 |
+
row = []
|
503 |
+
for j in range(0, x.shape[-1], overlap_size):
|
504 |
+
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
505 |
+
tile = self.encoder(tile)
|
506 |
+
tile = self.quant_conv(tile)
|
507 |
+
row.append(tile)
|
508 |
+
rows.append(row)
|
509 |
+
result_rows = []
|
510 |
+
for i, row in enumerate(rows):
|
511 |
+
result_row = []
|
512 |
+
for j, tile in enumerate(row):
|
513 |
+
# blend the above tile and the left tile
|
514 |
+
# to the current tile and add the current tile to the result row
|
515 |
+
if i > 0:
|
516 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
517 |
+
if j > 0:
|
518 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
519 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
520 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
521 |
+
|
522 |
+
moments = torch.cat(result_rows, dim=-2)
|
523 |
+
if return_moments:
|
524 |
+
return moments
|
525 |
+
|
526 |
+
posterior = DiagonalGaussianDistribution(moments)
|
527 |
+
if not return_dict:
|
528 |
+
return (posterior,)
|
529 |
+
|
530 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
531 |
+
|
532 |
+
|
533 |
+
def spatial_tiled_decode(self,
|
534 |
+
z: torch.FloatTensor,
|
535 |
+
return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
536 |
+
r"""
|
537 |
+
Decode a batch of images using a tiled decoder.
|
538 |
+
|
539 |
+
Args:
|
540 |
+
z (`torch.FloatTensor`): Input batch of latent vectors.
|
541 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
542 |
+
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
543 |
+
|
544 |
+
Returns:
|
545 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
546 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
547 |
+
returned.
|
548 |
+
"""
|
549 |
+
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
550 |
+
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
551 |
+
row_limit = self.tile_sample_min_size - blend_extent
|
552 |
+
|
553 |
+
# Split z into overlapping tiles and decode them separately.
|
554 |
+
# The tiles have an overlap to avoid seams between tiles.
|
555 |
+
rank = cur_rank()
|
556 |
+
rows = []
|
557 |
+
if self.parallel_decode and rank == 0:
|
558 |
+
rank = cur_rank()
|
559 |
+
#torch.cuda.set_device(rank) # set device for trt_runner
|
560 |
+
world_size = cur_world_size()
|
561 |
+
|
562 |
+
|
563 |
+
cur_device_id = 0
|
564 |
+
device_tasks = []
|
565 |
+
for i in range(world_size):
|
566 |
+
device_tasks.append([])
|
567 |
+
for i in range(0, z.shape[-2], overlap_size):
|
568 |
+
row = []
|
569 |
+
for j in range(0, z.shape[-1], overlap_size):
|
570 |
+
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
571 |
+
row.append(None)
|
572 |
+
device_tasks[cur_device_id].append((i // overlap_size, \
|
573 |
+
j // overlap_size, \
|
574 |
+
tile.to("cuda:" + str(cur_device_id))))
|
575 |
+
#device_tasks[cur_device_id].append((i // overlap_size, j // overlap_size, tile))
|
576 |
+
cur_device_id = (cur_device_id + 1) % world_size
|
577 |
+
rows.append(row)
|
578 |
+
|
579 |
+
def thread_run(decoder, device_id, inputs, outputs):
|
580 |
+
for input in inputs:
|
581 |
+
cur_vae = self.device_vaes[device_id]
|
582 |
+
ret = cur_vae.decoder(cur_vae.post_quant_conv(input[2]))
|
583 |
+
outputs[input[0]][input[1]] = ret
|
584 |
+
return
|
585 |
+
|
586 |
+
threads = []
|
587 |
+
for i in range(world_size):
|
588 |
+
cur_thread = threading.Thread(target=thread_run,
|
589 |
+
args=(self, i, device_tasks[i], rows),
|
590 |
+
name="DecoderThread-" + str(i))
|
591 |
+
threads.append(cur_thread)
|
592 |
+
cur_thread.start()
|
593 |
+
|
594 |
+
for cur_thread in threads:
|
595 |
+
cur_thread.join()
|
596 |
+
|
597 |
+
for i in range(len(rows)):
|
598 |
+
for j in range(len(rows[i])):
|
599 |
+
rows[i][j] = rows[i][j].to("cuda:0")
|
600 |
+
|
601 |
+
else:
|
602 |
+
for i in range(0, z.shape[-2], overlap_size):
|
603 |
+
row = []
|
604 |
+
for j in range(0, z.shape[-1], overlap_size):
|
605 |
+
tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
|
606 |
+
tile = self.post_quant_conv(tile)
|
607 |
+
decoded = self.decoder(tile)
|
608 |
+
row.append(decoded)
|
609 |
+
rows.append(row)
|
610 |
+
|
611 |
+
result_rows = []
|
612 |
+
for i, row in enumerate(rows):
|
613 |
+
result_row = []
|
614 |
+
for j, tile in enumerate(row):
|
615 |
+
# blend the above tile and the left tile
|
616 |
+
# to the current tile and add the current tile to the result row
|
617 |
+
if i > 0:
|
618 |
+
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
619 |
+
if j > 0:
|
620 |
+
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
621 |
+
result_row.append(tile[:, :, :, :row_limit, :row_limit])
|
622 |
+
result_rows.append(torch.cat(result_row, dim=-1))
|
623 |
+
|
624 |
+
if self.parallel_decode and rank != 0:
|
625 |
+
if not return_dict:
|
626 |
+
return (None,)
|
627 |
+
return DecoderOutput(sample=None)
|
628 |
+
|
629 |
+
dec = torch.cat(result_rows, dim=-2)
|
630 |
+
if not return_dict:
|
631 |
+
return (dec,)
|
632 |
+
|
633 |
+
return DecoderOutput(sample=dec)
|
634 |
+
|
635 |
+
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
636 |
+
assert not self.disable_causal_conv, "Temporal tiling is only compatible with causal convolutions."
|
637 |
+
|
638 |
+
B, C, T, H, W = x.shape
|
639 |
+
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
|
640 |
+
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
|
641 |
+
t_limit = self.tile_latent_min_tsize - blend_extent
|
642 |
+
|
643 |
+
# Split the video into tiles and encode them separately.
|
644 |
+
row = []
|
645 |
+
for i in range(0, T, overlap_size):
|
646 |
+
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
|
647 |
+
if self.use_spatial_tiling and \
|
648 |
+
(tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
|
649 |
+
tile = self.spatial_tiled_encode(tile, return_moments=True)
|
650 |
+
else:
|
651 |
+
tile = self.encoder(tile)
|
652 |
+
tile = self.quant_conv(tile)
|
653 |
+
if i > 0:
|
654 |
+
tile = tile[:, :, 1:, :, :]
|
655 |
+
row.append(tile)
|
656 |
+
result_row = []
|
657 |
+
for i, tile in enumerate(row):
|
658 |
+
if i > 0:
|
659 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
660 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
661 |
+
else:
|
662 |
+
result_row.append(tile[:, :, :t_limit+1, :, :])
|
663 |
+
|
664 |
+
moments = torch.cat(result_row, dim=2)
|
665 |
+
posterior = DiagonalGaussianDistribution(moments)
|
666 |
+
|
667 |
+
if not return_dict:
|
668 |
+
return (posterior,)
|
669 |
+
|
670 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
671 |
+
|
672 |
+
def temporal_tiled_decode(self,
|
673 |
+
z: torch.FloatTensor,
|
674 |
+
return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
675 |
+
# Split z into overlapping tiles and decode them separately.
|
676 |
+
assert not self.disable_causal_conv, "Temporal tiling is only supported with causal convolutions."
|
677 |
+
|
678 |
+
B, C, T, H, W = z.shape
|
679 |
+
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
|
680 |
+
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
|
681 |
+
t_limit = self.tile_sample_min_tsize - blend_extent
|
682 |
+
rank = 0 if CPU_OFFLOAD or DISABLE_SP else cur_rank()
|
683 |
+
row = []
|
684 |
+
for i in range(0, T, overlap_size):
|
685 |
+
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
|
686 |
+
if self.use_spatial_tiling and \
|
687 |
+
(tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
|
688 |
+
decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
|
689 |
+
else:
|
690 |
+
tile = self.post_quant_conv(tile)
|
691 |
+
decoded = self.decoder(tile)
|
692 |
+
if i > 0 and (not self.parallel_decode or rank == 0):
|
693 |
+
decoded = decoded[:, :, 1:, :, :]
|
694 |
+
row.append(decoded)
|
695 |
+
if not CPU_OFFLOAD and not DISABLE_SP and self.parallel_decode and rank != 0:
|
696 |
+
return DecoderOutput(sample=None)
|
697 |
+
result_row = []
|
698 |
+
for i, tile in enumerate(row):
|
699 |
+
if i > 0:
|
700 |
+
tile = self.blend_t(row[i - 1], tile, blend_extent)
|
701 |
+
result_row.append(tile[:, :, :t_limit, :, :])
|
702 |
+
else:
|
703 |
+
result_row.append(tile[:, :, :t_limit+1, :, :])
|
704 |
+
|
705 |
+
dec = torch.cat(result_row, dim=2)
|
706 |
+
if not return_dict:
|
707 |
+
return (dec,)
|
708 |
+
|
709 |
+
return DecoderOutput(sample=dec)
|
710 |
+
|
711 |
+
def forward(
|
712 |
+
self,
|
713 |
+
sample: torch.FloatTensor,
|
714 |
+
sample_posterior: bool = False,
|
715 |
+
return_dict: bool = True,
|
716 |
+
return_posterior: bool = False,
|
717 |
+
generator: Optional[torch.Generator] = None,
|
718 |
+
) -> Union[DecoderOutput2, torch.FloatTensor]:
|
719 |
+
r"""
|
720 |
+
Args:
|
721 |
+
sample (`torch.FloatTensor`): Input sample.
|
722 |
+
sample_posterior (`bool`, *optional*, defaults to `False`):
|
723 |
+
Whether to sample from the posterior.
|
724 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
725 |
+
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
726 |
+
"""
|
727 |
+
x = sample
|
728 |
+
posterior = self.encode(x).latent_dist
|
729 |
+
if sample_posterior:
|
730 |
+
z = posterior.sample(generator=generator)
|
731 |
+
else:
|
732 |
+
z = posterior.mode()
|
733 |
+
dec = self.decode(z).sample
|
734 |
+
|
735 |
+
if not return_dict:
|
736 |
+
if return_posterior:
|
737 |
+
return (dec, posterior)
|
738 |
+
else:
|
739 |
+
return (dec,)
|
740 |
+
if return_posterior:
|
741 |
+
return DecoderOutput2(sample=dec, posterior=posterior)
|
742 |
+
else:
|
743 |
+
return DecoderOutput2(sample=dec)
|
744 |
+
|
745 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
746 |
+
def fuse_qkv_projections(self):
|
747 |
+
"""
|
748 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
749 |
+
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
|
750 |
+
|
751 |
+
<Tip warning={true}>
|
752 |
+
|
753 |
+
This API is 🧪 experimental.
|
754 |
+
|
755 |
+
</Tip>
|
756 |
+
"""
|
757 |
+
self.original_attn_processors = None
|
758 |
+
|
759 |
+
for _, attn_processor in self.attn_processors.items():
|
760 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
761 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
762 |
+
|
763 |
+
self.original_attn_processors = self.attn_processors
|
764 |
+
|
765 |
+
for module in self.modules():
|
766 |
+
if isinstance(module, Attention):
|
767 |
+
module.fuse_projections(fuse=True)
|
768 |
+
|
769 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
770 |
+
def unfuse_qkv_projections(self):
|
771 |
+
"""Disables the fused QKV projection if enabled.
|
772 |
+
|
773 |
+
<Tip warning={true}>
|
774 |
+
|
775 |
+
This API is 🧪 experimental.
|
776 |
+
|
777 |
+
</Tip>
|
778 |
+
|
779 |
+
"""
|
780 |
+
if self.original_attn_processors is not None:
|
781 |
+
self.set_attn_processor(self.original_attn_processors)
|
hymm_sp/vae/unet_causal_3d_blocks.py
ADDED
@@ -0,0 +1,900 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from torch import nn
|
20 |
+
from einops import rearrange
|
21 |
+
|
22 |
+
from diffusers.utils import is_torch_version, logging
|
23 |
+
from diffusers.models.activations import get_activation
|
24 |
+
from diffusers.models.attention_processor import SpatialNorm
|
25 |
+
from diffusers.models.attention_processor import Attention
|
26 |
+
from diffusers.models.normalization import AdaGroupNorm
|
27 |
+
from diffusers.models.normalization import RMSNorm
|
28 |
+
|
29 |
+
|
30 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
31 |
+
|
32 |
+
|
33 |
+
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
|
34 |
+
seq_len = n_frame * n_hw
|
35 |
+
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
36 |
+
for i in range(seq_len):
|
37 |
+
i_frame = i // n_hw
|
38 |
+
mask[i, : (i_frame + 1) * n_hw] = 0
|
39 |
+
if batch_size is not None:
|
40 |
+
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
41 |
+
return mask
|
42 |
+
|
43 |
+
|
44 |
+
class CausalConv3d(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
chan_in,
|
48 |
+
chan_out,
|
49 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
50 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
51 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
52 |
+
pad_mode = 'replicate',
|
53 |
+
disable_causal=False,
|
54 |
+
**kwargs
|
55 |
+
):
|
56 |
+
super().__init__()
|
57 |
+
|
58 |
+
self.pad_mode = pad_mode
|
59 |
+
if disable_causal:
|
60 |
+
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2,
|
61 |
+
kernel_size // 2, kernel_size // 2, kernel_size // 2)
|
62 |
+
else:
|
63 |
+
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2,
|
64 |
+
kernel_size // 2, kernel_size - 1, 0) # W, H, T
|
65 |
+
self.time_causal_padding = padding
|
66 |
+
|
67 |
+
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
71 |
+
return self.conv(x)
|
72 |
+
|
73 |
+
class CausalAvgPool3d(nn.Module):
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
77 |
+
stride: Union[int, Tuple[int, int, int]],
|
78 |
+
pad_mode = 'replicate',
|
79 |
+
disable_causal=False,
|
80 |
+
**kwargs
|
81 |
+
):
|
82 |
+
super().__init__()
|
83 |
+
|
84 |
+
self.pad_mode = pad_mode
|
85 |
+
if disable_causal:
|
86 |
+
padding = (0, 0, 0, 0, 0, 0)
|
87 |
+
else:
|
88 |
+
padding = (0, 0, 0, 0, stride - 1, 0) # W, H, T
|
89 |
+
self.time_causal_padding = padding
|
90 |
+
|
91 |
+
self.conv = nn.AvgPool3d(kernel_size, stride=stride, ceil_mode=True, **kwargs)
|
92 |
+
self.pad_mode = pad_mode
|
93 |
+
|
94 |
+
def forward(self, x):
|
95 |
+
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
96 |
+
return self.conv(x)
|
97 |
+
|
98 |
+
class UpsampleCausal3D(nn.Module):
|
99 |
+
"""A 3D upsampling layer with an optional convolution.
|
100 |
+
|
101 |
+
Parameters:
|
102 |
+
channels (`int`):
|
103 |
+
number of channels in the inputs and outputs.
|
104 |
+
use_conv (`bool`, default `False`):
|
105 |
+
option to use a convolution.
|
106 |
+
use_conv_transpose (`bool`, default `False`):
|
107 |
+
option to use a convolution transpose.
|
108 |
+
out_channels (`int`, optional):
|
109 |
+
number of output channels. Defaults to `channels`.
|
110 |
+
name (`str`, default `conv`):
|
111 |
+
name of the upsampling 3D layer.
|
112 |
+
"""
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
channels: int,
|
117 |
+
use_conv: bool = False,
|
118 |
+
use_conv_transpose: bool = False,
|
119 |
+
out_channels: Optional[int] = None,
|
120 |
+
name: str = "conv",
|
121 |
+
kernel_size: Optional[int] = None,
|
122 |
+
padding=1,
|
123 |
+
norm_type=None,
|
124 |
+
eps=None,
|
125 |
+
elementwise_affine=None,
|
126 |
+
bias=True,
|
127 |
+
interpolate=True,
|
128 |
+
upsample_factor=(2, 2, 2),
|
129 |
+
disable_causal=False,
|
130 |
+
):
|
131 |
+
super().__init__()
|
132 |
+
self.channels = channels
|
133 |
+
self.out_channels = out_channels or channels
|
134 |
+
self.use_conv = use_conv
|
135 |
+
self.use_conv_transpose = use_conv_transpose
|
136 |
+
self.name = name
|
137 |
+
self.interpolate = interpolate
|
138 |
+
self.upsample_factor = upsample_factor
|
139 |
+
self.disable_causal = disable_causal
|
140 |
+
|
141 |
+
if norm_type == "ln_norm":
|
142 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
143 |
+
elif norm_type == "rms_norm":
|
144 |
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
145 |
+
elif norm_type is None:
|
146 |
+
self.norm = None
|
147 |
+
else:
|
148 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
149 |
+
|
150 |
+
conv = None
|
151 |
+
if use_conv_transpose:
|
152 |
+
assert False, "Not Implement yet"
|
153 |
+
if kernel_size is None:
|
154 |
+
kernel_size = 4
|
155 |
+
conv = nn.ConvTranspose2d(
|
156 |
+
channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
|
157 |
+
)
|
158 |
+
elif use_conv:
|
159 |
+
if kernel_size is None:
|
160 |
+
kernel_size = 3
|
161 |
+
conv = CausalConv3d(self.channels, self.out_channels,
|
162 |
+
kernel_size=kernel_size, bias=bias, disable_causal=disable_causal)
|
163 |
+
|
164 |
+
if name == "conv":
|
165 |
+
self.conv = conv
|
166 |
+
else:
|
167 |
+
self.Conv2d_0 = conv
|
168 |
+
|
169 |
+
def forward(
|
170 |
+
self,
|
171 |
+
hidden_states: torch.FloatTensor,
|
172 |
+
output_size: Optional[int] = None,
|
173 |
+
scale: float = 1.0,
|
174 |
+
) -> torch.FloatTensor:
|
175 |
+
assert hidden_states.shape[1] == self.channels
|
176 |
+
|
177 |
+
if self.norm is not None:
|
178 |
+
assert False, "Not Implement yet"
|
179 |
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
180 |
+
|
181 |
+
if self.use_conv_transpose:
|
182 |
+
return self.conv(hidden_states)
|
183 |
+
|
184 |
+
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
185 |
+
# https://github.com/pytorch/pytorch/issues/86679
|
186 |
+
dtype = hidden_states.dtype
|
187 |
+
if dtype == torch.bfloat16:
|
188 |
+
hidden_states = hidden_states.to(torch.float32)
|
189 |
+
|
190 |
+
# upsample_nearest_nhwc fails with large batch sizes.
|
191 |
+
# see https://github.com/huggingface/diffusers/issues/984
|
192 |
+
if hidden_states.shape[0] >= 64:
|
193 |
+
hidden_states = hidden_states.contiguous()
|
194 |
+
|
195 |
+
# if `output_size` is passed we force the interpolation output
|
196 |
+
# size and do not make use of `scale_factor=2`
|
197 |
+
if self.interpolate:
|
198 |
+
B, C, T, H, W = hidden_states.shape
|
199 |
+
if not self.disable_causal:
|
200 |
+
first_h, other_h = hidden_states.split((1, T-1), dim=2)
|
201 |
+
if output_size is None:
|
202 |
+
if T > 1:
|
203 |
+
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
204 |
+
|
205 |
+
first_h = first_h.squeeze(2)
|
206 |
+
first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
|
207 |
+
first_h = first_h.unsqueeze(2)
|
208 |
+
else:
|
209 |
+
assert False, "Not Implement yet"
|
210 |
+
other_h = F.interpolate(other_h, size=output_size, mode="nearest")
|
211 |
+
|
212 |
+
if T > 1:
|
213 |
+
hidden_states = torch.cat((first_h, other_h), dim=2)
|
214 |
+
else:
|
215 |
+
hidden_states = first_h
|
216 |
+
else:
|
217 |
+
hidden_states = F.interpolate(hidden_states, scale_factor=self.upsample_factor, mode="nearest")
|
218 |
+
|
219 |
+
if dtype == torch.bfloat16:
|
220 |
+
hidden_states = hidden_states.to(dtype)
|
221 |
+
|
222 |
+
if self.use_conv:
|
223 |
+
if self.name == "conv":
|
224 |
+
hidden_states = self.conv(hidden_states)
|
225 |
+
else:
|
226 |
+
hidden_states = self.Conv2d_0(hidden_states)
|
227 |
+
|
228 |
+
return hidden_states
|
229 |
+
|
230 |
+
class DownsampleCausal3D(nn.Module):
|
231 |
+
"""A 3D downsampling layer with an optional convolution.
|
232 |
+
|
233 |
+
Parameters:
|
234 |
+
channels (`int`):
|
235 |
+
number of channels in the inputs and outputs.
|
236 |
+
use_conv (`bool`, default `False`):
|
237 |
+
option to use a convolution.
|
238 |
+
out_channels (`int`, optional):
|
239 |
+
number of output channels. Defaults to `channels`.
|
240 |
+
padding (`int`, default `1`):
|
241 |
+
padding for the convolution.
|
242 |
+
name (`str`, default `conv`):
|
243 |
+
name of the downsampling 3D layer.
|
244 |
+
"""
|
245 |
+
|
246 |
+
def __init__(
|
247 |
+
self,
|
248 |
+
channels: int,
|
249 |
+
use_conv: bool = False,
|
250 |
+
out_channels: Optional[int] = None,
|
251 |
+
padding: int = 1,
|
252 |
+
name: str = "conv",
|
253 |
+
kernel_size=3,
|
254 |
+
norm_type=None,
|
255 |
+
eps=None,
|
256 |
+
elementwise_affine=None,
|
257 |
+
bias=True,
|
258 |
+
stride=2,
|
259 |
+
disable_causal=False,
|
260 |
+
):
|
261 |
+
super().__init__()
|
262 |
+
self.channels = channels
|
263 |
+
self.out_channels = out_channels or channels
|
264 |
+
self.use_conv = use_conv
|
265 |
+
self.padding = padding
|
266 |
+
stride = stride
|
267 |
+
self.name = name
|
268 |
+
|
269 |
+
if norm_type == "ln_norm":
|
270 |
+
self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
|
271 |
+
elif norm_type == "rms_norm":
|
272 |
+
self.norm = RMSNorm(channels, eps, elementwise_affine)
|
273 |
+
elif norm_type is None:
|
274 |
+
self.norm = None
|
275 |
+
else:
|
276 |
+
raise ValueError(f"unknown norm_type: {norm_type}")
|
277 |
+
|
278 |
+
if use_conv:
|
279 |
+
conv = CausalConv3d(
|
280 |
+
self.channels, self.out_channels, kernel_size=kernel_size, stride=stride,
|
281 |
+
disable_causal=disable_causal, bias=bias
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
raise NotImplementedError
|
285 |
+
if name == "conv":
|
286 |
+
self.Conv2d_0 = conv
|
287 |
+
self.conv = conv
|
288 |
+
elif name == "Conv2d_0":
|
289 |
+
self.conv = conv
|
290 |
+
else:
|
291 |
+
self.conv = conv
|
292 |
+
|
293 |
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
294 |
+
assert hidden_states.shape[1] == self.channels
|
295 |
+
|
296 |
+
if self.norm is not None:
|
297 |
+
hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
298 |
+
|
299 |
+
assert hidden_states.shape[1] == self.channels
|
300 |
+
|
301 |
+
hidden_states = self.conv(hidden_states)
|
302 |
+
|
303 |
+
return hidden_states
|
304 |
+
|
305 |
+
class ResnetBlockCausal3D(nn.Module):
|
306 |
+
r"""
|
307 |
+
A Resnet block.
|
308 |
+
|
309 |
+
Parameters:
|
310 |
+
in_channels (`int`): The number of channels in the input.
|
311 |
+
out_channels (`int`, *optional*, default to be `None`):
|
312 |
+
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
313 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
314 |
+
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
|
315 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
316 |
+
groups_out (`int`, *optional*, default to None):
|
317 |
+
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
|
318 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
319 |
+
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
|
320 |
+
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
|
321 |
+
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
|
322 |
+
"ada_group" for a stronger conditioning with scale and shift.
|
323 |
+
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
|
324 |
+
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
|
325 |
+
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
|
326 |
+
use_in_shortcut (`bool`, *optional*, default to `True`):
|
327 |
+
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
|
328 |
+
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
|
329 |
+
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
|
330 |
+
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
|
331 |
+
`conv_shortcut` output.
|
332 |
+
conv_3d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
|
333 |
+
If None, same as `out_channels`.
|
334 |
+
"""
|
335 |
+
|
336 |
+
def __init__(
|
337 |
+
self,
|
338 |
+
*,
|
339 |
+
in_channels: int,
|
340 |
+
out_channels: Optional[int] = None,
|
341 |
+
conv_shortcut: bool = False,
|
342 |
+
dropout: float = 0.0,
|
343 |
+
temb_channels: int = 512,
|
344 |
+
groups: int = 32,
|
345 |
+
groups_out: Optional[int] = None,
|
346 |
+
pre_norm: bool = True,
|
347 |
+
eps: float = 1e-6,
|
348 |
+
non_linearity: str = "swish",
|
349 |
+
skip_time_act: bool = False,
|
350 |
+
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
|
351 |
+
kernel: Optional[torch.FloatTensor] = None,
|
352 |
+
output_scale_factor: float = 1.0,
|
353 |
+
use_in_shortcut: Optional[bool] = None,
|
354 |
+
up: bool = False,
|
355 |
+
down: bool = False,
|
356 |
+
conv_shortcut_bias: bool = True,
|
357 |
+
conv_3d_out_channels: Optional[int] = None,
|
358 |
+
disable_causal: bool = False,
|
359 |
+
):
|
360 |
+
super().__init__()
|
361 |
+
self.pre_norm = pre_norm
|
362 |
+
self.pre_norm = True
|
363 |
+
self.in_channels = in_channels
|
364 |
+
out_channels = in_channels if out_channels is None else out_channels
|
365 |
+
self.out_channels = out_channels
|
366 |
+
self.use_conv_shortcut = conv_shortcut
|
367 |
+
self.up = up
|
368 |
+
self.down = down
|
369 |
+
self.output_scale_factor = output_scale_factor
|
370 |
+
self.time_embedding_norm = time_embedding_norm
|
371 |
+
self.skip_time_act = skip_time_act
|
372 |
+
|
373 |
+
linear_cls = nn.Linear
|
374 |
+
|
375 |
+
if groups_out is None:
|
376 |
+
groups_out = groups
|
377 |
+
|
378 |
+
if self.time_embedding_norm == "ada_group":
|
379 |
+
self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
|
380 |
+
elif self.time_embedding_norm == "spatial":
|
381 |
+
self.norm1 = SpatialNorm(in_channels, temb_channels)
|
382 |
+
else:
|
383 |
+
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
384 |
+
|
385 |
+
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, disable_causal=disable_causal)
|
386 |
+
|
387 |
+
if temb_channels is not None:
|
388 |
+
if self.time_embedding_norm == "default":
|
389 |
+
self.time_emb_proj = linear_cls(temb_channels, out_channels)
|
390 |
+
elif self.time_embedding_norm == "scale_shift":
|
391 |
+
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
|
392 |
+
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
393 |
+
self.time_emb_proj = None
|
394 |
+
else:
|
395 |
+
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
396 |
+
else:
|
397 |
+
self.time_emb_proj = None
|
398 |
+
|
399 |
+
if self.time_embedding_norm == "ada_group":
|
400 |
+
self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
|
401 |
+
elif self.time_embedding_norm == "spatial":
|
402 |
+
self.norm2 = SpatialNorm(out_channels, temb_channels)
|
403 |
+
else:
|
404 |
+
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
405 |
+
|
406 |
+
self.dropout = torch.nn.Dropout(dropout)
|
407 |
+
conv_3d_out_channels = conv_3d_out_channels or out_channels
|
408 |
+
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels,
|
409 |
+
kernel_size=3, stride=1, disable_causal=disable_causal)
|
410 |
+
|
411 |
+
self.nonlinearity = get_activation(non_linearity)
|
412 |
+
|
413 |
+
self.upsample = self.downsample = None
|
414 |
+
if self.up:
|
415 |
+
self.upsample = UpsampleCausal3D(in_channels, use_conv=False, disable_causal=disable_causal)
|
416 |
+
elif self.down:
|
417 |
+
self.downsample = DownsampleCausal3D(in_channels, use_conv=False,
|
418 |
+
disable_causal=disable_causal, name="op")
|
419 |
+
|
420 |
+
self.use_in_shortcut = self.in_channels != conv_3d_out_channels \
|
421 |
+
if use_in_shortcut is None else use_in_shortcut
|
422 |
+
|
423 |
+
self.conv_shortcut = None
|
424 |
+
if self.use_in_shortcut:
|
425 |
+
self.conv_shortcut = CausalConv3d(
|
426 |
+
in_channels,
|
427 |
+
conv_3d_out_channels,
|
428 |
+
kernel_size=1,
|
429 |
+
stride=1,
|
430 |
+
disable_causal=disable_causal,
|
431 |
+
bias=conv_shortcut_bias,
|
432 |
+
)
|
433 |
+
|
434 |
+
def forward(
|
435 |
+
self,
|
436 |
+
input_tensor: torch.FloatTensor,
|
437 |
+
temb: torch.FloatTensor,
|
438 |
+
scale: float = 1.0,
|
439 |
+
) -> torch.FloatTensor:
|
440 |
+
hidden_states = input_tensor
|
441 |
+
|
442 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
443 |
+
hidden_states = self.norm1(hidden_states, temb)
|
444 |
+
else:
|
445 |
+
hidden_states = self.norm1(hidden_states)
|
446 |
+
|
447 |
+
hidden_states = self.nonlinearity(hidden_states)
|
448 |
+
|
449 |
+
if self.upsample is not None:
|
450 |
+
# upsample_nearest_nhwc fails with large batch sizes.
|
451 |
+
# see https://github.com/huggingface/diffusers/issues/984
|
452 |
+
if hidden_states.shape[0] >= 64:
|
453 |
+
input_tensor = input_tensor.contiguous()
|
454 |
+
hidden_states = hidden_states.contiguous()
|
455 |
+
input_tensor = (
|
456 |
+
self.upsample(input_tensor, scale=scale)
|
457 |
+
)
|
458 |
+
hidden_states = (
|
459 |
+
self.upsample(hidden_states, scale=scale)
|
460 |
+
)
|
461 |
+
elif self.downsample is not None:
|
462 |
+
input_tensor = (
|
463 |
+
self.downsample(input_tensor, scale=scale)
|
464 |
+
)
|
465 |
+
hidden_states = (
|
466 |
+
self.downsample(hidden_states, scale=scale)
|
467 |
+
)
|
468 |
+
|
469 |
+
hidden_states = self.conv1(hidden_states)
|
470 |
+
|
471 |
+
if self.time_emb_proj is not None:
|
472 |
+
if not self.skip_time_act:
|
473 |
+
temb = self.nonlinearity(temb)
|
474 |
+
temb = (
|
475 |
+
self.time_emb_proj(temb, scale)[:, :, None, None]
|
476 |
+
)
|
477 |
+
|
478 |
+
if temb is not None and self.time_embedding_norm == "default":
|
479 |
+
hidden_states = hidden_states + temb
|
480 |
+
|
481 |
+
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
482 |
+
hidden_states = self.norm2(hidden_states, temb)
|
483 |
+
else:
|
484 |
+
hidden_states = self.norm2(hidden_states)
|
485 |
+
|
486 |
+
if temb is not None and self.time_embedding_norm == "scale_shift":
|
487 |
+
scale, shift = torch.chunk(temb, 2, dim=1)
|
488 |
+
hidden_states = hidden_states * (1 + scale) + shift
|
489 |
+
|
490 |
+
hidden_states = self.nonlinearity(hidden_states)
|
491 |
+
|
492 |
+
hidden_states = self.dropout(hidden_states)
|
493 |
+
hidden_states = self.conv2(hidden_states)
|
494 |
+
|
495 |
+
if self.conv_shortcut is not None:
|
496 |
+
input_tensor = (
|
497 |
+
self.conv_shortcut(input_tensor)
|
498 |
+
)
|
499 |
+
|
500 |
+
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
501 |
+
|
502 |
+
return output_tensor
|
503 |
+
|
504 |
+
def get_down_block3d(
|
505 |
+
down_block_type: str,
|
506 |
+
num_layers: int,
|
507 |
+
in_channels: int,
|
508 |
+
out_channels: int,
|
509 |
+
temb_channels: int,
|
510 |
+
add_downsample: bool,
|
511 |
+
downsample_stride: int,
|
512 |
+
resnet_eps: float,
|
513 |
+
resnet_act_fn: str,
|
514 |
+
transformer_layers_per_block: int = 1,
|
515 |
+
num_attention_heads: Optional[int] = None,
|
516 |
+
resnet_groups: Optional[int] = None,
|
517 |
+
cross_attention_dim: Optional[int] = None,
|
518 |
+
downsample_padding: Optional[int] = None,
|
519 |
+
dual_cross_attention: bool = False,
|
520 |
+
use_linear_projection: bool = False,
|
521 |
+
only_cross_attention: bool = False,
|
522 |
+
upcast_attention: bool = False,
|
523 |
+
resnet_time_scale_shift: str = "default",
|
524 |
+
attention_type: str = "default",
|
525 |
+
resnet_skip_time_act: bool = False,
|
526 |
+
resnet_out_scale_factor: float = 1.0,
|
527 |
+
cross_attention_norm: Optional[str] = None,
|
528 |
+
attention_head_dim: Optional[int] = None,
|
529 |
+
downsample_type: Optional[str] = None,
|
530 |
+
dropout: float = 0.0,
|
531 |
+
disable_causal: bool = False,
|
532 |
+
):
|
533 |
+
# If attn head dim is not defined, we default it to the number of heads
|
534 |
+
if attention_head_dim is None:
|
535 |
+
logger.warn(
|
536 |
+
f"It is recommended to provide `attention_head_dim` when calling \
|
537 |
+
`get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
|
538 |
+
)
|
539 |
+
attention_head_dim = num_attention_heads
|
540 |
+
|
541 |
+
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
542 |
+
if down_block_type == "DownEncoderBlockCausal3D":
|
543 |
+
return DownEncoderBlockCausal3D(
|
544 |
+
num_layers=num_layers,
|
545 |
+
in_channels=in_channels,
|
546 |
+
out_channels=out_channels,
|
547 |
+
dropout=dropout,
|
548 |
+
add_downsample=add_downsample,
|
549 |
+
downsample_stride=downsample_stride,
|
550 |
+
resnet_eps=resnet_eps,
|
551 |
+
resnet_act_fn=resnet_act_fn,
|
552 |
+
resnet_groups=resnet_groups,
|
553 |
+
downsample_padding=downsample_padding,
|
554 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
555 |
+
disable_causal=disable_causal,
|
556 |
+
)
|
557 |
+
raise ValueError(f"{down_block_type} does not exist.")
|
558 |
+
|
559 |
+
def get_up_block3d(
|
560 |
+
up_block_type: str,
|
561 |
+
num_layers: int,
|
562 |
+
in_channels: int,
|
563 |
+
out_channels: int,
|
564 |
+
prev_output_channel: int,
|
565 |
+
temb_channels: int,
|
566 |
+
add_upsample: bool,
|
567 |
+
upsample_scale_factor: Tuple,
|
568 |
+
resnet_eps: float,
|
569 |
+
resnet_act_fn: str,
|
570 |
+
resolution_idx: Optional[int] = None,
|
571 |
+
transformer_layers_per_block: int = 1,
|
572 |
+
num_attention_heads: Optional[int] = None,
|
573 |
+
resnet_groups: Optional[int] = None,
|
574 |
+
cross_attention_dim: Optional[int] = None,
|
575 |
+
dual_cross_attention: bool = False,
|
576 |
+
use_linear_projection: bool = False,
|
577 |
+
only_cross_attention: bool = False,
|
578 |
+
upcast_attention: bool = False,
|
579 |
+
resnet_time_scale_shift: str = "default",
|
580 |
+
attention_type: str = "default",
|
581 |
+
resnet_skip_time_act: bool = False,
|
582 |
+
resnet_out_scale_factor: float = 1.0,
|
583 |
+
cross_attention_norm: Optional[str] = None,
|
584 |
+
attention_head_dim: Optional[int] = None,
|
585 |
+
upsample_type: Optional[str] = None,
|
586 |
+
dropout: float = 0.0,
|
587 |
+
disable_causal: bool = False,
|
588 |
+
) -> nn.Module:
|
589 |
+
# If attn head dim is not defined, we default it to the number of heads
|
590 |
+
if attention_head_dim is None:
|
591 |
+
logger.warn(
|
592 |
+
f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. \
|
593 |
+
Defaulting `attention_head_dim` to {num_attention_heads}."
|
594 |
+
)
|
595 |
+
attention_head_dim = num_attention_heads
|
596 |
+
|
597 |
+
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
598 |
+
if up_block_type == "UpDecoderBlockCausal3D":
|
599 |
+
return UpDecoderBlockCausal3D(
|
600 |
+
num_layers=num_layers,
|
601 |
+
in_channels=in_channels,
|
602 |
+
out_channels=out_channels,
|
603 |
+
resolution_idx=resolution_idx,
|
604 |
+
dropout=dropout,
|
605 |
+
add_upsample=add_upsample,
|
606 |
+
upsample_scale_factor=upsample_scale_factor,
|
607 |
+
resnet_eps=resnet_eps,
|
608 |
+
resnet_act_fn=resnet_act_fn,
|
609 |
+
resnet_groups=resnet_groups,
|
610 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
611 |
+
temb_channels=temb_channels,
|
612 |
+
disable_causal=disable_causal,
|
613 |
+
)
|
614 |
+
raise ValueError(f"{up_block_type} does not exist.")
|
615 |
+
|
616 |
+
|
617 |
+
class UNetMidBlockCausal3D(nn.Module):
|
618 |
+
"""
|
619 |
+
A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
|
620 |
+
|
621 |
+
Args:
|
622 |
+
in_channels (`int`): The number of input channels.
|
623 |
+
temb_channels (`int`): The number of temporal embedding channels.
|
624 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
625 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
626 |
+
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
|
627 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
|
628 |
+
The type of normalization to apply to the time embeddings. This can help to improve the performance of the
|
629 |
+
model on tasks with long-range temporal dependencies.
|
630 |
+
resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
|
631 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
632 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
633 |
+
attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
|
634 |
+
resnet_pre_norm (`bool`, *optional*, defaults to `True`):
|
635 |
+
Whether to use pre-normalization for the resnet blocks.
|
636 |
+
add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
|
637 |
+
attention_head_dim (`int`, *optional*, defaults to 1):
|
638 |
+
Dimension of a single attention head. The number of attention heads is determined based on this value and
|
639 |
+
the number of input channels.
|
640 |
+
output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
|
641 |
+
|
642 |
+
Returns:
|
643 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
644 |
+
in_channels, height, width)`.
|
645 |
+
|
646 |
+
"""
|
647 |
+
|
648 |
+
def __init__(
|
649 |
+
self,
|
650 |
+
in_channels: int,
|
651 |
+
temb_channels: int,
|
652 |
+
dropout: float = 0.0,
|
653 |
+
num_layers: int = 1,
|
654 |
+
resnet_eps: float = 1e-6,
|
655 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
656 |
+
resnet_act_fn: str = "swish",
|
657 |
+
resnet_groups: int = 32,
|
658 |
+
attn_groups: Optional[int] = None,
|
659 |
+
resnet_pre_norm: bool = True,
|
660 |
+
add_attention: bool = True,
|
661 |
+
attention_head_dim: int = 1,
|
662 |
+
output_scale_factor: float = 1.0,
|
663 |
+
disable_causal: bool = False,
|
664 |
+
causal_attention: bool = False,
|
665 |
+
):
|
666 |
+
super().__init__()
|
667 |
+
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
668 |
+
self.add_attention = add_attention
|
669 |
+
self.causal_attention = causal_attention
|
670 |
+
|
671 |
+
if attn_groups is None:
|
672 |
+
attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
|
673 |
+
|
674 |
+
# there is always at least one resnet
|
675 |
+
resnets = [
|
676 |
+
ResnetBlockCausal3D(
|
677 |
+
in_channels=in_channels,
|
678 |
+
out_channels=in_channels,
|
679 |
+
temb_channels=temb_channels,
|
680 |
+
eps=resnet_eps,
|
681 |
+
groups=resnet_groups,
|
682 |
+
dropout=dropout,
|
683 |
+
time_embedding_norm=resnet_time_scale_shift,
|
684 |
+
non_linearity=resnet_act_fn,
|
685 |
+
output_scale_factor=output_scale_factor,
|
686 |
+
pre_norm=resnet_pre_norm,
|
687 |
+
disable_causal=disable_causal,
|
688 |
+
)
|
689 |
+
]
|
690 |
+
attentions = []
|
691 |
+
|
692 |
+
if attention_head_dim is None:
|
693 |
+
logger.warn(
|
694 |
+
f"It is not recommend to pass `attention_head_dim=None`. \
|
695 |
+
Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
|
696 |
+
)
|
697 |
+
attention_head_dim = in_channels
|
698 |
+
|
699 |
+
for _ in range(num_layers):
|
700 |
+
if self.add_attention:
|
701 |
+
#assert False, "Not implemented yet"
|
702 |
+
attentions.append(
|
703 |
+
Attention(
|
704 |
+
in_channels,
|
705 |
+
heads=in_channels // attention_head_dim,
|
706 |
+
dim_head=attention_head_dim,
|
707 |
+
rescale_output_factor=output_scale_factor,
|
708 |
+
eps=resnet_eps,
|
709 |
+
norm_num_groups=attn_groups,
|
710 |
+
spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
|
711 |
+
residual_connection=True,
|
712 |
+
bias=True,
|
713 |
+
upcast_softmax=True,
|
714 |
+
_from_deprecated_attn_block=True,
|
715 |
+
)
|
716 |
+
)
|
717 |
+
else:
|
718 |
+
attentions.append(None)
|
719 |
+
|
720 |
+
resnets.append(
|
721 |
+
ResnetBlockCausal3D(
|
722 |
+
in_channels=in_channels,
|
723 |
+
out_channels=in_channels,
|
724 |
+
temb_channels=temb_channels,
|
725 |
+
eps=resnet_eps,
|
726 |
+
groups=resnet_groups,
|
727 |
+
dropout=dropout,
|
728 |
+
time_embedding_norm=resnet_time_scale_shift,
|
729 |
+
non_linearity=resnet_act_fn,
|
730 |
+
output_scale_factor=output_scale_factor,
|
731 |
+
pre_norm=resnet_pre_norm,
|
732 |
+
disable_causal=disable_causal,
|
733 |
+
)
|
734 |
+
)
|
735 |
+
|
736 |
+
self.attentions = nn.ModuleList(attentions)
|
737 |
+
self.resnets = nn.ModuleList(resnets)
|
738 |
+
|
739 |
+
def forward(self,
|
740 |
+
hidden_states: torch.FloatTensor,
|
741 |
+
temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
|
742 |
+
hidden_states = self.resnets[0](hidden_states, temb)
|
743 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
744 |
+
if attn is not None:
|
745 |
+
B, C, T, H, W = hidden_states.shape
|
746 |
+
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
747 |
+
if self.causal_attention:
|
748 |
+
attention_mask = prepare_causal_attention_mask(T, H * W,
|
749 |
+
hidden_states.dtype,
|
750 |
+
hidden_states.device,
|
751 |
+
batch_size=B)
|
752 |
+
else:
|
753 |
+
attention_mask = None
|
754 |
+
hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
|
755 |
+
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
756 |
+
hidden_states = resnet(hidden_states, temb)
|
757 |
+
|
758 |
+
return hidden_states
|
759 |
+
|
760 |
+
|
761 |
+
class DownEncoderBlockCausal3D(nn.Module):
|
762 |
+
def __init__(
|
763 |
+
self,
|
764 |
+
in_channels: int,
|
765 |
+
out_channels: int,
|
766 |
+
dropout: float = 0.0,
|
767 |
+
num_layers: int = 1,
|
768 |
+
resnet_eps: float = 1e-6,
|
769 |
+
resnet_time_scale_shift: str = "default",
|
770 |
+
resnet_act_fn: str = "swish",
|
771 |
+
resnet_groups: int = 32,
|
772 |
+
resnet_pre_norm: bool = True,
|
773 |
+
output_scale_factor: float = 1.0,
|
774 |
+
add_downsample: bool = True,
|
775 |
+
downsample_stride: int = 2,
|
776 |
+
downsample_padding: int = 1,
|
777 |
+
disable_causal: bool = False,
|
778 |
+
):
|
779 |
+
super().__init__()
|
780 |
+
resnets = []
|
781 |
+
|
782 |
+
for i in range(num_layers):
|
783 |
+
in_channels = in_channels if i == 0 else out_channels
|
784 |
+
resnets.append(
|
785 |
+
ResnetBlockCausal3D(
|
786 |
+
in_channels=in_channels,
|
787 |
+
out_channels=out_channels,
|
788 |
+
temb_channels=None,
|
789 |
+
eps=resnet_eps,
|
790 |
+
groups=resnet_groups,
|
791 |
+
dropout=dropout,
|
792 |
+
time_embedding_norm=resnet_time_scale_shift,
|
793 |
+
non_linearity=resnet_act_fn,
|
794 |
+
output_scale_factor=output_scale_factor,
|
795 |
+
pre_norm=resnet_pre_norm,
|
796 |
+
disable_causal=disable_causal,
|
797 |
+
)
|
798 |
+
)
|
799 |
+
|
800 |
+
self.resnets = nn.ModuleList(resnets)
|
801 |
+
|
802 |
+
if add_downsample:
|
803 |
+
self.downsamplers = nn.ModuleList(
|
804 |
+
[
|
805 |
+
DownsampleCausal3D(
|
806 |
+
out_channels,
|
807 |
+
use_conv=True,
|
808 |
+
out_channels=out_channels,
|
809 |
+
padding=downsample_padding,
|
810 |
+
name="op",
|
811 |
+
stride=downsample_stride,
|
812 |
+
disable_causal=disable_causal,
|
813 |
+
)
|
814 |
+
]
|
815 |
+
)
|
816 |
+
else:
|
817 |
+
self.downsamplers = None
|
818 |
+
|
819 |
+
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
820 |
+
for resnet in self.resnets:
|
821 |
+
hidden_states = resnet(hidden_states, temb=None, scale=scale)
|
822 |
+
|
823 |
+
if self.downsamplers is not None:
|
824 |
+
for downsampler in self.downsamplers:
|
825 |
+
hidden_states = downsampler(hidden_states, scale)
|
826 |
+
|
827 |
+
return hidden_states
|
828 |
+
|
829 |
+
|
830 |
+
class UpDecoderBlockCausal3D(nn.Module):
|
831 |
+
def __init__(
|
832 |
+
self,
|
833 |
+
in_channels: int,
|
834 |
+
out_channels: int,
|
835 |
+
resolution_idx: Optional[int] = None,
|
836 |
+
dropout: float = 0.0,
|
837 |
+
num_layers: int = 1,
|
838 |
+
resnet_eps: float = 1e-6,
|
839 |
+
resnet_time_scale_shift: str = "default", # default, spatial
|
840 |
+
resnet_act_fn: str = "swish",
|
841 |
+
resnet_groups: int = 32,
|
842 |
+
resnet_pre_norm: bool = True,
|
843 |
+
output_scale_factor: float = 1.0,
|
844 |
+
add_upsample: bool = True,
|
845 |
+
upsample_scale_factor = (2, 2, 2),
|
846 |
+
temb_channels: Optional[int] = None,
|
847 |
+
disable_causal: bool = False,
|
848 |
+
):
|
849 |
+
super().__init__()
|
850 |
+
resnets = []
|
851 |
+
|
852 |
+
for i in range(num_layers):
|
853 |
+
input_channels = in_channels if i == 0 else out_channels
|
854 |
+
|
855 |
+
resnets.append(
|
856 |
+
ResnetBlockCausal3D(
|
857 |
+
in_channels=input_channels,
|
858 |
+
out_channels=out_channels,
|
859 |
+
temb_channels=temb_channels,
|
860 |
+
eps=resnet_eps,
|
861 |
+
groups=resnet_groups,
|
862 |
+
dropout=dropout,
|
863 |
+
time_embedding_norm=resnet_time_scale_shift,
|
864 |
+
non_linearity=resnet_act_fn,
|
865 |
+
output_scale_factor=output_scale_factor,
|
866 |
+
pre_norm=resnet_pre_norm,
|
867 |
+
disable_causal=disable_causal,
|
868 |
+
)
|
869 |
+
)
|
870 |
+
|
871 |
+
self.resnets = nn.ModuleList(resnets)
|
872 |
+
|
873 |
+
if add_upsample:
|
874 |
+
self.upsamplers = nn.ModuleList(
|
875 |
+
[
|
876 |
+
UpsampleCausal3D(
|
877 |
+
out_channels,
|
878 |
+
use_conv=True,
|
879 |
+
out_channels=out_channels,
|
880 |
+
upsample_factor=upsample_scale_factor,
|
881 |
+
disable_causal=disable_causal
|
882 |
+
)
|
883 |
+
]
|
884 |
+
)
|
885 |
+
else:
|
886 |
+
self.upsamplers = None
|
887 |
+
|
888 |
+
self.resolution_idx = resolution_idx
|
889 |
+
|
890 |
+
def forward(
|
891 |
+
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
|
892 |
+
) -> torch.FloatTensor:
|
893 |
+
for resnet in self.resnets:
|
894 |
+
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
895 |
+
|
896 |
+
if self.upsamplers is not None:
|
897 |
+
for upsampler in self.upsamplers:
|
898 |
+
hidden_states = upsampler(hidden_states)
|
899 |
+
|
900 |
+
return hidden_states
|
hymm_sp/vae/vae.py
ADDED
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
9 |
+
from diffusers.utils.torch_utils import randn_tensor
|
10 |
+
from diffusers.models.attention_processor import SpatialNorm
|
11 |
+
from .unet_causal_3d_blocks import (
|
12 |
+
CausalConv3d,
|
13 |
+
UNetMidBlockCausal3D,
|
14 |
+
get_down_block3d,
|
15 |
+
get_up_block3d,
|
16 |
+
)
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class DecoderOutput(BaseOutput):
|
20 |
+
r"""
|
21 |
+
Output of decoding method.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
25 |
+
The decoded output sample from the last layer of the model.
|
26 |
+
"""
|
27 |
+
|
28 |
+
sample: torch.FloatTensor
|
29 |
+
|
30 |
+
|
31 |
+
class EncoderCausal3D(nn.Module):
|
32 |
+
r"""
|
33 |
+
The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
|
34 |
+
|
35 |
+
Args:
|
36 |
+
in_channels (`int`, *optional*, defaults to 3):
|
37 |
+
The number of input channels.
|
38 |
+
out_channels (`int`, *optional*, defaults to 3):
|
39 |
+
The number of output channels.
|
40 |
+
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
|
41 |
+
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
|
42 |
+
options.
|
43 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
44 |
+
The number of output channels for each block.
|
45 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
46 |
+
The number of layers per block.
|
47 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
48 |
+
The number of groups for normalization.
|
49 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
50 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
51 |
+
double_z (`bool`, *optional*, defaults to `True`):
|
52 |
+
Whether to double the number of output channels for the last block.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_channels: int = 3,
|
58 |
+
out_channels: int = 3,
|
59 |
+
down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
|
60 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
61 |
+
layers_per_block: int = 2,
|
62 |
+
norm_num_groups: int = 32,
|
63 |
+
act_fn: str = "silu",
|
64 |
+
double_z: bool = True,
|
65 |
+
mid_block_add_attention=True,
|
66 |
+
time_compression_ratio: int = 4,
|
67 |
+
spatial_compression_ratio: int = 8,
|
68 |
+
disable_causal: bool = False,
|
69 |
+
mid_block_causal_attn: bool = False,
|
70 |
+
):
|
71 |
+
super().__init__()
|
72 |
+
self.layers_per_block = layers_per_block
|
73 |
+
|
74 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[0],
|
75 |
+
kernel_size=3, stride=1, disable_causal=disable_causal)
|
76 |
+
self.mid_block = None
|
77 |
+
self.down_blocks = nn.ModuleList([])
|
78 |
+
|
79 |
+
# down
|
80 |
+
output_channel = block_out_channels[0]
|
81 |
+
for i, down_block_type in enumerate(down_block_types):
|
82 |
+
input_channel = output_channel
|
83 |
+
output_channel = block_out_channels[i]
|
84 |
+
is_final_block = i == len(block_out_channels) - 1
|
85 |
+
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
86 |
+
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
87 |
+
|
88 |
+
if time_compression_ratio == 4:
|
89 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
90 |
+
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - \
|
91 |
+
num_time_downsample_layers) and not is_final_block)
|
92 |
+
elif time_compression_ratio == 8:
|
93 |
+
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
94 |
+
add_time_downsample = bool(i < num_time_downsample_layers)
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
|
97 |
+
|
98 |
+
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
99 |
+
downsample_stride_T = (2, ) if add_time_downsample else (1, )
|
100 |
+
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
101 |
+
down_block = get_down_block3d(
|
102 |
+
down_block_type,
|
103 |
+
num_layers=self.layers_per_block,
|
104 |
+
in_channels=input_channel,
|
105 |
+
out_channels=output_channel,
|
106 |
+
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
107 |
+
downsample_stride=downsample_stride,
|
108 |
+
resnet_eps=1e-6,
|
109 |
+
downsample_padding=0,
|
110 |
+
resnet_act_fn=act_fn,
|
111 |
+
resnet_groups=norm_num_groups,
|
112 |
+
attention_head_dim=output_channel,
|
113 |
+
temb_channels=None,
|
114 |
+
disable_causal=disable_causal,
|
115 |
+
)
|
116 |
+
self.down_blocks.append(down_block)
|
117 |
+
|
118 |
+
# mid
|
119 |
+
self.mid_block = UNetMidBlockCausal3D(
|
120 |
+
in_channels=block_out_channels[-1],
|
121 |
+
resnet_eps=1e-6,
|
122 |
+
resnet_act_fn=act_fn,
|
123 |
+
output_scale_factor=1,
|
124 |
+
resnet_time_scale_shift="default",
|
125 |
+
attention_head_dim=block_out_channels[-1],
|
126 |
+
resnet_groups=norm_num_groups,
|
127 |
+
temb_channels=None,
|
128 |
+
add_attention=mid_block_add_attention,
|
129 |
+
disable_causal=disable_causal,
|
130 |
+
causal_attention=mid_block_causal_attn,
|
131 |
+
)
|
132 |
+
|
133 |
+
# out
|
134 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
135 |
+
self.conv_act = nn.SiLU()
|
136 |
+
|
137 |
+
conv_out_channels = 2 * out_channels if double_z else out_channels
|
138 |
+
self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels,
|
139 |
+
kernel_size=3, disable_causal=disable_causal)
|
140 |
+
|
141 |
+
self.gradient_checkpointing = False
|
142 |
+
|
143 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
144 |
+
r"""The forward method of the `EncoderCausal3D` class."""
|
145 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
|
146 |
+
|
147 |
+
sample = self.conv_in(sample)
|
148 |
+
|
149 |
+
if self.training and self.gradient_checkpointing:
|
150 |
+
|
151 |
+
def create_custom_forward(module):
|
152 |
+
def custom_forward(*inputs):
|
153 |
+
return module(*inputs)
|
154 |
+
|
155 |
+
return custom_forward
|
156 |
+
|
157 |
+
# down
|
158 |
+
if is_torch_version(">=", "1.11.0"):
|
159 |
+
for down_block in self.down_blocks:
|
160 |
+
sample = torch.utils.checkpoint.checkpoint(
|
161 |
+
create_custom_forward(down_block), sample, use_reentrant=False
|
162 |
+
)
|
163 |
+
# middle
|
164 |
+
sample = torch.utils.checkpoint.checkpoint(
|
165 |
+
create_custom_forward(self.mid_block), sample, use_reentrant=False
|
166 |
+
)
|
167 |
+
else:
|
168 |
+
for down_block in self.down_blocks:
|
169 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
170 |
+
# middle
|
171 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
172 |
+
|
173 |
+
else:
|
174 |
+
# down
|
175 |
+
for down_block in self.down_blocks:
|
176 |
+
sample = down_block(sample)
|
177 |
+
|
178 |
+
# middle
|
179 |
+
sample = self.mid_block(sample)
|
180 |
+
|
181 |
+
# post-process
|
182 |
+
sample = self.conv_norm_out(sample)
|
183 |
+
sample = self.conv_act(sample)
|
184 |
+
sample = self.conv_out(sample)
|
185 |
+
|
186 |
+
return sample
|
187 |
+
|
188 |
+
|
189 |
+
class DecoderCausal3D(nn.Module):
|
190 |
+
r"""
|
191 |
+
The `DecoderCausal3D` layer of a variational autoencoder that decodes its
|
192 |
+
latent representation into an output sample.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
in_channels (`int`, *optional*, defaults to 3):
|
196 |
+
The number of input channels.
|
197 |
+
out_channels (`int`, *optional*, defaults to 3):
|
198 |
+
The number of output channels.
|
199 |
+
up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
|
200 |
+
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
201 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
202 |
+
The number of output channels for each block.
|
203 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
204 |
+
The number of layers per block.
|
205 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
206 |
+
The number of groups for normalization.
|
207 |
+
act_fn (`str`, *optional*, defaults to `"silu"`):
|
208 |
+
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
209 |
+
norm_type (`str`, *optional*, defaults to `"group"`):
|
210 |
+
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
211 |
+
"""
|
212 |
+
|
213 |
+
def __init__(
|
214 |
+
self,
|
215 |
+
in_channels: int = 3,
|
216 |
+
out_channels: int = 3,
|
217 |
+
up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
|
218 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
219 |
+
layers_per_block: int = 2,
|
220 |
+
norm_num_groups: int = 32,
|
221 |
+
act_fn: str = "silu",
|
222 |
+
norm_type: str = "group", # group, spatial
|
223 |
+
mid_block_add_attention=True,
|
224 |
+
time_compression_ratio: int = 4,
|
225 |
+
spatial_compression_ratio: int = 8,
|
226 |
+
disable_causal: bool = False,
|
227 |
+
mid_block_causal_attn: bool = False,
|
228 |
+
):
|
229 |
+
super().__init__()
|
230 |
+
self.layers_per_block = layers_per_block
|
231 |
+
|
232 |
+
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3,
|
233 |
+
stride=1, disable_causal=disable_causal)
|
234 |
+
self.mid_block = None
|
235 |
+
self.up_blocks = nn.ModuleList([])
|
236 |
+
|
237 |
+
temb_channels = in_channels if norm_type == "spatial" else None
|
238 |
+
|
239 |
+
# mid
|
240 |
+
self.mid_block = UNetMidBlockCausal3D(
|
241 |
+
in_channels=block_out_channels[-1],
|
242 |
+
resnet_eps=1e-6,
|
243 |
+
resnet_act_fn=act_fn,
|
244 |
+
output_scale_factor=1,
|
245 |
+
resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
|
246 |
+
attention_head_dim=block_out_channels[-1],
|
247 |
+
resnet_groups=norm_num_groups,
|
248 |
+
temb_channels=temb_channels,
|
249 |
+
add_attention=mid_block_add_attention,
|
250 |
+
disable_causal=disable_causal,
|
251 |
+
causal_attention=mid_block_causal_attn,
|
252 |
+
)
|
253 |
+
|
254 |
+
# up
|
255 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
256 |
+
output_channel = reversed_block_out_channels[0]
|
257 |
+
for i, up_block_type in enumerate(up_block_types):
|
258 |
+
prev_output_channel = output_channel
|
259 |
+
output_channel = reversed_block_out_channels[i]
|
260 |
+
is_final_block = i == len(block_out_channels) - 1
|
261 |
+
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
262 |
+
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
263 |
+
|
264 |
+
if time_compression_ratio == 4:
|
265 |
+
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
266 |
+
add_time_upsample = bool(i >= len(block_out_channels) - 1 - \
|
267 |
+
num_time_upsample_layers and not is_final_block)
|
268 |
+
elif time_compression_ratio == 8:
|
269 |
+
add_spatial_upsample = bool(i >= len(block_out_channels) - num_spatial_upsample_layers)
|
270 |
+
add_time_upsample = bool(i >= len(block_out_channels) - num_time_upsample_layers)
|
271 |
+
else:
|
272 |
+
raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
|
273 |
+
|
274 |
+
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
275 |
+
upsample_scale_factor_T = (2, ) if add_time_upsample else (1, )
|
276 |
+
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
277 |
+
up_block = get_up_block3d(
|
278 |
+
up_block_type,
|
279 |
+
num_layers=self.layers_per_block + 1,
|
280 |
+
in_channels=prev_output_channel,
|
281 |
+
out_channels=output_channel,
|
282 |
+
prev_output_channel=None,
|
283 |
+
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
284 |
+
upsample_scale_factor=upsample_scale_factor,
|
285 |
+
resnet_eps=1e-6,
|
286 |
+
resnet_act_fn=act_fn,
|
287 |
+
resnet_groups=norm_num_groups,
|
288 |
+
attention_head_dim=output_channel,
|
289 |
+
temb_channels=temb_channels,
|
290 |
+
resnet_time_scale_shift=norm_type,
|
291 |
+
disable_causal=disable_causal,
|
292 |
+
)
|
293 |
+
self.up_blocks.append(up_block)
|
294 |
+
prev_output_channel = output_channel
|
295 |
+
|
296 |
+
# out
|
297 |
+
if norm_type == "spatial":
|
298 |
+
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
|
299 |
+
else:
|
300 |
+
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
|
301 |
+
self.conv_act = nn.SiLU()
|
302 |
+
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, disable_causal=disable_causal)
|
303 |
+
|
304 |
+
self.gradient_checkpointing = False
|
305 |
+
|
306 |
+
def forward(
|
307 |
+
self,
|
308 |
+
sample: torch.FloatTensor,
|
309 |
+
latent_embeds: Optional[torch.FloatTensor] = None,
|
310 |
+
) -> torch.FloatTensor:
|
311 |
+
r"""The forward method of the `DecoderCausal3D` class."""
|
312 |
+
assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
|
313 |
+
|
314 |
+
sample = self.conv_in(sample)
|
315 |
+
|
316 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
317 |
+
if self.training and self.gradient_checkpointing:
|
318 |
+
|
319 |
+
def create_custom_forward(module):
|
320 |
+
def custom_forward(*inputs):
|
321 |
+
return module(*inputs)
|
322 |
+
|
323 |
+
return custom_forward
|
324 |
+
|
325 |
+
if is_torch_version(">=", "1.11.0"):
|
326 |
+
# middle
|
327 |
+
sample = torch.utils.checkpoint.checkpoint(
|
328 |
+
create_custom_forward(self.mid_block),
|
329 |
+
sample,
|
330 |
+
latent_embeds,
|
331 |
+
use_reentrant=False,
|
332 |
+
)
|
333 |
+
sample = sample.to(upscale_dtype)
|
334 |
+
|
335 |
+
# up
|
336 |
+
for up_block in self.up_blocks:
|
337 |
+
sample = torch.utils.checkpoint.checkpoint(
|
338 |
+
create_custom_forward(up_block),
|
339 |
+
sample,
|
340 |
+
latent_embeds,
|
341 |
+
use_reentrant=False,
|
342 |
+
)
|
343 |
+
else:
|
344 |
+
# middle
|
345 |
+
sample = torch.utils.checkpoint.checkpoint(
|
346 |
+
create_custom_forward(self.mid_block), sample, latent_embeds
|
347 |
+
)
|
348 |
+
sample = sample.to(upscale_dtype)
|
349 |
+
|
350 |
+
# up
|
351 |
+
for up_block in self.up_blocks:
|
352 |
+
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
|
353 |
+
else:
|
354 |
+
# middle
|
355 |
+
sample = self.mid_block(sample, latent_embeds)
|
356 |
+
sample = sample.to(upscale_dtype)
|
357 |
+
|
358 |
+
# up
|
359 |
+
for up_block in self.up_blocks:
|
360 |
+
sample = up_block(sample, latent_embeds)
|
361 |
+
|
362 |
+
# post-process
|
363 |
+
if latent_embeds is None:
|
364 |
+
sample = self.conv_norm_out(sample)
|
365 |
+
else:
|
366 |
+
sample = self.conv_norm_out(sample, latent_embeds)
|
367 |
+
sample = self.conv_act(sample)
|
368 |
+
sample = self.conv_out(sample)
|
369 |
+
|
370 |
+
return sample
|
371 |
+
|
372 |
+
|
373 |
+
class DiagonalGaussianDistribution(object):
|
374 |
+
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
375 |
+
if parameters.ndim == 3:
|
376 |
+
dim = 2 # (B, L, C)
|
377 |
+
elif parameters.ndim == 5 or parameters.ndim == 4:
|
378 |
+
dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
|
379 |
+
else:
|
380 |
+
raise NotImplementedError
|
381 |
+
self.parameters = parameters
|
382 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
383 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
384 |
+
self.deterministic = deterministic
|
385 |
+
self.std = torch.exp(0.5 * self.logvar)
|
386 |
+
self.var = torch.exp(self.logvar)
|
387 |
+
if self.deterministic:
|
388 |
+
self.var = self.std = torch.zeros_like(
|
389 |
+
self.mean, device=self.parameters.device, dtype=self.parameters.dtype
|
390 |
+
)
|
391 |
+
|
392 |
+
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
393 |
+
# make sure sample is on the same device as the parameters and has same dtype
|
394 |
+
sample = randn_tensor(
|
395 |
+
self.mean.shape,
|
396 |
+
generator=generator,
|
397 |
+
device=self.parameters.device,
|
398 |
+
dtype=self.parameters.dtype,
|
399 |
+
)
|
400 |
+
x = self.mean + self.std * sample
|
401 |
+
return x
|
402 |
+
|
403 |
+
def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
|
404 |
+
if self.deterministic:
|
405 |
+
return torch.Tensor([0.0])
|
406 |
+
else:
|
407 |
+
reduce_dim = list(range(1, self.mean.ndim))
|
408 |
+
if other is None:
|
409 |
+
return 0.5 * torch.sum(
|
410 |
+
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
411 |
+
dim=reduce_dim,
|
412 |
+
)
|
413 |
+
else:
|
414 |
+
return 0.5 * torch.sum(
|
415 |
+
torch.pow(self.mean - other.mean, 2) / other.var
|
416 |
+
+ self.var / other.var
|
417 |
+
- 1.0
|
418 |
+
- self.logvar
|
419 |
+
+ other.logvar,
|
420 |
+
dim=reduce_dim,
|
421 |
+
)
|
422 |
+
|
423 |
+
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
|
424 |
+
if self.deterministic:
|
425 |
+
return torch.Tensor([0.0])
|
426 |
+
logtwopi = np.log(2.0 * np.pi)
|
427 |
+
return 0.5 * torch.sum(
|
428 |
+
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
429 |
+
dim=dims,
|
430 |
+
)
|
431 |
+
|
432 |
+
def mode(self) -> torch.Tensor:
|
433 |
+
return self.mean
|
requirements.txt
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==1.9.0
|
2 |
+
av==15.0.0
|
3 |
+
certifi==2025.8.3
|
4 |
+
charset-normalizer==3.4.2
|
5 |
+
contourpy==1.3.3
|
6 |
+
cycler==0.12.1
|
7 |
+
decord==0.6.0
|
8 |
+
diffusers==0.34.0
|
9 |
+
einops==0.8.1
|
10 |
+
filelock==3.13.1
|
11 |
+
fonttools==4.59.0
|
12 |
+
fsspec==2024.6.1
|
13 |
+
hf-xet==1.1.5
|
14 |
+
huggingface-hub==0.34.3
|
15 |
+
idna==3.10
|
16 |
+
imageio==2.37.0
|
17 |
+
imageio-ffmpeg==0.6.0
|
18 |
+
importlib_metadata==8.7.0
|
19 |
+
Jinja2==3.1.4
|
20 |
+
kiwisolver==1.4.8
|
21 |
+
loguru==0.7.3
|
22 |
+
MarkupSafe==2.1.5
|
23 |
+
matplotlib==3.10.5
|
24 |
+
mpmath==1.3.0
|
25 |
+
networkx==3.3
|
26 |
+
ninja==1.11.1.4
|
27 |
+
numpy==2.1.2
|
28 |
+
nvidia-ml-py==12.575.51
|
29 |
+
nvidia-nccl-cu12==2.21.5
|
30 |
+
nvidia-nvjitlink-cu12==12.4.127
|
31 |
+
nvidia-nvtx-cu12==12.4.127
|
32 |
+
nvitop==1.5.2
|
33 |
+
opencv-python-headless==4.12.0.88
|
34 |
+
packaging==25.0
|
35 |
+
pandas==2.3.1
|
36 |
+
pillow==11.0.0
|
37 |
+
protobuf==6.31.1
|
38 |
+
psutil==7.0.0
|
39 |
+
pyparsing==3.2.3
|
40 |
+
python-dateutil==2.9.0.post0
|
41 |
+
pytz==2025.2
|
42 |
+
PyYAML==6.0.2
|
43 |
+
regex==2025.7.34
|
44 |
+
requests==2.32.4
|
45 |
+
safetensors==0.5.3
|
46 |
+
sentencepiece==0.2.0
|
47 |
+
setuptools==78.1.1
|
48 |
+
six==1.17.0
|
49 |
+
sympy==1.13.1
|
50 |
+
tokenizers==0.21.4
|
51 |
+
tqdm==4.67.1
|
52 |
+
transformers==4.54.1
|
53 |
+
triton==3.1.0
|
54 |
+
typing_extensions==4.12.2
|
55 |
+
tzdata==2025.2
|
56 |
+
urllib3==2.5.0
|
57 |
+
wheel==0.45.1
|
58 |
+
zipp==3.23.0
|
59 |
+
gradio==5.42.0
|
60 |
+
sageattention==1.0.6
|
scripts/run_sample_batch_4090.sh
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
3 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
4 |
+
export MODEL_BASE="/path/to/models"
|
5 |
+
checkpoint_path="/path/to/ckpts"
|
6 |
+
|
7 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
8 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
9 |
+
|
10 |
+
# disable sp and enable cpu offload
|
11 |
+
export DISABLE_SP=1
|
12 |
+
export CPU_OFFLOAD=1
|
13 |
+
export NUM_GPU=1
|
14 |
+
|
15 |
+
# # enable both sp and cpu offload
|
16 |
+
# export DISABLE_SP=0
|
17 |
+
# export CPU_OFFLOAD=1
|
18 |
+
# export NUM_GPU=8
|
19 |
+
|
20 |
+
torchrun --nnodes=1 --nproc_per_node=${NUM_GPU} --master_port 29605 hymm_sp/sample_batch.py \
|
21 |
+
--image-path "asset/village.png" \
|
22 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
23 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
24 |
+
--ckpt ${checkpoint_path} \
|
25 |
+
--video-size 704 1216 \
|
26 |
+
--cfg-scale 2.0 \
|
27 |
+
--image-start \
|
28 |
+
--action-list w s d a \
|
29 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
30 |
+
--seed 250160 \
|
31 |
+
--infer-steps 50 \
|
32 |
+
--flow-shift-eval-video 5.0 \
|
33 |
+
--cpu-offload \
|
34 |
+
--use-fp8 \
|
35 |
+
--save-path './results/'
|
scripts/run_sample_batch_distill.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
3 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
4 |
+
export MODEL_BASE="weights/stdmodels"
|
5 |
+
checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
|
6 |
+
|
7 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
8 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
9 |
+
|
10 |
+
torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
|
11 |
+
--image-path "asset/village.png" \
|
12 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
13 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
14 |
+
--ckpt ${checkpoint_path} \
|
15 |
+
--video-size 704 1216 \
|
16 |
+
--cfg-scale 1.0 \
|
17 |
+
--image-start \
|
18 |
+
--action-list w s d a \
|
19 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
20 |
+
--seed 250160 \
|
21 |
+
--infer-steps 8 \
|
22 |
+
--use-fp8 \
|
23 |
+
--flow-shift-eval-video 5.0 \
|
24 |
+
--save-path './results_distill/'
|
scripts/run_sample_batch_sp.sh
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
JOBS_DIR=$(dirname $(dirname "$0"))
|
3 |
+
export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
|
4 |
+
export MODEL_BASE="weights/stdmodels"
|
5 |
+
checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states.pt"
|
6 |
+
|
7 |
+
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
|
8 |
+
modelname='Tencent_hunyuanGameCraft_720P'
|
9 |
+
|
10 |
+
torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
|
11 |
+
--image-path "asset/village.png" \
|
12 |
+
--prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
|
13 |
+
--add-pos-prompt "Realistic, High-quality." \
|
14 |
+
--add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
|
15 |
+
--ckpt ${checkpoint_path} \
|
16 |
+
--video-size 704 1216 \
|
17 |
+
--cfg-scale 2.0 \
|
18 |
+
--image-start \
|
19 |
+
--action-list w s d a \
|
20 |
+
--action-speed-list 0.2 0.2 0.2 0.2 \
|
21 |
+
--seed 250160 \
|
22 |
+
--infer-steps 50 \
|
23 |
+
--flow-shift-eval-video 5.0 \
|
24 |
+
--save-path './results/'
|