jbilcke-hf (HF Staff) committed
Commit 01c0e76 · 0 Parent(s)

Initial commit with LFS-tracked binary files

This view is limited to 50 files because it contains too many changes.

Files changed (50):
  1. .claude/settings.local.json +19 -0
  2. .gitattributes +26 -0
  3. .gitignore +1 -0
  4. CLAUDE.md +137 -0
  5. LICENSE +77 -0
  6. Notice.txt +100 -0
  7. README.md +296 -0
  8. app.py +360 -0
  9. asset/method.png +3 -0
  10. asset/teaser.png +3 -0
  11. asset/village.png +3 -0
  12. docs_for_ai_coding_bots/.DS_Store +0 -0
  13. docs_for_ai_coding_bots/huggingface_hub/Downloading-model-from-hub.md +174 -0
  14. docs_for_ai_coding_bots/huggingface_hub/Using-the-cache-in-hf-hub-library.md +531 -0
  15. hymm_sp/__init__.py +0 -0
  16. hymm_sp/config.py +160 -0
  17. hymm_sp/constants.py +58 -0
  18. hymm_sp/data_kits/data_tools.py +115 -0
  19. hymm_sp/data_kits/video_dataset.py +259 -0
  20. hymm_sp/diffusion/__init__.py +30 -0
  21. hymm_sp/diffusion/pipelines/__init__.py +5 -0
  22. hymm_sp/diffusion/pipelines/pipeline_hunyuan_video_game.py +1152 -0
  23. hymm_sp/diffusion/schedulers/__init__.py +2 -0
  24. hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py +240 -0
  25. hymm_sp/helpers.py +194 -0
  26. hymm_sp/inference.py +201 -0
  27. hymm_sp/modules/__init__.py +38 -0
  28. hymm_sp/modules/activation_layers.py +23 -0
  29. hymm_sp/modules/attn_layers.py +437 -0
  30. hymm_sp/modules/cameranet.py +248 -0
  31. hymm_sp/modules/embed_layers.py +146 -0
  32. hymm_sp/modules/fp8_optimization.py +246 -0
  33. hymm_sp/modules/mlp_layers.py +97 -0
  34. hymm_sp/modules/models.py +697 -0
  35. hymm_sp/modules/modulate_layers.py +76 -0
  36. hymm_sp/modules/norm_layers.py +77 -0
  37. hymm_sp/modules/parallel_states.py +381 -0
  38. hymm_sp/modules/posemb_layers.py +112 -0
  39. hymm_sp/modules/token_refiner.py +265 -0
  40. hymm_sp/sample_batch.py +298 -0
  41. hymm_sp/sample_inference.py +716 -0
  42. hymm_sp/text_encoder/__init__.py +310 -0
  43. hymm_sp/vae/__init__.py +79 -0
  44. hymm_sp/vae/autoencoder_kl_causal_3d.py +781 -0
  45. hymm_sp/vae/unet_causal_3d_blocks.py +900 -0
  46. hymm_sp/vae/vae.py +433 -0
  47. requirements.txt +60 -0
  48. scripts/run_sample_batch_4090.sh +35 -0
  49. scripts/run_sample_batch_distill.sh +24 -0
  50. scripts/run_sample_batch_sp.sh +24 -0
.claude/settings.local.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "permissions": {
+     "allow": [
+       "Bash(git remote set-url:*)",
+       "Bash(git lfs track:*)",
+       "Bash(git add:*)",
+       "Bash(git commit:*)",
+       "Bash(git push:*)",
+       "Bash(git rm:*)",
+       "Bash(git lfs:*)",
+       "Bash(git gc:*)",
+       "Bash(GIT_TRACE=1 git push origin main -f)",
+       "Bash(git rev-list:*)",
+       "Bash(git checkout:*)"
+     ],
+     "deny": [],
+     "ask": []
+   }
+ }
.gitattributes ADDED
@@ -0,0 +1,26 @@
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.avi filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.hdf5 filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
+ *.mov filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.bmp filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.tar.gz filter=lfs diff=lfs merge=lfs -text
+ *.webm filter=lfs diff=lfs merge=lfs -text
+ asset/teaser.png filter=lfs diff=lfs merge=lfs -text
+ asset/*.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ .DS_Store
CLAUDE.md ADDED
@@ -0,0 +1,137 @@
+ # CLAUDE.md
+
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+ ## Project Overview
+
+ Hunyuan-GameCraft is a high-dynamic interactive game video generation system that creates gameplay videos with controllable camera movements and actions. The system uses diffusion models and action-controlled generation to synthesize realistic game footage from reference images and keyboard/mouse input controls.
+
+ ## Key Commands
+
+ ### Installation
+ ```bash
+ # Create and activate conda environment
+ conda create -n HYGameCraft python==3.10
+ conda activate HYGameCraft
+
+ # Install PyTorch and dependencies
+ conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
+
+ # Install requirements
+ python -m pip install -r requirements.txt
+
+ # Install flash attention (optional, for acceleration)
+ python -m pip install ninja
+ python -m pip install git+https://github.com/Dao-AILab/[email protected]
+ ```
+
+ ### Download Models
+ ```bash
+ cd weights
+ huggingface-cli download tencent/Hunyuan-GameCraft-1.0 --local-dir ./
+ ```
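For reference, the same weights can be fetched from Python with `huggingface_hub` (a minimal sketch mirroring the CLI command above; `snapshot_download` is the standard Hub API, not a helper defined in this repo, and the target directory is assumed to match the `weights/` layout used elsewhere in this file):

```python
from huggingface_hub import snapshot_download

# Download the full Hunyuan-GameCraft-1.0 snapshot into the local weights directory,
# equivalent to: huggingface-cli download tencent/Hunyuan-GameCraft-1.0 --local-dir ./
snapshot_download(
    repo_id="tencent/Hunyuan-GameCraft-1.0",
    local_dir="weights",
)
```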
+
+ ### Run Inference
+
+ **Multi-GPU (8 GPUs) - Standard Model:**
+ ```bash
+ torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
+     --image-path "asset/village.png" \
+     --prompt "YOUR_PROMPT" \
+     --ckpt weights/gamecraft_models/mp_rank_00_model_states.pt \
+     --video-size 704 1216 \
+     --cfg-scale 2.0 \
+     --image-start \
+     --action-list w s d a \
+     --action-speed-list 0.2 0.2 0.2 0.2 \
+     --seed 250160 \
+     --infer-steps 50 \
+     --save-path './results/'
+ ```
+
+ **Single GPU with Low VRAM (24GB minimum):**
+ ```bash
+ export DISABLE_SP=1
+ export CPU_OFFLOAD=1
+ torchrun --nnodes=1 --nproc_per_node=1 --master_port 29605 hymm_sp/sample_batch.py \
+     --ckpt weights/gamecraft_models/mp_rank_00_model_states.pt \
+     --cpu-offload \
+     --use-fp8 \
+     [other parameters...]
+ ```
+
+ **Distilled Model (faster, 8 inference steps):**
+ ```bash
+ torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
+     --ckpt weights/gamecraft_models/mp_rank_00_model_states_distill.pt \
+     --cfg-scale 1.0 \
+     --infer-steps 8 \
+     --use-fp8 \
+     [other parameters...]
+ ```
+
+ ## Architecture Overview
+
+ ### Core Components
+
+ 1. **Main Entry Points**
+    - `hymm_sp/sample_batch.py`: Main script for batch video generation with distributed processing
+    - `hymm_sp/sample_inference.py`: Core inference logic and model sampling
+    - `hymm_sp/config.py`: Configuration parsing and argument handling
+
+ 2. **Model Architecture (`hymm_sp/modules/`)**
+    - `models.py`: Core diffusion model implementation
+    - `cameranet.py`: Camera control and action encoding for game interactions
+    - `token_refiner.py`: Text token refinement for prompt conditioning
+    - `parallel_states.py`: Distributed training/inference state management
+    - `fp8_optimization.py`: FP8 quantization for memory/speed optimization
+
+ 3. **VAE Module (`hymm_sp/vae/`)**
+    - `autoencoder_kl_causal_3d.py`: 3D causal VAE for video encoding/decoding
+    - Handles latent space conversion for video frames
+
+ 4. **Diffusion Pipeline (`hymm_sp/diffusion/`)**
+    - `pipeline_hunyuan_video_game.py`: Custom pipeline for game video generation
+    - `scheduling_flow_match_discrete.py`: Flow matching scheduler for denoising
+
+ 5. **Data Processing (`hymm_sp/data_kits/`)**
+    - `video_dataset.py`: Dataset handling for video inputs
+    - `data_tools.py`: Video saving and processing utilities
+
+ ### Key Features
+
+ - **Action Control**: Maps keyboard inputs (w/a/s/d) to continuous camera space for smooth transitions (see the illustrative sketch after this list)
+ - **Hybrid History Conditioning**: Extends video sequences autoregressively while preserving scene context
+ - **Model Distillation**: Accelerated inference model (8 steps vs 50 steps)
+ - **Memory Optimization**: FP8 quantization, CPU offloading, and SageAttention support
+ - **Distributed Processing**: Multi-GPU support with sequence parallelism
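A purely illustrative sketch of the action-control idea, assuming a simple translation-style encoding; the actual mapping from key presses to camera trajectories is implemented in `hymm_sp/modules/cameranet.py` and may differ substantially:

```python
# Hypothetical illustration only: map w/a/s/d keys to unit direction vectors in a
# continuous camera space, scaled by the per-action speed (0.0-3.0).
ACTION_DIRECTIONS = {
    "w": (0.0, 0.0, 1.0),   # forward
    "s": (0.0, 0.0, -1.0),  # backward
    "a": (-1.0, 0.0, 0.0),  # left
    "d": (1.0, 0.0, 0.0),   # right
}

def actions_to_camera_offsets(action_list, action_speed_list):
    """Turn discrete key presses into continuous per-action camera offsets."""
    offsets = []
    for key, speed in zip(action_list, action_speed_list):
        dx, dy, dz = ACTION_DIRECTIONS[key]
        offsets.append((dx * speed, dy * speed, dz * speed))
    return offsets

print(actions_to_camera_offsets(["w", "s", "d", "a"], [0.2, 0.2, 0.2, 0.2]))
```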
+
+ ### Important Parameters
+
+ - `--action-list`: Sequence of keyboard actions (w/a/s/d)
+ - `--action-speed-list`: Movement speed for each action (0.0-3.0)
+ - `--video-size`: Output resolution (height width)
+ - `--cfg-scale`: Classifier-free guidance scale (1.0 for distilled, 2.0 for standard)
+ - `--infer-steps`: Denoising steps (8 for distilled, 50 for standard)
+ - `--use-fp8`: Enable FP8 optimization for memory reduction
+ - `--cpu-offload`: Offload model to CPU for low VRAM scenarios
+
+ ### Model Weights Structure
+ ```
+ weights/
+ ├── gamecraft_models/
+ │   ├── mp_rank_00_model_states.pt           # Standard model
+ │   └── mp_rank_00_model_states_distill.pt   # Distilled model
+ └── stdmodels/
+     ├── vae_3d/                              # 3D VAE model
+     ├── llava-llama-3-8b-v1_1-transformers/  # Text encoder
+     └── openai_clip-vit-large-patch14/       # CLIP encoder
+ ```
+
+ ## Development Notes
+
+ - Environment variable `MODEL_BASE` should point to `weights/stdmodels`
+ - Use `export DISABLE_SP=1` and `export CPU_OFFLOAD=1` for single GPU inference
+ - Minimum GPU memory: 24GB (very slow); recommended: 80GB per GPU
+ - Action length determines video duration (1 action = 33 frames at 25 FPS)
+ - SageAttention can be installed for additional acceleration
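A small sketch of the frame math implied by the notes above (plain Python; the 33-frames-per-action and 25 FPS figures come from the notes, everything else is illustrative):

```python
# Estimate clip length from an action list: each action contributes 33 frames,
# and the output video plays back at 25 FPS.
FRAMES_PER_ACTION = 33
FPS = 25

def estimate_duration(action_list):
    total_frames = len(action_list) * FRAMES_PER_ACTION
    return total_frames, total_frames / FPS

frames, seconds = estimate_duration(["w", "s", "d", "a"])
print(f"{frames} frames ~= {seconds:.2f} s")  # 132 frames ~= 5.28 s
```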
LICENSE ADDED
@@ -0,0 +1,77 @@
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
+ Tencent Hunyuan-GameCraft Release Date: August 14, 2025
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
+ 1. DEFINITIONS.
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan-GameCraft released at [https://github.com/Tencent-Hunyuan/Hunyuan-GameCraft-1.0].
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
+ n. “including” shall mean including but not limited to.
+ 2. GRANT OF RIGHTS.
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
+ 3. DISTRIBUTION.
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
+ 4. ADDITIONAL COMMERCIAL TERMS.
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
+ 5. RULES OF USE.
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
+ 6. INTELLECTUAL PROPERTY.
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+ 8. SURVIVAL AND TERMINATION.
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
+ 9. GOVERNING LAW AND JURISDICTION.
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
+
+ EXHIBIT A
+ ACCEPTABLE USE POLICY
+
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
+ Last modified: November 5, 2024
+
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
+ 1. Outside the Territory;
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
+ 3. To harm Yourself or others;
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
+ 9. To intentionally defame, disparage or otherwise harass others;
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
+ 13. To impersonate another individual without consent, authorization, or legal right;
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+ 19. For military purposes;
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
Notice.txt ADDED
@@ -0,0 +1,100 @@
+ Usage and Legal Notices:
+
+ Tencent is pleased to support the open source community by making Tencent Hunyuan-GameCraft available.
+
+ Copyright (C) 2025 Tencent. All rights reserved. The below softwares in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
+ Tencent Hunyuan-GameCraft is licensed under Tencent Hunyuan Community License Agreement, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent Hunyuan-GameCraft does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
+
+
+ Other dependencies and licenses:
+
+
+ Open Source Software Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT and Other Licenses of the Third-Party Components therein:
+ The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
+ --------------------------------------------------------------------
+ 1. HunyuanVideo
+ Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
+
+
+ Terms of the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT:
+ --------------------------------------------------------------------
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
+ Tencent HunyuanVideo Release Date: December 3, 2024
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
+ 1. DEFINITIONS.
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
+ i. “Tencent,” “We” or “Us” shall mean THL A29 Limited.
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent HunyuanVideo released at [https://github.com/Tencent/HunyuanVideo].
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
+ n. “including” shall mean including but not limited to.
+ 2. GRANT OF RIGHTS.
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
+ 3. DISTRIBUTION.
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
+ 4. ADDITIONAL COMMERCIAL TERMS.
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
+ 5. RULES OF USE.
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
+ 6. INTELLECTUAL PROPERTY.
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
+ 8. SURVIVAL AND TERMINATION.
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
+ 9. GOVERNING LAW AND JURISDICTION.
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
+
+ EXHIBIT A
+ ACCEPTABLE USE POLICY
+
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
+ Last modified: November 5, 2024
+
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
+ 1. Outside the Territory;
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
+ 3. To harm Yourself or others;
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
+ 9. To intentionally defame, disparage or otherwise harass others;
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
+ 13. To impersonate another individual without consent, authorization, or legal right;
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
+ 19. For military purposes;
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
+
+ For the license of other third party components, please refer to the following URL:
+ https://github.com/Tencent-Hunyuan/HunyuanVideo/blob/ff2dd59277b3177785d8279d4170968afa3b1d55/Notice
README.md ADDED
@@ -0,0 +1,296 @@
+ ---
+ title: Hunyuan-GameCraft
+ emoji: 🎮
+ colorFrom: blue
+ colorTo: purple
+ sdk: gradio
+ sdk_version: 5.42.0
+ app_file: app.py
+ pinned: true
+ license: mit
+ short_description: Interactive Game Video Generation
+ ---
+
+ <!-- ## **Hunyuan-GameCraft** -->
+
+ <!-- <p align="center">
+ <img src="assets/material/logo.png" height=100>
+ </p> -->
+
+ # **Hunyuan-GameCraft** 🎮
+
+ <div align="center">
+ <a href="https://github.com/Tencent-Hunyuan/Hunyuan-GameCraft-1.0"><img src="https://img.shields.io/static/v1?label=Code&message=Github&color=blue"></a> &ensp;
+ <a href="https://hunyuan-gamecraft.github.io/"><img src="https://img.shields.io/static/v1?label=Project%20Page&message=Web&color=green"></a> &ensp;
+ <a href="https://arxiv.org/abs/2506.17201"><img src="https://img.shields.io/badge/ArXiv-2506.17201-red"></a> &ensp;
+ <a href="https://huggingface.co/tencent/Hunyuan-GameCraft-1.0"><img src="https://img.shields.io/static/v1?label=Huggingface&message=Hunyuan-GameCraft-1.0&color=yellow"></a> &ensp;
+ </div>
+
+ ![image](asset/teaser.png)
+
+ > [**Hunyuan-GameCraft: High-dynamic Interactive Game Video Generation with Hybrid History Condition**](https://arxiv.org/abs/2506.17201) <br>
+
+
+
+ ## 🔥🔥🔥 News!!
+ * Aug 14, 2025: 👋 We release the inference code and model weights of Hunyuan-GameCraft. [Download](weights/README.md).
+
+
+ ## 📑 Open-source Plan
+
+ - Hunyuan-GameCraft
+   - [x] Inference
+   - [x] Checkpoints
+   - [ ] Gradio & Huggingface Demo
+
+ ## Contents
+ - [**Hunyuan-GameCraft** 🎮](#Hunyuan-GameCraft-)
+   - [🔥🔥🔥 News!!](#-news)
+   - [📑 Open-source Plan](#-open-source-plan)
+   - [Contents](#contents)
+   - [**Abstract**](#abstract)
+   - [**Overall Architecture**](#-overall-architecture)
+   - [📜 Requirements](#-requirements)
+   - [🛠️ Dependencies and Installation](#️-dependencies-and-installation)
+     - [Installation Guide for Linux](#installation-guide-for-linux)
+   - [🧱 Download Pretrained Models](#-download-pretrained-models)
+   - [🚀 Parallel Inference on Multiple GPUs](#-parallel-inference-on-multiple-gpus)
+   - [🔑 Single-gpu with Low-VRAM Inference](#-single-gpu-with-low-vram-inference)
+     - [Run with very low VRAM](#run-with-very-low-vram)
+     - [Run a Gradio Server](#run-a-gradio-server)
+   - [🔗 BibTeX](#-bibtex)
+   - [Acknowledgements](#acknowledgements)
+ ---
+
+ ## **Abstract**
+
+ Recent advances in diffusion-based and controllable video generation have enabled high-quality and temporally coherent video synthesis, laying the groundwork for immersive interactive gaming experiences. However, current methods face limitations in **dynamics**, **physical realism**, **long-term consistency**, and **efficiency**, which limit their ability to create diverse gameplay videos. To address these gaps, we introduce Hunyuan-GameCraft, a novel framework for high-dynamic interactive video generation in game environments. To achieve fine-grained action control, we unify standard keyboard and mouse inputs into a **shared camera representation space**, facilitating smooth interpolation between various camera and movement operations. Then we propose a **hybrid history-conditioned training strategy** that extends video sequences autoregressively while preserving game scene information. Additionally, to enhance inference efficiency and playability, we apply **model distillation** to reduce computational overhead while maintaining consistency across long temporal sequences, making it suitable for real-time deployment in complex interactive environments. The model is trained on a large-scale dataset comprising over one million gameplay recordings across over 100 AAA games, ensuring broad coverage and diversity, then fine-tuned on a carefully annotated synthetic dataset to enhance precision and control. The curated game scene data significantly improves the visual fidelity, realism and action controllability. Extensive experiments demonstrate that Hunyuan-GameCraft significantly outperforms existing models, advancing the realism and playability of interactive game video generation.
+
+ ## **Overall Architecture**
+
+ ![image](asset/method.png)
+
+ Given a reference image, the corresponding prompt, and a keyboard or mouse signal, we transform these inputs into a continuous camera space. We then design a lightweight action encoder to encode the input camera trajectory. The action and image features are added after the patchify step. For long video extension, we design a variable mask indicator, where 1 and 0 indicate history frames and predicted frames, respectively.
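As a hedged illustration of the mask indicator described in the paragraph above (a minimal sketch; the real conditioning logic lives in the diffusion pipeline and may operate on latent frames rather than raw frames):

```python
import torch

def make_history_mask(num_history_frames, num_total_frames):
    """Binary indicator per frame: 1 = history (conditioning) frame, 0 = frame to predict."""
    mask = torch.zeros(num_total_frames)
    mask[:num_history_frames] = 1.0
    return mask

# e.g. extend a 33-frame clip by another 33 predicted frames
print(make_history_mask(33, 66))
```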
+
+
+ ## 📜 Requirements
+
+ * An NVIDIA GPU with CUDA support is required.
+ * The model is tested on a machine with 8 × H20/H800 GPUs.
+ * **Minimum**: The minimum GPU memory required is 24GB, but generation is very slow.
+ * **Recommended**: We recommend using a GPU with 80GB of memory for better generation quality.
+ * Tested operating system: Linux
+
+
+ ## 🛠️ Dependencies and Installation
+
+ Begin by cloning the repository:
+ ```shell
+ git clone https://github.com/Tencent-Hunyuan/Hunyuan-GameCraft-1.0.git
+ cd Hunyuan-GameCraft-1.0
+ ```
+
+ ### Installation Guide for Linux
+
+ We recommend CUDA version 12.4 for manual installation.
+
+ Conda's installation instructions are available [here](https://docs.anaconda.com/free/miniconda/index.html).
+
+ ```shell
+ # 1. Create conda environment
+ conda create -n HYGameCraft python==3.10
+
+ # 2. Activate the environment
+ conda activate HYGameCraft
+
+ # 3. Install PyTorch and other dependencies using conda
+ conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
+
+ # 4. Install pip dependencies
+ python -m pip install -r requirements.txt
+ # 5. Install flash attention v2 for acceleration (requires CUDA 11.8 or above)
+ python -m pip install ninja
+ python -m pip install git+https://github.com/Dao-AILab/[email protected]
+ ```
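An optional post-install sanity check (a small sketch using only standard PyTorch calls; the version numbers it reports should roughly match the pins above, PyTorch 2.5.1 built against CUDA 12.4):

```python
import torch

# Confirm the environment roughly matches the pinned install.
print("torch:", torch.__version__)
print("cuda runtime:", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```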
+
+ Alternatively, you can use the HunyuanVideo Docker image. Use the following commands to pull and run it.
+
+ ```shell
+ # For CUDA 12.4 (updated to avoid a floating point exception)
+ docker pull hunyuanvideo/hunyuanvideo:cuda_12
+ docker run -itd --gpus all --init --net=host --uts=host --ipc=host --name hunyuanvideo --security-opt=seccomp=unconfined --ulimit=stack=67108864 --ulimit=memlock=-1 --privileged hunyuanvideo/hunyuanvideo:cuda_12
+ pip install diffusers==0.34.0 transformers==4.54.1
+
+ ```
+
+
+ ## 🧱 Download Pretrained Models
+
+ Details on downloading the pretrained models are given [here](weights/README.md).
+
+ ## 🚀 Parallel Inference on Multiple GPUs
+
+ For example, to generate a video using 8 GPUs, you can use the following command, where `--action-list w s d a` simulates keyboard input signals that drive the generated content. `--action-speed-list 0.2 0.2 0.2 0.2` specifies the displacement distance for each action and can be set to any value between 0 and 3.
+
+ You can use any combination and any length of the action list (one action per 33 frames, 25 FPS) to generate a long video; make sure `--action-speed-list` has the same length as `--action-list`. Note that inference time is linearly related to the number of actions:
+
+ ```bash
+ #!/bin/bash
+ JOBS_DIR=$(dirname $(dirname "$0"))
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
+ export MODEL_BASE="weights/stdmodels"
+ checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states.pt"
+
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
+ modelname='Tencent_hunyuanGameCraft_720P'
+
+ torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
+     --image-path "asset/village.png" \
+     --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
+     --add-pos-prompt "Realistic, High-quality." \
+     --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
+     --ckpt ${checkpoint_path} \
+     --video-size 704 1216 \
+     --cfg-scale 2.0 \
+     --image-start \
+     --action-list w s d a \
+     --action-speed-list 0.2 0.2 0.2 0.2 \
+     --seed 250160 \
+     --infer-steps 50 \
+     --flow-shift-eval-video 5.0 \
+     --save-path './results/'

+ ```
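A hedged helper for the constraints described before the command above (lengths must match, speeds between 0 and 3, actions limited to w/a/s/d); this is not part of the repo's CLI, just a sketch of checks you might run before launching a job:

```python
VALID_ACTIONS = {"w", "a", "s", "d"}

def check_action_args(action_list, action_speed_list):
    """Validate an --action-list / --action-speed-list pair before launching."""
    assert len(action_list) == len(action_speed_list), "lists must have the same length"
    assert all(a in VALID_ACTIONS for a in action_list), "actions must be one of w/a/s/d"
    assert all(0.0 <= s <= 3.0 for s in action_speed_list), "speeds must be in [0, 3]"

check_action_args(["w", "s", "d", "a"], [0.2, 0.2, 0.2, 0.2])
```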
164
+
165
+
166
+ Additionally, we support FP8 optimization and [SageAttn](https://github.com/thu-ml/SageAttention). To enable FP8, simply add the `--use-fp8` to your command.
167
+ And install SageAttention with:
168
+ ```bash
169
+ git clone https://github.com/thu-ml/SageAttention.git
170
+ cd SageAttention
171
+ python setup.py install # or pip install -e .
172
+ ```
173
+
174
+ We also provide an accelerated model, you can use the following command:
175
+ ```bash
176
+ #!/bin/bash
177
+ JOBS_DIR=$(dirname $(dirname "$0"))
178
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
179
+ export MODEL_BASE="weights/stdmodels"
180
+ checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
181
+
182
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
183
+ modelname='Tencent_hunyuanGameCraft_720P'
184
+
185
+ torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
186
+ --image-path "asset/village.png" \
187
+ --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
188
+ --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
189
+ --ckpt ${checkpoint_path} \
190
+ --video-size 704 1216 \
191
+ --cfg-scale 1.0 \
192
+ --image-start \
193
+ --action-list w s d a \
194
+ --action-speed-list 0.2 0.2 0.2 0.2 \
195
+ --seed 250160 \
196
+ --infer-steps 8 \
197
+ --use-fp8 \
198
+ --flow-shift-eval-video 5.0 \
199
+ --save-path './results_distill/'
200
+ ```
201
+
202
+
203
+ ## 🔑 Single-gpu with Low-VRAM Inference
204
+
205
+ For example, to generate a video with 1 GPU with Low-VRAM (minimum GPU memory required is 24GB for 704px1216p but very slow), you can use the following command:
206
+
207
+ ```bash
208
+ #!/bin/bash
209
+ JOBS_DIR=$(dirname $(dirname "$0"))
210
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
211
+ export MODEL_BASE="weights/stdmodels"
212
+ checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states.pt"
213
+
214
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
215
+ modelname='Tencent_hunyuanGameCraft_720P'
216
+
217
+ # disable sp and cpu offload
218
+ export DISABLE_SP=1
219
+ export CPU_OFFLOAD=1
220
+
221
+ torchrun --nnodes=1 --nproc_per_node=1 --master_port 29605 hymm_sp/sample_batch.py \
222
+ --image-path "asset/village.png" \
223
+ --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
224
+ --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
225
+ --ckpt ${checkpoint_path} \
226
+ --video-size 704 1216 \
227
+ --cfg-scale 2.0 \
228
+ --image-start \
229
+ --action-list w a d s \
230
+ --action-speed-list 0.2 0.2 0.2 0.2 \
231
+ --seed 250160 \
232
+ --sample-n-frames 33 \
233
+ --infer-steps 50 \
234
+ --flow-shift-eval-video 5.0 \
235
+ --cpu-offload \
236
+ --use-fp8 \
237
+ --save-path './results_poor/'
238
+
239
+ ```
240
+
241
+ As for using the accelerated model, you can use the following command:
242
+
243
+ ```bash
244
+ #!/bin/bash
245
+ JOBS_DIR=$(dirname $(dirname "$0"))
246
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
247
+ export MODEL_BASE="weights/stdmodels"
248
+ checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
249
+
250
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
251
+ modelname='Tencent_hunyuanGameCraft_720P'
252
+
253
+ # disable sp and cpu offload
254
+ export DISABLE_SP=1
255
+ export CPU_OFFLOAD=1
256
+
257
+ torchrun --nnodes=1 --nproc_per_node=1 --master_port 29605 hymm_sp/sample_batch.py \
258
+ --image-path "asset/village.png" \
259
+ --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
260
+ --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
261
+ --ckpt ${checkpoint_path} \
262
+ --video-size 704 1216 \
263
+ --cfg-scale 1.0 \
264
+ --image-start \
265
+ --action-list w a d s \
266
+ --action-speed-list 0.2 0.2 0.2 0.2 \
267
+ --seed 250160 \
268
+ --sample-n-frames 33 \
269
+ --infer-steps 8 \
270
+ --flow-shift-eval-video 5.0 \
271
+ --cpu-offload \
272
+ --use-fp8 \
273
+ --save-path './results_distill_poor/'
274
+ ```
275
+
276
+
277
+
278
+ ## 🔗 BibTeX
279
+
280
+ If you find [Hunyuan-GameCraft](https://arxiv.org/abs/2506.17201) useful for your research and applications, please cite using this BibTeX:
281
+
282
+ ```BibTeX
283
+ @misc{li2025hunyuangamecrafthighdynamicinteractivegame,
284
+ title={Hunyuan-GameCraft: High-dynamic Interactive Game Video Generation with Hybrid History Condition},
285
+ author={Jiaqi Li and Junshu Tang and Zhiyong Xu and Longhuang Wu and Yuan Zhou and Shuai Shao and Tianbao Yu and Zhiguo Cao and Qinglin Lu},
286
+ year={2025},
287
+ eprint={2506.17201},
288
+ archivePrefix={arXiv},
289
+ primaryClass={cs.CV},
290
+ url={https://arxiv.org/abs/2506.17201},
291
+ }
292
+ ```
293
+
294
+ ## Acknowledgements
295
+
296
+ We would like to thank the contributors to the [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), [HunyuanVideo-Avatar](https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar), [SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium), [FLUX](https://github.com/black-forest-labs/flux), [Llama](https://github.com/meta-llama/llama), [LLaVA](https://github.com/haotian-liu/LLaVA), [Xtuner](https://github.com/InternLM/xtuner), [diffusers](https://github.com/huggingface/diffusers) and [HuggingFace](https://huggingface.co) repositories for their open research and exploration.
app.py ADDED
@@ -0,0 +1,360 @@
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import numpy as np
5
+ import random
6
+ from pathlib import Path
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+ from loguru import logger
10
+ from huggingface_hub import hf_hub_download
11
+ import tempfile
12
+
13
+ from hymm_sp.sample_inference import HunyuanVideoSampler
14
+ from hymm_sp.data_kits.data_tools import save_videos_grid
15
+ from hymm_sp.config import parse_args
16
+ import argparse
17
+
18
+ os.environ["MODEL_BASE"] = "weights/stdmodels"
19
+ os.environ["DISABLE_SP"] = "1"
20
+ os.environ["CPU_OFFLOAD"] = "1"
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+
24
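+ # Transform that scales an image up until it fully covers the 704x1216 target, then center-crops to that size.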
+ class CropResize:
25
+ def __init__(self, size=(704, 1216)):
26
+ self.target_h, self.target_w = size
27
+
28
+ def __call__(self, img):
29
+ w, h = img.size
30
+ scale = max(
31
+ self.target_w / w,
32
+ self.target_h / h
33
+ )
34
+ new_size = (int(h * scale), int(w * scale))
35
+ resize_transform = transforms.Resize(
36
+ new_size,
37
+ interpolation=transforms.InterpolationMode.BILINEAR
38
+ )
39
+ resized_img = resize_transform(img)
40
+ crop_transform = transforms.CenterCrop((self.target_h, self.target_w))
41
+ return crop_transform(resized_img)
42
+
43
+ def create_args():
44
+ args = argparse.Namespace()
45
+ args.ckpt = "weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
46
+ args.video_size = [704, 1216]
47
+ args.cfg_scale = 1.0
48
+ args.image_start = True
49
+ args.seed = None
50
+ args.infer_steps = 8
51
+ args.use_fp8 = True
52
+ args.flow_shift_eval_video = 5.0
53
+ args.sample_n_frames = 33
54
+ args.num_images = 1
55
+ args.use_linear_quadratic_schedule = False
56
+ args.linear_schedule_end = 0.25
57
+ args.use_deepcache = False
58
+ args.cpu_offload = True
59
+ args.use_sage = True
60
+ args.save_path = './results/'
61
+ args.save_path_suffix = ''
62
+ args.add_pos_prompt = "Realistic, High-quality."
63
+ args.add_neg_prompt = "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border."
64
+ return args
65
+
66
+ logger.info("Initializing Hunyuan-GameCraft model...")
67
+
68
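+ # Download the distilled GameCraft checkpoint from the Hub on first launch (cached under weights/).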
+ if not os.path.exists("weights/gamecraft_models/mp_rank_00_model_states_distill.pt"):
69
+ logger.info("Downloading model weights from Hugging Face...")
70
+ os.makedirs("weights/gamecraft_models", exist_ok=True)
71
+ hf_hub_download(
72
+ repo_id="tencent/Hunyuan-GameCraft-1.0",
73
+ filename="gamecraft_models/mp_rank_00_model_states_distill.pt",
74
+ local_dir="weights/",
75
+ local_dir_use_symlinks=False
76
+ )
77
+
78
+ args = create_args()
79
+ hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(
80
+ args.ckpt,
81
+ args=args,
82
+ device=torch.device("cpu")
83
+ )
84
+ args = hunyuan_video_sampler.args
85
+
86
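+ # With CPU offload enabled, stream transformer blocks to the GPU one group at a time to reduce peak VRAM.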
+ if args.cpu_offload:
87
+ from diffusers.hooks import apply_group_offloading
88
+ onload_device = torch.device("cuda")
89
+ apply_group_offloading(
90
+ hunyuan_video_sampler.pipeline.transformer,
91
+ onload_device=onload_device,
92
+ offload_type="block_level",
93
+ num_blocks_per_group=1
94
+ )
95
+ logger.info("Enabled CPU offloading for transformer blocks")
96
+
97
+ logger.info("Model loaded successfully!")
98
+
99
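+ # Generate a video from a single input image: run one denoising pass per (action, speed) pair,
+ # carrying the latents of each segment into the next, then concatenate the segments into one clip.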
+ def generate_video(
100
+ input_image,
101
+ prompt,
102
+ action_sequence,
103
+ action_speeds,
104
+ negative_prompt,
105
+ seed,
106
+ cfg_scale,
107
+ num_inference_steps,
108
+ progress=gr.Progress(track_tqdm=True)
109
+ ):
110
+ try:
111
+ progress(0, desc="Initializing...")
112
+
113
+ if input_image is None:
114
+ return None, "Please upload an image first!"
115
+
116
+ action_list = action_sequence.lower().replace(" ", "").split(",") if action_sequence else ["w"]
117
+ speed_list = [float(s.strip()) for s in action_speeds.split(",")] if action_speeds else [0.2]
118
+
119
+ if len(speed_list) != len(action_list):
120
+ if len(speed_list) == 1:
121
+ speed_list = speed_list * len(action_list)
122
+ else:
123
+ return None, f"Error: Number of speeds ({len(speed_list)}) must match number of actions ({len(action_list)})"
124
+
125
+ for action in action_list:
126
+ if action not in ['w', 'a', 's', 'd']:
127
+ return None, f"Error: Invalid action '{action}'. Use only w, a, s, d"
128
+
129
+ for speed in speed_list:
130
+ if not 0.0 <= speed <= 3.0:
131
+ return None, f"Error: Speed {speed} out of range. Use values between 0.0 and 3.0"
132
+
133
+ progress(0.1, desc="Processing image...")
134
+
135
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
136
+ input_image.save(tmp_file.name)
137
+ image_path = tmp_file.name
138
+
139
+ closest_size = (704, 1216)
140
+ ref_image_transform = transforms.Compose([
141
+ CropResize(closest_size),
142
+ transforms.CenterCrop(closest_size),
143
+ transforms.ToTensor(),
144
+ transforms.Normalize([0.5], [0.5])
145
+ ])
146
+
147
+ raw_ref_image = Image.open(image_path).convert('RGB')
148
+ ref_image_pixel_values = ref_image_transform(raw_ref_image)
149
+ ref_image_pixel_values = ref_image_pixel_values.unsqueeze(0).unsqueeze(2).to(device)
150
+
151
+ progress(0.2, desc="Encoding image...")
152
+
153
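+ # Encode the reference image into VAE latents; when offloading, the VAE encoder is moved to the GPU only for this step.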
+ with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
154
+ if args.cpu_offload:
155
+ hunyuan_video_sampler.vae.quant_conv.to('cuda')
156
+ hunyuan_video_sampler.vae.encoder.to('cuda')
157
+
158
+ hunyuan_video_sampler.pipeline.vae.enable_tiling()
159
+
160
+ raw_last_latents = hunyuan_video_sampler.vae.encode(
161
+ ref_image_pixel_values
162
+ ).latent_dist.sample().to(dtype=torch.float16)
163
+ raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
164
+ raw_ref_latents = raw_last_latents.clone()
165
+
166
+ hunyuan_video_sampler.pipeline.vae.disable_tiling()
167
+ if args.cpu_offload:
168
+ hunyuan_video_sampler.vae.quant_conv.to('cpu')
169
+ hunyuan_video_sampler.vae.encoder.to('cpu')
170
+
171
+ ref_images = [raw_ref_image]
172
+ last_latents = raw_last_latents
173
+ ref_latents = raw_ref_latents
174
+
175
+ progress(0.3, desc="Starting video generation...")
176
+
177
+ if seed is None or seed == -1:
178
+ seed = random.randint(0, 1_000_000)
179
+
180
+ all_samples = []
181
+
182
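+ # One generation pass per action: the first pass is conditioned on the input image, later passes on the previous latents.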
+ for idx, (action_id, action_speed) in enumerate(zip(action_list, speed_list)):
183
+ is_image = (idx == 0)
184
+
185
+ progress(0.3 + (0.6 * idx / len(action_list)),
186
+ desc=f"Generating segment {idx+1}/{len(action_list)} (action: {action_id})")
187
+
188
+ outputs = hunyuan_video_sampler.predict(
189
+ prompt=prompt,
190
+ action_id=action_id,
191
+ action_speed=action_speed,
192
+ is_image=is_image,
193
+ size=(704, 1216),
194
+ seed=seed,
195
+ last_latents=last_latents,
196
+ ref_latents=ref_latents,
197
+ video_length=args.sample_n_frames,
198
+ guidance_scale=cfg_scale,
199
+ num_images_per_prompt=1,
200
+ negative_prompt=negative_prompt,
201
+ infer_steps=num_inference_steps,
202
+ flow_shift=args.flow_shift_eval_video,
203
+ use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
204
+ linear_schedule_end=args.linear_schedule_end,
205
+ use_deepcache=args.use_deepcache,
206
+ cpu_offload=args.cpu_offload,
207
+ ref_images=ref_images,
208
+ output_dir=None,
209
+ return_latents=True,
210
+ use_sage=args.use_sage,
211
+ )
212
+
213
+ ref_latents = outputs["ref_latents"]
214
+ last_latents = outputs["last_latents"]
215
+
216
+ sub_samples = outputs['samples'][0]
217
+ all_samples.append(sub_samples)
218
+
219
+ progress(0.9, desc="Finalizing video...")
220
+
221
+ # all_samples always contains at least one segment (action_list is never empty),
+ # so simply concatenate the generated segments along the temporal dimension.
+ out_cat = torch.cat(all_samples, dim=2)
225
+
226
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
227
+ output_path = tmp_video.name
228
+
229
+ save_videos_grid(out_cat, output_path, n_rows=1, fps=25)
230
+
231
+ if os.path.exists(image_path):
232
+ os.remove(image_path)
233
+
234
+ progress(1.0, desc="Complete!")
235
+ return output_path, "Video generated successfully!"
236
+
237
+ except Exception as e:
238
+ logger.error(f"Error generating video: {e}")
239
+ return None, f"Error: {str(e)}"
240
+
241
+ with gr.Blocks(title="Hunyuan-GameCraft") as demo:
242
+ gr.Markdown("""
243
+ # 🎮 Hunyuan-GameCraft Video Generation
244
+
245
+ Generate interactive game-style videos from a single image using keyboard actions (W/A/S/D).
246
+ Using the **distilled model** for faster generation (8 inference steps).
247
+ """)
248
+
249
+ with gr.Row():
250
+ with gr.Column(scale=1):
251
+ input_image = gr.Image(
252
+ label="Input Image",
253
+ type="pil",
254
+ height=400
255
+ )
256
+
257
+ prompt = gr.Textbox(
258
+ label="Prompt",
259
+ placeholder="Describe the scene...",
260
+ value="A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
261
+ lines=3
262
+ )
263
+
264
+ with gr.Accordion("Action Controls", open=True):
265
+ action_sequence = gr.Textbox(
266
+ label="Action Sequence (comma-separated)",
267
+ placeholder="w, a, s, d",
268
+ value="w, s, d, a",
269
+ info="Use w (forward), a (left), s (backward), d (right)"
270
+ )
271
+
272
+ action_speeds = gr.Textbox(
273
+ label="Action Speeds (comma-separated)",
274
+ placeholder="0.2, 0.2, 0.2, 0.2",
275
+ value="0.2, 0.2, 0.2, 0.2",
276
+ info="Speed for each action (0.0 to 3.0). Single value applies to all."
277
+ )
278
+
279
+ with gr.Accordion("Advanced Settings", open=False):
280
+ negative_prompt = gr.Textbox(
281
+ label="Negative Prompt",
282
+ value="overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border.",
283
+ lines=2
284
+ )
285
+
286
+ seed = gr.Number(
287
+ label="Seed",
288
+ value=-1,
289
+ precision=0,
290
+ info="Set to -1 for random seed"
291
+ )
292
+
293
+ cfg_scale = gr.Slider(
294
+ label="CFG Scale",
295
+ minimum=0.5,
296
+ maximum=3.0,
297
+ value=1.0,
298
+ step=0.1,
299
+ info="Classifier-free guidance scale (1.0 for distilled model)"
300
+ )
301
+
302
+ num_inference_steps = gr.Slider(
303
+ label="Inference Steps",
304
+ minimum=4,
305
+ maximum=20,
306
+ value=8,
307
+ step=1,
308
+ info="Number of denoising steps (8 for distilled model)"
309
+ )
310
+
311
+ generate_btn = gr.Button("Generate Video", variant="primary")
312
+
313
+ with gr.Column(scale=1):
314
+ output_video = gr.Video(
315
+ label="Generated Video",
316
+ height=400
317
+ )
318
+ status_text = gr.Textbox(
319
+ label="Status",
320
+ interactive=False
321
+ )
322
+
323
+ gr.Markdown("""
324
+ ### Tips:
325
+ - Each action generates 33 frames (1.3 seconds at 25 FPS)
326
+ - The distilled model is optimized for speed with 8 inference steps
327
+ - Use FP8 optimization for better memory efficiency
328
+ - Minimum GPU memory: 24GB VRAM
329
+ """)
330
+
331
+ generate_btn.click(
332
+ fn=generate_video,
333
+ inputs=[
334
+ input_image,
335
+ prompt,
336
+ action_sequence,
337
+ action_speeds,
338
+ negative_prompt,
339
+ seed,
340
+ cfg_scale,
341
+ num_inference_steps
342
+ ],
343
+ outputs=[output_video, status_text]
344
+ )
345
+
346
+ gr.Examples(
347
+ examples=[
348
+ [
349
+ "asset/village.png",
350
+ "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
351
+ "w, a, d, s",
352
+ "0.2, 0.2, 0.2, 0.2"
353
+ ]
354
+ ],
355
+ inputs=[input_image, prompt, action_sequence, action_speeds],
356
+ label="Example"
357
+ )
358
+
359
+ if __name__ == "__main__":
360
+ demo.launch(share=True)
asset/method.png ADDED

Git LFS Details

  • SHA256: e9d0546830d54f90e96392614405472ab06eb700ad184e4a6689bd84ff436890
  • Pointer size: 132 Bytes
  • Size of remote file: 2.86 MB
asset/teaser.png ADDED

Git LFS Details

  • SHA256: 5272120a5f85af9ee44c5f9714d6d0d99ba186ef2a66181b4bbd1b718c399555
  • Pointer size: 133 Bytes
  • Size of remote file: 20.9 MB
asset/village.png ADDED

Git LFS Details

  • SHA256: 5a5e986bd3100537653cd82c280ad1b3a1f2b0edec7abfd6926fae440bb92cd1
  • Pointer size: 132 Bytes
  • Size of remote file: 2.61 MB
docs_for_ai_coding_bots/.DS_Store ADDED
Binary file (6.15 kB). View file
 
docs_for_ai_coding_bots/huggingface_hub/Downloading-model-from-hub.md ADDED
@@ -0,0 +1,174 @@
1
+ [](#download-files-from-the-hub)Download files from the Hub
2
+ ===========================================================
3
+
4
+ The `huggingface_hub` library provides functions to download files from the repositories stored on the Hub. You can use these functions independently or integrate them into your own library, making it more convenient for your users to interact with the Hub. This guide will show you how to:
5
+
6
+ * Download and cache a single file.
7
+ * Download and cache an entire repository.
8
+ * Download files to a local folder.
9
+
10
+ [](#download-a-single-file)Download a single file
11
+ -------------------------------------------------
12
+
13
+ The [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) function is the main function for downloading files from the Hub. It downloads the remote file, caches it on disk (in a version-aware way), and returns its local file path.
14
+
15
+ The returned filepath is a pointer to the HF local cache. Therefore, it is important to not modify the file to avoid having a corrupted cache. If you are interested in getting to know more about how files are cached, please refer to our [caching guide](./manage-cache).
16
+
17
+ ### [](#from-latest-version)From latest version
18
+
19
+ Select the file to download using the `repo_id`, `repo_type` and `filename` parameters. By default, the file will be considered as being part of a `model` repo.
20
+
21
+ Copied
22
+
23
+ \>>> from huggingface\_hub import hf\_hub\_download
24
+ \>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json")
25
+ '/root/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade/config.json'
26
+
27
+ \# Download from a dataset
28
+ \>>> hf\_hub\_download(repo\_id="google/fleurs", filename="fleurs.py", repo\_type="dataset")
29
+ '/root/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34/fleurs.py'
30
+
31
+ ### [](#from-specific-version)From specific version
32
+
33
+ By default, the latest version from the `main` branch is downloaded. However, in some cases you want to download a file at a particular version (e.g. from a specific branch, a PR, a tag or a commit hash). To do so, use the `revision` parameter:
34
+
35
+ Copied
36
+
37
+ \# Download from the \`v1.0\` tag
38
+ \>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="v1.0")
39
+
40
+ \# Download from the \`test-branch\` branch
41
+ \>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="test-branch")
42
+
43
+ \# Download from Pull Request #3
44
+ \>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="refs/pr/3")
45
+
46
+ \# Download from a specific commit hash
47
+ \>>> hf\_hub\_download(repo\_id="lysandre/arxiv-nlp", filename="config.json", revision="877b84a8f93f2d619faa2a6e514a32beef88ab0a")
48
+
49
+ **Note:** When using the commit hash, it must be the full-length hash instead of a 7-character commit hash.
50
+
51
+ ### [](#construct-a-download-url)Construct a download URL
52
+
53
+ In case you want to construct the URL used to download a file from a repo, you can use [hf\_hub\_url()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_url) which returns a URL. Note that it is used internally by [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download).
54
+
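+ For reference, here is a minimal sketch of how `hf_hub_url()` can be used (the `revision` value is illustrative):
+
+ ```python
+ from huggingface_hub import hf_hub_url
+
+ # Build the direct download URL for a file without downloading it
+ url = hf_hub_url(repo_id="lysandre/arxiv-nlp", filename="config.json", revision="main")
+ print(url)  # e.g. https://huggingface.co/lysandre/arxiv-nlp/resolve/main/config.json
+ ```
+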
55
+ [](#download-an-entire-repository)Download an entire repository
56
+ ---------------------------------------------------------------
57
+
58
+ [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) downloads an entire repository at a given revision. It internally uses [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download), which means all downloaded files are also cached on your local disk. Downloads are made concurrently to speed up the process.
59
+
60
+ To download a whole repository, just pass the `repo_id` and `repo_type`:
61
+
62
+ Copied
63
+
64
+ \>>> from huggingface\_hub import snapshot\_download
65
+ \>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp")
66
+ '/home/lysandre/.cache/huggingface/hub/models--lysandre--arxiv-nlp/snapshots/894a9adde21d9a3e3843e6d5aeaaf01875c7fade'
67
+
68
+ \# Or from a dataset
69
+ \>>> snapshot\_download(repo\_id="google/fleurs", repo\_type="dataset")
70
+ '/home/lysandre/.cache/huggingface/hub/datasets--google--fleurs/snapshots/199e4ae37915137c555b1765c01477c216287d34'
71
+
72
+ [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) downloads the latest revision by default. If you want a specific repository revision, use the `revision` parameter:
73
+
74
+ Copied
75
+
76
+ \>>> from huggingface\_hub import snapshot\_download
77
+ \>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp", revision="refs/pr/1")
78
+
79
+ ### [](#filter-files-to-download)Filter files to download
80
+
81
+ [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) provides an easy way to download a repository. However, you don’t always want to download the entire content of a repository. For example, you might want to prevent downloading all `.bin` files if you know you’ll only use the `.safetensors` weights. You can do that using `allow_patterns` and `ignore_patterns` parameters.
82
+
83
+ These parameters accept either a single pattern or a list of patterns. Patterns are Standard Wildcards (globbing patterns) as documented [here](https://tldp.org/LDP/GNU-Linux-Tools-Summary/html/x11655.htm). The pattern matching is based on [`fnmatch`](https://docs.python.org/3/library/fnmatch.html).
84
+
85
+ For example, you can use `allow_patterns` to only download JSON configuration files:
86
+
87
+ Copied
88
+
89
+ \>>> from huggingface\_hub import snapshot\_download
90
+ \>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp", allow\_patterns="\*.json")
91
+
92
+ On the other hand, `ignore_patterns` can exclude certain files from being downloaded. The following example ignores the `.msgpack` and `.h5` file extensions:
93
+
94
+ Copied
95
+
96
+ \>>> from huggingface\_hub import snapshot\_download
97
+ \>>> snapshot\_download(repo\_id="lysandre/arxiv-nlp", ignore\_patterns=\["\*.msgpack", "\*.h5"\])
98
+
99
+ Finally, you can combine both to precisely filter your download. Here is an example to download all json and markdown files except `vocab.json`.
100
+
101
+ Copied
102
+
103
+ \>>> from huggingface\_hub import snapshot\_download
104
+ \>>> snapshot\_download(repo\_id="gpt2", allow\_patterns=\["\*.md", "\*.json"\], ignore\_patterns="vocab.json")
105
+
106
+ [](#download-files-to-a-local-folder)Download file(s) to a local folder
107
+ -----------------------------------------------------------------------
108
+
109
+ By default, we recommend using the [cache system](./manage-cache) to download files from the Hub. You can specify a custom cache location using the `cache_dir` parameter in [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) and [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download), or by setting the [`HF_HOME`](../package_reference/environment_variables#hf_home) environment variable.
110
+
111
+ However, if you need to download files to a specific folder, you can pass a `local_dir` parameter to the download function. This is useful to get a workflow closer to what the `git` command offers. The downloaded files will maintain their original file structure within the specified folder. For example, if `filename="data/train.csv"` and `local_dir="path/to/folder"`, the resulting filepath will be `"path/to/folder/data/train.csv"`.
112
+
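+ As a minimal sketch (the repo id is the same example used above; the target folder is a placeholder), this looks like:
+
+ ```python
+ from huggingface_hub import hf_hub_download, snapshot_download
+
+ # Download a single file into ./my-model/config.json
+ hf_hub_download(repo_id="lysandre/arxiv-nlp", filename="config.json", local_dir="my-model")
+
+ # Or mirror a whole repository into ./my-model
+ snapshot_download(repo_id="lysandre/arxiv-nlp", local_dir="my-model")
+ ```
+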
113
+ A `.cache/huggingface/` folder is created at the root of your local directory containing metadata about the downloaded files. This prevents re-downloading files if they’re already up-to-date. If the metadata has changed, then the new file version is downloaded. This makes the `local_dir` optimized for pulling only the latest changes.
114
+
115
+ After completing the download, you can safely remove the `.cache/huggingface/` folder if you no longer need it. However, be aware that re-running your script without this folder may result in longer recovery times, as metadata will be lost. Rest assured that your local data will remain intact and unaffected.
116
+
117
+ Don’t worry about the `.cache/huggingface/` folder when committing changes to the Hub! This folder is automatically ignored by both `git` and [upload\_folder()](/docs/huggingface_hub/v0.32.2/en/package_reference/hf_api#huggingface_hub.HfApi.upload_folder).
118
+
119
+ [](#download-from-the-cli)Download from the CLI
120
+ -----------------------------------------------
121
+
122
+ You can use the `huggingface-cli download` command from the terminal to directly download files from the Hub. Internally, it uses the same [hf\_hub\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.hf_hub_download) and [snapshot\_download()](/docs/huggingface_hub/v0.32.2/en/package_reference/file_download#huggingface_hub.snapshot_download) helpers described above and prints the returned path to the terminal.
123
+
124
+ Copied
125
+
126
+ \>>> huggingface-cli download gpt2 config.json
127
+ /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10/config.json
128
+
129
+ You can download multiple files at once; this displays a progress bar and returns the snapshot path in which the files are located:
130
+
131
+ Copied
132
+
133
+ \>>> huggingface-cli download gpt2 config.json model.safetensors
134
+ Fetching 2 files: 100%|████████████████████████████████████████████| 2/2 \[00:00<00:00, 23831.27it/s\]
135
+ /home/wauplin/.cache/huggingface/hub/models--gpt2/snapshots/11c5a3d5811f50298f278a704980280950aedb10
136
+
137
+ For more details about the CLI download command, please refer to the [CLI guide](./cli#huggingface-cli-download).
138
+
139
+ [](#faster-downloads)Faster downloads
140
+ -------------------------------------
141
+
142
+ There are two options to speed up downloads. Both involve installing a Python package written in Rust.
143
+
144
+ * `hf_xet` is newer and uses the Xet storage backend for upload/download. It is available in production, but is in the process of being rolled out to all users, so join the [waitlist](https://huggingface.co/join/xet) to get onboarded soon!
145
+ * `hf_transfer` is a power-tool to download and upload to our LFS storage backend (note: this is less future-proof than Xet). It is thoroughly tested and has been in production for a long time, but it has some limitations.
146
+
147
+ ### [](#hfxet)hf\_xet
148
+
149
+ Take advantage of faster downloads through `hf_xet`, the Python binding to the [`xet-core`](https://github.com/huggingface/xet-core) library that enables chunk-based deduplication for faster downloads and uploads. `hf_xet` integrates seamlessly with `huggingface_hub`, but uses the Rust `xet-core` library and Xet storage instead of LFS.
150
+
151
+ `hf_xet` uses the Xet storage system, which breaks files down into immutable chunks, storing collections of these chunks (called blocks or xorbs) remotely and retrieving them to reassemble the file when requested. When downloading, after confirming the user is authorized to access the files, `hf_xet` will query the Xet content-addressable service (CAS) with the LFS SHA256 hash for this file to receive the reconstruction metadata (ranges within xorbs) to assemble these files, along with presigned URLs to download the xorbs directly. Then `hf_xet` will efficiently download the xorb ranges necessary and will write out the files on disk. `hf_xet` uses a local disk cache to only download chunks once, learn more in the [Chunk-based caching(Xet)](./manage-cache#chunk-based-caching-xet) section.
152
+
153
+ To enable it, specify the `hf_xet` package when installing `huggingface_hub`:
154
+
155
+ Copied
156
+
157
+ pip install -U "huggingface\_hub\[hf\_xet\]"
158
+
159
+ Note: `hf_xet` will only be utilized when the files being downloaded are being stored with Xet Storage.
160
+
161
+ All other `huggingface_hub` APIs will continue to work without any modification. To learn more about the benefits of Xet storage and `hf_xet`, refer to this [section](https://huggingface.co/docs/hub/storage-backends).
162
+
163
+ ### [](#hftransfer)hf\_transfer
164
+
165
+ If you are running on a machine with high bandwidth, you can increase your download speed with [`hf_transfer`](https://github.com/huggingface/hf_transfer), a Rust-based library developed to speed up file transfers with the Hub. To enable it:
166
+
167
+ 1. Specify the `hf_transfer` extra when installing `huggingface_hub` (e.g. `pip install huggingface_hub[hf_transfer]`).
168
+ 2. Set `HF_HUB_ENABLE_HF_TRANSFER=1` as an environment variable.
169
+
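+ As a minimal sketch of these two steps from Python (the repo id is only an example), assuming the `hf_transfer` extra is installed:
+
+ ```python
+ import os
+
+ # Must be set before huggingface_hub is imported
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+ from huggingface_hub import hf_hub_download
+
+ hf_hub_download(repo_id="gpt2", filename="model.safetensors")
+ ```
+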
170
+ `hf_transfer` is a power user tool! It is tested and production-ready, but it lacks user-friendly features like advanced error handling or proxies. For more details, please take a look at this [section](https://huggingface.co/docs/huggingface_hub/hf_transfer).
171
+
172
+ [< \> Update on GitHub](https://github.com/huggingface/huggingface_hub/blob/main/docs/source/en/guides/download.md)
173
+
174
+ Command Line Interface (CLI)
docs_for_ai_coding_bots/huggingface_hub/Using-the-cache-in-hf-hub-library.md ADDED
@@ -0,0 +1,531 @@
1
+ [](#understand-caching)Understand caching
2
+ =========================================
3
+
4
+ `huggingface_hub` uses the local disk as two caches, which avoid re-downloading items. The first is a file-based cache, which caches individual files downloaded from the Hub and ensures that the same file is not downloaded again when a repo gets updated. The second is a chunk cache, where each chunk represents a byte range from a file and ensures that chunks shared across files are only downloaded once.
5
+
6
+ [](#file-based-caching)File-based caching
7
+ -----------------------------------------
8
+
9
+ The Hugging Face Hub cache-system is designed to be the central cache shared across libraries that depend on the Hub. It has been updated in v0.8.0 to prevent re-downloading the same files between revisions.
10
+
11
+ The caching system is designed as follows:
12
+
13
+ Copied
14
+
15
+ <CACHE\_DIR\>
16
+ ├─ <MODELS\>
17
+ ├─ <DATASETS\>
18
+ ├─ <SPACES\>
19
+
20
+ The default `<CACHE_DIR>` is `~/.cache/huggingface/hub`. However, it is customizable with the `cache_dir` argument on all methods, or by specifying either `HF_HOME` or `HF_HUB_CACHE` environment variable.
21
+
22
+ Models, datasets and spaces share a common root. Each of these repositories contains the repository type, the namespace (organization or username) if it exists and the repository name:
23
+
24
+ Copied
25
+
26
+ <CACHE\_DIR\>
27
+ ├─ models\--julien\-c\--EsperBERTo\-small
28
+ ├─ models\--lysandrejik\--arxiv\-nlp
29
+ ├─ models\--bert\-base\-cased
30
+ ├─ datasets\--glue
31
+ ├─ datasets\--huggingface\--DataMeasurementsFiles
32
+ ├─ spaces\--dalle\-mini\--dalle\-mini
33
+
34
+ It is within these folders that all files will now be downloaded from the Hub. Caching ensures that a file isn’t downloaded twice if it already exists and wasn’t updated; but if it was updated, and you’re asking for the latest file, then it will download the latest file (while keeping the previous file intact in case you need it again).
35
+
36
+ In order to achieve this, all folders contain the same skeleton:
37
+
38
+ Copied
39
+
40
+ <CACHE\_DIR>
41
+ ├─ datasets\--glue
42
+ │ ├─ refs
43
+ │ ├─ blobs
44
+ │ ├─ snapshots
45
+ ...
46
+
47
+ Each folder is designed to contain the following:
48
+
49
+ ### [](#refs)Refs
50
+
51
+ The `refs` folder contains files which indicate the latest revision of the given reference. For example, if we have previously fetched a file from the `main` branch of a repository, the `refs` folder will contain a file named `main`, which will itself contain the commit identifier of the current head.
52
+
53
+ If the latest commit of `main` has `aaaaaa` as identifier, then it will contain `aaaaaa`.
54
+
55
+ If that same branch gets updated with a new commit, that has `bbbbbb` as an identifier, then re-downloading a file from that reference will update the `refs/main` file to contain `bbbbbb`.
56
+
57
+ ### [](#blobs)Blobs
58
+
59
+ The `blobs` folder contains the actual files that we have downloaded. The name of each file is their hash.
60
+
61
+ ### [](#snapshots)Snapshots
62
+
63
+ The `snapshots` folder contains symlinks to the blobs mentioned above. It is itself made up of several folders: one per known revision!
64
+
65
+ In the explanation above, we had initially fetched a file from the `aaaaaa` revision, before fetching a file from the `bbbbbb` revision. In this situation, we would now have two folders in the `snapshots` folder: `aaaaaa` and `bbbbbb`.
66
+
67
+ In each of these folders, live symlinks that have the names of the files that we have downloaded. For example, if we had downloaded the `README.md` file at revision `aaaaaa`, we would have the following path:
68
+
69
+ Copied
70
+
71
+ <CACHE\_DIR>/<REPO\_NAME>/snapshots/aaaaaa/README.md
72
+
73
+ That `README.md` file is actually a symlink linking to the blob that has the hash of the file.
74
+
75
+ By creating the skeleton this way we open the mechanism to file sharing: if the same file was fetched in revision `bbbbbb`, it would have the same hash and the file would not need to be re-downloaded.
76
+
77
+ ### [](#noexist-advanced).no\_exist (advanced)
78
+
79
+ In addition to the `blobs`, `refs` and `snapshots` folders, you might also find a `.no_exist` folder in your cache. This folder keeps track of files that you’ve tried to download once but don’t exist on the Hub. Its structure is the same as the `snapshots` folder with 1 subfolder per known revision:
80
+
81
+ Copied
82
+
83
+ <CACHE\_DIR>/<REPO\_NAME>/.no\_exist/aaaaaa/config\_that\_does\_not\_exist.json
84
+
85
+ Unlike the `snapshots` folder, files are simple empty files (no symlinks). In this example, the file `"config_that_does_not_exist.json"` does not exist on the Hub for the revision `"aaaaaa"`. As it only stores empty files, this folder is negligible in terms of disk usage.
86
+
87
+ So now you might wonder, why is this information even relevant? In some cases, a framework tries to load optional files for a model. Saving the non-existence of optional files makes it faster to load a model as it saves 1 HTTP call per possible optional file. This is, for example, the case in `transformers`, where each tokenizer can support additional files. The first time you load the tokenizer on your machine, it will cache which optional files exist (and which don't) to make the loading time faster for the next initializations.
88
+
89
+ To test if a file is cached locally (without making any HTTP request), you can use the [try\_to\_load\_from\_cache()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.try_to_load_from_cache) helper. It will either return the filepath (if exists and cached), the object `_CACHED_NO_EXIST` (if non-existence is cached) or `None` (if we don’t know).
90
+
91
+ Copied
92
+
93
+ from huggingface\_hub import try\_to\_load\_from\_cache, \_CACHED\_NO\_EXIST
94
+
95
+ filepath = try\_to\_load\_from\_cache(repo\_id="bert-base-cased", filename="config.json") \# repo\_id and filename are required; the values here are only illustrative
96
+ if isinstance(filepath, str):
97
+ \# file exists and is cached
98
+ ...
99
+ elif filepath is \_CACHED\_NO\_EXIST:
100
+ \# non-existence of file is cached
101
+ ...
102
+ else:
103
+ \# file is not cached
104
+ ...
105
+
106
+ ### [](#in-practice)In practice
107
+
108
+ In practice, your cache should look like the following tree:
109
+
110
+ Copied
111
+
112
+ \[ 96\] .
113
+ └── \[ 160\] models--julien-c--EsperBERTo-small
114
+ ├── \[ 160\] blobs
115
+ │ ├── \[321M\] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
116
+ │ ├── \[ 398\] 7cb18dc9bafbfcf74629a4b760af1b160957a83e
117
+ │ └── \[1.4K\] d7edf6bd2a681fb0175f7735299831ee1b22b812
118
+ ├── \[ 96\] refs
119
+ │ └── \[ 40\] main
120
+ └── \[ 128\] snapshots
121
+ ├── \[ 128\] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
122
+ │ ├── \[ 52\] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
123
+ │ └── \[ 76\] pytorch\_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
124
+ └── \[ 128\] bbc77c8132af1cc5cf678da3f1ddf2de43606d48
125
+ ├── \[ 52\] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
126
+ └── \[ 76\] pytorch\_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
127
+
128
+ ### [](#limitations)Limitations
129
+
130
+ In order to have an efficient cache-system, `huggingface-hub` uses symlinks. However, symlinks are not supported on all machines. This is a known limitation, especially on Windows. When this is the case, `huggingface_hub` does not use the `blobs/` directory but stores the files directly in the `snapshots/` directory instead. This workaround allows users to download and cache files from the Hub exactly the same way. Tools to inspect and delete the cache (see below) are also supported. However, the cache-system is less efficient as a single file might be downloaded several times if multiple revisions of the same repo are downloaded.
131
+
132
+ If you want to benefit from the symlink-based cache-system on a Windows machine, you either need to [activate Developer Mode](https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development) or to run Python as an administrator.
133
+
134
+ When symlinks are not supported, a warning message is displayed to the user to alert them they are using a degraded version of the cache-system. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable to true.
135
+
136
+ [](#chunk-based-caching-xet)Chunk-based caching (Xet)
137
+ -----------------------------------------------------
138
+
139
+ To provide more efficient file transfers, `hf_xet` adds a `xet` directory to the existing `huggingface_hub` cache, creating an additional caching layer to enable chunk-based deduplication. This cache holds chunks, which are immutable byte ranges from files (up to 64KB) that are created using content-defined chunking. For more information on the Xet Storage system, see this [section](https://huggingface.co/docs/hub/storage-backends).
140
+
141
+ The `xet` directory, located at `~/.cache/huggingface/xet` by default, contains two caches, used for uploads and downloads, with the following structure:
142
+
143
+ Copied
144
+
145
+ <CACHE\_DIR>
146
+ ├─ chunk\_cache
147
+ ├─ shard\_cache
148
+
149
+ The `xet` cache, like the rest of `hf_xet` is fully integrated with `huggingface_hub`. If you use the existing APIs for interacting with cached assets, there is no need to update your workflow. The `xet` cache is built as an optimization layer on top of the existing `hf_xet` chunk-based deduplication and `huggingface_hub` cache system.
150
+
151
+ The `chunk_cache` directory contains cached data chunks that are used to speed up downloads, while the `shard_cache` directory contains cached shards that are used on the upload path.
152
+
153
+ ### [](#chunkcache)chunk\_cache
154
+
155
+ This cache is used on the download path. The cache directory structure is based on a base-64 encoded hash from the content-addressed store (CAS) that backs each Xet-enabled repository. A CAS hash serves as the key to lookup the offsets of where the data is stored.
156
+
157
+ At the topmost level, the first two letters of the base 64 encoded CAS hash are used to create a subdirectory in the `chunk_cache` (keys that share these first two letters are grouped here). The inner levels are comprised of subdirectories with the full key as the directory name. At the base are the cache items which are ranges of blocks that contain the cached chunks.
158
+
159
+ Copied
160
+
161
+ <CACHE\_DIR>
162
+ ├─ xet
163
+ │ ├─ chunk\_cache
164
+ │ │ ├─ A1
165
+ │ │ │ ├─ A1GerURLUcISVivdseeoY1PnYifYkOaCCJ7V5Q9fjgxkZWZhdWx0
166
+ │ │ │ │ ├─ AAAAAAEAAAA5DQAAAAAAAIhRLjDI3SS5jYs4ysNKZiJy9XFI8CN7Ww0UyEA9KPD9
167
+ │ │ │ │ ├─ AQAAAAIAAABzngAAAAAAAPNqPjd5Zby5aBvabF7Z1itCx0ryMwoCnuQcDwq79jlB
168
+
169
+ When requesting a file, the first thing `hf_xet` does is communicate with Xet storage’s content addressed store (CAS) for reconstruction information. The reconstruction information contains information about the CAS keys required to download the file in its entirety.
170
+
171
+ Before executing the requests for the CAS keys, the `chunk_cache` is consulted. If a key in the cache matches a CAS key, then there is no reason to issue a request for that content. `hf_xet` uses the chunks stored in the directory instead.
172
+
173
+ As the `chunk_cache` is purely an optimization, not a guarantee, `hf_xet` utilizes a computationally efficient eviction policy. When the `chunk_cache` is full (see `Limits and Limitations` below), `hf_xet` implements a random eviction policy when selecting an eviction candidate. This significantly reduces the overhead of managing a robust caching system (e.g., LRU) while still providing most of the benefits of caching chunks.
174
+
175
+ ### [](#shardcache)shard\_cache
176
+
177
+ This cache is used when uploading content to the Hub. The directory is flat, comprising only of shard files, each using an ID for the shard name.
178
+
179
+ Copied
180
+
181
+ <CACHE\_DIR>
182
+ ├─ xet
183
+ │ ├─ shard\_cache
184
+ │ │ ├─ 1fe4ffd5cf0c3375f1ef9aec5016cf773ccc5ca294293d3f92d92771dacfc15d.mdb
185
+ │ │ ├─ 906ee184dc1cd0615164a89ed64e8147b3fdccd1163d80d794c66814b3b09992.mdb
186
+ │ │ ├─ ceeeb7ea4cf6c0a8d395a2cf9c08871211fbbd17b9b5dc1005811845307e6b8f.mdb
187
+ │ │ ├─ e8535155b1b11ebd894c908e91a1e14e3461dddd1392695ddc90ae54a548d8b2.mdb
188
+
189
+ The `shard_cache` contains shards that are:
190
+
191
+ * Locally generated and successfully uploaded to the CAS
192
+ * Downloaded from CAS as part of the global deduplication algorithm
193
+
194
+ Shards provide a mapping between files and chunks. During uploads, each file is chunked and the hash of the chunk is saved. Every shard in the cache is then consulted. If a shard contains a chunk hash that is present in the local file being uploaded, then that chunk can be discarded as it is already stored in CAS.
195
+
196
+ All shards have an expiration date of 3-4 weeks from when they are downloaded. Shards that are expired are not loaded during upload and are deleted one week after expiration.
197
+
198
+ ### [](#limits-and-limitations)Limits and Limitations
199
+
200
+ The `chunk_cache` is limited to 10GB in size while the `shard_cache` is technically without limits (in practice, the size and use of shards are such that limiting the cache is unnecessary).
201
+
202
+ By design, both caches are without high-level APIs. These caches are used primarily to facilitate the reconstruction (download) or upload of a file. To interact with the assets themselves, it’s recommended that you use the [`huggingface_hub` cache system APIs](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
203
+
204
+ If you need to reclaim the space utilized by either cache or need to debug any potential cache-related issues, simply remove the `xet` cache entirely by running `rm -rf <cache_dir>/xet`, where `<cache_dir>` is the location of your Hugging Face cache, typically `~/.cache/huggingface`.
205
+
206
+ Example of a full `xet` cache directory tree:
207
+
208
+ Copied
209
+
210
+ <CACHE\_DIR>
211
+ ├─ xet
212
+ │ ├─ chunk\_cache
213
+ │ │ ├─ L1
214
+ │ │ │ ├─ L1GerURLUcISVivdseeoY1PnYifYkOaCCJ7V5Q9fjgxkZWZhdWx0
215
+ │ │ │ │ ├─ AAAAAAEAAAA5DQAAAAAAAIhRLjDI3SS5jYs4ysNKZiJy9XFI8CN7Ww0UyEA9KPD9
216
+ │ │ │ │ ├─ AQAAAAIAAABzngAAAAAAAPNqPjd5Zby5aBvabF7Z1itCx0ryMwoCnuQcDwq79jlB
217
+ │ ├─ shard\_cache
218
+ │ │ ├─ 1fe4ffd5cf0c3375f1ef9aec5016cf773ccc5ca294293d3f92d92771dacfc15d.mdb
219
+ │ │ ├─ 906ee184dc1cd0615164a89ed64e8147b3fdccd1163d80d794c66814b3b09992.mdb
220
+ │ │ ├─ ceeeb7ea4cf6c0a8d395a2cf9c08871211fbbd17b9b5dc1005811845307e6b8f.mdb
221
+ │ │ ├─ e8535155b1b11ebd894c908e91a1e14e3461dddd1392695ddc90ae54a548d8b2.mdb
222
+
223
+ To learn more about Xet Storage, see this [section](https://huggingface.co/docs/hub/storage-backends).
224
+
225
+ [](#caching-assets)Caching assets
226
+ ---------------------------------
227
+
228
+ In addition to caching files from the Hub, downstream libraries often need to cache other files related to HF but not handled directly by `huggingface_hub` (for example: files downloaded from GitHub, preprocessed data, logs, …). In order to cache those files, called `assets`, one can use [cached\_assets\_path()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.cached_assets_path). This small helper generates paths in the HF cache in a unified way, based on the name of the library requesting it and optionally on a namespace and a subfolder name. The goal is to let every downstream library manage its assets in its own way (e.g. no rule on the structure) as long as they stay in the right assets folder. Those libraries can then leverage tools from `huggingface_hub` to manage the cache, in particular scanning and deleting parts of the assets from a CLI command.
229
+
230
+ Copied
231
+
232
+ from huggingface\_hub import cached\_assets\_path
233
+
234
+ assets\_path = cached\_assets\_path(library\_name="datasets", namespace="SQuAD", subfolder="download")
235
+ something\_path = assets\_path / "something.json" \# Do anything you like in your assets folder !
236
+
237
+ [cached\_assets\_path()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.cached_assets_path) is the recommended way to store assets but is not mandatory. If your library already uses its own cache, feel free to use it!
238
+
239
+ ### [](#assets-in-practice)Assets in practice
240
+
241
+ In practice, your assets cache should look like the following tree:
242
+
243
+ Copied
244
+
245
+ assets/
246
+ └── datasets/
247
+ │ ├── SQuAD/
248
+ │ │ ├── downloaded/
249
+ │ │ ├── extracted/
250
+ │ │ └── processed/
251
+ │ ├── Helsinki-NLP--tatoeba\_mt/
252
+ │ ├── downloaded/
253
+ │ ├── extracted/
254
+ │ └── processed/
255
+ └── transformers/
256
+ ├── default/
257
+ │ ├── something/
258
+ ├── bert-base-cased/
259
+ │ ├── default/
260
+ │ └── training/
261
+ hub/
262
+ └── models--julien-c--EsperBERTo-small/
263
+ ├── blobs/
264
+ │ ├── (...)
265
+ │ ├── (...)
266
+ ├── refs/
267
+ │ └── (...)
268
+ └── \[ 128\] snapshots/
269
+ ├── 2439f60ef33a0d46d85da5001d52aeda5b00ce9f/
270
+ │ ├── (...)
271
+ └── bbc77c8132af1cc5cf678da3f1ddf2de43606d48/
272
+ └── (...)
273
+
274
+ [](#manage-your-file-based-cache)Manage your file-based cache
275
+ -------------------------------------------------------------
276
+
277
+ ### [](#scan-your-cache)Scan your cache
278
+
279
+ At the moment, cached files are never deleted from your local directory: when you download a new revision of a branch, previous files are kept in case you need them again. Therefore it can be useful to scan your cache directory in order to know which repos and revisions are taking the most disk space. `huggingface_hub` provides a helper to do so that can be used via the `huggingface-cli` tool or in a Python script.
280
+
281
+ **Scan cache from the terminal**
282
+
283
+ The easiest way to scan your HF cache-system is to use the `scan-cache` command from `huggingface-cli` tool. This command scans the cache and prints a report with information like repo id, repo type, disk usage, refs and full local path.
284
+
285
+ The snippet below shows a scan report in a folder in which 4 models and 2 datasets are cached.
286
+
287
+ Copied
288
+
289
+ ➜ huggingface-cli scan-cache
290
+ REPO ID REPO TYPE SIZE ON DISK NB FILES LAST\_ACCESSED LAST\_MODIFIED REFS LOCAL PATH
291
+ --------------------------- --------- ------------ -------- ------------- ------------- ------------------- -------------------------------------------------------------------------
292
+ glue dataset 116.3K 15 4 days ago 4 days ago 2.4.0, main, 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue
293
+ google/fleurs dataset 64.9M 6 1 week ago 1 week ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs
294
+ Jean-Baptiste/camembert-ner model 441.0M 7 2 weeks ago 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner
295
+ bert-base-cased model 1.9G 13 1 week ago 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased
296
+ t5-base model 10.1K 3 3 months ago 3 months ago main /home/wauplin/.cache/huggingface/hub/models--t5-base
297
+ t5-small model 970.7M 11 3 days ago 3 days ago refs/pr/1, main /home/wauplin/.cache/huggingface/hub/models--t5-small
298
+
299
+ Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
300
+ Got 1 warning(s) while scanning. Use -vvv to print details.
301
+
302
+ To get a more detailed report, use the `--verbose` option. For each repo, you get a list of all revisions that have been downloaded. As explained above, the files that don’t change between 2 revisions are shared thanks to the symlinks. This means that the size of the repo on disk is expected to be less than the sum of the size of each of its revisions. For example, here `bert-base-cased` has 2 revisions of 1.4G and 1.5G but the total disk usage is only 1.9G.
303
+
304
+ Copied
305
+
306
+ ➜ huggingface-cli scan-cache -v
307
+ REPO ID REPO TYPE REVISION SIZE ON DISK NB FILES LAST\_MODIFIED REFS LOCAL PATH
308
+ --------------------------- --------- ---------------------------------------- ------------ -------- ------------- ----------- ----------------------------------------------------------------------------------------------------------------------------
309
+ glue dataset 9338f7b671827df886678df2bdd7cc7b4f36dffd 97.7K 14 4 days ago main, 2.4.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/9338f7b671827df886678df2bdd7cc7b4f36dffd
310
+ glue dataset f021ae41c879fcabcf823648ec685e3fead91fe7 97.8K 14 1 week ago 1.17.0 /home/wauplin/.cache/huggingface/hub/datasets--glue/snapshots/f021ae41c879fcabcf823648ec685e3fead91fe7
311
+ google/fleurs dataset 129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8 25.4K 3 2 weeks ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/129b6e96cf1967cd5d2b9b6aec75ce6cce7c89e8
312
+ google/fleurs dataset 24f85a01eb955224ca3946e70050869c56446805 64.9M 4 1 week ago main /home/wauplin/.cache/huggingface/hub/datasets--google--fleurs/snapshots/24f85a01eb955224ca3946e70050869c56446805
313
+ Jean-Baptiste/camembert-ner model dbec8489a1c44ecad9da8a9185115bccabd799fe 441.0M 7 16 hours ago main /home/wauplin/.cache/huggingface/hub/models--Jean-Baptiste--camembert-ner/snapshots/dbec8489a1c44ecad9da8a9185115bccabd799fe
314
+ bert-base-cased model 378aa1bda6387fd00e824948ebe3488630ad8565 1.5G 9 2 years ago /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/378aa1bda6387fd00e824948ebe3488630ad8565
315
+ bert-base-cased model a8d257ba9925ef39f3036bfc338acf5283c512d9 1.4G 9 3 days ago main /home/wauplin/.cache/huggingface/hub/models--bert-base-cased/snapshots/a8d257ba9925ef39f3036bfc338acf5283c512d9
316
+ t5-base model 23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9 10.1K 3 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-base/snapshots/23aa4f41cb7c08d4b05c8f327b22bfa0eb8c7ad9
317
+ t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a
318
+ t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617
319
+ t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5
320
+
321
+ Done in 0.0s. Scanned 6 repo(s) for a total of 3.4G.
322
+ Got 1 warning(s) while scanning. Use -vvv to print details.
323
+
324
+ **Grep example**
325
+
326
+ Since the output is in tabular format, you can combine it with any `grep`\-like tools to filter the entries. Here is an example to filter only revisions from the “t5-small” model on a Unix-based machine.
327
+
328
+ Copied
329
+
330
+ ➜ eval "huggingface-cli scan-cache -v" | grep "t5-small"
331
+ t5-small model 98ffebbb27340ec1b1abd7c45da12c253ee1882a 726.2M 6 1 week ago refs/pr/1 /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/98ffebbb27340ec1b1abd7c45da12c253ee1882a
332
+ t5-small model d0a119eedb3718e34c648e594394474cf95e0617 485.8M 6 4 weeks ago /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d0a119eedb3718e34c648e594394474cf95e0617
333
+ t5-small model d78aea13fa7ecd06c29e3e46195d6341255065d5 970.7M 9 1 week ago main /home/wauplin/.cache/huggingface/hub/models--t5-small/snapshots/d78aea13fa7ecd06c29e3e46195d6341255065d5
334
+
335
+ **Scan cache from Python**
336
+
337
+ For a more advanced usage, use [scan\_cache\_dir()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.scan_cache_dir) which is the python utility called by the CLI tool.
338
+
339
+ You can use it to get a detailed report structured around 4 dataclasses:
340
+
341
+ * [HFCacheInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo): complete report returned by [scan\_cache\_dir()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.scan_cache_dir)
342
+ * [CachedRepoInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedRepoInfo): information about a cached repo
343
+ * [CachedRevisionInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedRevisionInfo): information about a cached revision (e.g. “snapshot”) inside a repo
344
+ * [CachedFileInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedFileInfo): information about a cached file in a snapshot
345
+
346
+ Here is a simple usage example. See reference for details.
347
+
348
+ Copied
349
+
350
+ \>>> from huggingface\_hub import scan\_cache\_dir
351
+
352
+ \>>> hf\_cache\_info = scan\_cache\_dir()
353
+ HFCacheInfo(
354
+ size\_on\_disk=3398085269,
355
+ repos=frozenset({
356
+ CachedRepoInfo(
357
+ repo\_id='t5-small',
358
+ repo\_type='model',
359
+ repo\_path=PosixPath(...),
360
+ size\_on\_disk=970726914,
361
+ nb\_files=11,
362
+ last\_accessed=1662971707.3567169,
363
+ last\_modified=1662971107.3567169,
364
+ revisions=frozenset({
365
+ CachedRevisionInfo(
366
+ commit\_hash='d78aea13fa7ecd06c29e3e46195d6341255065d5',
367
+ size\_on\_disk=970726339,
368
+ snapshot\_path=PosixPath(...),
369
+ \# No \`last\_accessed\` as blobs are shared among revisions
370
+ last\_modified=1662971107.3567169,
371
+ files=frozenset({
372
+ CachedFileInfo(
373
+ file\_name='config.json',
374
+ size\_on\_disk=1197
375
+ file\_path=PosixPath(...),
376
+ blob\_path=PosixPath(...),
377
+ blob\_last\_accessed=1662971707.3567169,
378
+ blob\_last\_modified=1662971107.3567169,
379
+ ),
380
+ CachedFileInfo(...),
381
+ ...
382
+ }),
383
+ ),
384
+ CachedRevisionInfo(...),
385
+ ...
386
+ }),
387
+ ),
388
+ CachedRepoInfo(...),
389
+ ...
390
+ }),
391
+ warnings=\[
392
+ CorruptedCacheException("Snapshots dir doesn't exist in cached repo: ..."),
393
+ CorruptedCacheException(...),
394
+ ...
395
+ \],
396
+ )
397
+
398
+ ### [](#clean-your-cache)Clean your cache
399
+
400
+ Scanning your cache is interesting, but what you usually want to do next is delete some portions to free up space on your drive. This is possible using the `delete-cache` CLI command. One can also programmatically use the [delete\_revisions()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo.delete_revisions) helper from the [HFCacheInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo) object returned when scanning the cache (see the sketch after the delete strategy below).
401
+
402
+ **Delete strategy**
403
+
404
+ To delete some cache, you need to pass a list of revisions to delete. The tool will define a strategy to free up space based on this list. It returns a [DeleteCacheStrategy](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.DeleteCacheStrategy) object that describes which files and folders will be deleted and how much space is expected to be freed. Once you agree with the deletion, you must execute the strategy to make it effective. To avoid discrepancies, you cannot edit a strategy object manually.
405
+
406
+ The strategy to delete revisions is the following:
407
+
408
+ * the `snapshot` folder containing the revision symlinks is deleted.
409
+ * blobs files that are targeted only by revisions to be deleted are deleted as well.
410
+ * if a revision is linked to 1 or more `refs`, the references are deleted as well.
411
+ * if all revisions from a repo are deleted, the entire cached repository is deleted.
412
+
413
+ Revision hashes are unique across all repositories. This means you don’t need to provide any `repo_id` or `repo_type` when removing revisions.
414
+
415
+ If a revision is not found in the cache, it will be silently ignored. Besides, if a file or folder cannot be found while trying to delete it, a warning will be logged but no error is thrown. The deletion continues for other paths contained in the [DeleteCacheStrategy](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.DeleteCacheStrategy) object.
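+ As an illustration, here is a minimal sketch that frees every revision not attached to any ref; the `refs` attribute is assumed from the [CachedRevisionInfo](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.CachedRevisionInfo) reference:
+
+ >>> from huggingface_hub import scan_cache_dir
+ >>> report = scan_cache_dir()
+ >>> # `rev.refs` is assumed from the CachedRevisionInfo reference: an empty set means the revision is detached
+ >>> detached = [rev.commit_hash for repo in report.repos for rev in repo.revisions if not rev.refs]
+ >>> strategy = report.delete_revisions(*detached)
+ >>> print("Will free " + strategy.expected_freed_size_str)
+ >>> strategy.execute()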
416
+
417
+ **Clean cache from the terminal**
418
+
419
+ The easiest way to delete some revisions from your HF cache-system is to use the `delete-cache` command of the `huggingface-cli` tool. The command has two modes. By default, a TUI (Terminal User Interface) is displayed so you can select which revisions to delete. This TUI is currently in beta as it has not been tested on all platforms. If the TUI doesn’t work on your machine, you can disable it with the `--disable-tui` flag.
420
+
421
+ **Using the TUI**
422
+
423
+ This is the default mode. To use it, you first need to install extra dependencies by running the following command:
424
+
425
426
+
427
+ pip install huggingface_hub["cli"]
428
+
429
+ Then run the command:
430
+
431
432
+
433
+ huggingface-cli delete-cache
434
+
435
+ You should now see a list of revisions that you can select/deselect:
436
+
437
+ ![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/hub/delete-cache-tui.png)
438
+
439
+ Instructions:
440
+
441
+ * Press keyboard arrow keys `<up>` and `<down>` to move the cursor.
442
+ * Press `<space>` to toggle (select/unselect) an item.
443
+ * When a revision is selected, the first line is updated to show you how much space will be freed.
444
+ * Press `<enter>` to confirm your selection.
445
+ * If you want to cancel the operation and quit, you can select the first item (“None of the following”). If this item is selected, the delete process will be cancelled, no matter what other items are selected. Otherwise you can also press `<ctrl+c>` to quit the TUI.
446
+
447
+ Once you’ve selected the revisions you want to delete and pressed `<enter>`, a final confirmation message is displayed. Press `<enter>` again and the deletion becomes effective. If you want to cancel, enter `n`.
448
+
449
450
+
451
+ ✗ huggingface-cli delete-cache --dir ~/.cache/huggingface/hub
452
+ ? Select revisions to delete: 2 revision(s) selected.
453
+ ? 2 revisions selected counting for 3.1G. Confirm deletion ? Yes
454
+ Start deletion.
455
+ Done. Deleted 1 repo(s) and 0 revision(s) for a total of 3.1G.
456
+
457
+ **Without TUI**
458
+
459
+ As mentioned above, the TUI mode is currently in beta and is optional. It may not work on your machine, or you may simply not find it convenient.
460
+
461
+ Another approach is to use the `--disable-tui` flag. The process is very similar: you will still be asked to review the list of revisions to delete, but this manual step takes place not in the terminal but in a temporary file generated on the fly that you can edit by hand.
462
+
463
+ This file has all the instructions you need in its header. Open it in your favorite text editor. To deselect a revision, comment it out with a `#`; to keep it selected for deletion, leave it uncommented. Once the manual review is done, save the file, go back to your terminal and press `<enter>`. The tool will compute how much space would be freed with the updated list of revisions. You can continue editing the file or confirm with `"y"`.
464
+
465
466
+
467
+ huggingface-cli delete-cache --disable-tui
468
+
469
+ Example of command file:
470
+
471
472
+
473
+ # INSTRUCTIONS
474
+ # ------------
475
+ # This is a temporary file created by running `huggingface-cli delete-cache` with the
476
+ # `--disable-tui` option. It contains a set of revisions that can be deleted from your
477
+ # local cache directory.
478
+ #
479
+ # Please manually review the revisions you want to delete:
480
+ # - Revision hashes can be commented out with '#'.
481
+ # - Only non-commented revisions in this file will be deleted.
482
+ # - Revision hashes that are removed from this file are ignored as well.
483
+ # - If `CANCEL_DELETION` line is uncommented, the all cache deletion is cancelled and
484
+ # no changes will be applied.
485
+ #
486
+ # Once you've manually reviewed this file, please confirm deletion in the terminal. This
487
+ # file will be automatically removed once done.
488
+ # ------------
489
+
490
+ # KILL SWITCH
491
+ # ------------
492
+ # Un-comment following line to completely cancel the deletion process
493
+ # CANCEL_DELETION
494
+ # ------------
495
+
496
+ # REVISIONS
497
+ # ------------
498
+ # Dataset chrisjay/crowd-speech-africa (761.7M, used 5 days ago)
499
+ ebedcd8c55c90d39fd27126d29d8484566cd27ca # Refs: main # modified 5 days ago
500
+
501
+ # Dataset oscar (3.3M, used 4 days ago)
502
+ # 916f956518279c5e60c63902ebdf3ddf9fa9d629 # Refs: main # modified 4 days ago
503
+
504
+ # Dataset wikiann (804.1K, used 2 weeks ago)
505
+ 89d089624b6323d69dcd9e5eb2def0551887a73a # Refs: main # modified 2 weeks ago
506
+
507
+ # Dataset z-uo/male-LJSpeech-italian (5.5G, used 5 days ago)
508
+ # 9cfa5647b32c0a30d0adfca06bf198d82192a0d1 # Refs: main # modified 5 days ago
509
+
510
+ **Clean cache from Python**
511
+
512
+ For more flexibility, you can also use the [delete_revisions()](/docs/huggingface_hub/v0.32.2/en/package_reference/cache#huggingface_hub.HFCacheInfo.delete_revisions) method programmatically. Here is a simple example. See reference for details.
+
+ >>> from huggingface_hub import scan_cache_dir
+
+ >>> delete_strategy = scan_cache_dir().delete_revisions(
+ ...     "81fd1d6e7847c99f5862c9fb81387956d99ec7aa",
+ ...     "e2983b237dccf3ab4937c97fa717319a9ca1a96d",
+ ...     "6c0e6080953db56375760c0471a8c5f2929baf11",
+ ... )
+ >>> print("Will free " + delete_strategy.expected_freed_size_str)
+ Will free 8.6G
+
+ >>> delete_strategy.execute()
+ Cache deletion done. Saved 8.6G.
528
+
529
hymm_sp/__init__.py ADDED
File without changes
hymm_sp/config.py ADDED
@@ -0,0 +1,160 @@
1
+ import argparse
2
+ from hymm_sp.constants import *
3
+ import re
4
+ import collections.abc
5
+
6
+ def as_tuple(x):
7
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
8
+ return tuple(x)
9
+ if x is None or isinstance(x, (int, float, str)):
10
+ return (x,)
11
+ else:
12
+ raise ValueError(f"Unknown type {type(x)}")
13
+
14
+ def parse_args(namespace=None):
15
+ parser = argparse.ArgumentParser(description="Hunyuan Multimodal training/inference script")
16
+ parser = add_extra_args(parser)
17
+ args = parser.parse_args(namespace=namespace)
18
+ args = sanity_check_args(args)
19
+ return args
20
+
21
+ def add_extra_args(parser: argparse.ArgumentParser):
22
+ parser = add_network_args(parser)
23
+ parser = add_extra_models_args(parser)
24
+ parser = add_denoise_schedule_args(parser)
25
+ parser = add_evaluation_args(parser)
26
+ parser = add_test_args(parser)
27
+ return parser
28
+
29
+ def add_test_args(parser: argparse.ArgumentParser):
30
+ group = parser.add_argument_group(title="Test")
31
+
32
+ group.add_argument("--image-start", action="store_true", help="Use one image from video for training")
33
+ group.add_argument("--use-csv-pose", action="store_true", help="Use one image from video for training")
34
+ group.add_argument("--add-button", action="store_true", help="Use one image from video for training")
35
+ group.add_argument("--action-list", type=str, nargs='+', default=None, help="CSV file for evaluation.")
36
+ group.add_argument("--action-speed-list", type=float, nargs='+', default=None, help="CSV file for evaluation.")
37
+ group.add_argument("--pose", type=str, default=None, help="CSV file for evaluation.")
38
+
39
+ return parser
40
+
41
+
42
+ def add_network_args(parser: argparse.ArgumentParser):
43
+ group = parser.add_argument_group(title="Network")
44
+ group.add_argument("--model", type=str, default="HYVideo-T/2",
45
+ help="Model architecture to use. It it also used to determine the experiment directory.")
46
+ group.add_argument("--latent-channels", type=str, default=None,
47
+ help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
48
+ "it still needs to match the latent channels of the VAE model.")
49
+ group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
50
+ return parser
51
+
52
+ def add_extra_models_args(parser: argparse.ArgumentParser):
53
+ group = parser.add_argument_group(title="Extra Models (VAE, Text Encoder, Tokenizer)")
54
+
55
+ # VAE
56
+ group.add_argument("--vae", type=str, default="884-16c-hy0801", help="Name of the VAE model.")
57
+ group.add_argument("--vae-precision", type=str, default="fp16",
58
+ help="Precision mode for the VAE model.")
59
+ group.add_argument("--vae-tiling", action="store_true", default=True, help="Enable tiling for the VAE model.")
60
+ group.add_argument("--text-encoder", type=str, default="llava-llama-3-8b", choices=list(TEXT_ENCODER_PATH),
61
+ help="Name of the text encoder model.")
62
+ group.add_argument("--text-encoder-precision", type=str, default="fp16", choices=PRECISIONS,
63
+ help="Precision mode for the text encoder model.")
64
+ group.add_argument("--text-states-dim", type=int, default=4096, help="Dimension of the text encoder hidden states.")
65
+ group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
66
+ group.add_argument("--tokenizer", type=str, default="llava-llama-3-8b", choices=list(TOKENIZER_PATH),
67
+ help="Name of the tokenizer model.")
68
+ group.add_argument("--text-encoder-infer-mode", type=str, default="encoder", choices=["encoder", "decoder"],
69
+ help="Inference mode for the text encoder model. It should match the text encoder type. T5 and "
70
+ "CLIP can only work in 'encoder' mode, while Llava/GLM can work in both modes.")
71
+
72
+ group.add_argument("--prompt-template-video", type=str, default='li-dit-encode-video', choices=PROMPT_TEMPLATE,
73
+ help="Video prompt template for the decoder-only text encoder model.")
74
+ group.add_argument("--hidden-state-skip-layer", type=int, default=2,
75
+ help="Skip layer for hidden states.")
76
+ group.add_argument("--apply-final-norm", action="store_true",
77
+ help="Apply final normalization to the used text encoder hidden states.")
78
+
79
+ # - CLIP
80
+ group.add_argument("--text-encoder-2", type=str, default='clipL', choices=list(TEXT_ENCODER_PATH),
81
+ help="Name of the second text encoder model.")
82
+ group.add_argument("--text-encoder-precision-2", type=str, default="fp16", choices=PRECISIONS,
83
+ help="Precision mode for the second text encoder model.")
84
+ group.add_argument("--text-states-dim-2", type=int, default=768,
85
+ help="Dimension of the second text encoder hidden states.")
86
+ group.add_argument("--tokenizer-2", type=str, default='clipL', choices=list(TOKENIZER_PATH),
87
+ help="Name of the second tokenizer model.")
88
+ group.add_argument("--text-len-2", type=int, default=77, help="Maximum length of the second text input.")
89
+ group.set_defaults(use_attention_mask=True)
90
+ group.add_argument("--text-projection", type=str, default="single_refiner", choices=TEXT_PROJECTION,
91
+ help="A projection layer for bridging the text encoder hidden states and the diffusion model "
92
+ "conditions.")
93
+ return parser
94
+
95
+
96
+ def add_denoise_schedule_args(parser: argparse.ArgumentParser):
97
+ group = parser.add_argument_group(title="Denoise schedule")
98
+ group.add_argument("--flow-shift-eval-video", type=float, default=None, help="Shift factor for flow matching schedulers when using video data.")
99
+ group.add_argument("--flow-reverse", action="store_true", default=True, help="If reverse, learning/sampling from t=1 -> t=0.")
100
+ group.add_argument("--flow-solver", type=str, default="euler", help="Solver for flow matching.")
101
+ group.add_argument("--use-linear-quadratic-schedule", action="store_true", help="Use linear quadratic schedule for flow matching."
102
+ "Follow MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)")
103
+ group.add_argument("--linear-schedule-end", type=int, default=25, help="End step for linear quadratic schedule for flow matching.")
104
+ return parser
105
+
106
+ def add_evaluation_args(parser: argparse.ArgumentParser):
107
+ group = parser.add_argument_group(title="Validation Loss Evaluation")
108
+ parser.add_argument("--precision", type=str, default="bf16", choices=PRECISIONS,
109
+ help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.")
110
+ parser.add_argument("--reproduce", action="store_true",
111
+ help="Enable reproducibility by setting random seeds and deterministic algorithms.")
112
+ parser.add_argument("--ckpt", type=str, help="Path to the checkpoint to evaluate.")
113
+ parser.add_argument("--load-key", type=str, default="module", choices=["module", "ema"],
114
+ help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.")
115
+ parser.add_argument("--cpu-offload", action="store_true", help="Use CPU offload for the model load.")
116
+ group.add_argument( "--use-fp8", action="store_true", help="Enable use fp8 for inference acceleration.")
117
+ group.add_argument("--video-size", type=int, nargs='+', default=512,
118
+ help="Video size for training. If a single value is provided, it will be used for both width "
119
+ "and height. If two values are provided, they will be used for width and height "
120
+ "respectively.")
121
+ group.add_argument("--sample-n-frames", type=int, default=33,
122
+ help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1")
123
+ group.add_argument("--infer-steps", type=int, default=100, help="Number of denoising steps for inference.")
124
+ group.add_argument("--val-disable-autocast", action="store_true",
125
+ help="Disable autocast for denoising loop and vae decoding in pipeline sampling.")
126
+ group.add_argument("--num-images", type=int, default=1, help="Number of images to generate for each prompt.")
127
+ group.add_argument("--seed", type=int, default=1024, help="Seed for evaluation.")
128
+ group.add_argument("--save-path-suffix", type=str, default="", help="Suffix for the directory of saved samples.")
129
+ group.add_argument("--prompt", type=str, default='', help="Main prompt")
130
+ group.add_argument("--pos-prompt", type=str, default='', help="Prompt for sampling during evaluation.")
131
+ group.add_argument("--neg-prompt", type=str, default='', help="Negative prompt for sampling during evaluation.")
132
+ group.add_argument("--add-pos-prompt", type=str, default='', help="Addition prompt for sampling during evaluation.")
133
+ group.add_argument("--add-neg-prompt", type=str, default='', help="Addition negative prompt for sampling during evaluation.")
134
+ group.add_argument("--pad-face-size", type=float, default=0.7, help="Pad bbox for face align.")
135
+ group.add_argument("--image-path", type=str, default="", help="")
136
+ group.add_argument("--save-path", type=str, default=None, help="Path to save the generated samples.")
137
+ group.add_argument("--input", type=str, default=None, help="test data.")
138
+ group.add_argument("--item-name", type=str, default=None, help="")
139
+ group.add_argument("--cfg-scale", type=float, default=7.5, help="Classifier free guidance scale.")
140
+ group.add_argument("--ip-cfg-scale", type=float, default=0, help="Classifier free guidance scale.")
141
+ group.add_argument("--use-deepcache", type=int, default=1)
142
+ group.add_argument("--use-sage", action="store_true", help="Use sage attention for speed up.")
143
+
144
+ return parser
145
+
146
+ def sanity_check_args(args):
147
+ # VAE channels
148
+ vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
149
+ if not re.match(vae_pattern, args.vae):
150
+ raise ValueError(
151
+ f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
152
+ )
153
+ vae_channels = int(args.vae.split("-")[1][:-1])
154
+ if args.latent_channels is None:
155
+ args.latent_channels = vae_channels
156
+ if vae_channels != args.latent_channels:
157
+ raise ValueError(
158
+ f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
159
+ )
160
+ return args
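A minimal, hypothetical smoke test for the parser defined above: it builds the full argument parser, parses an empty command line so only the defaults apply, and runs the sanity check (assumes the `hymm_sp` package is importable and `torch` is installed).

```python
import argparse

from hymm_sp.config import add_extra_args, sanity_check_args

# Build the full parser (network, extra models, denoise schedule, evaluation, test groups).
parser = add_extra_args(argparse.ArgumentParser(description="Hunyuan config smoke test"))

# Parse an empty argv so only the defaults defined in config.py are used.
args = sanity_check_args(parser.parse_args([]))

# latent_channels is derived from the default VAE name "884-16c-hy0801" -> 16 channels.
print(args.model, args.vae, args.latent_channels, args.infer_steps)
```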
hymm_sp/constants.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ import torch
3
+
4
+ __all__ = [
5
+ "PROMPT_TEMPLATE", "MODEL_BASE", "PRECISION_TO_TYPE",
6
+ "PRECISIONS", "VAE_PATH", "TEXT_ENCODER_PATH", "TOKENIZER_PATH",
7
+ "TEXT_PROJECTION",
8
+ ]
9
+
10
+ # =================== Constant Values =====================
11
+
12
+ PRECISION_TO_TYPE = {
13
+ 'fp32': torch.float32,
14
+ 'fp16': torch.float16,
15
+ 'bf16': torch.bfloat16,
16
+ }
17
+
18
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
19
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
20
+ "1. The main content and theme of the video."
21
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
22
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
23
+ "4. background environment, light, style and atmosphere."
24
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
25
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
26
+ )
27
+
28
+ PROMPT_TEMPLATE = {
29
+ "li-dit-encode-video": {"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "crop_start": 95},
30
+ }
31
+
32
+ # ======================= Model ======================
33
+ PRECISIONS = {"fp32", "fp16", "bf16"}
34
+
35
+ # =================== Model Path =====================
36
+ MODEL_BASE = os.getenv("MODEL_BASE")
37
+
38
+ # 3D VAE
39
+ VAE_PATH = {
40
+ "884-16c-hy0801": f"{MODEL_BASE}/vae_3d/hyvae",
41
+ }
42
+
43
+ # Text Encoder
44
+ TEXT_ENCODER_PATH = {
45
+ "clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14",
46
+ "llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1-transformers",
47
+ }
48
+
49
+ # Tokenizer
50
+ TOKENIZER_PATH = {
51
+ "clipL": f"{MODEL_BASE}/openai_clip-vit-large-patch14",
52
+ "llava-llama-3-8b": f"{MODEL_BASE}/llava-llama-3-8b-v1_1-transformers",
53
+ }
54
+
55
+ TEXT_PROJECTION = {
56
+ "linear", # Default, an nn.Linear() layer
57
+ "single_refiner", # Single TokenRefiner. Refer to LI-DiT
58
+ }
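Because `MODEL_BASE` is read once at import time, it has to be set in the environment before `hymm_sp.constants` (or any module importing it) is loaded; a minimal sketch with a hypothetical weights directory:

```python
import os

# Must be set before the first import of hymm_sp.constants, since the path
# dictionaries above are built at module import time.
os.environ.setdefault("MODEL_BASE", "/path/to/weights")  # hypothetical location

from hymm_sp.constants import PRECISION_TO_TYPE, TEXT_ENCODER_PATH, VAE_PATH

dtype = PRECISION_TO_TYPE["bf16"]             # torch.bfloat16
print(VAE_PATH["884-16c-hy0801"])             # /path/to/weights/vae_3d/hyvae
print(TEXT_ENCODER_PATH["llava-llama-3-8b"])  # /path/to/weights/llava-llama-3-8b-v1_1-transformers
```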
hymm_sp/data_kits/data_tools.py ADDED
@@ -0,0 +1,115 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import imageio
6
+ import torchvision
7
+ from einops import rearrange
8
+
9
+
10
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
11
+ """
12
+ Saves a batch of videos as a grid animation in GIF or video format.
13
+
14
+ Args:
15
+ videos (torch.Tensor): Input video tensor with shape (batch, channels, time, height, width)
16
+ path (str): Output file path (e.g., "output/videos.gif")
17
+ rescale (bool): If True, rescales video values from [-1, 1] to [0, 1]
18
+ n_rows (int): Number of rows in the grid layout
19
+ fps (int): Frames per second for the output animation
20
+ quality (int): Quality parameter for imageio (1-10, higher = better quality)
21
+
22
+ Process:
23
+ 1. Rearranges tensor dimensions to (time, batch, channels, height, width)
24
+ 2. For each frame in time:
25
+ a. Creates a grid of videos using torchvision.utils.make_grid
26
+ b. Adjusts dimensions to (height, width, channels)
27
+ c. Rescales values if needed
28
+ d. Converts to 8-bit uint8 format (0-255)
29
+ 3. Saves frames as an animated GIF/video using imageio
30
+ """
31
+ # Rearrange dimensions to (time, batch, channels, height, width) for frame-wise processing
32
+ videos = rearrange(videos, "b c t h w -> t b c h w")
33
+ outputs = [] # Stores processed frames for animation
34
+
35
+ for frame in videos:
36
+ # Create a grid of videos with n_rows rows
37
+ grid = torchvision.utils.make_grid(frame, nrow=n_rows)
38
+
39
+ # Convert from (channels, height, width) to (height, width, channels)
40
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
41
+
42
+ # Rescale from [-1, 1] to [0, 1] if needed (common in GAN outputs)
43
+ if rescale:
44
+ grid = (grid + 1.0) / 2.0
45
+
46
+ # Clamp values to valid range [0, 1] and convert to 8-bit uint8 (0-255)
47
+ grid = torch.clamp(grid, 0, 1)
48
+ grid_np = (grid * 255).numpy().astype(np.uint8)
49
+
50
+ outputs.append(grid_np)
51
+
52
+ # Create output directory if it doesn't exist
53
+ os.makedirs(os.path.dirname(path), exist_ok=True)
54
+
55
+ # Save frames as an animated GIF/video
56
+ imageio.mimsave(path, outputs, fps=fps, quality=quality)
57
+
58
+
59
+ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
60
+ """
61
+ Resizes and pads an image to fit a target size while preserving aspect ratio.
62
+
63
+ Args:
64
+ crop_img (np.ndarray): Input image (shape: [height, width, channels])
65
+ size (tuple): Target size in (width, height) format
66
+ color (tuple): RGB color for padding (default: white)
67
+ resize_ratio (float): Scaling factor for resizing before padding (0-1)
68
+
69
+ Returns:
70
+ np.ndarray: Padded image with shape (target_height, target_width, channels)
71
+
72
+ Process:
73
+ 1. Calculates scaling factors to fit image within target size
74
+ 2. Resizes image while preserving aspect ratio
75
+ 3. Adds padding to reach exact target size, centering the resized image
76
+ """
77
+ # Get input image dimensions
78
+ crop_h, crop_w = crop_img.shape[:2]
79
+ target_w, target_h = size # Target dimensions (width, height)
80
+
81
+ # Calculate scaling factors to fit image within target size
82
+ scale_h = target_h / crop_h # Scale needed to fit height
83
+ scale_w = target_w / crop_w # Scale needed to fit width
84
+
85
+ # Choose the smaller scale to avoid exceeding target dimensions
86
+ if scale_w > scale_h:
87
+ # Height is the limiting factor: resize based on height
88
+ resize_h = int(target_h * resize_ratio)
89
+ resize_w = int(crop_w / crop_h * resize_h) # Preserve aspect ratio
90
+ else:
91
+ # Width is the limiting factor: resize based on width
92
+ resize_w = int(target_w * resize_ratio)
93
+ resize_h = int(crop_h / crop_w * resize_w) # Preserve aspect ratio
94
+
95
+ # Resize the image using OpenCV
96
+ resized_img = cv2.resize(crop_img, (resize_w, resize_h))
97
+
98
+ # Calculate padding needed to reach target size (centered)
99
+ pad_left = (target_w - resize_w) // 2
100
+ pad_top = (target_h - resize_h) // 2
101
+ pad_right = target_w - resize_w - pad_left # Ensure total width matches target
102
+ pad_bottom = target_h - resize_h - pad_top # Ensure total height matches target
103
+
104
+ # Add padding with the specified color
105
+ padded_img = cv2.copyMakeBorder(
106
+ resized_img,
107
+ top=pad_top,
108
+ bottom=pad_bottom,
109
+ left=pad_left,
110
+ right=pad_right,
111
+ borderType=cv2.BORDER_CONSTANT,
112
+ value=color
113
+ )
114
+
115
+ return padded_img
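A minimal usage sketch for the two helpers above; shapes and the output path are made up, and writing `.mp4` assumes an imageio backend (e.g. imageio-ffmpeg) that accepts the `fps` and `quality` arguments.

```python
import numpy as np
import torch

from hymm_sp.data_kits.data_tools import pad_image, save_videos_grid

# Pad a 640x480 RGB frame onto a 1216x704 white canvas; `size` is (width, height).
frame = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
padded = pad_image(frame, (1216, 704))
print(padded.shape)  # (704, 1216, 3)

# Save a tiny random clip as a grid animation: (batch, channels, time, height, width), values in [0, 1].
# The tensor must live on CPU because save_videos_grid converts frames with .numpy().
clip = torch.rand(2, 3, 8, 64, 64)
save_videos_grid(clip, "output/demo.mp4", n_rows=2, fps=8)  # hypothetical output path
```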
hymm_sp/data_kits/video_dataset.py ADDED
@@ -0,0 +1,259 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import json
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
+ import torchvision.transforms as transforms
9
+ from torch.utils.data import Dataset
10
+ import csv
+
+ # Needed by JsonDataset.read_image below
+ from hymm_sp.data_kits.data_tools import pad_image
11
+
12
+
13
+ def fix_nulls(s):
14
+ """
15
+ Helper generator to remove null characters from input lines.
16
+ Prevents parsing errors caused by invalid null bytes in CSV/JSON files.
17
+
18
+ Args:
19
+ s: Input iterable containing strings with potential null characters
20
+
21
+ Yields:
22
+ Strings with null characters replaced by spaces
23
+ """
24
+ for line in s:
25
+ yield line.replace('\0', ' ')
26
+
27
+ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
28
+ """
29
+ Find the closest aspect ratio from predefined buckets
30
+
31
+ Args:
32
+ height: Image height
33
+ width: Image width
34
+ ratios: List of predefined aspect ratios to match against
35
+ buckets: List of size tuples corresponding to ratios
36
+
37
+ Returns:
38
+ Tuple containing:
39
+ - Closest matching size bucket
40
+ - Closest ratio value
41
+ """
42
+ aspect_ratio = float(height) / float(width)
43
+ closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
44
+ closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
45
+ return buckets[closest_ratio_id], float(closest_ratio)
46
+
47
+
48
+ def generate_crop_size_list(base_size=256, patch_size=16, max_ratio=4.0):
49
+ """
50
+ Generate valid crop sizes that maintain compatible dimensions with model patches
51
+
52
+ Args:
53
+ base_size: Base dimension for calculating patch count
54
+ patch_size: Size of model's input patches
55
+ max_ratio: Maximum allowed aspect ratio (height/width)
56
+
57
+ Returns:
58
+ List of (width, height) tuples representing valid crop sizes
59
+ """
60
+ # Calculate total number of patches from base size
61
+ num_patches = round((base_size / patch_size) ** 2)
62
+ assert max_ratio >= 1.0, "Maximum ratio must be at least 1.0"
63
+
64
+ crop_size_list = []
65
+ wp, hp = num_patches, 1 # Initialize with maximum width patches
66
+
67
+ # Generate valid patch combinations
68
+ while wp > 0:
69
+ # Only add sizes that maintain acceptable aspect ratio
70
+ if max(wp, hp) / min(wp, hp) <= max_ratio:
71
+ crop_size_list.append((wp * patch_size, hp * patch_size))
72
+
73
+ # Move to next valid patch configuration
74
+ if (hp + 1) * wp <= num_patches:
75
+ hp += 1
76
+ else:
77
+ wp -= 1
78
+ return crop_size_list
79
+
80
+
81
+ class VideoCSVDataset(Dataset):
82
+ """
83
+ Dataset class for loading video generation data from CSV files
84
+
85
+ Handles:
86
+ - CSV parsing with null character handling
87
+ - Loading prompt and metadata
88
+ - Supporting multiple task types (image-to-video, etc.)
89
+ """
90
+ def __init__(self, csv_path, col_name='prompt', task_type=''):
91
+ """
92
+ Args:
93
+ csv_path: Path to CSV file containing dataset metadata
94
+ col_name: Column name containing generation prompts
95
+ task_type: Type of task (e.g., "i2v" for image-to-video)
96
+ """
97
+ # Read CSV with null character handling
98
+ with open(csv_path, 'r', newline="\n", encoding='utf-8-sig') as csvfile:
99
+ self.dataset = list(csv.DictReader(fix_nulls(csvfile), delimiter=';'))
100
+
101
+ self.col_name = col_name
102
+ self.task_type = task_type
103
+
104
+ def __len__(self):
105
+ """Return total number of samples in dataset"""
106
+ return len(self.dataset)
107
+
108
+ def __getitem__(self, idx):
109
+ """
110
+ Get dataset item by index
111
+
112
+ Args:
113
+ idx: Index of sample to retrieve
114
+
115
+ Returns:
116
+ Dictionary containing:
117
+ - Prompt and metadata
118
+ - Paths to auxiliary files (npy, video, poses)
119
+ - Index for tracking outputs
120
+ """
121
+ example = {}
122
+ example["prompt"] = self.dataset[idx][self.col_name]
123
+ example['seed'] = int(self.dataset[idx]['seed'])
124
+ example['index'] = self.dataset[idx]['index']
125
+
126
+ # Add optional auxiliary paths if present in CSV
127
+ if "npy_path" in self.dataset[idx]:
128
+ example['npy_path'] = self.dataset[idx]['npy_path']
129
+ if "video_path" in self.dataset[idx]:
130
+ example['video_path'] = self.dataset[idx]['video_path']
131
+ if "monst3r_poses" in self.dataset[idx]:
132
+ example['monst3r_poses'] = self.dataset[idx]['monst3r_poses']
133
+
134
+ # Add image reference path for image-to-video tasks
135
+ if self.task_type == "i2v":
136
+ example['ref_image'] = self.dataset[idx]['ref_image_path']
137
+
138
+ return example
139
+
140
+
141
+ class JsonDataset(object):
142
+ """
143
+ Dataset class for loading data from JSON files and image sequences
144
+
145
+ Handles:
146
+ - Reading image data from multiple formats
147
+ - Preprocessing for model compatibility
148
+ - Generating conditional and unconditional inputs
149
+ """
150
+ def __init__(self, args):
151
+ """
152
+ Args:
153
+ args: Command-line arguments containing configuration
154
+ """
155
+ self.args = args
156
+ self.data_list = args.input
157
+ self.pad_color = (255, 255, 255) # White padding
158
+ self.llava_size = (336, 336) # Standard size for LLaVA model
159
+ self.ref_size = (args.video_size[1], args.video_size[0]) # Reference output size
160
+
161
+ # Get list of data paths from input list or single file
162
+ if self.data_list.endswith('.list'):
163
+ self.data_paths = [line.strip() for line in open(self.data_list, 'r')] if self.data_list else []
164
+ else:
165
+ self.data_paths = [self.data_list]
166
+
167
+ # Transformation pipeline for LLaVA model input
168
+ self.llava_transform = transforms.Compose(
169
+ [
170
+ transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
171
+ transforms.ToTensor(),
172
+ transforms.Normalize(
173
+ (0.48145466, 0.4578275, 0.4082107),
174
+ (0.26862954, 0.26130258, 0.27577711)
175
+ ),
176
+ ]
177
+ )
178
+
179
+ def __len__(self):
180
+ """Return total number of data items"""
181
+ return len(self.data_paths)
182
+
183
+ def read_image(self, image_path):
184
+ """
185
+ Read image from path with fallback handling
186
+
187
+ Args:
188
+ image_path: Path to image file or dictionary containing path
189
+
190
+ Returns:
191
+ Tuple of (LLaVA-formatted image, reference-sized image)
192
+ """
193
+ # Extract path from dictionary if needed
194
+ if isinstance(image_path, dict):
195
+ image_path = image_path['seg_item_image_path']
196
+
197
+ try:
198
+ # Primary method: OpenCV for faster reading
199
+ face_image_masked = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
200
+ except:
201
+ # Fallback: PIL for special formats
202
+ face_image_masked = Image.open(image_path).convert('RGB')
203
+
204
+ # Prepare images for different processing stages
205
+ cat_face_image = pad_image(face_image_masked.copy(), self.ref_size)
206
+ llava_face_image = pad_image(face_image_masked.copy(), self.llava_size)
207
+ return llava_face_image, cat_face_image
208
+
209
+ def __getitem__(self, idx):
210
+ """
211
+ Get preprocessed data item by index
212
+
213
+ Args:
214
+ idx: Index of item to retrieve
215
+
216
+ Returns:
217
+ Dictionary containing:
218
+ - Preprocessed tensors for model input
219
+ - Metadata (prompt, index, paths)
220
+ """
221
+ data_path = self.data_paths[idx]
222
+ data_name = os.path.basename(os.path.splitext(data_path)[0])
223
+
224
+ # Load data from JSON or use default parameters
225
+ if data_path.endswith('.json'):
226
+ data = json.load(open(data_path, 'r'))
227
+ llava_item_image, cat_item_image = self.read_image(data)
228
+ item_prompt = data['item_prompt']
229
+ seed = data['seed']
230
+ prompt = data['prompt']
231
+ negative_prompt = data.get('negative_prompt', '') # Default to empty string
232
+ else:
233
+ # Handle non-JSON data (direct image files)
234
+ llava_item_image, cat_item_image = self.read_image(data_path)
235
+ item_prompt = 'object'
236
+ seed = self.args.seed
237
+ prompt = self.args.pos_prompt
238
+ negative_prompt = self.args.neg_prompt
239
+
240
+ # Convert to tensors with appropriate transformations
241
+ llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
242
+ cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0 # Normalize to [0,1]
243
+
244
+ # Create unconditional input (white background)
245
+ uncond_llava_item_image = np.ones_like(llava_item_image) * 255
246
+ uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
247
+
248
+ # Assemble final batch dictionary
249
+ return {
250
+ "pixel_value_llava": llava_item_tensor,
251
+ "uncond_pixel_value_llava": uncond_llava_item_tensor,
252
+ "pixel_value_ref": cat_item_tensor,
253
+ "prompt": prompt,
254
+ "negative_prompt": negative_prompt,
255
+ "seed": seed,
256
+ "name": item_prompt,
257
+ 'data_name': data_name,
258
+ 'index': [idx] # Index for output tracking
259
+ }
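A hedged sketch of how `VideoCSVDataset` might be driven; `samples.csv` is a hypothetical semicolon-delimited file with at least the `prompt`, `seed` and `index` columns that `__getitem__` reads.

```python
from torch.utils.data import DataLoader

from hymm_sp.data_kits.video_dataset import VideoCSVDataset

# "samples.csv" is a hypothetical file, e.g.:
# index;prompt;seed
# 0;A knight walks through a snowy village;1024
dataset = VideoCSVDataset("samples.csv", col_name="prompt")
loader = DataLoader(dataset, batch_size=1, shuffle=False)

for batch in loader:
    print(batch["index"], batch["seed"], batch["prompt"])
```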
hymm_sp/diffusion/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from .pipelines import HunyuanVideoGamePipeline
2
+ from .schedulers import FlowMatchDiscreteScheduler
3
+
4
+ def load_diffusion_pipeline(args, rank, vae, text_encoder, text_encoder_2, model, scheduler=None,
5
+ device=None, progress_bar_config=None):
6
+ """ Load the denoising scheduler for inference. """
7
+ if scheduler is None:
8
+ scheduler = FlowMatchDiscreteScheduler(
9
+ shift=args.flow_shift_eval_video,
10
+ reverse=args.flow_reverse,
11
+ solver=args.flow_solver,
12
+ )
13
+ # Only enable progress bar for rank 0
14
+ progress_bar_config = progress_bar_config or {'leave': True, 'disable': rank != 0}
15
+
16
+ pipeline = HunyuanVideoGamePipeline(vae=vae,
17
+ text_encoder=text_encoder,
18
+ text_encoder_2=text_encoder_2,
19
+ transformer=model,
20
+ scheduler=scheduler,
21
+ # safety_checker=None,
22
+ # feature_extractor=None,
23
+ # requires_safety_checker=False,
24
+ progress_bar_config=progress_bar_config,
25
+ args=args,
26
+ )
27
+ if not args.cpu_offload:
28
+ pipeline = pipeline.to(device)
29
+
30
+ return pipeline
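A full `load_diffusion_pipeline` call needs the VAE, text encoders and transformer built elsewhere in this repo, so only the scheduler fallback is sketched here; the `shift` value is illustrative, not a recommended setting.

```python
from hymm_sp.diffusion import FlowMatchDiscreteScheduler

# Mirrors the fallback inside load_diffusion_pipeline when no scheduler is passed.
scheduler = FlowMatchDiscreteScheduler(
    shift=7.0,       # illustrative; normally taken from args.flow_shift_eval_video
    reverse=True,    # args.flow_reverse defaults to True in config.py
    solver="euler",  # args.flow_solver default
)
print(scheduler)
```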
hymm_sp/diffusion/pipelines/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Import the HunyuanVideoGamePipeline class from the current package
2
+ # This pipeline is specifically designed for handling video game content generation
3
+ # using the Hunyuan model architecture, providing specialized functionality
4
+ # for game-related video synthesis, character animation, and environment rendering.
5
+ from .pipeline_hunyuan_video_game import HunyuanVideoGamePipeline
hymm_sp/diffusion/pipelines/pipeline_hunyuan_video_game.py ADDED
@@ -0,0 +1,1152 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import numpy as np
22
+ import torch
23
+ from packaging import version
24
+ from diffusers.utils import BaseOutput
25
+ from dataclasses import dataclass
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.configuration_utils import FrozenDict
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
30
+ from diffusers.models import AutoencoderKL, ImageProjection
31
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ deprecate,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+
44
+ from hymm_sp.constants import PRECISION_TO_TYPE
45
+ from hymm_sp.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
46
+ from hymm_sp.text_encoder import TextEncoder
47
+ from einops import rearrange
48
+ from ...modules import HYVideoDiffusionTransformer
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """"""
53
+
54
+
55
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
56
+ """
57
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
58
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
59
+ """
60
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
61
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
62
+ # rescale the results from guidance (fixes overexposure)
63
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
64
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
65
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
66
+ return noise_cfg
67
+
68
+
69
+ def retrieve_timesteps(
70
+ scheduler,
71
+ num_inference_steps: Optional[int] = None,
72
+ device: Optional[Union[str, torch.device]] = None,
73
+ timesteps: Optional[List[int]] = None,
74
+ sigmas: Optional[List[float]] = None,
75
+ **kwargs,
76
+ ):
77
+ """
78
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
79
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
80
+
81
+ Args:
82
+ scheduler (`SchedulerMixin`):
83
+ The scheduler to get timesteps from.
84
+ num_inference_steps (`int`):
85
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
86
+ must be `None`.
87
+ device (`str` or `torch.device`, *optional*):
88
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
89
+ timesteps (`List[int]`, *optional*):
90
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
91
+ `num_inference_steps` and `sigmas` must be `None`.
92
+ sigmas (`List[float]`, *optional*):
93
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
94
+ `num_inference_steps` and `timesteps` must be `None`.
95
+
96
+ Returns:
97
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
98
+ second element is the number of inference steps.
99
+ """
100
+ if timesteps is not None and sigmas is not None:
101
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
102
+ if timesteps is not None:
103
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
104
+ if not accepts_timesteps:
105
+ raise ValueError(
106
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
107
+ f" timestep schedules. Please check whether you are using the correct scheduler."
108
+ )
109
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
110
+ timesteps = scheduler.timesteps
111
+ num_inference_steps = len(timesteps)
112
+ elif sigmas is not None:
113
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
114
+ if not accept_sigmas:
115
+ raise ValueError(
116
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
117
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
118
+ )
119
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
120
+ timesteps = scheduler.timesteps
121
+ num_inference_steps = len(timesteps)
122
+ else:
123
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
124
+ timesteps = scheduler.timesteps
125
+ return timesteps, num_inference_steps
126
+
127
+ @dataclass
128
+ class HunyuanVideoPipelineOutput(BaseOutput):
129
+ videos: Union[torch.Tensor, np.ndarray]
130
+
131
+
132
+ class HunyuanVideoGamePipeline(DiffusionPipeline):
133
+ r"""
134
+ Pipeline for text-to-video generation using HunyuanVideo.
135
+
136
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
137
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
138
+
139
+ Args:
140
+ vae ([`AutoencoderKL`]):
141
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
142
+ text_encoder ([`TextEncoder`]):
143
+ Frozen text-encoder.
144
+ text_encoder_2 ([`TextEncoder`]):
145
+ Frozen text-encoder_2.
146
+ transformer ([`HYVideoDiffusionTransformer`]):
147
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
148
+ scheduler ([`SchedulerMixin`]):
149
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
150
+ """
151
+
152
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
153
+ _optional_components = ["text_encoder_2"]
154
+ _exclude_from_cpu_offload = ["transformer"]
155
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
156
+
157
+ def __init__(
158
+ self,
159
+ vae: AutoencoderKL,
160
+ text_encoder: TextEncoder,
161
+ transformer: HYVideoDiffusionTransformer,
162
+ scheduler: KarrasDiffusionSchedulers,
163
+ text_encoder_2: Optional[TextEncoder] = None,
164
+ progress_bar_config: Dict[str, Any] = None,
165
+ args=None,
166
+ ):
167
+ super().__init__()
168
+
169
+ # ==========================================================================================
170
+ if progress_bar_config is None:
171
+ progress_bar_config = {}
172
+ if not hasattr(self, '_progress_bar_config'):
173
+ self._progress_bar_config = {}
174
+ self._progress_bar_config.update(progress_bar_config)
175
+
176
+ self.args = args
177
+ # ==========================================================================================
178
+
179
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
180
+ deprecation_message = (
181
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
182
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
183
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
184
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
185
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
186
+ " file"
187
+ )
188
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
189
+ new_config = dict(scheduler.config)
190
+ new_config["steps_offset"] = 1
191
+ scheduler._internal_dict = FrozenDict(new_config)
192
+
193
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
194
+ deprecation_message = (
195
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
196
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
197
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
198
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
199
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
200
+ )
201
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
202
+ new_config = dict(scheduler.config)
203
+ new_config["clip_sample"] = False
204
+ scheduler._internal_dict = FrozenDict(new_config)
205
+
206
+ self.register_modules(
207
+ vae=vae,
208
+ text_encoder=text_encoder,
209
+ transformer=transformer,
210
+ scheduler=scheduler,
211
+ text_encoder_2=text_encoder_2
212
+ )
213
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
214
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
215
+
216
+ def encode_prompt(
217
+ self,
218
+ prompt,
219
+ device,
220
+ num_videos_per_prompt,
221
+ do_classifier_free_guidance,
222
+ negative_prompt=None,
223
+ prompt_embeds: Optional[torch.Tensor] = None,
224
+ attention_mask: Optional[torch.Tensor] = None,
225
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
226
+ negative_attention_mask: Optional[torch.Tensor] = None,
227
+ lora_scale: Optional[float] = None,
228
+ clip_skip: Optional[int] = None,
229
+ text_encoder: Optional[TextEncoder] = None,
230
+ data_type: Optional[str] = "image",
231
+ ):
232
+ r"""
233
+ Encodes the prompt into text encoder hidden states.
234
+
235
+ Args:
236
+ prompt (`str` or `List[str]`, *optional*):
237
+ prompt to be encoded
238
+ device: (`torch.device`):
239
+ torch device
240
+ num_videos_per_prompt (`int`):
241
+ number of images that should be generated per prompt
242
+ do_classifier_free_guidance (`bool`):
243
+ whether to use classifier free guidance or not
244
+ negative_prompt (`str` or `List[str]`, *optional*):
245
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
246
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
247
+ less than `1`).
248
+ pixel_value_llava (`torch.Tensor`, *optional*):
249
+ The image tensor for llava.
250
+ uncond_pixel_value_llava (`torch.Tensor`, *optional*):
251
+ The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
252
+ less than `1`).
253
+ prompt_embeds (`torch.Tensor`, *optional*):
254
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
255
+ provided, text embeddings will be generated from `prompt` input argument.
256
+ attention_mask (`torch.Tensor`, *optional*):
257
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
258
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
259
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
260
+ argument.
261
+ negative_attention_mask (`torch.Tensor`, *optional*):
262
+ lora_scale (`float`, *optional*):
263
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
264
+ clip_skip (`int`, *optional*):
265
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
266
+ the output of the pre-final layer will be used for computing the prompt embeddings.
267
+ text_encoder (TextEncoder, *optional*):
268
+ """
269
+ if text_encoder is None:
270
+ text_encoder = self.text_encoder
271
+
272
+ # set lora scale so that monkey patched LoRA
273
+ # function of text encoder can correctly access it
274
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
275
+ self._lora_scale = lora_scale
276
+
277
+ # dynamically adjust the LoRA scale
278
+ if not USE_PEFT_BACKEND:
279
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
280
+ else:
281
+ scale_lora_layers(text_encoder.model, lora_scale)
282
+
283
+ if prompt is not None and isinstance(prompt, str):
284
+ batch_size = 1
285
+ elif prompt is not None and isinstance(prompt, list):
286
+ batch_size = len(prompt)
287
+ else:
288
+ batch_size = prompt_embeds.shape[0]
289
+
290
+ if prompt_embeds is None:
291
+ # textual inversion: process multi-vector tokens if necessary
292
+ if isinstance(self, TextualInversionLoaderMixin):
293
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
294
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
295
+
296
+ if clip_skip is None:
297
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
298
+ prompt_embeds = prompt_outputs.hidden_state
299
+ else:
300
+ prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
301
+ # Access the `hidden_states` first, that contains a tuple of
302
+ # all the hidden states from the encoder layers. Then index into
303
+ # the tuple to access the hidden states from the desired layer.
304
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
305
+ # We also need to apply the final LayerNorm here to not mess with the
306
+ # representations. The `last_hidden_states` that we typically use for
307
+ # obtaining the final prompt representations passes through the LayerNorm
308
+ # layer.
309
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
310
+
311
+ attention_mask = prompt_outputs.attention_mask
312
+ if attention_mask is not None:
313
+ attention_mask = attention_mask.to(device)
314
+ bs_embed, seq_len = attention_mask.shape
315
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
316
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
317
+
318
+
319
+ if text_encoder is not None:
320
+ prompt_embeds_dtype = text_encoder.dtype
321
+ elif self.transformer is not None:
322
+ prompt_embeds_dtype = self.transformer.dtype
323
+ else:
324
+ prompt_embeds_dtype = prompt_embeds.dtype
325
+
326
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
327
+
328
+ if prompt_embeds.ndim == 2:
329
+ bs_embed, _ = prompt_embeds.shape
330
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
331
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
332
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
333
+ else:
334
+ bs_embed, seq_len, _ = prompt_embeds.shape
335
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
336
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
337
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
338
+
339
+ # get unconditional embeddings for classifier free guidance
340
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
341
+ uncond_tokens: List[str]
342
+ if negative_prompt is None:
343
+ uncond_tokens = [""] * batch_size
344
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
345
+ raise TypeError(
346
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
347
+ f" {type(prompt)}."
348
+ )
349
+ elif isinstance(negative_prompt, str):
350
+ uncond_tokens = [negative_prompt]
351
+ elif batch_size != len(negative_prompt):
352
+ raise ValueError(
353
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
354
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
355
+ " the batch size of `prompt`."
356
+ )
357
+ else:
358
+ uncond_tokens = negative_prompt
359
+
360
+ # textual inversion: process multi-vector tokens if necessary
361
+ if isinstance(self, TextualInversionLoaderMixin):
362
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
363
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
364
+
365
+ negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
366
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
367
+
368
+ negative_attention_mask = negative_prompt_outputs.attention_mask
369
+ if negative_attention_mask is not None:
370
+ negative_attention_mask = negative_attention_mask.to(device)
371
+ _, seq_len = negative_attention_mask.shape
372
+ negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt)
373
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
374
+
375
+ if do_classifier_free_guidance:
376
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
377
+ seq_len = negative_prompt_embeds.shape[1]
378
+
379
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
380
+
381
+ if negative_prompt_embeds.ndim == 2:
382
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt)
383
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
384
+ else:
385
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
386
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
387
+
388
+ if text_encoder is not None:
389
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
390
+ # Retrieve the original scale by scaling back the LoRA layers
391
+ unscale_lora_layers(text_encoder.model, lora_scale)
392
+
393
+ return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
394
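The repeat/view pattern above is how each prompt embedding is tiled `num_videos_per_prompt` times along the batch axis. A minimal standalone sketch with hypothetical shapes (2 prompts, 3 videos per prompt, sequence length 5, hidden size 8):

import torch

prompt_embeds = torch.randn(2, 5, 8)
num_videos_per_prompt = 3
bs_embed, seq_len, dim = prompt_embeds.shape
# repeat along the sequence axis, then fold the copies back into the batch axis
tiled = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
tiled = tiled.view(bs_embed * num_videos_per_prompt, seq_len, dim)
print(tiled.shape)  # torch.Size([6, 5, 8])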
+
395
+ def decode_latents(self, latents, enable_tiling=True):
396
+ deprecation_message = \
397
+ "The decode_latents method is deprecated and will be removed \
398
+ in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
399
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
400
+
401
+ latents = 1 / self.vae.config.scaling_factor * latents
402
+ if enable_tiling:
403
+ self.vae.enable_tiling()
404
+ image = self.vae.decode(latents, return_dict=False)[0]
405
+ self.vae.disable_tiling()
406
+ else:
407
+ image = self.vae.decode(latents, return_dict=False)[0]
408
+ image = (image / 2 + 0.5).clamp(0, 1)
409
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
410
+ if image.ndim==4: image = image.cpu().permute(0, 2, 3, 1).float()
411
+ else: image = image.cpu().float()
412
+ return image
413
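The post-decode normalization is a simple affine map from the VAE's [-1, 1] output range to [0, 1]; a one-line check of the `(x / 2 + 0.5).clamp(0, 1)` step used above:

import torch

decoded = torch.tensor([-1.0, -0.5, 0.0, 1.0])
print((decoded / 2 + 0.5).clamp(0, 1))  # tensor([0.0000, 0.2500, 0.5000, 1.0000])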
+
414
+ def prepare_extra_func_kwargs(self, func, kwargs):
415
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
416
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
417
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
418
+ # and should be between [0, 1]
419
+ extra_step_kwargs = {}
420
+
421
+ for k, v in kwargs.items():
422
+ accepts = k in set(inspect.signature(func).parameters.keys())
423
+ if accepts:
424
+ extra_step_kwargs[k] = v
425
+ return extra_step_kwargs
426
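`prepare_extra_func_kwargs` simply drops any keyword the target callable does not accept, which is why `eta` silently disappears for schedulers without a DDIM-style signature. An illustrative sketch using a hypothetical `fake_step` function:

import inspect

def fake_step(model_output, timestep, sample, generator=None):
    # Stand-in for a scheduler step that accepts `generator` but not `eta`.
    return sample

candidate_kwargs = {"generator": None, "eta": 0.0}
accepted = set(inspect.signature(fake_step).parameters.keys())
filtered = {k: v for k, v in candidate_kwargs.items() if k in accepted}
print(filtered)  # {'generator': None} -- 'eta' is dropped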
+
427
+ def check_inputs(
428
+ self,
429
+ prompt,
430
+ height,
431
+ width,
432
+ frame,
433
+ callback_steps,
434
+ negative_prompt=None,
435
+ prompt_embeds=None,
436
+ negative_prompt_embeds=None,
437
+ callback_on_step_end_tensor_inputs=None,
438
+ vae_ver='88-4c-sd'
439
+ ):
440
+ if height % 8 != 0 or width % 8 != 0:
441
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
442
+
443
+ # if frame is not None:
444
+ # if '884' in vae_ver:
445
+ # if frame!=1 and (frame-1)%4!=0:
446
+ # raise ValueError(f'`frame` has to be 1 or a multiple of 4 but is {frame}.')
447
+ # elif '888' in vae_ver:
448
+ # if frame!=1 and (frame-1)%8!=0:
449
+ # raise ValueError(f'`frame` has to be 1 or a multiple of 8 but is {frame}.')
450
+
451
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
452
+ raise ValueError(
453
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
454
+ f" {type(callback_steps)}."
455
+ )
456
+ if callback_on_step_end_tensor_inputs is not None and not all(
457
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
458
+ ):
459
+ raise ValueError(
460
+ f"`callback_on_step_end_tensor_inputs` has to be in \
461
+ {self._callback_tensor_inputs}, but found \
462
+ {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
463
+ )
464
+
465
+ if prompt is not None and prompt_embeds is not None:
466
+ raise ValueError(
467
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
468
+ " only forward one of the two."
469
+ )
470
+ elif prompt is None and prompt_embeds is None:
471
+ raise ValueError(
472
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
473
+ )
474
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
475
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
476
+
477
+ if negative_prompt is not None and negative_prompt_embeds is not None:
478
+ raise ValueError(
479
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
480
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
481
+ )
482
+
483
+
484
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
485
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
486
+ raise ValueError(
487
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
488
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
489
+ f" {negative_prompt_embeds.shape}."
490
+ )
491
+
492
+ def get_timesteps(self, num_inference_steps, strength, device):
493
+ # get the original timestep using init_timestep
494
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
495
+
496
+ t_start = max(num_inference_steps - init_timestep, 0)
497
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
498
+ if hasattr(self.scheduler, "set_begin_index"):
499
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
500
+
501
+ return timesteps.to(device), num_inference_steps - t_start
502
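The strength-based truncation above decides how far into the schedule denoising starts; a quick worked example with hypothetical values:

num_inference_steps, strength = 50, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
print(t_start, num_inference_steps - t_start)  # 20 30 -- only the last 30 timesteps are run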
+
503
+ def prepare_latents(self, batch_size, num_channels_latents, num_inference_steps,
504
+ height, width, frame, dtype, device, timesteps,generator,
505
+ latents=None, gt_latents=None, denoise_strength=1.0,):
506
+ shape = (
507
+ batch_size,
508
+ num_channels_latents,
509
+ frame,
510
+ int(height) // self.vae_scale_factor,
511
+ int(width) // self.vae_scale_factor,
512
+ )
513
+ if isinstance(generator, list) and len(generator) != batch_size:
514
+ raise ValueError(
515
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
516
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
517
+ )
518
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
519
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
520
+
521
+ if gt_latents.shape[2] == 1:
522
+ gt_latents = gt_latents.repeat(1, 1, frame, 1, 1)
523
+
524
+ # TODO: correct
525
+ x0 = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
526
+ # print("!!!!!!!!!!!!!! RANDOM NOISE !!!!!!!!!!!!!!!!!!")
527
+ # x0 = randn_tensor(shape, device=device, dtype=dtype)
528
+ x1 = gt_latents
529
+
530
+ t = torch.tensor([0.999]).to(device=device)
531
+ latents = x0 * t + x1 * (1 - t)
532
+ latents = torch.randn_like(x1)
533
+ # print("!!!randn_like", latents.shape)
534
+ latents = latents.to(dtype=dtype)
535
+
536
+ if latents is None:
537
+ latents = noise
538
+ original_latents = None
539
+ else:
540
+ latents = latents.to(device)
541
+
542
+ if hasattr(self.scheduler, "init_noise_sigma"):
543
+ latents = latents * self.scheduler.init_noise_sigma
544
+
545
+ return latents, timesteps
546
+
547
+ # Copied from diffusers.pipelines.latent_consistency_models.
548
+ # pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
549
+ def get_guidance_scale_embedding(
550
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
551
+ ) -> torch.Tensor:
552
+ """
553
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
554
+
555
+ Args:
556
+ w (`torch.Tensor`):
557
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
558
+ embedding_dim (`int`, *optional*, defaults to 512):
559
+ Dimension of the embeddings to generate.
560
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
561
+ Data type of the generated embeddings.
562
+
563
+ Returns:
564
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
565
+ """
566
+ assert len(w.shape) == 1
567
+ w = w * 1000.0
568
+
569
+ half_dim = embedding_dim // 2
570
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
571
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
572
+ emb = w.to(dtype)[:, None] * emb[None, :]
573
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
574
+ if embedding_dim % 2 == 1: # zero pad
575
+ emb = torch.nn.functional.pad(emb, (0, 1))
576
+ assert emb.shape == (w.shape[0], embedding_dim)
577
+ return emb
578
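For reference, the embedding above reduces to a standard sinusoidal encoding of the scaled guidance weight. A condensed sketch (odd-dimension padding omitted, `embedding_dim` kept small for illustration):

import torch

def guidance_embedding(w: torch.Tensor, embedding_dim: int = 8) -> torch.Tensor:
    w = w * 1000.0
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
    emb = w[:, None] * freqs[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

print(guidance_embedding(torch.tensor([7.5])).shape)  # torch.Size([1, 8])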
+
579
+ @property
580
+ def guidance_scale(self):
581
+ return self._guidance_scale
582
+
583
+ @property
584
+ def guidance_rescale(self):
585
+ return self._guidance_rescale
586
+
587
+ @property
588
+ def clip_skip(self):
589
+ return self._clip_skip
590
+
591
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
592
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
593
+ # corresponds to doing no classifier free guidance.
594
+ @property
595
+ def do_classifier_free_guidance(self):
596
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
597
+ return self._guidance_scale > 1
598
+
599
+ @property
600
+ def cross_attention_kwargs(self):
601
+ return self._cross_attention_kwargs
602
+
603
+ @property
604
+ def num_timesteps(self):
605
+ return self._num_timesteps
606
+
607
+ @property
608
+ def interrupt(self):
609
+ return self._interrupt
610
+
611
+ @torch.no_grad()
612
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
613
+ def __call__(
614
+ self,
615
+ prompt: Union[str, List[str]],
616
+ cam_latents: torch.Tensor,
617
+ last_latents: torch.Tensor,
618
+ uncond_cam_latents: torch.Tensor,
619
+ gt_latents: torch.Tensor,
620
+ height: int,
621
+ width: int,
622
+ video_length: int, # frame is called video_len in hunyuan_multimodal/dev_video
623
+ data_type: str='video',
624
+ num_inference_steps: int = 50,
625
+ timesteps: List[int] = None,
626
+ sigmas: List[float] = None,
627
+ guidance_scale: float = 7.5,
628
+ negative_prompt: Optional[Union[str, List[str]]] = None,
629
+ ref_latents: Optional[torch.Tensor] = None,
630
+ uncond_ref_latents: Optional[torch.Tensor] = None,
631
+ ip_cfg_scale: float = 0.0,
632
+ use_deepcache: int = 1,
633
+ num_videos_per_prompt: Optional[int] = 1,
634
+ eta: float = 0.0,
635
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
636
+ latents: Optional[torch.Tensor] = None,
637
+ prompt_embeds: Optional[torch.Tensor] = None,
638
+ attention_mask: Optional[torch.Tensor] = None,
639
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
640
+ negative_attention_mask: Optional[torch.Tensor] = None,
641
+ output_type: Optional[str] = "pil",
642
+ return_dict: bool = True,
643
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
644
+ guidance_rescale: float = 0.0,
645
+ clip_skip: Optional[int] = None,
646
+ callback_on_step_end: Optional[
647
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
648
+ ] = None,
649
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
650
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
651
+ vae_ver: str='88-4c-sd',
652
+ enable_tiling: bool=False,
653
+ n_tokens: Optional[int] = None,
654
+ video_val_flag: bool=False,
655
+ denoise_strength: float = 1.0,
656
+ mask = None,
657
+ cpu_offload: bool=False,
658
+ use_sage: bool=False,
659
+ **kwargs,
660
+ ):
661
+ r"""
662
+ The call function to the pipeline for generation.
663
+
664
+ Args:
665
+ prompt (`str` or `List[str]`):
666
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
667
+ height (`int`):
668
+ The height in pixels of the generated image.
669
+ width (`int`):
670
+ The width in pixels of the generated image.
671
+ video_length (`int`):
672
+ The number of frames in the generated video.
673
+ num_inference_steps (`int`, *optional*, defaults to 50):
674
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
675
+ expense of slower inference.
676
+ timesteps (`List[int]`, *optional*):
677
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
678
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
679
+ passed will be used. Must be in descending order.
680
+ sigmas (`List[float]`, *optional*):
681
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
682
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
683
+ will be used.
684
+ guidance_scale (`float`, *optional*, defaults to 7.5):
685
+ A higher guidance scale value encourages the model to generate images closely linked to the text
686
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
687
+ negative_prompt (`str` or `List[str]`, *optional*):
688
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
689
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
690
+ ref_latents (`torch.Tensor`, *optional*):
691
+ The image tensor for time-concat.
692
+ uncond_ref_latents (`torch.Tensor`, *optional*):
693
+ The image tensor for time-concat. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
694
+ less than `1`).
695
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
701
+ The number of images to generate per prompt.
702
+ eta (`float`, *optional*, defaults to 0.0):
703
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
704
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
705
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
706
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
707
+ generation deterministic.
708
+ latents (`torch.Tensor`, *optional*):
709
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
710
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
711
+ tensor is generated by sampling using the supplied random `generator`.
712
+ prompt_embeds (`torch.Tensor`, *optional*):
713
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
714
+ provided, text embeddings are generated from the `prompt` input argument.
715
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
716
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
717
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
718
+ output_type (`str`, *optional*, defaults to `"pil"`):
719
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
720
+ return_dict (`bool`, *optional*, defaults to `True`):
721
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
722
+ plain tuple.
723
+ cross_attention_kwargs (`dict`, *optional*):
724
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
725
+ [`self.processor`]
726
+ (https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
727
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
728
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
729
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
730
+ using zero terminal SNR.
731
+ clip_skip (`int`, *optional*):
732
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
733
+ the output of the pre-final layer will be used for computing the prompt embeddings.
734
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
735
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
736
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
737
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
738
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
739
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
740
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
741
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
742
+ `._callback_tensor_inputs` attribute of your pipeline class.
743
+
744
+ Examples:
745
+
746
+ Returns:
747
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
748
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
749
+ otherwise a list with the generated images is returned.
750
+ """
751
+ callback = kwargs.pop("callback", None)
752
+ callback_steps = kwargs.pop("callback_steps", None)
753
+ if callback is not None:
754
+ deprecate(
755
+ "callback",
756
+ "1.0.0",
757
+ "Passing `callback` as an input argument to \
758
+ `__call__` is deprecated, consider using `callback_on_step_end`",
759
+ )
760
+ if callback_steps is not None:
761
+ deprecate(
762
+ "callback_steps",
763
+ "1.0.0",
764
+ "Passing `callback_steps` as an input argument to \
765
+ `__call__` is deprecated, consider using `callback_on_step_end`",
766
+ )
767
+
768
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
769
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
770
+
771
+ # 0. Default height and width to transformer
772
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
773
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
774
+ # to deal with lora scaling and other possible forward hooks
775
+
776
+ # 1. Check inputs. Raise error if not correct
777
+ self.check_inputs(
778
+ prompt,
779
+ height,
780
+ width,
781
+ video_length,
782
+ callback_steps,
783
+ negative_prompt,
784
+ prompt_embeds,
785
+ negative_prompt_embeds,
786
+ callback_on_step_end_tensor_inputs,
787
+ vae_ver=vae_ver
788
+ )
789
+
790
+ self._guidance_scale = guidance_scale
791
+ self._guidance_rescale = guidance_rescale
792
+ self._clip_skip = clip_skip
793
+ self._cross_attention_kwargs = cross_attention_kwargs
794
+ self._interrupt = False
795
+
796
+ # 2. Define call parameters
797
+ if prompt is not None and isinstance(prompt, str):
798
+ batch_size = 1
799
+ elif prompt is not None and isinstance(prompt, list):
800
+ batch_size = len(prompt)
801
+ else:
802
+ batch_size = prompt_embeds.shape[0]
803
+
804
+ # device = self._execution_device
805
+ device = torch.device("cuda")
806
+
807
+ # 3. Encode input prompt
808
+ lora_scale = (
809
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
810
+ )
811
+
812
+ prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask = \
813
+ self.encode_prompt(
814
+ prompt,
815
+ device,
816
+ num_videos_per_prompt,
817
+ self.do_classifier_free_guidance,
818
+ negative_prompt,
819
+ prompt_embeds=prompt_embeds,
820
+ attention_mask=attention_mask,
821
+ negative_prompt_embeds=negative_prompt_embeds,
822
+ negative_attention_mask=negative_attention_mask,
823
+ lora_scale=lora_scale,
824
+ clip_skip=self.clip_skip,
825
+ data_type=data_type
826
+ )
827
+
828
+ if self.text_encoder_2 is not None:
829
+ prompt_embeds_2, negative_prompt_embeds_2, prompt_mask_2, negative_prompt_mask_2 = \
830
+ self.encode_prompt(
831
+ prompt,
832
+ device,
833
+ num_videos_per_prompt,
834
+ self.do_classifier_free_guidance,
835
+ negative_prompt,
836
+ prompt_embeds=None,
837
+ attention_mask=None,
838
+ negative_prompt_embeds=None,
839
+ negative_attention_mask=None,
840
+ lora_scale=lora_scale,
841
+ clip_skip=self.clip_skip,
842
+ text_encoder=self.text_encoder_2,
843
+ )
844
+ else:
845
+ prompt_embeds_2 = None
846
+ negative_prompt_embeds_2 = None
847
+ prompt_mask_2 = None
848
+ negative_prompt_mask_2 = None
849
+
850
+ # For classifier free guidance, we need to do two forward passes.
851
+ # Here we concatenate the unconditional and text embeddings into a single batch
852
+ # to avoid doing two forward passes
853
+ if self.do_classifier_free_guidance:
854
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
855
+ if prompt_mask is not None:
856
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
857
+ if prompt_embeds_2 is not None:
858
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
859
+ if prompt_mask_2 is not None:
860
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
861
+
862
+ if self.do_classifier_free_guidance:
863
+ if ref_latents is not None:
864
+ ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
865
+ if prompt_mask[0].sum() > 575:
866
+ prompt_mask[0] = torch.cat(
867
+ [torch.ones((1, prompt_mask[0].sum() - 575)).to(prompt_mask),
868
+ torch.zeros((1, prompt_mask.shape[1] - prompt_mask[0].sum() + 575)).to(prompt_mask)], dim=1)
869
+
870
+ if ip_cfg_scale>0:
871
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1:]])
872
+ prompt_embeds_2 = torch.cat([prompt_embeds_2, prompt_embeds_2[1:]])
873
+ prompt_mask = torch.cat([prompt_mask, prompt_mask[1:]], dim=0)
874
+ ref_latents = torch.cat([uncond_ref_latents, uncond_ref_latents, ref_latents[1:]], dim=0)
875
+
876
+ # 4. Prepare timesteps
877
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
878
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
879
+ )
880
+ timesteps, num_inference_steps = retrieve_timesteps(
881
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, **extra_set_timesteps_kwargs,
882
+ )
883
+
884
+
885
+ if '884' in vae_ver:
886
+ frame_length = (video_length - 2) // 4 + 2
887
+ elif '888' in vae_ver:
888
+ frame_length = (video_length - 1) // 8 + 1
889
+ else:
890
+ frame_length = video_length
891
+
892
+ # 5. Prepare latent variables
893
+ num_channels_latents = self.transformer.config.in_channels
894
+ latents, timesteps = self.prepare_latents(
895
+ batch_size * num_videos_per_prompt,
896
+ num_channels_latents,
897
+ num_inference_steps,
898
+ height,
899
+ width,
900
+ frame_length,
901
+ prompt_embeds.dtype,
902
+ device,
903
+ timesteps,
904
+ generator,
905
+ latents,
906
+ gt_latents,
907
+ denoise_strength,
908
+ )
909
+
910
+ gt_latents = gt_latents.repeat(1, 1, frame_length, 1, 1)
911
+ gt_latents_concat = gt_latents.clone()
912
+
913
+ if frame_length == 10:
914
+ gt_latents_concat[:,:,1:,:,:] = 0.0
915
+ mask_concat = torch.ones(gt_latents.shape[0],
916
+ 1,
917
+ gt_latents.shape[2],
918
+ gt_latents.shape[3],
919
+ gt_latents.shape[4]).to(device=gt_latents.device)
920
+ mask_concat[:, :, 1:,...] = 0.0
921
+ else:
922
+ gt_latents_concat[:,:,gt_latents_concat.shape[2]//2:,:,:] = 0.0
923
+ mask_zeros = torch.zeros(gt_latents.shape[0],
924
+ 1,
925
+ gt_latents.shape[2]//2,
926
+ gt_latents.shape[3],
927
+ gt_latents.shape[4])
928
+ mask_ones = torch.ones(gt_latents.shape[0],
929
+ 1,
930
+ gt_latents.shape[2]//2,
931
+ gt_latents.shape[3],
932
+ gt_latents.shape[4])
933
+ mask_concat = torch.cat([mask_ones, mask_zeros], dim=2).to(device=gt_latents.device)
934
+
935
+ # 6. Prepare extra step kwargs.
936
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
937
+ self.scheduler.step, {"generator": generator, "eta": eta},
938
+ )
939
+
940
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
941
+ autocast_enabled = (target_dtype != torch.float32) and not self.args.val_disable_autocast
942
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
943
+ vae_autocast_enabled = (vae_dtype != torch.float32) and not self.args.val_disable_autocast
944
+
945
+ # 7. Denoising loop
946
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
947
+ self._num_timesteps = len(timesteps)
948
+
949
+ start_scale = ip_cfg_scale # 3.0
950
+ end_scale = 1.0
951
+ step_scale = (start_scale - end_scale) / (self._num_timesteps - 1 + 1e-3)
952
+ if cpu_offload: torch.cuda.empty_cache()
953
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
954
+ for i, t in enumerate(timesteps):
955
+ if self.interrupt:
956
+ continue
957
+
958
+ if last_latents.shape[2] == 1:
959
+ latents[:,:,0,:,:] = last_latents[:,:,-1,:,:]
960
+ else:
961
+ latents[:,:,:latents.shape[2]//2,:,:] = last_latents
962
+ gt_latents_concat[:,:,:latents.shape[2]//2,:,:] = last_latents
963
+
964
+ # expand the latents if we are doing classifier free guidance
965
+ latents_concat = torch.concat([latents, gt_latents_concat, mask_concat], dim=1)
966
+ latent_model_input = torch.cat([latents_concat] * 2) \
967
+ if self.do_classifier_free_guidance else latents_concat
968
+
969
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
970
+ t_expand = t.repeat(latent_model_input.shape[0])
971
+
972
+ guidance_expand = None
974
+
975
+ cam_latents_ = torch.cat([uncond_cam_latents, cam_latents], dim=0) \
976
+ if self.do_classifier_free_guidance else cam_latents
977
+
978
+ # predict the noise residual
979
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
980
+ is_cache = False
981
+ if use_deepcache and num_inference_steps==50:
982
+
983
+ no_cache_steps = list(range(0, 10)) + list(range(10, 40, 2)) + list(range(40, 50))
984
+ if i in no_cache_steps:
985
+ is_cache = False
986
+ else:
987
+ is_cache = True
988
+ if latent_model_input.shape[-1]*latent_model_input.shape[-2]>64*112 and cpu_offload:
989
+ if i==0:
990
+ print(f'cpu_offload={cpu_offload} and {latent_model_input.shape[-2:]} is large; '
991
+ f'splitting the noise prediction into two forward passes')
992
+ noise_pred_uncond = self.transformer(latent_model_input[:1],
993
+ t_expand[:1],
994
+ text_states=prompt_embeds[:1],
995
+ text_mask=prompt_mask[:1],
996
+ text_states_2=prompt_embeds_2[:1],
997
+ freqs_cos=freqs_cis[0],
998
+ freqs_sin=freqs_cis[1],
999
+ guidance=guidance_expand,
1000
+ return_dict=True,
1001
+ is_cache=is_cache,
1002
+ cam_latents=cam_latents_[:1])['x']
1003
+ torch.cuda.empty_cache()
1004
+ noise_pred_text = self.transformer(latent_model_input[1:],
1005
+ t_expand[1:],
1006
+ text_states=prompt_embeds[1:],
1007
+ text_mask=prompt_mask[1:],
1008
+ text_states_2=prompt_embeds_2[1:],
1009
+ freqs_cos=freqs_cis[0],
1010
+ freqs_sin=freqs_cis[1],
1011
+ guidance=guidance_expand,
1012
+ return_dict=True,
1013
+ is_cache=is_cache,
1014
+ cam_latents=cam_latents_[1:])['x']
1015
+ noise_pred = torch.cat([noise_pred_uncond, noise_pred_text], dim=0)
1016
+ torch.cuda.empty_cache()
1017
+ else:
1018
+ noise_pred = self.transformer( # For an input image (1, 256, 256)
1019
+ latent_model_input, # [2, 16, 1, 32, 32] #
1020
+ t_expand, # [2]
1021
+ text_states=prompt_embeds, # [2, 256, 4096]
1022
+ text_mask=prompt_mask, # [2, 256]
1023
+ text_states_2=prompt_embeds_2, # [2, 768]
1024
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
1025
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
1026
+ guidance=guidance_expand,
1027
+ return_dict=True,
1028
+ is_cache=is_cache,
1029
+ cam_latents=cam_latents_,
1030
+ use_sage=use_sage,
1031
+ )['x']
1032
+
1033
+ # perform guidance
1034
+ if self.do_classifier_free_guidance and ip_cfg_scale < 0.1:
1035
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1036
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1037
+
1038
+ if ip_cfg_scale > 0:
1039
+ noise_pred_uncond, noise_pred_text, noise_pred_ip = noise_pred.chunk(3)
1040
+ noise_pred = noise_pred_uncond + self.guidance_scale * \
1041
+ (noise_pred_text - noise_pred_uncond) + start_scale * (noise_pred_ip-noise_pred_text)
1042
+ start_scale -= step_scale
1043
+ if i==0:
1044
+ print(f'i={i}, noise_pred shape={noise_pred.shape}')
1045
+
1046
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1047
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1048
+ noise_pred = rescale_noise_cfg(noise_pred,
1049
+ noise_pred_text,
1050
+ guidance_rescale=self.guidance_rescale)
1051
+
1052
+ # compute the previous noisy sample x_t -> x_t-1
1053
+ # latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1054
+ if last_latents.shape[2] == 1:
1055
+ latents[:,:,1:,:,:] = self.scheduler.step(noise_pred[:,:,1:,:,:],
1056
+ t,
1057
+ latents[:,:,1:,:,:],
1058
+ **extra_step_kwargs,
1059
+ return_dict=False)[0]
1060
+ else:
1061
+ latents[:,:,noise_pred.shape[2]//2:,:,:] = self.scheduler.step(
1062
+ noise_pred[:,:,noise_pred.shape[2]//2:,:,:],
1063
+ t,
1064
+ latents[:,:,latents.shape[2]//2:,:,:],
1065
+ **extra_step_kwargs, return_dict=False)[0]
1066
+
1067
+
1068
+ if callback_on_step_end is not None:
1069
+ callback_kwargs = {}
1070
+ for k in callback_on_step_end_tensor_inputs:
1071
+ callback_kwargs[k] = locals()[k]
1072
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1073
+
1074
+ latents = callback_outputs.pop("latents", latents)
1075
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1076
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1077
+
1078
+ # call the callback, if provided
1079
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1080
+ if progress_bar is not None:
1081
+ progress_bar.update()
1082
+ if callback is not None and i % callback_steps == 0:
1083
+ step_idx = i // getattr(self.scheduler, "order", 1)
1084
+ callback(step_idx, t, latents)
1085
+
1086
+ if cpu_offload: torch.cuda.empty_cache()
1087
+ # if mask_latents is not None:
1088
+ # latents = mask_latents * latents + (1 - mask_latents) * original_latents
1089
+ if last_latents.shape[2] == 1:
1090
+ latents = latents[:,:,1:,:,:]
1091
+
1092
+ if not output_type == "latent":
1093
+ expand_temporal_dim = False
1094
+ if len(latents.shape) == 4:
1095
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1096
+ latents = latents.unsqueeze(2)
1097
+ expand_temporal_dim = True
1098
+ elif len(latents.shape) == 5:
1099
+ pass
1100
+ else:
1101
+ raise ValueError(
1102
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
1103
+
1104
+ if not last_latents.shape[2] == 1:
1105
+ last_latents = latents[:,:,latents.shape[2]//2:,:,:]
1106
+ else:
1107
+ last_latents = latents
1108
+ latent_decode = last_latents.clone()
1109
+ latent_decode = latent_decode / self.vae.config.scaling_factor
1110
+
1111
+ with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
1112
+ if enable_tiling:
1113
+ self.vae.enable_tiling()
1114
+ if cpu_offload:
1115
+ self.vae.post_quant_conv.to('cuda')
1116
+ self.vae.decoder.to('cuda')
1117
+ image = self.vae.decode(latent_decode, return_dict=False, generator=generator)[0]
1118
+ self.vae.disable_tiling()
1119
+ if cpu_offload:
1120
+ self.vae.post_quant_conv.to('cpu')
1121
+ self.vae.decoder.to('cpu')
1122
+ torch.cuda.empty_cache()
1123
+ else:
1124
+ image = self.vae.decode(latent_decode, return_dict=False, generator=generator)[0]
1125
+ # if image is None:
1126
+ # return (None, )
1127
+
1128
+ # if expand_temporal_dim or (not video_val_flag and image.shape[2] == 1):
1129
+ # image = image.squeeze(2)
1130
+
1131
+ if image is not None and (expand_temporal_dim or (not video_val_flag and image.shape[2] == 1)):
1132
+ image = image.squeeze(2)
1133
+
1134
+ if image is not None:
1135
+ image = (image / 2 + 0.5).clamp(0, 1)
1136
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1137
+ image = image.cpu().float()
1138
+
1139
+ # Offload all models
1140
+ self.maybe_free_model_hooks()
1141
+
1142
+ if cpu_offload: torch.cuda.empty_cache()
1143
+ if not return_dict:
1144
+ return image
1145
+
1146
+ return_latents = kwargs.get("return_latents", False)
1147
+
1148
+ if return_latents:
1149
+ return HunyuanVideoPipelineOutput(videos=image), \
1150
+ latents, timesteps, last_latents, last_latents[:,:,-1:, ...]
1151
+
1152
+ return HunyuanVideoPipelineOutput(videos=image)
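As a reference for the denoising loop above, the classifier-free-guidance combination used when `ip_cfg_scale` is zero boils down to a single linear blend of the unconditional and text-conditioned predictions (shapes below are arbitrary placeholders):

import torch

guidance_scale = 7.5
noise_pred = torch.randn(2, 16, 8, 4, 4)  # [uncond, text] stacked along the batch axis
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 16, 8, 4, 4])

When `ip_cfg_scale > 0`, a third chunk is added and blended with a scale that is annealed from `ip_cfg_scale` down to 1 over the course of sampling, as the loop shows.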
hymm_sp/diffusion/schedulers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
hymm_sp/diffusion/schedulers/scheduling_flow_match_discrete.py ADDED
@@ -0,0 +1,240 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ @dataclass
34
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
35
+ """
36
+ Output class for the scheduler's `step` function output.
37
+
38
+ Args:
39
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
40
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
41
+ denoising loop.
42
+ """
43
+
44
+ prev_sample: torch.FloatTensor
45
+
46
+
47
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
48
+ """
49
+ Euler scheduler for discrete flow matching.
50
+
51
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
52
+ methods the library implements for all schedulers such as loading and saving.
53
+
54
+ Args:
55
+ num_train_timesteps (`int`, defaults to 1000):
56
+ The number of diffusion steps to train the model.
57
+ timestep_spacing (`str`, defaults to `"linspace"`):
58
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
59
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
60
+ shift (`float`, defaults to 1.0):
61
+ The shift value for the timestep schedule.
62
+ reverse (`bool`, defaults to `True`):
63
+ Whether to reverse the timestep schedule.
64
+ """
65
+
66
+ _compatibles = []
67
+ order = 1
68
+
69
+ @register_to_config
70
+ def __init__(
71
+ self,
72
+ num_train_timesteps: int = 1000,
73
+ shift: float = 1.0,
74
+ reverse: bool = True,
75
+ solver: str = "euler",
76
+ n_tokens: Optional[int] = None,
77
+ ):
78
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
79
+
80
+ if not reverse:
81
+ sigmas = sigmas.flip(0)
82
+
83
+ self.sigmas = sigmas
84
+ # the value fed to model
85
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
86
+
87
+ self._step_index = None
88
+ self._begin_index = None
89
+
90
+ self.supported_solver = ["euler"]
91
+ if solver not in self.supported_solver:
92
+ raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
93
+
94
+ @property
95
+ def step_index(self):
96
+ """
97
+ The index counter for the current timestep. It increases by 1 after each scheduler step.
98
+ """
99
+ return self._step_index
100
+
101
+ @property
102
+ def begin_index(self):
103
+ """
104
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
105
+ """
106
+ return self._begin_index
107
+
108
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
109
+ def set_begin_index(self, begin_index: int = 0):
110
+ """
111
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
112
+
113
+ Args:
114
+ begin_index (`int`):
115
+ The begin index for the scheduler.
116
+ """
117
+ self._begin_index = begin_index
118
+
119
+ def _sigma_to_t(self, sigma):
120
+ return sigma * self.config.num_train_timesteps
121
+
122
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None,
123
+ n_tokens: int = None):
124
+ """
125
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
126
+
127
+ Args:
128
+ num_inference_steps (`int`):
129
+ The number of diffusion steps used when generating samples with a pre-trained model.
130
+ device (`str` or `torch.device`, *optional*):
131
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
132
+ n_tokens (`int`, *optional*):
133
+ Number of tokens in the input sequence.
134
+ """
135
+ self.num_inference_steps = num_inference_steps
136
+
137
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
138
+ sigmas = self.sd3_time_shift(sigmas)
139
+
140
+ if not self.config.reverse:
141
+ sigmas = 1 - sigmas
142
+
143
+ self.sigmas = sigmas
144
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
145
+
146
+ # Reset step index
147
+ self._step_index = None
148
+
149
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
150
+ if schedule_timesteps is None:
151
+ schedule_timesteps = self.timesteps
152
+
153
+ indices = (schedule_timesteps == timestep).nonzero()
154
+
155
+ # The sigma index that is taken for the **very** first `step`
156
+ # is always the second index (or the last index if there is only 1)
157
+ # This way we can ensure we don't accidentally skip a sigma in
158
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
159
+ pos = 1 if len(indices) > 1 else 0
160
+
161
+ return indices[pos].item()
162
+
163
+ def _init_step_index(self, timestep):
164
+ if self.begin_index is None:
165
+ if isinstance(timestep, torch.Tensor):
166
+ timestep = timestep.to(self.timesteps.device)
167
+ self._step_index = self.index_for_timestep(timestep)
168
+ else:
169
+ self._step_index = self._begin_index
170
+
171
+ def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
172
+ return sample
173
+
174
+ def sd3_time_shift(self, t: torch.Tensor):
175
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
176
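The shift is the SD3-style time warp shift * t / (1 + (shift - 1) * t); a shift greater than 1 keeps sigmas higher for longer, i.e. more steps are spent at high noise levels. A quick check with a hypothetical shift of 3:

import torch

shift = 3.0
t = torch.linspace(1, 0, 5)
print((shift * t) / (1 + (shift - 1) * t))  # tensor([1.0000, 0.9000, 0.7500, 0.5000, 0.0000])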
+
177
+ def step(
178
+ self,
179
+ model_output: torch.FloatTensor,
180
+ timestep: Union[float, torch.FloatTensor],
181
+ sample: torch.FloatTensor,
182
+ return_dict: bool = True,
183
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
184
+ """
185
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
186
+ process from the learned model outputs (most often the predicted noise).
187
+
188
+ Args:
189
+ model_output (`torch.FloatTensor`):
190
+ The direct output from learned diffusion model.
191
+ timestep (`float`):
192
+ The current discrete timestep in the diffusion chain.
193
+ sample (`torch.FloatTensor`):
194
+ A current instance of a sample created by the diffusion process.
195
+ return_dict (`bool`):
196
+ Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or a plain
197
+ tuple.
198
+
199
+ Returns:
200
+ [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
201
+ If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is
202
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
203
+ """
204
+
205
+ if (
206
+ isinstance(timestep, int)
207
+ or isinstance(timestep, torch.IntTensor)
208
+ or isinstance(timestep, torch.LongTensor)
209
+ ):
210
+ raise ValueError(
211
+ (
212
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
213
+ " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
214
+ " one of the `scheduler.timesteps` as a timestep."
215
+ ),
216
+ )
217
+
218
+ if self.step_index is None:
219
+ self._init_step_index(timestep)
220
+
221
+ # Upcast to avoid precision issues when computing prev_sample
222
+ sample = sample.to(torch.float32)
223
+
224
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
225
+
226
+ if self.config.solver == "euler":
227
+ prev_sample = sample + model_output.float() * dt
228
+ else:
229
+ raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
230
+
231
+ # upon completion increase step index by one
232
+ self._step_index += 1
233
+
234
+ if not return_dict:
235
+ return (prev_sample,)
236
+
237
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
238
+
239
+ def __len__(self):
240
+ return self.config.num_train_timesteps
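Taken together, `set_timesteps` and `step` implement a plain Euler integration of the flow-matching ODE: each update moves the sample along the predicted velocity by the sigma decrement. A minimal sketch of one update, with a random tensor standing in for the transformer output:

import torch

sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])    # illustrative 4-step schedule (shift = 1)
step_index = 0
sample = torch.randn(1, 16, 8, 4, 4)
model_output = torch.randn_like(sample)                # stand-in for the model's velocity prediction
dt = sigmas[step_index + 1] - sigmas[step_index]       # negative: sigma decreases toward 0
prev_sample = sample + model_output * dt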
hymm_sp/helpers.py ADDED
@@ -0,0 +1,194 @@
1
+ import torch
2
+ from typing import Union, List
3
+ from hymm_sp.modules.posemb_layers import get_1d_rotary_pos_embed, get_meshgrid_nd
4
+
5
+ from itertools import repeat
6
+ import collections.abc
7
+
8
+
9
+ def _ntuple(n):
10
+ """
11
+ Creates a helper function to convert inputs to tuples of specified length.
12
+
13
+ Converts iterable inputs (excluding strings) to tuples of length n,
14
+ or repeats single values n times to form a tuple. Useful for handling
15
+ multi-dimensional parameters like sizes and strides.
16
+
17
+ Args:
18
+ n (int): Target length of the tuple
19
+
20
+ Returns:
21
+ function: Parser function that converts inputs to n-length tuples
22
+ """
23
+ def parse(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ x = tuple(x)
26
+ if len(x) == 1:
27
+ x = tuple(repeat(x[0], n))
28
+ return x
29
+ return tuple(repeat(x, n))
30
+ return parse
31
+
32
+
33
+ # Create common tuple conversion functions for 1-4 dimensions
34
+ to_1tuple = _ntuple(1)
35
+ to_2tuple = _ntuple(2)
36
+ to_3tuple = _ntuple(3)
37
+ to_4tuple = _ntuple(4)
38
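These helpers normalise scalars and single-element iterables to fixed-length tuples, which keeps patch-size and stride handling uniform:

from hymm_sp.helpers import to_2tuple, to_3tuple

print(to_2tuple(3))           # (3, 3)
print(to_2tuple([4]))         # (4, 4)   single-element iterables are broadcast
print(to_3tuple((2, 4, 4)))   # (2, 4, 4) correctly sized iterables pass through unchanged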
+
39
+
40
+ def get_rope_freq_from_size(
41
+ latents_size,
42
+ ndim,
43
+ target_ndim,
44
+ args,
45
+ rope_theta_rescale_factor: Union[float, List[float]] = 1.0,
46
+ rope_interpolation_factor: Union[float, List[float]] = 1.0,
47
+ concat_dict={}
48
+ ):
49
+ """
50
+ Calculates RoPE (Rotary Position Embedding) frequencies based on latent dimensions.
51
+
52
+ Converts latent space dimensions to rope-compatible sizes by accounting for
53
+ patch size, then generates the appropriate frequency embeddings for each dimension.
54
+
55
+ Args:
56
+ latents_size: Dimensions of the latent space tensor
57
+ ndim (int): Number of dimensions in the latent space
58
+ target_ndim (int): Target number of dimensions for the embeddings
59
+ args: Configuration arguments containing model parameters (patch_size, rope_theta, etc.)
60
+ rope_theta_rescale_factor: Rescaling factor(s) for theta parameter (per dimension)
61
+ rope_interpolation_factor: Interpolation factor(s) for position embeddings (per dimension)
62
+ concat_dict: Dictionary for special concatenation modes (e.g., time-based extensions)
63
+
64
+ Returns:
65
+ tuple: Cosine and sine frequency embeddings (freqs_cos, freqs_sin)
66
+ """
67
+ # Calculate rope sizes by dividing latent dimensions by patch size
68
+ if isinstance(args.patch_size, int):
69
+ # Validate all latent dimensions are divisible by patch size
70
+ assert all(s % args.patch_size == 0 for s in latents_size), \
71
+ f"Latent size (last {ndim} dimensions) must be divisible by patch size ({args.patch_size}), " \
72
+ f"but got {latents_size}."
73
+ rope_sizes = [s // args.patch_size for s in latents_size]
74
+ elif isinstance(args.patch_size, list):
75
+ # Validate with per-dimension patch sizes
76
+ assert all(s % args.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
77
+ f"Latent size (last {ndim} dimensions) must be divisible by patch size ({args.patch_size}), " \
78
+ f"but got {latents_size}."
79
+ rope_sizes = [s // args.patch_size[idx] for idx, s in enumerate(latents_size)]
80
+
81
+ # Add singleton dimensions if needed to match target_ndim (typically for time axis)
82
+ if len(rope_sizes) != target_ndim:
83
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes
84
+
85
+ # Calculate head dimension and validate rope dimensions
86
+ head_dim = args.hidden_size // args.num_heads
87
+ rope_dim_list = args.rope_dim_list
88
+
89
+ # Default: split head dimension equally across target dimensions
90
+ if rope_dim_list is None:
91
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
92
+
93
+ # Ensure rope dimensions sum to head dimension
94
+ assert sum(rope_dim_list) == head_dim, \
95
+ "Sum of rope_dim_list must equal attention head dimension (hidden_size // num_heads)"
96
+
97
+ # Generate rotary position embeddings
98
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(
99
+ rope_dim_list,
100
+ rope_sizes,
101
+ theta=args.rope_theta,
102
+ use_real=True,
103
+ theta_rescale_factor=rope_theta_rescale_factor,
104
+ interpolation_factor=rope_interpolation_factor,
105
+ concat_dict=concat_dict
106
+ )
107
+ return freqs_cos, freqs_sin
108
+
109
+
110
+ def get_nd_rotary_pos_embed_new(
111
+ rope_dim_list,
112
+ start,
113
+ *args,
114
+ theta=10000.,
115
+ use_real=False,
116
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
117
+ interpolation_factor: Union[float, List[float]] = 1.0,
118
+ concat_dict={}
119
+ ):
120
+ """
121
+ Generates multi-dimensional Rotary Position Embeddings (RoPE).
122
+
123
+ Creates position embeddings for n-dimensional spaces by generating a meshgrid
124
+ of positions and applying 1D rotary embeddings to each dimension, then combining them.
125
+
126
+ Args:
127
+ rope_dim_list (list): List of embedding dimensions for each axis
128
+ start: Starting dimensions for generating the meshgrid
129
+ *args: Additional arguments for meshgrid generation
130
+ theta (float): Base theta parameter for RoPE frequency calculation
131
+ use_real (bool): If True, returns separate cosine and sine embeddings
132
+ theta_rescale_factor: Rescaling factor(s) for theta (per dimension)
133
+ interpolation_factor: Interpolation factor(s) for position scaling (per dimension)
134
+ concat_dict: Dictionary for special concatenation modes (e.g., time-based extensions)
135
+
136
+ Returns:
137
+ tuple or tensor: Cosine and sine embeddings if use_real=True, combined embedding otherwise
138
+ """
139
+ # Generate n-dimensional meshgrid of positions (shape: [dim, *sizes])
140
+ grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))
141
+
142
+ # Handle special concatenation modes (e.g., adding time-based bias)
143
+ if concat_dict:
144
+ if concat_dict['mode'] == 'timecat':
145
+ # Add bias as first element in first dimension
146
+ bias = grid[:, :1].clone()
147
+ bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
148
+ grid = torch.cat([bias, grid], dim=1)
149
+ elif concat_dict['mode'] == 'timecat-w':
150
+ # Add biased first element with spatial offset
151
+ bias = grid[:, :1].clone()
152
+ bias[0] = concat_dict['bias'] * torch.ones_like(bias[0])
153
+ bias[2] += start[-1] # Spatial offset reference: OminiControl implementation
154
+ grid = torch.cat([bias, grid], dim=1)
155
+
156
+ # Normalize theta rescale factors to list format (per dimension)
157
+ if isinstance(theta_rescale_factor, (int, float)):
158
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
159
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
160
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
161
+ assert len(theta_rescale_factor) == len(rope_dim_list), \
162
+ "Length of theta_rescale_factor must match number of dimensions"
163
+
164
+ # Normalize interpolation factors to list format (per dimension)
165
+ if isinstance(interpolation_factor, (int, float)):
166
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
167
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
168
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
169
+ assert len(interpolation_factor) == len(rope_dim_list), \
170
+ "Length of interpolation_factor must match number of dimensions"
171
+
172
+ # Generate 1D rotary embeddings for each dimension and combine
173
+ embs = []
174
+ for i in range(len(rope_dim_list)):
175
+ # Flatten grid dimension and generate embeddings
176
+ emb = get_1d_rotary_pos_embed(
177
+ rope_dim_list[i],
178
+ grid[i].reshape(-1), # Flatten to 1D positions
179
+ theta,
180
+ use_real=use_real,
181
+ theta_rescale_factor=theta_rescale_factor[i],
182
+ interpolation_factor=interpolation_factor[i]
183
+ )
184
+ embs.append(emb)
185
+
186
+ # Combine embeddings from all dimensions
187
+ if use_real:
188
+ # Return separate cosine and sine components
189
+ cos = torch.cat([emb[0] for emb in embs], dim=1)
190
+ sin = torch.cat([emb[1] for emb in embs], dim=1)
191
+ return cos, sin
192
+ else:
193
+ # Return combined embedding
194
+ return torch.cat(embs, dim=1)
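One practical note on the per-axis split in `get_rope_freq_from_size`: the default of dividing the head dimension evenly across axes only works when it divides exactly, which is what the assertion there guards against. With hypothetical sizes:

hidden_size, num_heads, target_ndim = 3072, 24, 3
head_dim = hidden_size // num_heads                      # 128
rope_dim_list = [head_dim // target_ndim] * target_ndim  # [42, 42, 42]
print(sum(rope_dim_list) == head_dim)                    # False -> rope_dim_list must be set explicitly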
hymm_sp/inference.py ADDED
@@ -0,0 +1,201 @@
1
+ import torch
2
+ from pathlib import Path
3
+ from loguru import logger
4
+ from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
5
+ from hymm_sp.vae import load_vae
6
+ from hymm_sp.modules import load_model
7
+ from hymm_sp.text_encoder import TextEncoder
8
+ import torch.distributed
9
+ from hymm_sp.modules.parallel_states import (
10
+ initialize_sequence_parallel_state,
11
+ get_sequence_parallel_state,
12
+ nccl_info,
13
+ )
14
+ from hymm_sp.modules.fp8_optimization import convert_fp8_linear
15
+
16
+
17
+ class Inference(object):
18
+ def __init__(self,
19
+ args,
20
+ vae,
21
+ vae_kwargs,
22
+ text_encoder,
23
+ model,
24
+ text_encoder_2=None,
25
+ pipeline=None,
26
+ cpu_offload=False,
27
+ device=None,
28
+ logger=None):
29
+ self.vae = vae
30
+ self.vae_kwargs = vae_kwargs
31
+
32
+ self.text_encoder = text_encoder
33
+ self.text_encoder_2 = text_encoder_2
34
+
35
+ self.model = model
36
+ self.pipeline = pipeline
37
+ self.cpu_offload = cpu_offload
38
+
39
+ self.args = args
40
+ self.device = device if device is not None else "cuda" if torch.cuda.is_available() else "cpu"
41
+ if nccl_info.sp_size > 1:
42
+ self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
43
+
44
+ self.logger = logger
45
+
46
+ @classmethod
47
+ def from_pretrained(cls,
48
+ pretrained_model_path,
49
+ args,
50
+ device=None,
51
+ **kwargs):
52
+ """
53
+ Initialize the Inference pipeline.
54
+
55
+ Args:
56
+ pretrained_model_path (str or pathlib.Path): The model path,
57
+ including t2v, text encoder and vae checkpoints.
58
+ device (int): The device for inference. Default is 0.
59
+ logger (logging.Logger): The logger for the inference pipeline. Default is None.
60
+ """
61
+ # ========================================================================
62
+ logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
63
+
64
+ # ======================== Get the args path =============================
65
+
66
+ # Set device and disable gradient
67
+ if device is None:
68
+ device = "cuda" if torch.cuda.is_available() else "cpu"
69
+ torch.set_grad_enabled(False)
70
+ logger.info("Building model...")
71
+ factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
72
+ in_channels = args.latent_channels
73
+ out_channels = args.latent_channels
74
+ print("="*25, f"build model", "="*25)
75
+ model = load_model(
76
+ args,
77
+ in_channels=in_channels,
78
+ out_channels=out_channels,
79
+ factor_kwargs=factor_kwargs
80
+ )
81
+ if args.cpu_offload:
82
+ print(f'='*20, f'load transformer to cpu')
83
+ model = model.to('cpu')
84
+ torch.cuda.empty_cache()
85
+ else:
86
+ model = model.to(device)
87
+ model = Inference.load_state_dict(args, model, pretrained_model_path)
88
+ model.eval()
89
+
90
+ if args.use_fp8:
91
+ convert_fp8_linear(model)
92
+
93
+ # ============================= Build extra models ========================
94
+ # VAE
95
+ print("="*25, f"load vae", "="*25)
96
+ vae, _, s_ratio, t_ratio = load_vae(args.vae,
97
+ args.vae_precision,
98
+ logger=logger,
99
+ device='cpu' if args.cpu_offload else device)
100
+ vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}
101
+
102
+ # Parallel VAE
103
+ device_vaes = []
104
+ device_vaes.append(vae)
105
+ if nccl_info.sp_size > 1 and nccl_info.rank_within_group == 0:
106
+ for i in range(1, nccl_info.sp_size):
107
+ cur_device = torch.device(f"cuda:{i}")
108
+ # print("!!!!!!!!!! Load vae for ", cur_device)
109
+ device_vae, _, _, _ = load_vae(args.vae,
110
+ args.vae_precision,
111
+ logger=logger,
112
+ device='cpu' if args.cpu_offload else cur_device)
113
+ device_vaes.append(device_vae)
114
+ vae.device_vaes = device_vaes
115
+
116
+ # Text encoder
117
+ if args.prompt_template_video is not None:
118
+ crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
119
+ else:
120
+ crop_start = 0
121
+ max_length = args.text_len + crop_start
122
+
123
+ # prompt_template_video
124
+ prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] \
125
+ if args.prompt_template_video is not None else None
126
+ print("="*25, f"load llava", "="*25)
127
+ text_encoder = TextEncoder(text_encoder_type = args.text_encoder,
128
+ max_length = max_length,
129
+ text_encoder_precision = args.text_encoder_precision,
130
+ tokenizer_type = args.tokenizer,
131
+ use_attention_mask = args.use_attention_mask,
132
+ prompt_template_video = prompt_template_video,
133
+ hidden_state_skip_layer = args.hidden_state_skip_layer,
134
+ apply_final_norm = args.apply_final_norm,
135
+ reproduce = args.reproduce,
136
+ logger = logger,
137
+ device = 'cpu' if args.cpu_offload else device ,
138
+ )
139
+ text_encoder_2 = None
140
+ if args.text_encoder_2 is not None:
141
+ text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
142
+ max_length=args.text_len_2,
143
+ text_encoder_precision=args.text_encoder_precision_2,
144
+ tokenizer_type=args.tokenizer_2,
145
+ use_attention_mask=args.use_attention_mask,
146
+ reproduce=args.reproduce,
147
+ logger=logger,
148
+ device='cpu' if args.cpu_offload else device ,
149
+ # if not args.use_cpu_offload else 'cpu'
150
+ )
151
+
152
+ return cls(args=args,
153
+ vae=vae,
154
+ vae_kwargs=vae_kwargs,
155
+ text_encoder=text_encoder,
156
+ model=model,
157
+ text_encoder_2=text_encoder_2,
158
+ device=device,
159
+ logger=logger)
160
+
161
+ @staticmethod
162
+ def load_state_dict(args, model, ckpt_path):
163
+ load_key = args.load_key
164
+ ckpt_path = Path(ckpt_path)
165
+ if ckpt_path.is_dir():
166
+ ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
167
+ state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
168
+ if load_key in state_dict:
169
+ state_dict = state_dict[load_key]
170
+ elif load_key == ".":
171
+ pass
172
+ else:
173
+ raise KeyError(f"Key '{load_key}' not found in the checkpoint. Existing keys: {list(state_dict.keys())}")
174
+ model.load_state_dict(state_dict, strict=False)
175
+ return model
176
+
177
+ def get_exp_dir_and_ckpt_id(self):
178
+ if self.ckpt is None:
179
+ raise ValueError("The checkpoint path is not provided.")
180
+
181
+ ckpt = Path(self.ckpt)
182
+ if ckpt.parents[1].name == "checkpoints":
183
+ # It should be a standard checkpoint path. We use the parent directory as the default save directory.
184
+ exp_dir = ckpt.parents[2]
185
+ else:
186
+ raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
187
+ f"It seems that the checkpoint path is not standard. Please explicitly provide the "
188
+ f"save path by --save-path.")
189
+ return exp_dir, ckpt.parent.name
190
+
191
+ @staticmethod
192
+ def parse_size(size):
193
+ if isinstance(size, int):
194
+ size = [size]
195
+ if not isinstance(size, (list, tuple)):
196
+ raise ValueError(f"Size must be an integer or (height, width), got {size}.")
197
+ if len(size) == 1:
198
+ size = [size[0], size[0]]
199
+ if len(size) != 2:
200
+ raise ValueError(f"Size must be an integer or (height, width), got {size}.")
201
+ return size
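Note that `get_exp_dir_and_ckpt_id` above assumes a fixed checkpoint layout of `<exp_dir>/checkpoints/<ckpt_id>/*_model_states.pt`. A minimal sketch of that convention, using a hypothetical path purely for illustration:

```python
from pathlib import Path

# Hypothetical checkpoint path, shown only to illustrate the expected layout.
ckpt = Path("experiments/run_01/checkpoints/step_5000/mp_rank_00_model_states.pt")

# Mirrors get_exp_dir_and_ckpt_id: the grandparent directory must be named
# "checkpoints", otherwise the caller has to pass --save-path explicitly.
assert ckpt.parents[1].name == "checkpoints"
exp_dir, ckpt_id = ckpt.parents[2], ckpt.parent.name
print(exp_dir, ckpt_id)  # experiments/run_01 step_5000
```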
hymm_sp/modules/__init__.py ADDED
@@ -0,0 +1,38 @@
1
+ from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2
+ """
3
+ This module provides functionality to load the Hunyuan video diffusion transformer model,
4
+ which is used for video generation tasks with specific configurations and parameters.
5
+ """
6
+
7
+ def load_model(args, in_channels, out_channels, factor_kwargs):
8
+ """
9
+ Load and initialize the HYVideoDiffusionTransformer model with specified parameters.
10
+
11
+ Args:
12
+ args: Command-line arguments or configuration object containing model settings.
13
+ Must include 'model' attribute to select the appropriate configuration.
14
+ in_channels (int): Number of input channels for the model.
15
+ out_channels (int): Number of output channels the model should produce.
16
+ factor_kwargs (dict): Additional keyword arguments for factor adjustments
17
+ in the model architecture.
18
+
19
+ Returns:
20
+ HYVideoDiffusionTransformer: Initialized instance of the video diffusion transformer
21
+ model with the specified configuration.
22
+
23
+ Notes:
24
+ - Uses the HUNYUAN_VIDEO_CONFIG dictionary to retrieve model-specific configurations
25
+ based on the model name provided in args.
26
+ - Sets multitask_mask_training_type to "concat" as a default for this loading setup.
27
+ """
28
+ # Initialize the Hunyuan video diffusion transformer with combined configurations
29
+ # Merges base config from HUNYUAN_VIDEO_CONFIG and additional factor arguments
30
+ model = HYVideoDiffusionTransformer(
31
+ args,
32
+ in_channels=in_channels,
33
+ out_channels=out_channels,
34
+ multitask_mask_training_type="concat",
35
+ **HUNYUAN_VIDEO_CONFIG[args.model], # Unpack model-specific configuration
36
+ **factor_kwargs, # Unpack additional factor adjustments
37
+ )
38
+ return model
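A hedged usage sketch for `load_model`: the `args.model` value must name an entry in `HUNYUAN_VIDEO_CONFIG`, and in practice `args` is the full namespace built by `hymm_sp/config.py`, since `HYVideoDiffusionTransformer` reads further attributes from it. The `"HYVideo-T/2"` key, channel counts, and `factor_kwargs` below are placeholders, not verified values.

```python
from types import SimpleNamespace
import torch
from hymm_sp.modules import load_model

# Placeholder args; a real run passes the parsed CLI namespace instead.
args = SimpleNamespace(model="HYVideo-T/2")  # assumed HUNYUAN_VIDEO_CONFIG key
model = load_model(
    args,
    in_channels=16,   # latent channels, as passed from inference.py
    out_channels=16,
    factor_kwargs={"device": "cpu", "dtype": torch.bfloat16},
)
print(sum(p.numel() for p in model.parameters()))
```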
hymm_sp/modules/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ torch.nn.functional: the activation layer
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
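Note that `get_activation_layer` returns a layer factory rather than an instance, so callers construct the module themselves. A quick usage sketch:

```python
import torch
from hymm_sp.modules.activation_layers import get_activation_layer

act_layer = get_activation_layer("gelu_tanh")  # factory, not an instance
act = act_layer()                              # nn.GELU(approximate="tanh")
print(act(torch.randn(2, 4)).shape)            # torch.Size([2, 4])
```

Returning a factory keeps the signature uniform with `nn.ReLU`/`nn.SiLU` (which are returned as classes) and lets every block build its own fresh module.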
hymm_sp/modules/attn_layers.py ADDED
@@ -0,0 +1,437 @@
1
+ import importlib.metadata
2
+ import math
3
+ from typing import Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ try:
9
+ from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func
10
+ from flash_attn.bert_padding import index_first_axis
11
+ except ImportError:
12
+ flash_attn_qkvpacked_func, flash_attn_kvpacked_func, flash_attn_varlen_kvpacked_func = None, None, None
13
+ index_first_axis = None
14
+ from packaging import version
15
+ from transformers.utils.import_utils import _is_package_available
16
+
17
+ from .norm_layers import get_norm_layer
18
+
19
+
20
+ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
21
+ """
22
+ Reshape frequency tensor for broadcasting it with another tensor.
23
+
24
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
25
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
26
+
27
+ Notes:
28
+ When using FlashMHAModified, head_first should be False.
29
+ When using Attention, head_first should be True.
30
+
31
+ Args:
32
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
33
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
34
+ head_first (bool): head dimension first (except batch dim) or not.
35
+
36
+ Returns:
37
+ torch.Tensor: Reshaped frequency tensor.
38
+
39
+ Raises:
40
+ AssertionError: If the frequency tensor doesn't match the expected shape.
41
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
42
+ """
43
+ ndim = x.ndim
44
+ assert 0 <= 1 < ndim
45
+
46
+ if isinstance(freqs_cis, tuple):
47
+ # freqs_cis: (cos, sin) in real space
48
+ if head_first:
49
+ assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), \
50
+ f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
51
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
52
+ else:
53
+ assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), \
54
+ f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
55
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
56
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
57
+ else:
58
+ # freqs_cis: values in complex space
59
+ if head_first:
60
+ assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), \
61
+ f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
62
+ shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
63
+ else:
64
+ assert freqs_cis.shape == (x.shape[1], x.shape[-1]), \
65
+ f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
66
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
67
+ return freqs_cis.view(*shape)
68
+
69
+
70
+ def rotate_half(x):
71
+ x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
72
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
73
+
74
+
75
+ def apply_rotary_emb(
76
+ xq: torch.Tensor,
77
+ xk: torch.Tensor,
78
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
79
+ head_first: bool = False,
80
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
81
+ """
82
+ Apply rotary embeddings to input tensors using the given frequency tensor.
83
+
84
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
85
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
86
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
87
+ returned as real tensors.
88
+
89
+ Args:
90
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
91
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
92
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
93
+ head_first (bool): head dimension first (except batch dim) or not.
94
+
95
+ Returns:
96
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
97
+
98
+ """
99
+ xk_out = None
100
+ if isinstance(freqs_cis, tuple):
101
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
102
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
103
+ # real * cos - imag * sin
104
+ # imag * cos + real * sin
105
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
106
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
107
+ else:
108
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
109
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
110
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
111
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
112
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
113
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
114
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
115
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
116
+
117
+ return xq_out, xk_out
118
+
119
+
120
+ class BasicAttentionLayer(nn.Module):
121
+ def __init__(self, attn_mode='flash', deterministic=False):
122
+ super().__init__()
123
+ self.attn_mode = attn_mode
124
+ self.deterministic = deterministic
125
+
126
+ def set_attn_mode(self, new_mode):
127
+ self.attn_mode = new_mode
128
+
129
+ def enable_deterministic(self):
130
+ self.deterministic = True
131
+
132
+ def disable_deterministic(self):
133
+ self.deterministic = False
134
+
135
+
136
+ MEMORY_LAYOUT = {
137
+ "self_flash": (
138
+ lambda x: x,
139
+ lambda x: x,
140
+ ),
141
+ "cross_flash": (
142
+ lambda x: x,
143
+ lambda x: x,
144
+ ),
145
+ "torch": (
146
+ lambda x: x.transpose(1, 2),
147
+ lambda x: x.transpose(1, 2),
148
+ ),
149
+ "vanilla": (
150
+ lambda x: x.transpose(1, 2),
151
+ lambda x: x.transpose(1, 2),
152
+ ),
153
+ }
154
+
155
+
156
+ # Copied from https://github.com/huggingface/transformers/blob/
157
+ # b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/modeling_flash_attention_utils.py#L33C1-L57C6
158
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
159
+ """
160
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
161
+
162
+ Arguments:
163
+ attention_mask (`torch.Tensor`):
164
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means
165
+ valid and 0 means not valid.
166
+
167
+ Return:
168
+ indices (`torch.Tensor`):
169
+ The indices of non-masked tokens from the flattened input sequence.
170
+ cu_seqlens (`torch.Tensor`):
171
+ The cumulative sequence lengths, used to index into ragged (unpadded)
172
+ tensors. `cu_seqlens` shape is (batch_size + 1,).
173
+ max_seqlen_in_batch (`int`):
174
+ Maximum sequence length in batch.
175
+ """
176
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
177
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
178
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
179
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
180
+ return (
181
+ indices,
182
+ cu_seqlens,
183
+ max_seqlen_in_batch,
184
+ )
185
+
186
+
187
+ # Copied from https://github.com/huggingface/transformers/blob/
188
+ # b873234cb649a24865021f0d598627ce2b24d34a/src/transformers/utils/import_utils.py#L822
189
+ def is_flash_attn_greater_or_equal(library_version: str):
190
+ if not _is_package_available("flash_attn"):
191
+ return False
192
+
193
+ return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
194
+
195
+
196
+ def get_kv_seqlens_with_mask(attn_mask, k, v):
197
+ indices_k, cu_seqlens_k, max_seqlen_k = _get_unpad_data(attn_mask)
198
+ b, s1, a, d = k.shape
199
+ k = index_first_axis(k.reshape(b * s1, a, d), indices_k)
200
+ v = index_first_axis(v.reshape(b * s1, a, d), indices_k)
201
+ kv = torch.stack([k, v], dim=1)
202
+ return cu_seqlens_k, max_seqlen_k, kv
203
+
204
+
205
+ def get_q_seqlens(q):
206
+ bs, s, a, d = q.shape
207
+ cu_seqlens_q = torch.arange(0, (bs + 1) * s, step=s, dtype=torch.int32, device=q.device)
208
+ q = q.reshape(bs * s, a, d)
209
+ return cu_seqlens_q, s, q
210
+
211
+
212
+ def attention(q, k, v, mode, drop_rate=0, attn_mask=None, causal=False, deterministic=False,
213
+ cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None, max_seqlen_k=None):
214
+ """
215
+ Perform QKV self attention.
216
+
217
+ Args:
218
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
219
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
220
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
221
+ mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'.
222
+ drop_rate (float): Dropout rate in attention map. (default: 0)
223
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
224
+ (default: None)
225
+ causal (bool): Whether to use causal attention. (default: False)
226
+ deterministic (bool): Whether to use deterministic attention. (default: False)
227
+ cu_seqlens (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
228
+ used to index into q.
229
+ max_seqlen (int): The maximum sequence length in the batch of q.
230
+ cu_seqlens_k (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
231
+ used to index into kv.
232
+ max_seqlen_k (int): The maximum sequence length in the batch of k and v.
233
+
234
+ Returns:
235
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
236
+ """
237
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
238
+ q = pre_attn_layout(q)
239
+ k = pre_attn_layout(k)
240
+ v = pre_attn_layout(v)
241
+
242
+ if mode == 'torch':
243
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
244
+ attn_mask = attn_mask.to(q.dtype)
245
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
246
+
247
+ elif mode == 'vanilla':
248
+ scale_factor = 1 / math.sqrt(q.size(-1))
249
+
250
+ b, a, s, _ = q.shape
251
+ s1 = k.size(2)
252
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
253
+ if causal:
254
+ # Only applied to self attention
255
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
256
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
257
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
258
+ attn_bias.to(q.dtype)
259
+
260
+ if attn_mask is not None:
261
+ if attn_mask.dtype == torch.bool:
262
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
263
+ else:
264
+ attn_bias += attn_mask
265
+
266
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
267
+ attn += attn_bias
268
+ attn = attn.softmax(dim=-1)
269
+ attn = torch.dropout(attn, p=drop_rate, train=True)
270
+ x = attn @ v
271
+ else:
272
+ raise NotImplementedError(f'Unsupported attention mode: {mode}')
273
+
274
+ x = post_attn_layout(x)
275
+ b, s, a, d = x.shape
276
+ out = x.reshape(b, s, -1)
277
+ return out
278
+
279
+
280
+ class SelfAttentionLayer(BasicAttentionLayer):
281
+ def __init__(self,
282
+ dim,
283
+ num_heads,
284
+ qkv_bias=True,
285
+ qk_norm=True,
286
+ attn_drop=0,
287
+ proj_drop=0,
288
+ dtype=None,
289
+ device=None,
290
+ norm_type='layer',
291
+ attn_mode='self_flash',
292
+ deterministic=False,
293
+ ) -> None:
294
+ factory_kwargs = {'device': device, 'dtype': dtype}
295
+ super().__init__(attn_mode, deterministic)
296
+ self.dim = dim
297
+ self.num_heads = num_heads
298
+ assert self.dim % num_heads == 0, "dim must be divisible by num_heads"
299
+ self.head_dim = self.dim // num_heads
300
+ self.attn_drop = attn_drop
301
+
302
+ # This assertion is aligned with flash attention
303
+ assert (
304
+ self.head_dim % 8 == 0 and self.head_dim <= 128
305
+ ), "Only support head_dim <= 128 and divisible by 8"
306
+
307
+ self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias, **factory_kwargs)
308
+
309
+ norm_layer = get_norm_layer(norm_type)
310
+ self.q_norm = (
311
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
312
+ if qk_norm
313
+ else nn.Identity()
314
+ )
315
+ self.k_norm = (
316
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
317
+ if qk_norm
318
+ else nn.Identity()
319
+ )
320
+
321
+ self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs)
322
+ self.proj_drop = nn.Dropout(proj_drop)
323
+
324
+ def forward(self, x, freqs_cis=None, attn_mask=None):
325
+ """
326
+ Args:
327
+ x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
328
+ freqs_cis (torch.Tensor, optional): (batch, hidden_dim // 2), RoPE for image
329
+ attn_mask (torch.Tensor, optional): (batch, seq_len, seq_len), mask for attention
330
+ """
331
+ b, s, d = x.shape
332
+
333
+ # Apply QKV projection
334
+ qkv = self.Wqkv(x)
335
+ qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, a, d]
336
+ q, k, v = qkv.unbind(dim=2) # [b, s, a, d]
337
+
338
+ # Apply QK-Norm if needed
339
+ q = self.q_norm(q)
340
+ k = self.k_norm(k)
341
+
342
+ # Apply RoPE if needed
343
+ if freqs_cis is not None:
344
+ qq, kk = apply_rotary_emb(q, k, freqs_cis)
345
+ assert qq.shape == q.shape and kk.shape == k.shape, \
346
+ f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
347
+ q, k = qq, kk
348
+
349
+ # Apply self attention
350
+ context = attention(q, k, v,
351
+ drop_rate=self.attn_drop if self.training else 0,
352
+ attn_mask=attn_mask,
353
+ mode=self.attn_mode,
354
+ deterministic=self.deterministic,
355
+ )
356
+ out = self.out_proj(context)
357
+ out = self.proj_drop(out)
358
+
359
+ return out
360
+
361
+
362
+ class CrossAttentionLayer(BasicAttentionLayer):
363
+ def __init__(self,
364
+ qdim,
365
+ kdim,
366
+ num_heads,
367
+ qkv_bias=True,
368
+ qk_norm=True,
369
+ attn_drop=0,
370
+ proj_drop=0,
371
+ dtype=None,
372
+ device=None,
373
+ norm_type='layer',
374
+ attn_mode='cross_flash',
375
+ deterministic=False,
376
+ ):
377
+ factory_kwargs = {'device': device, 'dtype': dtype}
378
+ super().__init__(attn_mode, deterministic)
379
+ self.qdim = qdim
380
+ self.kdim = kdim
381
+ self.num_heads = num_heads
382
+ assert self.qdim % num_heads == 0, "qdim must be divisible by num_heads"
383
+ self.head_dim = self.qdim // num_heads
384
+ self.attn_drop = attn_drop
385
+
386
+ # This assertion is aligned with flash attention
387
+ assert (
388
+ self.head_dim % 8 == 0 and self.head_dim <= 128
389
+ ), "Only support head_dim <= 128 and divisible by 8"
390
+
391
+ self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
392
+ self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
393
+
394
+ norm_layer = get_norm_layer(norm_type)
395
+ self.q_norm = (
396
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
397
+ if qk_norm
398
+ else nn.Identity()
399
+ )
400
+ self.k_norm = (
401
+ norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
402
+ if qk_norm
403
+ else nn.Identity()
404
+ )
405
+
406
+ self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
407
+ self.proj_drop = nn.Dropout(proj_drop)
408
+
409
+ def forward(self, x, y, attn_mask=None):
410
+ """
411
+ Args:
412
+ x (torch.Tensor): (batch, seq_len, hidden_dim) (where hidden_dim = num heads * head dim)
413
+ y (torch.Tensor): (batch, seq_len1, hidden_dim1)
414
+ attn_mask (torch.Tensor): (batch, seq_len1), mask for attention
415
+ """
416
+ b, s, d = x.shape
417
+ _, s1, d1 = y.shape
418
+
419
+ q = self.q_proj(x).view(b, s, self.num_heads, self.head_dim)
420
+ kv = self.kv_proj(y).view(b, s1, 2, self.num_heads, self.head_dim)
421
+ k, v = kv.unbind(dim=2)
422
+
423
+ # Apply QK-Norm if needed
424
+ q = self.q_norm(q)
425
+ k = self.k_norm(k)
426
+
427
+ # Apply cross attention
428
+ context = attention(q, k, v,
429
+ attn_mask=attn_mask,
430
+ drop_rate=self.attn_drop if self.training else 0,
431
+ mode=self.attn_mode,
432
+ deterministic=self.deterministic,
433
+ )
434
+ out = self.out_proj(context)
435
+ out = self.proj_drop(out)
436
+
437
+ return out
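A small self-contained check of the pure-PyTorch `'torch'` attention path defined above (no flash-attn required); shapes follow the `[batch, seq, heads, head_dim]` convention used throughout the file:

```python
import torch
from hymm_sp.modules.attn_layers import attention

b, s, heads, head_dim = 1, 16, 4, 32
q = torch.randn(b, s, heads, head_dim)
k = torch.randn(b, s, heads, head_dim)
v = torch.randn(b, s, heads, head_dim)

# 'torch' mode transposes to [b, heads, s, head_dim], runs
# scaled_dot_product_attention, then flattens heads back into the channel dim.
out = attention(q, k, v, mode="torch")
print(out.shape)  # torch.Size([1, 16, 128])
```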
hymm_sp/modules/cameranet.py ADDED
@@ -0,0 +1,248 @@
1
+
2
+ import einops
3
+ import torch.nn.functional as F
4
+ import collections.abc
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.init as init
9
+
10
+ from pathlib import Path
11
+ from einops import rearrange
12
+ from typing import Any, Dict, Optional, Tuple, Union
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from itertools import repeat
15
+ from .embed_layers import PatchEmbed
16
+
17
+
18
+ def _ntuple(n):
19
+ """
20
+ Creates a helper function to convert inputs to tuples of specified length.
21
+
22
+ Functionality:
23
+ - Converts iterable inputs (excluding strings) to tuples, ensuring length n
24
+ - Repeats single values n times to form a tuple
25
+ Useful for handling multi-dimensional parameters like kernel sizes and strides.
26
+
27
+ Args:
28
+ n (int): Target length of the tuple
29
+
30
+ Returns:
31
+ function: A parser function that converts inputs to n-length tuples
32
+ """
33
+ def parse(x):
34
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
35
+ x = tuple(x)
36
+ if len(x) == 1:
37
+ x = tuple(repeat(x[0], n))
38
+ return x
39
+ return tuple(repeat(x, n))
40
+ return parse
41
+
42
+
43
+ # Create common tuple conversion functions
44
+ to_1tuple = _ntuple(1)
45
+ to_2tuple = _ntuple(2)
46
+ to_3tuple = _ntuple(3)
47
+ to_4tuple = _ntuple(4)
48
+
49
+
50
+ class CameraNet(ModelMixin):
51
+ """
52
+ Camera state encoding network that processes camera parameters into feature embeddings.
53
+
54
+ This network converts camera state information into suitable feature representations
55
+ for video generation models through downsampling, convolutional encoding, and
56
+ temporal dimension compression. Supports loading from pretrained weights.
57
+ """
58
+ def __init__(
59
+ self,
60
+ in_channels,
61
+ downscale_coef,
62
+ out_channels,
63
+ patch_size,
64
+ hidden_size,
65
+ ):
66
+ super().__init__()
67
+ # Calculate initial channels: PixelUnshuffle moves spatial info to channel dimension
68
+ # resulting in channels = in_channels * (downscale_coef^2)
69
+ start_channels = in_channels * (downscale_coef ** 2)
70
+ input_channels = [start_channels, start_channels // 2, start_channels // 4]
71
+ self.input_channels = input_channels
72
+ self.unshuffle = nn.PixelUnshuffle(downscale_coef)
73
+
74
+ self.encode_first = nn.Sequential(
75
+ nn.Conv2d(input_channels[0], input_channels[1], kernel_size=1, stride=1, padding=0),
76
+ nn.GroupNorm(2, input_channels[1]),
77
+ nn.ReLU(),
78
+ )
79
+ self._initialize_weights(self.encode_first)
80
+ self.encode_second = nn.Sequential(
81
+ nn.Conv2d(input_channels[1], input_channels[2], kernel_size=1, stride=1, padding=0),
82
+ nn.GroupNorm(2, input_channels[2]),
83
+ nn.ReLU(),
84
+ )
85
+ self._initialize_weights(self.encode_second)
86
+
87
+ self.final_proj = nn.Conv2d(input_channels[2], out_channels, kernel_size=1)
88
+ self.zeros_init_linear(self.final_proj)
89
+
90
+ self.scale = nn.Parameter(torch.ones(1))
91
+
92
+ self.camera_in = PatchEmbed(patch_size=patch_size, in_chans=out_channels, embed_dim=hidden_size)
93
+
94
+
95
+ def zeros_init_linear(self, linear: nn.Module):
96
+ """
97
+ Zero-initializes weights and biases of linear or convolutional layers.
98
+
99
+ Args:
100
+ linear (nn.Module): Linear or convolutional layer to initialize
101
+ """
102
+ if isinstance(linear, (nn.Linear, nn.Conv2d)):
103
+ if hasattr(linear, "weight"):
104
+ nn.init.zeros_(linear.weight)
105
+ if hasattr(linear, "bias"):
106
+ nn.init.zeros_(linear.bias)
107
+
108
+ def _initialize_weights(self, block):
109
+ """
110
+ Initializes convolutional layer weights using He initialization,
111
+ with biases initialized to zero.
112
+
113
+ Args:
114
+ block (nn.Sequential): Sequential block containing convolutional layers
115
+ """
116
+ for m in block:
117
+ if isinstance(m, nn.Conv2d):
118
+ n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
119
+ init.normal_(m.weight, mean=0.0, std=np.sqrt(2.0 / n))
120
+ if m.bias is not None:
121
+ init.zeros_(m.bias)
122
+
123
+
124
+ def compress_time(self, x, num_frames):
125
+ """
126
+ Temporal dimension compression: reduces number of frames using average pooling
127
+ while preserving key temporal information.
128
+
129
+ Handling logic:
130
+ - Special frame counts (66 or 34): split into two segments, keep first frame of each
131
+ segment then pool remaining frames
132
+ - Odd frame counts: keep first frame, pool remaining frames
133
+ - Even frame counts: directly pool all frames
134
+
135
+ Args:
136
+ x (torch.Tensor): Input tensor with shape (b*f, c, h, w)
137
+ num_frames (int): Number of frames in temporal dimension
138
+
139
+ Returns:
140
+ torch.Tensor: Temporally compressed tensor with shape (b*f', c, h, w) where f' < f
141
+ """
142
+ # Reshape: (b*f, c, h, w) -> (b, f, c, h, w)
143
+ x = rearrange(x, '(b f) c h w -> b f c h w', f=num_frames)
144
+ batch_size, frames, channels, height, width = x.shape
145
+ x = rearrange(x, 'b f c h w -> (b h w) c f')
146
+
147
+ # print(x.shape)
148
+ # raise Exception
149
+ # Handle special frame counts (66 or 34)
150
+ if x.shape[-1] == 66 or x.shape[-1] == 34:
151
+ x_len = x.shape[-1]
152
+ # Process first segment: keep first frame, pool remaining
153
+ x_clip1 = x[...,:x_len//2]
154
+ x_clip1_first, x_clip1_rest = x_clip1[..., 0].unsqueeze(-1), x_clip1[..., 1:]
155
+ x_clip1_rest = F.avg_pool1d(x_clip1_rest, kernel_size=2, stride=2)
156
+
157
+ # Process second segment: keep first frame, pool remaining
158
+ x_clip2 = x[...,x_len//2:x_len]
159
+ x_clip2_first, x_clip2_rest = x_clip2[..., 0].unsqueeze(-1), x_clip2[..., 1:]
160
+ x_clip2_rest = F.avg_pool1d(x_clip2_rest, kernel_size=2, stride=2)
161
+
162
+ # Concatenate results from both segments
163
+ x = torch.cat([x_clip1_first, x_clip1_rest, x_clip2_first, x_clip2_rest], dim=-1)
164
+
165
+ elif x.shape[-1] % 2 == 1:
166
+ x_first, x_rest = x[..., 0], x[..., 1:]
167
+ if x_rest.shape[-1] > 0:
168
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
169
+
170
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
171
+ else:
172
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
173
+ x = rearrange(x, '(b h w) c f -> (b f) c h w', b=batch_size, h=height, w=width)
174
+ return x
175
+
176
+ def forward(
177
+ self,
178
+ camera_states: torch.Tensor,
179
+ ):
180
+ """
181
+ Forward pass: encodes camera states into feature embeddings.
182
+
183
+ Args:
184
+ camera_states (torch.Tensor): Camera state tensor with dimensions
185
+ (batch, frames, channels, height, width)
186
+
187
+ Returns:
188
+ torch.Tensor: Encoded feature embeddings after patch embedding and scaling
189
+ """
190
+ # import pdb;pdb.set_trace()
191
+ batch_size, num_frames, channels, height, width = camera_states.shape
192
+ camera_states = rearrange(camera_states, 'b f c h w -> (b f) c h w')
193
+ camera_states = self.unshuffle(camera_states)
194
+ camera_states = self.encode_first(camera_states)
195
+ camera_states = self.compress_time(camera_states, num_frames=num_frames)
196
+ num_frames = camera_states.shape[0] // batch_size
197
+ camera_states = self.encode_second(camera_states)
198
+ camera_states = self.compress_time(camera_states, num_frames=num_frames)
199
+ # camera_states = rearrange(camera_states, '(b f) c h w -> b f c h w', b=batch_size)
200
+ camera_states = self.final_proj(camera_states)
201
+ camera_states = rearrange(camera_states, "(b f) c h w -> b c f h w", b=batch_size)
202
+ camera_states = self.camera_in(camera_states)
203
+ return camera_states * self.scale
204
+
205
+ @classmethod
206
+ def from_pretrained(cls, pretrained_model_path):
207
+ """
208
+ Loads model from pretrained weight file.
209
+
210
+ Args:
211
+ pretrained_model_path (str): Path to pretrained weight file
212
+
213
+ Returns:
214
+ CameraNet: Model instance with loaded pretrained weights
215
+ """
216
+ if not Path(pretrained_model_path).exists():
217
+ raise FileNotFoundError(f"There is no model file at {pretrained_model_path}")
218
+ print(f"Loading CameraNet's pretrained weights from {pretrained_model_path}.")
219
+
220
+ state_dict = torch.load(pretrained_model_path, map_location="cpu")
221
+ model = CameraNet(in_channels=6, downscale_coef=8, out_channels=16, patch_size=[1, 2, 2], hidden_size=3072) # patch_size/hidden_size assumed to mirror the self-test below
222
+ model.load_state_dict(state_dict, strict=True)
223
+ return model
224
+
225
+
226
+ if __name__ == "__main__":
227
+ # Test model initialization and forward pass
228
+ model = CameraNet(
229
+ in_channels=6,
230
+ downscale_coef=8,
231
+ out_channels=16,
232
+ patch_size=[1,2,2],
233
+ hidden_size=3072
234
+ )
235
+ print("Model structure:")
236
+ print(model)
237
+
238
+ # Generate test input (batch 1, 33 frames, 6 channels, 704x1280 resolution)
239
+ num_frames = 33
240
+ input_tensor = torch.randn(1, num_frames, 6, 704, 1280)
241
+
242
+ # Forward pass
243
+ output_tensor = model(input_tensor)
244
+
245
+ # Print results
246
+ print(f"Output shape: {output_tensor.shape}") # Expected: torch.Size([1, ...])
247
+ print("Output tensor example:")
248
+ print(output_tensor)
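For intuition, `compress_time`'s odd-frame branch keeps the first frame and average-pools the remaining frames in pairs; a tiny sketch of that behaviour on a 5-frame signal:

```python
import torch
import torch.nn.functional as F

frames = torch.arange(5, dtype=torch.float32).view(1, 1, 5)  # (b*h*w, c, f)
first, rest = frames[..., 0], frames[..., 1:]
pooled = F.avg_pool1d(rest, kernel_size=2, stride=2)
out = torch.cat([first[..., None], pooled], dim=-1)
print(out)  # tensor([[[0.0000, 1.5000, 3.5000]]]) -> 5 frames compressed to 3
```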
hymm_sp/modules/embed_layers.py ADDED
@@ -0,0 +1,146 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from hymm_sp.helpers import to_2tuple
5
+
6
+
7
+ class PatchEmbed(nn.Module):
8
+ """ 2D Image to Patch Embedding
9
+
10
+ Image to Patch Embedding using Conv2d
11
+
12
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
13
+
14
+ Based on the impl in https://github.com/google-research/vision_transformer
15
+
16
+ Hacked together by / Copyright 2020 Ross Wightman
17
+
18
+ Remove the _assert function in forward function to be compatible with multi-resolution images.
19
+ """
20
+ def __init__(
21
+ self,
22
+ patch_size=16,
23
+ in_chans=3,
24
+ embed_dim=768,
25
+ multitask_mask_training_type=None,
26
+ norm_layer=None,
27
+ flatten=True,
28
+ bias=True,
29
+ dtype=None,
30
+ device=None
31
+ ):
32
+ factory_kwargs = {'dtype': dtype, 'device': device}
33
+ super().__init__()
34
+ patch_size = to_2tuple(patch_size)
35
+ self.patch_size = patch_size
36
+ self.flatten = flatten
37
+
38
+ if multitask_mask_training_type == "concat":
39
+ orig_in_chans = in_chans
40
+ in_chans = in_chans * 2 + 1
41
+
42
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias,
43
+ **factory_kwargs)
44
+ if multitask_mask_training_type == "concat":
45
+ nn.init.xavier_uniform_(\
46
+ self.proj.weight[:, :orig_in_chans].view(self.proj.weight[:, :orig_in_chans].size(0), -1))
47
+ nn.init.zeros_(self.proj.weight[:, orig_in_chans:].view(self.proj.weight[:, orig_in_chans:].size(0), -1))
48
+ else:
49
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
50
+
51
+
52
+ if bias:
53
+ nn.init.zeros_(self.proj.bias)
54
+
55
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
56
+
57
+ def forward(self, x):
58
+ x = self.proj(x)
59
+ if self.flatten:
60
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
61
+ x = self.norm(x)
62
+ return x
63
+
64
+
65
+ class TextProjection(nn.Module):
66
+ """
67
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
68
+
69
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
70
+ """
71
+
72
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
73
+ factory_kwargs = {'dtype': dtype, 'device': device}
74
+ super().__init__()
75
+ self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
76
+ self.act_1 = act_layer()
77
+ self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)
78
+
79
+ def forward(self, caption):
80
+ hidden_states = self.linear_1(caption)
81
+ hidden_states = self.act_1(hidden_states)
82
+ hidden_states = self.linear_2(hidden_states)
83
+ return hidden_states
84
+
85
+
86
+ def timestep_embedding(t, dim, max_period=10000):
87
+ """
88
+ Create sinusoidal timestep embeddings.
89
+
90
+ Args:
91
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
92
+ dim (int): the dimension of the output.
93
+ max_period (int): controls the minimum frequency of the embeddings.
94
+
95
+ Returns:
96
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
97
+
98
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
99
+ """
100
+ half = dim // 2
101
+ freqs = torch.exp(
102
+ -math.log(max_period)
103
+ * torch.arange(start=0, end=half, dtype=torch.float32)
104
+ / half
105
+ ).to(device=t.device)
106
+ args = t[:, None].float() * freqs[None]
107
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
108
+ if dim % 2:
109
+ embedding = torch.cat(
110
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
111
+ )
112
+ return embedding
113
+
114
+
115
+ class TimestepEmbedder(nn.Module):
116
+ """
117
+ Embeds scalar timesteps into vector representations.
118
+ """
119
+ def __init__(self,
120
+ hidden_size,
121
+ act_layer,
122
+ frequency_embedding_size=256,
123
+ max_period=10000,
124
+ out_size=None,
125
+ dtype=None,
126
+ device=None
127
+ ):
128
+ factory_kwargs = {'dtype': dtype, 'device': device}
129
+ super().__init__()
130
+ self.frequency_embedding_size = frequency_embedding_size
131
+ self.max_period = max_period
132
+ if out_size is None:
133
+ out_size = hidden_size
134
+
135
+ self.mlp = nn.Sequential(
136
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
137
+ act_layer(),
138
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
139
+ )
140
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
141
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
142
+
143
+ def forward(self, t):
144
+ t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
145
+ t_emb = self.mlp(t_freq)
146
+ return t_emb
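A quick sketch of `timestep_embedding`: it returns `[cos | sin]` features of the (possibly fractional) timestep at geometrically spaced frequencies, so `t = 0` yields all ones in the cosine half:

```python
import torch
from hymm_sp.modules.embed_layers import timestep_embedding

t = torch.tensor([0.0, 250.0, 999.0])
emb = timestep_embedding(t, dim=256)
print(emb.shape)   # torch.Size([3, 256])
print(emb[0, :4])  # tensor([1., 1., 1., 1.]) -- cosine half of t=0
```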
hymm_sp/modules/fp8_optimization.py ADDED
@@ -0,0 +1,246 @@
1
+ # modified from https://github.com/neuralmagic/AutoFP8/blob/main/auto_fp8/quantize.py
2
+ import gc
3
+ from typing import Tuple
4
+ import copy
5
+ import torch
6
+ import tqdm
7
+ import triton
8
+ import triton.language as tl
9
+
10
+
11
+
12
+ def cleanup_memory():
13
+ gc.collect()
14
+ torch.cuda.empty_cache()
15
+
16
+
17
+ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
18
+ """Quantize a tensor using per-tensor static scaling factor.
19
+ Args:
20
+ tensor: The input tensor.
21
+ """
22
+ finfo = torch.finfo(torch.float8_e4m3fn)
23
+ # Calculate the scale as dtype max divided by absmax.
24
+ # Since .abs() creates a new tensor, we use aminmax to get
25
+ # the min and max first and then calculate the absmax.
26
+ if tensor.numel() == 0:
27
+ # Deal with empty tensors (triggered by empty MoE experts)
28
+ min_val, max_val = (
29
+ torch.tensor(-16.0, dtype=tensor.dtype),
30
+ torch.tensor(16.0, dtype=tensor.dtype),
31
+ )
32
+ else:
33
+ min_val, max_val = tensor.aminmax()
34
+ amax = torch.maximum(min_val.abs(), max_val.abs())
35
+ scale = finfo.max / amax.clamp(min=1e-12)
36
+ # scale and clamp the tensor to bring it to
37
+ # the representative range of float8 data type
38
+ # (as default cast is unsaturated)
39
+ qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
40
+ # Return both float8 data and the inverse scale (as float),
41
+ # as both required as inputs to torch._scaled_mm
42
+ qweight = qweight.to(torch.float8_e4m3fn)
43
+ scale = scale.float().reciprocal()
44
+ return qweight, scale
45
+
46
+
47
+ fp8_gemm_configs = [
48
+ triton.Config({'BLOCK_SIZE_M': block_m,
49
+ 'BLOCK_SIZE_N': block_n,
50
+ 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
51
+ for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
52
+ ]
53
+ @triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
54
+ @triton.jit
55
+ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
56
+ a_scale, b_scale, # changed to single per-tensor scale values
57
+ M, N: tl.constexpr, K: tl.constexpr,
58
+ BLOCK_SIZE_M: tl.constexpr,
59
+ BLOCK_SIZE_N: tl.constexpr,
60
+ BLOCK_SIZE_K: tl.constexpr):
61
+ """
62
+ Performs a matrix multiplication operation on FP8 matrices with scaling factors.
63
+ """
64
+ pid_m = tl.program_id(axis=0)
65
+ pid_n = tl.program_id(axis=1)
66
+
67
+ offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
68
+ offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
69
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
70
+
71
+ a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
72
+ b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
73
+
74
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
75
+
76
+ for i in range(0, K, BLOCK_SIZE_K):
77
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i, other=0.0)
78
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i, other=0.0)
79
+
80
+ accumulator += tl.dot(a, b) * a_scale * b_scale
81
+
82
+ a_ptrs += BLOCK_SIZE_K
83
+ b_ptrs += BLOCK_SIZE_K
84
+
85
+ c = accumulator.to(c_ptr.dtype.element_ty)
86
+ offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
87
+ offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
88
+ c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
89
+ mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
90
+ tl.store(c_ptrs, c, mask=mask)
91
+
92
+
93
+ def triton_fp8_gemm(a: torch.Tensor,
94
+ b: torch.Tensor,
95
+ a_scale: float,
96
+ b_scale: float,
97
+ out_dtype=torch.bfloat16,
98
+ bias=None) -> torch.Tensor:
99
+ """
100
+ Perform a matrix multiplication using FP8 precision with per-tensor quantization.
101
+ """
102
+ assert a.is_contiguous() and b.is_contiguous()
103
+
104
+ K = a.size(-1)
105
+ M = a.numel() // K
106
+ N = b.size(0)
107
+ c = torch.empty((M, N), dtype=out_dtype, device=a.device)
108
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
109
+ if isinstance(a_scale, torch.Tensor):
110
+ a_scale = a_scale.item()
111
+ if isinstance(b_scale, torch.Tensor):
112
+ b_scale = b_scale.item()
113
+ # import pdb; pdb.set_trace()
114
+ fp8_gemm_kernel[grid](a, b, c, a_scale, b_scale, M, N, K)
115
+ if bias is not None:
116
+
117
+ c += bias
118
+
119
+ return c
120
+
121
+
122
+ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype, native_fp8_support=False):
123
+ """
124
+ Optimized FP8 GEMM implementation, supports both native FP8 and Triton paths,
125
+ and automatically handles 3D input and bias.
126
+ """
127
+ if A.numel() == 0:
128
+ # Handle empty tensor (e.g., when MoE expert is empty)
129
+ return torch.empty(size=(0, B.shape[0]), dtype=out_dtype, device=A.device)
130
+
131
+ # Check if reshape is needed (support for 3D input)
132
+ need_reshape = (A.dim() == 3)
133
+ batch_size = A.shape[0] if need_reshape else None
134
+ A_input = A.reshape(-1, A.shape[-1]).contiguous() if need_reshape else A
135
+
136
+ if native_fp8_support:
137
+ # Native FP8 support
138
+ output = torch._scaled_mm(
139
+ A_input,
140
+ B.t(),
141
+ out_dtype=out_dtype,
142
+ scale_a=torch.tensor(A_scale) if not isinstance(A_scale, torch.Tensor) else A_scale,
143
+ scale_b=torch.tensor(B_scale) if not isinstance(B_scale, torch.Tensor) else B_scale,
144
+ bias=bias.to(out_dtype) if bias is not None else None,
145
+ )
146
+ else:
147
+ # Triton implementation
148
+ output = triton_fp8_gemm(
149
+ A_input,
150
+ B.contiguous(),
151
+ out_dtype=out_dtype,
152
+ a_scale=A_scale,
153
+ b_scale=B_scale,
154
+ bias=None,
155
+ )
156
+ if bias is not None:
157
+ output += bias
158
+
159
+ if need_reshape:
160
+ # Restore original batch dimension
161
+ output = output.reshape(batch_size, -1, output.shape[-1])
162
+
163
+ return output
164
+
165
+
166
+ # Class responsible for quantizing weights
167
+ class FP8DynamicLinear(torch.nn.Module):
168
+ def __init__(
169
+ self,
170
+ weight: torch.Tensor,
171
+ weight_scale: torch.Tensor,
172
+ bias: torch.nn.Parameter,
173
+ native_fp8_support: bool = False,
174
+ name: str = ""
175
+ ):
176
+ super().__init__()
177
+ self.weight = torch.nn.Parameter(weight, requires_grad=False)
178
+ self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
179
+ self.bias = bias
180
+ self.native_fp8_support = native_fp8_support
181
+ self.name = name
182
+ # @torch.compile
183
+ def forward(self, x):
184
+ if x.dtype != torch.float16 and x.dtype != torch.bfloat16:
185
+ # print(f"Warning: {self.name}'s input is not quantized to float16 or bfloat16")
186
+ # print(f"input dtype: {x.dtype}")
187
+ x = x.to(torch.bfloat16)
188
+ qinput, x_scale = per_tensor_quantize(x)
189
+ # print("--------------")
190
+ # print("layer_name:", self.name)
191
+ # print("A_input.shape:", qinput.shape)
192
+ # print("B.shape:", self.weight.shape)
193
+ # print("--------------")
194
+ output = fp8_gemm(
195
+ A=qinput,
196
+ A_scale=x_scale,
197
+ B=self.weight,
198
+ B_scale=self.weight_scale,
199
+ bias=self.bias,
200
+ out_dtype=x.dtype,
201
+ native_fp8_support=self.native_fp8_support,
202
+ )
203
+ return output
204
+
205
+
206
+ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
207
+ if "." in name:
208
+ parent_name = name.rsplit(".", 1)[0]
209
+ child_name = name[len(parent_name) + 1 :]
210
+ parent = model.get_submodule(parent_name)
211
+ else:
212
+ parent_name = ""
213
+ parent = model
214
+ child_name = name
215
+ setattr(parent, child_name, new_module)
216
+
217
+
218
+ def convert_fp8_linear(model: torch.nn.Module):
219
+ # native_fp8_support = (
220
+ # torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
221
+ # )
222
+ native_fp8_support = False
223
+ named_modules = list(model.named_modules())
224
+ for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
225
+ if not isinstance(linear, torch.nn.Linear):
226
+ continue
227
+ if "mod" in name:
228
+ print(f"Warning: {name} is a mod module, skipping")
229
+ continue
230
+ if "block" not in name:
231
+ print(f"Warning: {name} is not in a block module, skipping")
232
+ continue
233
+ quant_weight, weight_scale = per_tensor_quantize(linear.weight)
234
+ bias = copy.deepcopy(linear.bias) if linear.bias is not None else None
235
+ quant_linear = FP8DynamicLinear(
236
+ weight=quant_weight,
237
+ weight_scale=weight_scale,
238
+ bias=bias,
239
+ native_fp8_support=native_fp8_support,
240
+ name = name
241
+ )
242
+ replace_module(model, name, quant_linear)
243
+ del linear.weight
244
+ del linear.bias
245
+ del linear
246
+ cleanup_memory()
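A hedged round-trip sketch for `per_tensor_quantize` (note that the returned scale is the *inverse* scale, so dequantization is a single multiply). This assumes `triton` is importable, since the module imports it at load time:

```python
import torch
from hymm_sp.modules.fp8_optimization import per_tensor_quantize

w = torch.randn(128, 64, dtype=torch.bfloat16)
qweight, inv_scale = per_tensor_quantize(w)

w_hat = qweight.to(torch.float32) * inv_scale   # dequantize
print(qweight.dtype)                            # torch.float8_e4m3fn
print((w.float() - w_hat).abs().max().item())   # small 8-bit quantization error
```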
hymm_sp/modules/mlp_layers.py ADDED
@@ -0,0 +1,97 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from hymm_sp.helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks
15
+ """
16
+ def __init__(self,
17
+ in_channels,
18
+ hidden_channels=None,
19
+ out_features=None,
20
+ act_layer=nn.GELU,
21
+ norm_layer=None,
22
+ bias=True,
23
+ drop=0.,
24
+ use_conv=False,
25
+ device=None,
26
+ dtype=None
27
+ ):
28
+ factory_kwargs = {'device': device, 'dtype': dtype}
29
+ super().__init__()
30
+ out_features = out_features or in_channels
31
+ hidden_channels = hidden_channels or in_channels
32
+ bias = to_2tuple(bias)
33
+ drop_probs = to_2tuple(drop)
34
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
35
+
36
+ self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
37
+ self.act = act_layer()
38
+ self.drop1 = nn.Dropout(drop_probs[0])
39
+ self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
40
+ self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
41
+ self.drop2 = nn.Dropout(drop_probs[1])
42
+
43
+ def forward(self, x):
44
+ x = self.fc1(x)
45
+ x = self.act(x)
46
+ x = self.drop1(x)
47
+ x = self.norm(x)
48
+ x = self.fc2(x)
49
+ x = self.drop2(x)
50
+ return x
51
+
52
+
53
+ class MLPEmbedder(nn.Module):
54
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
55
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
56
+ factory_kwargs = {'device': device, 'dtype': dtype}
57
+ super().__init__()
58
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
59
+ self.silu = nn.SiLU()
60
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
61
+
62
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
63
+ return self.out_layer(self.silu(self.in_layer(x)))
64
+
65
+
66
+ class FinalLayer(nn.Module):
67
+ """The final layer of DiT."""
68
+
69
+ def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
70
+ factory_kwargs = {'device': device, 'dtype': dtype}
71
+ super().__init__()
72
+
73
+ # Just use LayerNorm for the final layer
74
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
75
+ if isinstance(patch_size, int):
76
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, **factory_kwargs)
77
+ else:
78
+ self.linear = nn.Linear(hidden_size,
79
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
80
+ bias=True)
81
+ nn.init.zeros_(self.linear.weight)
82
+ nn.init.zeros_(self.linear.bias)
83
+
84
+ # Here we don't distinguish between the modulate types. Just use the simple one.
85
+ self.adaLN_modulation = nn.Sequential(
86
+ act_layer(),
87
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
88
+ )
89
+ # Zero-initialize the modulation
90
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
91
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
92
+
93
+ def forward(self, x, c):
94
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
95
+ x = modulate(self.norm_final(x), shift=shift, scale=scale)
96
+ x = self.linear(x)
97
+ return x
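Usage sketch for `MLPEmbedder`, the small SiLU MLP used to project conditioning vectors to the transformer width (the dimensions here are illustrative only):

```python
import torch
from hymm_sp.modules.mlp_layers import MLPEmbedder

embedder = MLPEmbedder(in_dim=256, hidden_dim=3072)
vec = torch.randn(2, 256)
print(embedder(vec).shape)  # torch.Size([2, 3072])
```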
hymm_sp/modules/models.py ADDED
@@ -0,0 +1,697 @@
1
+ from typing import List, Tuple, Optional, Union, Dict
2
+ from einops import rearrange
3
+
4
+ import torch, os
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from diffusers.models import ModelMixin
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
10
+
11
+ from .activation_layers import get_activation_layer
12
+ from .norm_layers import get_norm_layer
13
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
14
+ from .attn_layers import apply_rotary_emb
15
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
16
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
17
+ from .token_refiner import SingleTokenRefiner
18
+ from .cameranet import CameraNet
19
+
20
+ from .parallel_states import (
21
+ nccl_info,
22
+ get_cu_seqlens,
23
+ get_sequence_parallel_state,
24
+ parallel_attention,
25
+ all_gather,
26
+ )
27
+
28
+ CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
29
+ DISABLE_SP = int(os.environ.get("DISABLE_SP", 0))
30
+ print(f'models: cpu_offload={CPU_OFFLOAD}, DISABLE_SP={DISABLE_SP}')
31
+
32
+ class DoubleStreamBlock(nn.Module):
33
+ def __init__(
34
+ self,
35
+ hidden_size: int,
36
+ num_heads: int,
37
+ mlp_width_ratio: float,
38
+ mlp_act_type: str = 'gelu_tanh',
39
+ qk_norm: bool = True,
40
+ qk_norm_type: str = 'rms',
41
+ qkv_bias: bool = False,
42
+ dtype: Optional[torch.dtype] = None,
43
+ device: Optional[torch.device] = None,
44
+ ):
45
+ factory_kwargs = {'device': device, 'dtype': dtype}
46
+ super().__init__()
47
+
48
+ self.deterministic = False
49
+ self.num_heads = num_heads
50
+ head_dim = hidden_size // num_heads
51
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
52
+
53
+ self.img_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu"), **factory_kwargs)
54
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
55
+
56
+ self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
57
+ qk_norm_layer = get_norm_layer(qk_norm_type)
58
+ self.img_attn_q_norm = (
59
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
60
+ if qk_norm
61
+ else nn.Identity()
62
+ )
63
+ self.img_attn_k_norm = (
64
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
65
+ if qk_norm
66
+ else nn.Identity()
67
+ )
68
+ self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
69
+
70
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
71
+ self.img_mlp = MLP(
72
+ hidden_size,
73
+ mlp_hidden_dim,
74
+ act_layer=get_activation_layer(mlp_act_type),
75
+ bias=True,
76
+ **factory_kwargs
77
+ )
78
+
79
+ self.txt_mod = ModulateDiT(hidden_size, factor=6, act_layer=get_activation_layer("silu"), **factory_kwargs)
80
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
81
+
82
+ self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
83
+ qk_norm_layer = get_norm_layer(qk_norm_type)
84
+ self.txt_attn_q_norm = (
85
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
86
+ if qk_norm
87
+ else nn.Identity()
88
+ )
89
+ self.txt_attn_k_norm = (
90
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
91
+ if qk_norm
92
+ else nn.Identity()
93
+ )
94
+ self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
95
+
96
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
97
+ self.txt_mlp = MLP(
98
+ hidden_size,
99
+ mlp_hidden_dim,
100
+ act_layer=get_activation_layer(mlp_act_type),
101
+ bias=True,
102
+ **factory_kwargs
103
+ )
104
+
105
+ def enable_deterministic(self):
106
+ self.deterministic = True
107
+
108
+ def disable_deterministic(self):
109
+ self.deterministic = False
110
+
111
+ def forward(
112
+ self,
113
+ img: torch.Tensor,
114
+ txt: torch.Tensor,
115
+ vec: torch.Tensor,
116
+ cu_seqlens_q: Optional[torch.Tensor] = None,
117
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
118
+ max_seqlen_q: Optional[int] = None,
119
+ max_seqlen_kv: Optional[int] = None,
120
+ freqs_cis: tuple = None,
121
+ use_sage: bool = True,
122
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
123
+ img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate = (
124
+ self.img_mod(vec).chunk(6, dim=-1)
125
+ )
126
+ txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate = (
127
+ self.txt_mod(vec).chunk(6, dim=-1)
128
+ )
129
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
130
+
131
+ # Prepare image for attention.
132
+ img_modulated = self.img_norm1(img)
133
+ img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
134
+ img_qkv = self.img_attn_qkv(img_modulated)
135
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
136
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
137
+ # Apply QK-Norm if needed
138
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
139
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
140
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
141
+
142
+ # Apply RoPE if needed.
143
+ if freqs_cis is not None:
144
+ img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
145
+ assert img_qq.shape == img_q.shape and img_kk.shape == img_k.shape, \
146
+ f'img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}'
147
+ img_q, img_k = img_qq, img_kk
148
+
149
+ # Prepare txt for attention.
150
+ txt_modulated = self.txt_norm1(txt)
151
+ txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
152
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
153
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
154
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
155
+ # Apply QK-Norm if needed.
156
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
157
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
158
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
159
+
160
+ # Run actual attention.
161
+ q = torch.cat((img_q, txt_q), dim=1)
162
+ k = torch.cat((img_k, txt_k), dim=1)
163
+ v = torch.cat((img_v, txt_v), dim=1)
164
+
165
+ # Compute attention.
166
+ if DISABLE_SP:
167
+ assert cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
168
+
169
+ q, k, v = [
170
+ x.view(x.shape[0] * x.shape[1], *x.shape[2:])
171
+ for x in [q, k, v]
172
+ ]
173
+ attn = flash_attn_varlen_func(
174
+ q,
175
+ k,
176
+ v,
177
+ cu_seqlens_q,
178
+ cu_seqlens_kv,
179
+ max_seqlen_q,
180
+ max_seqlen_kv,
181
+ )
182
+ attn = attn.view(img_k.shape[0], max_seqlen_q, -1).contiguous()
183
+ else:
184
+ attn, _ = parallel_attention(
185
+ (img_q, txt_q),
186
+ (img_k, txt_k),
187
+ (img_v, txt_v),
188
+ img_q_len=img_q.shape[1],
189
+ img_kv_len=img_k.shape[1],
190
+ cu_seqlens_q=cu_seqlens_q,
191
+ cu_seqlens_kv=cu_seqlens_kv,
192
+ max_seqlen_q=max_seqlen_q,
193
+ max_seqlen_kv=max_seqlen_kv,
194
+ use_sage=use_sage,
195
+ )
196
+ img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
197
+
198
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
199
+
200
+ # Calculate the img blocks.
201
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
202
+ img = img + apply_gate(self.img_mlp(modulate(
203
+ self.img_norm2(img),
204
+ shift=img_mod2_shift,
205
+ scale=img_mod2_scale)), gate=img_mod2_gate)
206
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
207
+ # Calculate the txt blocks.
208
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
209
+ txt = txt + apply_gate(self.txt_mlp(modulate(self.txt_norm2(txt),
210
+ shift=txt_mod2_shift,
211
+ scale=txt_mod2_scale)), gate=txt_mod2_gate)
212
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
213
+ return img, txt
214
+
215
+
216
+ class SingleStreamBlock(nn.Module):
217
+ """
218
+ A DiT block with parallel linear layers as described in
219
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ hidden_size: int,
225
+ num_heads: int,
226
+ mlp_width_ratio: float = 4.0,
227
+ mlp_act_type: str = 'gelu_tanh',
228
+ qk_norm: bool = True,
229
+ qk_norm_type: str = 'rms',
230
+ qk_scale: float = None,
231
+ dtype: Optional[torch.dtype] = None,
232
+ device: Optional[torch.device] = None,
233
+ ):
234
+ factory_kwargs = {'device': device, 'dtype': dtype}
235
+ super().__init__()
236
+
237
+ self.deterministic = False
238
+ self.hidden_size = hidden_size
239
+ self.num_heads = num_heads
240
+ head_dim = hidden_size // num_heads
241
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
242
+ self.mlp_hidden_dim = mlp_hidden_dim
243
+ self.scale = qk_scale or head_dim**-0.5
244
+
245
+ # qkv and mlp_in
246
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
247
+ # proj and mlp_out
248
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
249
+
250
+ qk_norm_layer = get_norm_layer(qk_norm_type)
251
+ self.q_norm = (
252
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
253
+ if qk_norm
254
+ else nn.Identity()
255
+ )
256
+ self.k_norm = (
257
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
258
+ if qk_norm
259
+ else nn.Identity()
260
+ )
261
+
262
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
263
+
264
+ self.mlp_act = get_activation_layer(mlp_act_type)()
265
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
266
+
267
+ def enable_deterministic(self):
268
+ self.deterministic = True
269
+
270
+ def disable_deterministic(self):
271
+ self.deterministic = False
272
+
273
+ def forward(
274
+ self,
275
+ x: torch.Tensor,
276
+ vec: torch.Tensor,
277
+ txt_len: int,
278
+ cu_seqlens_q: Optional[torch.Tensor] = None,
279
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
280
+ max_seqlen_q: Optional[int] = None,
281
+ max_seqlen_kv: Optional[int] = None,
282
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
283
+ use_sage: bool = True,
284
+ ) -> torch.Tensor:
285
+ mod_shift, mod_scale, mod_gate = (
286
+ self.modulation(vec).chunk(3, dim=-1)
287
+ )
288
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
289
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
290
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
291
+
292
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
293
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
294
+
295
+ # Apply QK-Norm if needed.
296
+ q = self.q_norm(q).to(v)
297
+ k = self.k_norm(k).to(v)
298
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
299
+
300
+ # Apply RoPE if needed.
301
+ if freqs_cis is not None:
302
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
303
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
304
+ img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
305
+ assert img_qq.shape == img_q.shape and img_kk.shape == img_k.shape, \
306
+ f'img_qq: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}'
307
+ img_q, img_k = img_qq, img_kk
308
+ q = torch.cat((img_q, txt_q), dim=1)
309
+ k = torch.cat((img_k, txt_k), dim=1)
310
+
311
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
312
+
313
+ # Compute attention.
314
+ if DISABLE_SP:
315
+ assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, \
316
+ f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
317
+ # [b, s+l, a, d] -> [s+l, b, a, d]
318
+ q, k, v = [
319
+ x.view(x.shape[0] * x.shape[1], *x.shape[2:])
320
+ for x in [q, k, v]
321
+ ]
322
+
323
+ attn = flash_attn_varlen_func(
324
+ q,
325
+ k,
326
+ v,
327
+ cu_seqlens_q,
328
+ cu_seqlens_kv,
329
+ max_seqlen_q,
330
+ max_seqlen_kv,
331
+ )
332
+ attn = attn.view(x.shape[0], max_seqlen_q, -1).contiguous()
333
+ else:
334
+ img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
335
+ attn, _ = parallel_attention(
336
+ (img_q, txt_q),
337
+ (img_k, txt_k),
338
+ (img_v, txt_v),
339
+ img_q_len=img_q.shape[1],
340
+ img_kv_len=img_k.shape[1],
341
+ cu_seqlens_q=cu_seqlens_q,
342
+ cu_seqlens_kv=cu_seqlens_kv,
343
+ max_seqlen_q=max_seqlen_q,
344
+ max_seqlen_kv=max_seqlen_kv,
345
+ use_sage=use_sage,
346
+ )
347
+ if CPU_OFFLOAD:
348
+ torch.cuda.empty_cache()
349
+ tmp = torch.cat((attn, self.mlp_act(mlp)), 2)
350
+ torch.cuda.empty_cache()
351
+ output = self.linear2(tmp)
352
+ else:
353
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
354
+ return x + apply_gate(output, gate=mod_gate)
355
+
356
+
357
+ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
358
+ """
359
+ HunyuanVideo Transformer backbone
360
+
361
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
362
+
363
+ Reference:
364
+ [1] Flux.1: https://github.com/black-forest-labs/flux
365
+ [2] MMDiT: http://arxiv.org/abs/2403.03206,
366
+ https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
367
+
368
+ """
369
+ @register_to_config
370
+ def __init__(
371
+ self,
372
+ args,
373
+ patch_size: list = [1,2,2],
374
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
375
+ out_channels: int = None,
376
+ hidden_size: int = 3072,
377
+ mlp_width_ratio: float = 4.0,
378
+ mlp_act_type: str = 'gelu_tanh',
379
+ num_heads: int = 24,
380
+ depth_double_blocks: int = 19,
381
+ depth_single_blocks: int = 38,
382
+ rope_dim_list: List[int] = [16, 56, 56],
383
+ qkv_bias: bool = True,
384
+ qk_norm: bool = True,
385
+ qk_norm_type: str = 'rms',
386
+ guidance_embed: bool = False, # For modulation.
387
+ dtype: Optional[torch.dtype] = None,
388
+ device: Optional[torch.device] = None,
389
+ multitask_mask_training_type: Optional[str] = None,
390
+ camera_in_channels: int = 6,
391
+ camera_down_coef: int = 8,
392
+ ):
393
+ factory_kwargs = {'device': device, 'dtype': dtype}
394
+ super().__init__()
395
+
396
+ # Text projection. Default to linear projection.
397
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
398
+ self.text_projection = args.text_projection
399
+ self.text_states_dim = args.text_states_dim
400
+ self.use_attention_mask = args.use_attention_mask
401
+ self.text_states_dim_2 = args.text_states_dim_2
402
+
403
+ # Only the configurations above are taken from args.
404
+ self.patch_size = patch_size
405
+ self.in_channels = in_channels
406
+ self.out_channels = in_channels if out_channels is None else out_channels
407
+ self.unpatchify_channels = self.out_channels
408
+ self.guidance_embed = guidance_embed
409
+ self.rope_dim_list = rope_dim_list
410
+ self.multitask_mask_training_type = multitask_mask_training_type
411
+
412
+ if hidden_size % num_heads != 0:
413
+ raise ValueError(
414
+ f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
415
+ )
416
+ pe_dim = hidden_size // num_heads
417
+ if sum(rope_dim_list) != pe_dim:
418
+ raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
419
+ self.hidden_size = hidden_size
420
+ self.num_heads = num_heads
421
+
422
+ # image projection
423
+ self.img_in = PatchEmbed(
424
+ self.patch_size, self.in_channels, self.hidden_size, self.multitask_mask_training_type, **factory_kwargs
425
+ )
426
+
427
+ # text projection
428
+ if self.text_projection == "linear":
429
+ self.txt_in = TextProjection(
430
+ self.text_states_dim,
431
+ self.hidden_size,
432
+ get_activation_layer("silu"),
433
+ **factory_kwargs
434
+ )
435
+ elif self.text_projection == "single_refiner":
436
+ self.txt_in = SingleTokenRefiner(
437
+ self.text_states_dim, hidden_size, num_heads, depth=2, **factory_kwargs
438
+ )
439
+ else:
440
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
441
+
442
+ # time modulation
443
+ self.time_in = TimestepEmbedder(
444
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
445
+ )
446
+
447
+ # text modulation
448
+ self.vector_in = MLPEmbedder(
449
+ self.text_states_dim_2, self.hidden_size, **factory_kwargs
450
+ )
451
+
452
+ # guidance modulation
453
+ self.guidance_in = TimestepEmbedder(
454
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
455
+ ) if guidance_embed else None
456
+
457
+ # double blocks
458
+ self.double_blocks = nn.ModuleList(
459
+ [
460
+ DoubleStreamBlock(
461
+ self.hidden_size,
462
+ self.num_heads,
463
+ mlp_width_ratio=mlp_width_ratio,
464
+ mlp_act_type=mlp_act_type,
465
+ qk_norm=qk_norm,
466
+ qk_norm_type=qk_norm_type,
467
+ qkv_bias=qkv_bias,
468
+ **factory_kwargs
469
+ )
470
+ for _ in range(depth_double_blocks)
471
+ ]
472
+ )
473
+
474
+ # single blocks
475
+ self.single_blocks = nn.ModuleList(
476
+ [
477
+ SingleStreamBlock(
478
+ self.hidden_size,
479
+ self.num_heads,
480
+ mlp_width_ratio=mlp_width_ratio,
481
+ mlp_act_type=mlp_act_type,
482
+ qk_norm=qk_norm,
483
+ qk_norm_type=qk_norm_type,
484
+ **factory_kwargs
485
+ )
486
+ for _ in range(depth_single_blocks)
487
+ ]
488
+ )
489
+
490
+ self.final_layer = FinalLayer(
491
+ self.hidden_size,
492
+ self.patch_size,
493
+ self.out_channels,
494
+ get_activation_layer("silu"),
495
+ **factory_kwargs
496
+ )
497
+
498
+ self.camera_net = CameraNet(in_channels=camera_in_channels,
499
+ out_channels=in_channels,
500
+ downscale_coef=camera_down_coef,
501
+ patch_size=self.patch_size,
502
+ hidden_size=self.hidden_size,
503
+ )
504
+
505
+ def enable_deterministic(self):
506
+ for block in self.double_blocks:
507
+ block.enable_deterministic()
508
+ for block in self.single_blocks:
509
+ block.enable_deterministic()
510
+
511
+ def disable_deterministic(self):
512
+ for block in self.double_blocks:
513
+ block.disable_deterministic()
514
+ for block in self.single_blocks:
515
+ block.disable_deterministic()
516
+
517
+ def forward(
518
+ self,
519
+ x: torch.Tensor,
520
+ t: torch.Tensor, # Should be in range(0, 1000).
521
+ text_states: torch.Tensor = None,
522
+ text_mask: torch.Tensor = None, # Now we don't use it.
523
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
524
+ freqs_cos: Optional[torch.Tensor] = None,
525
+ freqs_sin: Optional[torch.Tensor] = None,
526
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
527
+ return_dict: bool = True,
528
+ is_cache: bool = False,
529
+ cam_latents = None,
530
+ use_sage: bool = False,
531
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
532
+ out = {}
533
+ img = x
534
+ txt = text_states
535
+ _, _, ot, oh, ow = x.shape
536
+ tt, th, tw = ot // self.patch_size[0], oh // self.patch_size[1], ow // self.patch_size[2]
537
+
538
+ # Prepare modulation vectors.
539
+ vec = self.time_in(t)
540
+
541
+ # text modulation
542
+ vec = vec + self.vector_in(text_states_2)
543
+
544
+ # guidance modulation
545
+ if self.guidance_embed:
546
+ if guidance is None:
547
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
548
+ else:
549
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
550
+ vec = vec + self.guidance_in(guidance)
551
+
552
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
553
+
554
+ camera_condition = cam_latents
555
+ assert camera_condition is not None, "plucker_embedding is not provided"
556
+ latent_len = img.shape[2]
557
+
558
+
559
+ # Embed image and text.
560
+ img = self.img_in(img)
561
+ # ref_latents = self.img_in(ref_latents) # off in latent concat
562
+ if self.text_projection == "linear":
563
+ txt = self.txt_in(txt)
564
+ elif self.text_projection == "single_refiner":
565
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
566
+ else:
567
+ raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
568
+
569
+ if camera_condition is not None:
570
+
571
+ if latent_len == 18:
572
+ camera_latents = torch.cat([self.camera_net(torch.zeros_like(camera_condition)), \
573
+ self.camera_net(camera_condition)], dim=1)
574
+ elif latent_len == 9:
575
+ camera_latents = self.camera_net(camera_condition)
576
+ elif latent_len == 10:
577
+ camera_latents = torch.cat([self.camera_net(torch.zeros_like(camera_condition)[:,0:4,:,:,:]), \
578
+ self.camera_net(camera_condition)], dim=1)
579
+ img = img + camera_latents
580
+
581
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
582
+ # ref_length = ref_latents.shape[-2]
583
+ # img = torch.cat([ref_latents, img], dim=-2) # t c
584
+ txt_seq_len = txt.shape[1]
585
+ img_seq_len = img.shape[1]
586
+ # Compute 'self-attention mask'.
587
+
588
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
589
+ cu_seqlens_kv = cu_seqlens_q
590
+ max_seqlen_q = img_seq_len + txt_seq_len
591
+ max_seqlen_kv = max_seqlen_q
592
+
593
+ if get_sequence_parallel_state():
594
+ sp_size = nccl_info.sp_size
595
+ sp_rank = nccl_info.rank_within_group
596
+ assert img.shape[1] % sp_size == 0, f"Cannot split video sequence into ulysses SP ({sp_size}) parts evenly"
597
+ img = torch.chunk(img, sp_size, dim=1)[sp_rank]
598
+ freqs_cos = torch.chunk(freqs_cos, sp_size, dim=0)[sp_rank]
599
+ freqs_sin = torch.chunk(freqs_sin, sp_size, dim=0)[sp_rank]
600
+
601
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
602
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
603
+ # --------------------- Pass through DiT blocks ------------------------
604
+ if not is_cache:
605
+ for layer_num, block in enumerate(self.double_blocks):
606
+ double_block_args = [img, txt, vec, cu_seqlens_q, cu_seqlens_kv, \
607
+ max_seqlen_q, max_seqlen_kv, freqs_cis, use_sage]
608
+ img, txt = block(*double_block_args)
609
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
610
+
611
+ # Merge txt and img to pass through single stream blocks.
612
+ x = torch.cat((img, txt), 1)
613
+ # Compatible with MMDiT.
614
+ if len(self.single_blocks) > 0:
615
+ for layer_num, block in enumerate(self.single_blocks):
616
+ if layer_num == (len(self.single_blocks) - 1):
617
+ self.cache_out = x
618
+ single_block_args = [x, vec, txt_seq_len, cu_seqlens_q, \
619
+ cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, (freqs_cos, freqs_sin), use_sage]
620
+ x = block(*single_block_args)
621
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
622
+ else:
623
+ x = self.cache_out
624
+ if len(self.single_blocks) > 0:
625
+ for layer_num, block in enumerate(self.single_blocks):
626
+ if layer_num < (len(self.single_blocks) - 1):
627
+ continue
628
+ single_block_args = [x, vec, txt_seq_len, cu_seqlens_q, \
629
+ cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, (freqs_cos, freqs_sin), use_sage]
630
+ x = block(*single_block_args)
631
+ if CPU_OFFLOAD: torch.cuda.empty_cache()
632
+
633
+ img = x[:, :-txt_seq_len, ...]
634
+
635
+ if get_sequence_parallel_state():
636
+ img = all_gather(img, dim=1)
637
+
638
+ # img = img[:, ref_length:]
639
+ # ---------------------------- Final layer ------------------------------
640
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
641
+ img = self.unpatchify(img, tt, th, tw)
642
+
643
+ if return_dict:
644
+ out['x'] = img
645
+ return out
646
+ return img
647
+
648
+ def unpatchify(self, x, t, h, w):
649
+ """
650
+ x: (N, T, patch_size**2 * C)
651
+ imgs: (N, H, W, C)
652
+ """
653
+ c = self.unpatchify_channels
654
+ pt, ph, pw = self.patch_size
655
+ assert t * h * w == x.shape[1]
656
+
657
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
658
+ x = torch.einsum('nthwcopq->nctohpwq', x)
659
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
660
+
661
+ return imgs
662
+
663
+ def params_count(self):
664
+ counts = {
665
+ "double": sum([
666
+ sum(p.numel() for p in block.img_attn_qkv.parameters()) +
667
+ sum(p.numel() for p in block.img_attn_proj.parameters()) +
668
+ sum(p.numel() for p in block.img_mlp.parameters()) +
669
+ sum(p.numel() for p in block.txt_attn_qkv.parameters()) +
670
+ sum(p.numel() for p in block.txt_attn_proj.parameters()) +
671
+ sum(p.numel() for p in block.txt_mlp.parameters())
672
+ for block in self.double_blocks
673
+ ]),
674
+ "single": sum([
675
+ sum(p.numel() for p in block.linear1.parameters()) +
676
+ sum(p.numel() for p in block.linear2.parameters())
677
+ for block in self.single_blocks
678
+ ]),
679
+ "total": sum(p.numel() for p in self.parameters()),
680
+ }
681
+ counts["attn+mlp"] = counts["double"] + counts["single"]
682
+ return counts
683
+
684
+ #################################################################################
685
+ # HunyuanVideo Configs #
686
+ #################################################################################
687
+
688
+ HUNYUAN_VIDEO_CONFIG = { # Attn+MLP / Total
689
+ 'HYVideo-T/2': { # 9.0B / 12.5B
690
+ 'depth_double_blocks': 20,
691
+ 'depth_single_blocks': 40,
692
+ 'rope_dim_list': [16, 56, 56],
693
+ 'hidden_size': 3072,
694
+ 'num_heads': 24,
695
+ 'mlp_width_ratio': 4,
696
+ },
697
+ }
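
The `unpatchify` method above folds patch tokens back into a latent video tensor with a single einsum. A minimal standalone shape check of that operation, using assumed toy sizes rather than the model's real dimensions:

import torch

N, c = 2, 4                    # batch size, latent channels (assumed)
t, h, w = 3, 8, 8              # patch-grid sizes (assumed)
pt, ph, pw = 1, 2, 2           # patch_size, as in the default [1, 2, 2]

x = torch.randn(N, t * h * w, c * pt * ph * pw)   # (N, T, patch_size**2 * C)
x = x.reshape(N, t, h, w, c, pt, ph, pw)
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(N, c, t * pt, h * ph, w * pw)
print(imgs.shape)              # torch.Size([2, 4, 3, 16, 16])
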
hymm_sp/modules/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class ModulateDiT(nn.Module):
8
+ """Modulation layer for DiT."""
9
+ def __init__(
10
+ self,
11
+ hidden_size: int,
12
+ factor: int,
13
+ act_layer: Callable,
14
+ dtype=None,
15
+ device=None,
16
+ ):
17
+ factory_kwargs = {"dtype": dtype, "device": device}
18
+ super().__init__()
19
+ self.act = act_layer()
20
+ self.linear = nn.Linear(
21
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
+ )
23
+ # Zero-initialize the modulation
24
+ nn.init.zeros_(self.linear.weight)
25
+ nn.init.zeros_(self.linear.bias)
26
+
27
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
28
+ return self.linear(self.act(x))
29
+
30
+
31
+ def modulate(x, shift=None, scale=None):
32
+ """modulate by shift and scale
33
+
34
+ Args:
35
+ x (torch.Tensor): input tensor.
36
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
37
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
38
+
39
+ Returns:
40
+ torch.Tensor: the output tensor after modulate.
41
+ """
42
+ if scale is None and shift is None:
43
+ return x
44
+ elif shift is None:
45
+ return x * (1 + scale.unsqueeze(1))
46
+ elif scale is None:
47
+ return x + shift.unsqueeze(1)
48
+ else:
49
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
50
+
51
+
52
+ def apply_gate(x, gate=None, tanh=False):
53
+ """AI is creating summary for apply_gate
54
+
55
+ Args:
56
+ x (torch.Tensor): input tensor.
57
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
58
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
59
+
60
+ Returns:
61
+ torch.Tensor: the output tensor after apply gate.
62
+ """
63
+ if gate is None:
64
+ return x
65
+ if tanh:
66
+ return x * gate.unsqueeze(1).tanh()
67
+ else:
68
+ return x * gate.unsqueeze(1)
69
+
70
+
71
+ def ckpt_wrapper(module):
72
+ def ckpt_forward(*inputs):
73
+ outputs = module(*inputs)
74
+ return outputs
75
+
76
+ return ckpt_forward
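
The `modulate` and `apply_gate` helpers above broadcast per-sample `(B, C)` shift/scale/gate vectors over a `(B, L, C)` token sequence. A minimal sketch of that broadcasting with toy tensors (all shapes here are assumptions, not the model's real sizes):

import torch

B, L, C = 2, 16, 8
x = torch.randn(B, L, C)      # token sequence
shift = torch.randn(B, C)     # per-sample shift, e.g. produced by a ModulateDiT layer
scale = torch.randn(B, C)     # per-sample scale
gate = torch.randn(B, C)      # per-sample gate

y = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)   # modulate(x, shift=shift, scale=scale)
out = y * gate.unsqueeze(1)                             # apply_gate(y, gate=gate)
assert out.shape == (B, L, C)
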
hymm_sp/modules/norm_layers.py ADDED
@@ -0,0 +1,77 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ elementwise_affine=True,
10
+ eps: float = 1e-6,
11
+ device=None,
12
+ dtype=None,
13
+ ):
14
+ """
15
+ Initialize the RMSNorm normalization layer.
16
+
17
+ Args:
18
+ dim (int): The dimension of the input tensor.
19
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20
+
21
+ Attributes:
22
+ eps (float): A small value added to the denominator for numerical stability.
23
+ weight (nn.Parameter): Learnable scaling parameter.
24
+
25
+ """
26
+ factory_kwargs = {"device": device, "dtype": dtype}
27
+ super().__init__()
28
+ self.eps = eps
29
+ if elementwise_affine:
30
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31
+
32
+ def _norm(self, x):
33
+ """
34
+ Apply the RMSNorm normalization to the input tensor.
35
+
36
+ Args:
37
+ x (torch.Tensor): The input tensor.
38
+
39
+ Returns:
40
+ torch.Tensor: The normalized tensor.
41
+
42
+ """
43
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
44
+
45
+ def forward(self, x):
46
+ """
47
+ Forward pass through the RMSNorm layer.
48
+
49
+ Args:
50
+ x (torch.Tensor): The input tensor.
51
+
52
+ Returns:
53
+ torch.Tensor: The output tensor after applying RMSNorm.
54
+
55
+ """
56
+ output = self._norm(x.float()).type_as(x)
57
+ if hasattr(self, "weight"):
58
+ output = output * self.weight
59
+ return output
60
+
61
+
62
+ def get_norm_layer(norm_layer):
63
+ """
64
+ Get the normalization layer.
65
+
66
+ Args:
67
+ norm_layer (str): The type of normalization layer.
68
+
69
+ Returns:
70
+ norm_layer (nn.Module): The normalization layer.
71
+ """
72
+ if norm_layer == "layer":
73
+ return nn.LayerNorm
74
+ elif norm_layer == "rms":
75
+ return RMSNorm
76
+ else:
77
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
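
A short usage sketch of `RMSNorm` as a per-head QK-Norm layer. The class is re-declared inline (without the `elementwise_affine` and factory options) so the snippet runs standalone, and the `(B, L, num_heads, head_dim)` layout with `head_dim=128` is an assumption:

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        out = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)
        return out.type_as(x) * self.weight

norm = RMSNorm(128)
q = torch.randn(2, 16, 24, 128)   # (B, L, num_heads, head_dim)
print(norm(q).shape)              # torch.Size([2, 16, 24, 128]), normalized over head_dim
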
hymm_sp/modules/parallel_states.py ADDED
@@ -0,0 +1,381 @@
1
+ import os
2
+ import torch
3
+ import datetime
4
+ import torch.distributed as dist
5
+ from typing import Any, Tuple
6
+ from torch import Tensor
7
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
8
+
9
+
10
+ class COMM_INFO:
11
+ def __init__(self):
12
+ self.group = None
13
+ self.sp_size = 1
14
+ self.global_rank = 0
15
+ self.rank_within_group = 0
16
+ self.group_id = 0
17
+
18
+
19
+ nccl_info = COMM_INFO()
20
+ _SEQUENCE_PARALLEL_STATE = False
21
+
22
+
23
+ def get_cu_seqlens(text_mask, img_len):
24
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
25
+
26
+ Args:
27
+ text_mask (torch.Tensor): the mask of text
28
+ img_len (int): the length of image
29
+
30
+ Returns:
31
+ torch.Tensor: the calculated cu_seqlens for flash attention
32
+ """
33
+ batch_size = text_mask.shape[0]
34
+ text_len = text_mask.sum(dim=1)
35
+ max_len = text_mask.shape[1] + img_len
36
+
37
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
38
+
39
+ for i in range(batch_size):
40
+ s = text_len[i] + img_len
41
+ s1 = i * max_len + s
42
+ s2 = (i + 1) * max_len
43
+ cu_seqlens[2 * i + 1] = s1
44
+ cu_seqlens[2 * i + 2] = s2
45
+
46
+ return cu_seqlens
47
+
48
+ def initialize_sequence_parallel_state(sequence_parallel_size):
49
+ global _SEQUENCE_PARALLEL_STATE
50
+ if sequence_parallel_size > 1:
51
+ _SEQUENCE_PARALLEL_STATE = True
52
+ initialize_sequence_parallel_group(sequence_parallel_size)
53
+ else:
54
+ nccl_info.sp_size = 1
55
+ nccl_info.global_rank = int(os.getenv("RANK", "0"))
56
+ nccl_info.rank_within_group = 0
57
+ nccl_info.group_id = int(os.getenv("RANK", "0"))
58
+
59
+ def get_sequence_parallel_state():
60
+ return _SEQUENCE_PARALLEL_STATE
61
+
62
+ def initialize_sequence_parallel_group(sequence_parallel_size):
63
+ """Initialize the sequence parallel group."""
64
+ rank = int(os.getenv("RANK", "0"))
65
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
66
+ assert (
67
+ world_size % sequence_parallel_size == 0
68
+ ), "world_size must be divisible by sequence_parallel_size, \
69
+ but got world_size: {}, sequence_parallel_size: {}".format(
70
+ world_size, sequence_parallel_size)
71
+ nccl_info.sp_size = sequence_parallel_size
72
+ nccl_info.global_rank = rank
73
+ num_sequence_parallel_groups: int = world_size // sequence_parallel_size
74
+ for i in range(num_sequence_parallel_groups):
75
+ ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
76
+ group = dist.new_group(ranks)
77
+ if rank in ranks:
78
+ nccl_info.group = group
79
+ nccl_info.rank_within_group = rank - i * sequence_parallel_size
80
+ nccl_info.group_id = i
81
+
82
+ def initialize_distributed(seed):
83
+ local_rank = int(os.getenv("RANK", 0))
84
+ world_size = int(os.getenv("WORLD_SIZE", 1))
85
+ torch.cuda.set_device(local_rank)
86
+ dist.init_process_group(backend="nccl",
87
+ init_method="env://",
88
+ timeout=datetime.timedelta(seconds=2**31-1),
89
+ world_size=world_size,
90
+ rank=local_rank)
91
+ torch.manual_seed(seed)
92
+ torch.cuda.manual_seed_all(seed)
93
+ initialize_sequence_parallel_state(world_size)
94
+
95
+ def _all_to_all_4D(input: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
96
+ """
97
+ all-to-all for QKV
98
+
99
+ Args:
100
+ input (torch.tensor): a tensor sharded along dim scatter dim
101
+ scatter_idx (int): dimension to scatter along. Defaults to 2.
102
+ gather_idx (int): dimension to gather along. Defaults to 1.
103
+ group : torch process group
104
+
105
+ Returns:
106
+ torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
107
+ """
108
+ assert (input.dim() == 4), f"input must be 4D tensor, got {input.dim()} and shape {input.shape}"
109
+
110
+ seq_world_size = dist.get_world_size(group)
111
+ if scatter_idx == 2 and gather_idx == 1:
112
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen/P, hc, hs) output: (bs, seqlen, hc/P, hs)
113
+ bs, shard_seqlen, hc, hs = input.shape
114
+ seqlen = shard_seqlen * seq_world_size
115
+ shard_hc = hc // seq_world_size
116
+
117
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
118
+ # (bs, seqlen/P, hc, hs) -reshape->
119
+ # (bs, seq_len/P, P, hc/P, hs) -transpose(0,2)->
120
+ # (P, seq_len/P, bs, hc/P, hs)
121
+ input_t = (input.reshape(bs, shard_seqlen, seq_world_size, shard_hc, hs).transpose(0, 2).contiguous())
122
+
123
+ output = torch.empty_like(input_t)
124
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
125
+ # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
126
+ if seq_world_size > 1:
127
+ dist.all_to_all_single(output, input_t, group=group)
128
+ torch.cuda.synchronize()
129
+ else:
130
+ output = input_t
131
+ # if scattering the seq-dim, transpose the heads back to the original dimension
132
+ output = output.reshape(seqlen, bs, shard_hc, hs)
133
+
134
+ # (seq_len, bs, hc/P, hs) -reshape-> (bs, seq_len, hc/P, hs)
135
+ output = output.transpose(0, 1).contiguous().reshape(bs, seqlen, shard_hc, hs)
136
+
137
+ return output
138
+
139
+ elif scatter_idx == 1 and gather_idx == 2:
140
+ # input (torch.tensor): a tensor sharded along dim 1 (bs, seqlen, hc/P, hs) output: (bs, seqlen/P, hc, hs)
141
+ bs, seqlen, shard_hc, hs = input.shape
142
+ hc = shard_hc * seq_world_size
143
+ shard_seqlen = seqlen // seq_world_size
144
+ seq_world_size = dist.get_world_size(group)
145
+
146
+ # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them!
147
+ # (bs, seqlen, hc/P, hs) -reshape-> (bs, P, seq_len/P, hc/P, hs) -transpose(0, 3)->
148
+ # (hc/P, P, seqlen/P, bs, hs) -transpose(0, 1) -> (P, hc/P, seqlen/P, bs, hs)
149
+ input_t = (input.reshape(bs, seq_world_size, shard_seqlen, shard_hc,
150
+ hs).transpose(0,
151
+ 3).transpose(0,
152
+ 1).contiguous().reshape(seq_world_size, shard_hc,
153
+ shard_seqlen, bs, hs))
154
+
155
+ output = torch.empty_like(input_t)
156
+ # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
157
+ # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
158
+ if seq_world_size > 1:
159
+ dist.all_to_all_single(output, input_t, group=group)
160
+ torch.cuda.synchronize()
161
+ else:
162
+ output = input_t
163
+
164
+ # if scattering the seq-dim, transpose the heads back to the original dimension
165
+ output = output.reshape(hc, shard_seqlen, bs, hs)
166
+
167
+ # (hc, seqlen/N, bs, hs) -tranpose(0,2)-> (bs, seqlen/N, hc, hs)
168
+ output = output.transpose(0, 2).contiguous().reshape(bs, shard_seqlen, hc, hs)
169
+
170
+ return output
171
+ else:
172
+ raise RuntimeError("scatter_idx must be 1 or 2 and gather_idx must be 1 or 2")
173
+
174
+
175
+ class SeqAllToAll4D(torch.autograd.Function):
176
+ @staticmethod
177
+ def forward(
178
+ ctx: Any,
179
+ group: dist.ProcessGroup,
180
+ input: Tensor,
181
+ scatter_idx: int,
182
+ gather_idx: int,
183
+ ) -> Tensor:
184
+ ctx.group = group
185
+ ctx.scatter_idx = scatter_idx
186
+ ctx.gather_idx = gather_idx
187
+
188
+ return _all_to_all_4D(input, scatter_idx, gather_idx, group=group)
189
+
190
+ @staticmethod
191
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
192
+ return (
193
+ None,
194
+ SeqAllToAll4D.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx),
195
+ None,
196
+ None,
197
+ )
198
+
199
+
200
+ def all_to_all_4D(
201
+ input_: torch.Tensor,
202
+ scatter_dim: int = 2,
203
+ gather_dim: int = 1,
204
+ ):
205
+ return SeqAllToAll4D.apply(nccl_info.group, input_, scatter_dim, gather_dim)
206
+
207
+
208
+ def _all_to_all(
209
+ input_: torch.Tensor,
210
+ world_size: int,
211
+ group: dist.ProcessGroup,
212
+ scatter_dim: int,
213
+ gather_dim: int,
214
+ ):
215
+ input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
216
+ output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
217
+ dist.all_to_all(output_list, input_list, group=group)
218
+ return torch.cat(output_list, dim=gather_dim).contiguous()
219
+
220
+
221
+ class _AllToAll(torch.autograd.Function):
222
+ """All-to-all communication.
223
+
224
+ Args:
225
+ input_: input matrix
226
+ process_group: communication group
227
+ scatter_dim: scatter dimension
228
+ gather_dim: gather dimension
229
+ """
230
+
231
+ @staticmethod
232
+ def forward(ctx, input_, process_group, scatter_dim, gather_dim):
233
+ ctx.process_group = process_group
234
+ ctx.scatter_dim = scatter_dim
235
+ ctx.gather_dim = gather_dim
236
+ ctx.world_size = dist.get_world_size(process_group)
237
+ output = _all_to_all(input_, ctx.world_size, process_group, scatter_dim, gather_dim)
238
+ return output
239
+
240
+ @staticmethod
241
+ def backward(ctx, grad_output):
242
+ grad_output = _all_to_all(
243
+ grad_output,
244
+ ctx.world_size,
245
+ ctx.process_group,
246
+ ctx.gather_dim,
247
+ ctx.scatter_dim,
248
+ )
249
+ return (
250
+ grad_output,
251
+ None,
252
+ None,
253
+ None,
254
+ )
255
+
256
+ def all_to_all(
257
+ input_: torch.Tensor,
258
+ scatter_dim: int = 2,
259
+ gather_dim: int = 1,
260
+ ):
261
+ return _AllToAll.apply(input_, nccl_info.group, scatter_dim, gather_dim)
262
+
263
+
264
+ class _AllGather(torch.autograd.Function):
265
+ """All-gather communication with autograd support.
266
+
267
+ Args:
268
+ input_: input tensor
269
+ dim: dimension along which to concatenate
270
+ """
271
+
272
+ @staticmethod
273
+ def forward(ctx, input_, dim):
274
+ ctx.dim = dim
275
+ world_size = nccl_info.sp_size
276
+ group = nccl_info.group
277
+ input_size = list(input_.size())
278
+
279
+ ctx.input_size = input_size[dim]
280
+
281
+ tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
282
+ input_ = input_.contiguous()
283
+ dist.all_gather(tensor_list, input_, group=group)
284
+
285
+ output = torch.cat(tensor_list, dim=dim)
286
+ return output
287
+
288
+ @staticmethod
289
+ def backward(ctx, grad_output):
290
+ world_size = nccl_info.sp_size
291
+ rank = nccl_info.rank_within_group
292
+ dim = ctx.dim
293
+ input_size = ctx.input_size
294
+
295
+ sizes = [input_size] * world_size
296
+
297
+ grad_input_list = torch.split(grad_output, sizes, dim=dim)
298
+ grad_input = grad_input_list[rank]
299
+
300
+ return grad_input, None
301
+
302
+
303
+ def all_gather(input_: torch.Tensor, dim: int = 1):
304
+ """Performs an all-gather operation on the input tensor along the specified dimension.
305
+
306
+ Args:
307
+ input_ (torch.Tensor): Input tensor of shape [B, H, S, D].
308
+ dim (int, optional): Dimension along which to concatenate. Defaults to 1.
309
+
310
+ Returns:
311
+ torch.Tensor: Output tensor after all-gather operation, concatenated along 'dim'.
312
+ """
313
+ return _AllGather.apply(input_, dim)
314
+
315
+ def parallel_attention(q, k, v,
316
+ img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv,
317
+ max_seqlen_q, max_seqlen_kv, use_sage):
318
+ """
319
+ img_q_len,img_kv_len: 32256
320
+ text_mask: 2x256
321
+ query: [2, 32256, 24, 128])
322
+ encoder_query: [2, 256, 24, 128]
323
+ """
324
+ query, encoder_query = q
325
+ key, encoder_key = k
326
+ value, encoder_value = v
327
+ rank = torch.distributed.get_rank()
328
+ if get_sequence_parallel_state():
329
+ query = all_to_all_4D(query, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128]
330
+ key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
331
+ value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)
332
+ def shrink_head(encoder_state, dim):
333
+ local_heads = encoder_state.shape[dim] // nccl_info.sp_size
334
+ return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
335
+ encoder_query = shrink_head(encoder_query, dim=2)
336
+ encoder_key = shrink_head(encoder_key, dim=2)
337
+ encoder_value = shrink_head(encoder_value, dim=2)
338
+
339
+ sequence_length = query.size(1) # 32256
340
+ encoder_sequence_length = encoder_query.size(1) # 256
341
+
342
+ query = torch.cat([query, encoder_query], dim=1)
343
+ key = torch.cat([key, encoder_key], dim=1)
344
+ value = torch.cat([value, encoder_value], dim=1)
345
+ bsz = query.shape[0]
346
+ head = query.shape[-2]
347
+ head_dim = query.shape[-1]
348
+ if use_sage:
349
+ from sageattention import sageattn
350
+ hidden_states = sageattn(query, key, value, tensor_layout="NHD")
351
+ else:
352
+ query, key, value = [
353
+ x.view(x.shape[0] * x.shape[1], *x.shape[2:])
354
+ for x in [query, key, value]
355
+ ]
356
+ hidden_states = flash_attn_varlen_func(
357
+ query,
358
+ key,
359
+ value,
360
+ cu_seqlens_q,
361
+ cu_seqlens_kv,
362
+ max_seqlen_q,
363
+ max_seqlen_kv,
364
+ )
365
+
366
+ # Reshape the flattened attention output back to (B, S, H, D)
367
+ hidden_states = hidden_states.view(bsz, max_seqlen_q, head, head_dim).contiguous()
368
+
369
+ hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length),
370
+ dim=1)
371
+ if get_sequence_parallel_state():
372
+ hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
373
+ encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()
374
+ hidden_states = hidden_states.to(query.dtype)
375
+ encoder_hidden_states = encoder_hidden_states.to(query.dtype)
376
+
377
+ attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
378
+
379
+ b, s, _, _= attn.shape
380
+ attn = attn.reshape(b, s, -1)
381
+ return attn, None
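
`get_cu_seqlens` above packs, for every sample, a (valid tokens, padded tail) boundary pair into the cumulative-length vector expected by `flash_attn_varlen_func`. A CPU-only illustration of that layout with a toy text mask (the real function allocates the tensor on CUDA; the sizes here are assumptions):

import torch

img_len = 6
text_mask = torch.tensor([[1, 1, 1, 0],    # sample 0: 3 valid text tokens
                          [1, 1, 0, 0]])   # sample 1: 2 valid text tokens
batch, max_len = text_mask.shape[0], text_mask.shape[1] + img_len

cu_seqlens = torch.zeros(2 * batch + 1, dtype=torch.int32)
for i in range(batch):
    s = int(text_mask[i].sum()) + img_len          # image tokens + valid text tokens
    cu_seqlens[2 * i + 1] = i * max_len + s        # end of the valid segment
    cu_seqlens[2 * i + 2] = (i + 1) * max_len      # end of the padded segment
print(cu_seqlens.tolist())                         # [0, 9, 10, 18, 20]
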
hymm_sp/modules/posemb_layers.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (torch.Tensor): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
+ def get_1d_rotary_pos_embed(dim: int,
66
+ pos: Union[torch.FloatTensor, int],
67
+ theta: float = 10000.0,
68
+ use_real: bool = False,
69
+ theta_rescale_factor: float = 1.0,
70
+ interpolation_factor: float = 1.0,
71
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
72
+ """
73
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
74
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
75
+
76
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
77
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
78
+ The returned tensor contains complex values in complex64 data type.
79
+
80
+ Args:
81
+ dim (int): Dimension of the frequency tensor.
82
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
83
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
84
+ use_real (bool, optional): If True, return real part and imaginary part separately.
85
+ Otherwise, return complex numbers.
86
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
87
+
88
+ Returns:
89
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
90
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
91
+ """
92
+ if isinstance(pos, int):
93
+ pos = torch.arange(pos).float()
94
+
95
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
96
+ # has some connection to NTK literature
97
+ if theta_rescale_factor != 1.0:
98
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
99
+
100
+ freqs = 1.0 / (
101
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
102
+ ) # [D/2]
103
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
104
+ if use_real:
105
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
106
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
107
+ return freqs_cos, freqs_sin
108
+ else:
109
+ freqs_cis = torch.polar(
110
+ torch.ones_like(freqs), freqs
111
+ ) # complex64 # [S, D/2]
112
+ return freqs_cis
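
A standalone sketch of the `use_real=True` branch of `get_1d_rotary_pos_embed` above; `dim=16` and an 8-position axis are assumed toy values:

import torch

dim, seq_len, theta = 16, 8, 10000.0
pos = torch.arange(seq_len).float()
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))   # [D/2]
freqs = torch.outer(pos, freqs)                                                # [S, D/2]
freqs_cos = freqs.cos().repeat_interleave(2, dim=1)                            # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1)                            # [S, D]
print(freqs_cos.shape, freqs_sin.shape)   # torch.Size([8, 16]) torch.Size([8, 16])
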
hymm_sp/modules/token_refiner.py ADDED
@@ -0,0 +1,265 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .activation_layers import get_activation_layer
8
+ from .attn_layers import attention
9
+ from .norm_layers import get_norm_layer
10
+ from .embed_layers import TimestepEmbedder, TextProjection
11
+ from .attn_layers import attention
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import apply_gate
14
+
15
+
16
+ class IndividualTokenRefinerBlock(nn.Module):
17
+ """
18
+ Transformer block for refining individual tokens with adaptive layer normalization.
19
+
20
+ Combines self-attention and feed-forward network (FFN) layers with modulation
21
+ based on conditional inputs (timestep and context embeddings). Supports query-key
22
+ normalization for improved attention stability.
23
+ """
24
+ def __init__(
25
+ self,
26
+ hidden_size,
27
+ num_heads,
28
+ mlp_ratio: float = 4.0,
29
+ mlp_drop_rate: float = 0.0,
30
+ act_type: str = "silu",
31
+ qk_norm: bool = False,
32
+ qk_norm_type: str = "layer",
33
+ qkv_bias: bool = True,
34
+ dtype: Optional[torch.dtype] = None,
35
+ device: Optional[torch.device] = None,
36
+ ):
37
+ factory_kwargs = {'device': device, 'dtype': dtype}
38
+ super().__init__()
39
+ self.num_heads = num_heads
40
+ head_dim = hidden_size // num_heads
41
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
42
+
43
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
44
+ self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
45
+ qk_norm_layer = get_norm_layer(qk_norm_type)
46
+ self.self_attn_q_norm = (
47
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
48
+ if qk_norm
49
+ else nn.Identity()
50
+ )
51
+ self.self_attn_k_norm = (
52
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
53
+ if qk_norm
54
+ else nn.Identity()
55
+ )
56
+ self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
57
+
58
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
59
+ act_layer = get_activation_layer(act_type)
60
+ self.mlp = MLP(
61
+ in_channels=hidden_size,
62
+ hidden_channels=mlp_hidden_dim,
63
+ act_layer=act_layer,
64
+ drop=mlp_drop_rate,
65
+ **factory_kwargs,
66
+ )
67
+
68
+ self.adaLN_modulation = nn.Sequential(
69
+ act_layer(),
70
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs)
71
+ )
72
+ # Zero-initialize the modulation
73
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
74
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
75
+
76
+ def forward(
77
+ self,
78
+ x: torch.Tensor,
79
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
80
+ attn_mask: torch.Tensor = None,
81
+ ):
82
+ """
83
+ Forward pass of the transformer block.
84
+
85
+ Args:
86
+ x: Input token embeddings (batch_size, seq_len, hidden_size)
87
+ c: Conditional embeddings (batch_size, hidden_size)
88
+ attn_mask: Attention mask (batch_size, 1, seq_len, seq_len)
89
+
90
+ Returns:
91
+ Updated token embeddings after self-attention and FFN
92
+ """
93
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
94
+
95
+ norm_x = self.norm1(x)
96
+ qkv = self.self_attn_qkv(norm_x)
97
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
98
+ # Apply QK-Norm if needed
99
+ q = self.self_attn_q_norm(q).to(v)
100
+ k = self.self_attn_k_norm(k).to(v)
101
+
102
+ # Self-Attention
103
+ attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
104
+
105
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
106
+
107
+ # FFN Layer
108
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
109
+
110
+ return x
111
+
112
+
113
+ class IndividualTokenRefiner(nn.Module):
114
+ """
115
+ Stack of IndividualTokenRefinerBlocks for sequential token refinement.
116
+
117
+ Processes token sequences through multiple transformer blocks with
118
+ attention masking support for handling variable-length sequences.
119
+ """
120
+ def __init__(
121
+ self,
122
+ hidden_size,
123
+ num_heads,
124
+ depth,
125
+ mlp_ratio: float = 4.0,
126
+ mlp_drop_rate: float = 0.0,
127
+ act_type: str = "silu",
128
+ qk_norm: bool = False,
129
+ qk_norm_type: str = "layer",
130
+ qkv_bias: bool = True,
131
+ dtype: Optional[torch.dtype] = None,
132
+ device: Optional[torch.device] = None,
133
+ ):
134
+ factory_kwargs = {'device': device, 'dtype': dtype}
135
+ super().__init__()
136
+ self.blocks = nn.ModuleList([
137
+ IndividualTokenRefinerBlock(
138
+ hidden_size=hidden_size,
139
+ num_heads=num_heads,
140
+ mlp_ratio=mlp_ratio,
141
+ mlp_drop_rate=mlp_drop_rate,
142
+ act_type=act_type,
143
+ qk_norm=qk_norm,
144
+ qk_norm_type=qk_norm_type,
145
+ qkv_bias=qkv_bias,
146
+ **factory_kwargs,
147
+ ) for _ in range(depth)
148
+ ])
149
+
150
+ def forward(
151
+ self,
152
+ x: torch.Tensor,
153
+ c: torch.LongTensor,
154
+ mask: Optional[torch.Tensor] = None,
155
+ ):
156
+ """
157
+ Forward pass through the stack of transformer blocks.
158
+
159
+ Args:
160
+ x: Input token embeddings (batch_size, seq_len, hidden_size)
161
+ c: Conditional embeddings (batch_size, hidden_size)
162
+ mask: Sequence mask indicating valid tokens (batch_size, seq_len)
163
+
164
+ Returns:
165
+ Refined token embeddings after all blocks
166
+ """
167
+ self_attn_mask = None
168
+ if mask is not None:
169
+ batch_size = mask.shape[0]
170
+ seq_len = mask.shape[1]
171
+ mask = mask.to(x.device)
172
+ # batch_size x 1 x seq_len x seq_len
173
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
174
+ # batch_size x 1 x seq_len x seq_len
175
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
176
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of num_heads
177
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
178
+ # avoids self-attention weight being NaN for padding tokens
179
+ self_attn_mask[:, :, :, 0] = True
180
+
181
+ for block in self.blocks:
182
+ x = block(x, c, self_attn_mask)
183
+ return x
184
+
185
+
186
+ class SingleTokenRefiner(nn.Module):
187
+ """
188
+ Complete token refinement module with input embedding and conditional modulation.
189
+
190
+ Integrates timestep embedding, context projection, and a stack of transformer
191
+ blocks to refine token sequences based on both input data and conditional inputs.
192
+ """
193
+ def __init__(
194
+ self,
195
+ in_channels,
196
+ hidden_size,
197
+ num_heads,
198
+ depth,
199
+ mlp_ratio: float = 4.0,
200
+ mlp_drop_rate: float = 0.0,
201
+ act_type: str = "silu",
202
+ qk_norm: bool = False,
203
+ qk_norm_type: str = "layer",
204
+ qkv_bias: bool = True,
205
+ dtype: Optional[torch.dtype] = None,
206
+ device: Optional[torch.device] = None,
207
+ ):
208
+ factory_kwargs = {'device': device, 'dtype': dtype}
209
+ super().__init__()
210
+
211
+ self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)
212
+
213
+ act_layer = get_activation_layer(act_type)
214
+ # Build timestep embedding layer
215
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
216
+ # Build context embedding layer
217
+ self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)
218
+
219
+ self.individual_token_refiner = IndividualTokenRefiner(
220
+ hidden_size=hidden_size,
221
+ num_heads=num_heads,
222
+ depth=depth,
223
+ mlp_ratio=mlp_ratio,
224
+ mlp_drop_rate=mlp_drop_rate,
225
+ act_type=act_type,
226
+ qk_norm=qk_norm,
227
+ qk_norm_type=qk_norm_type,
228
+ qkv_bias=qkv_bias,
229
+ **factory_kwargs
230
+ )
231
+
232
+ def forward(
233
+ self,
234
+ x: torch.Tensor,
235
+ t: torch.LongTensor,
236
+ mask: Optional[torch.LongTensor] = None,
237
+ ):
238
+ """
239
+ Forward pass of the complete token refiner.
240
+
241
+ Args:
242
+ x: Input features (batch_size, seq_len, in_channels)
243
+ t: Timestep indices (batch_size,)
244
+ mask: Sequence mask for variable-length inputs (batch_size, seq_len)
245
+
246
+ Returns:
247
+ Refined token embeddings (batch_size, seq_len, hidden_size)
248
+ """
249
+ timestep_aware_representations = self.t_embedder(t)
250
+
251
+ if mask is None:
252
+ context_aware_representations = x.mean(dim=1)
253
+ else:
254
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
255
+ context_aware_representations = (
256
+ (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
257
+ )
258
+ context_aware_representations = self.c_embedder(context_aware_representations)
259
+ c = timestep_aware_representations + context_aware_representations
260
+
261
+ x = self.input_embedder(x)
262
+
263
+ x = self.individual_token_refiner(x, c, mask)
264
+
265
+ return x
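
`IndividualTokenRefiner` builds a padding-aware self-attention mask before running its blocks. A toy reproduction of that mask construction (a single sample with five tokens, three of them valid, is an assumed example):

import torch

mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool)   # (batch, seq_len)
b, s = mask.shape
m1 = mask.view(b, 1, 1, s).repeat(1, 1, s, 1)   # which keys are valid, per query
m2 = m1.transpose(2, 3)                         # which queries are valid, per key
self_attn_mask = m1 & m2                        # attend only between valid tokens
self_attn_mask[:, :, :, 0] = True               # keep one key so padded rows are not all False
print(self_attn_mask[0, 0].int())
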
hymm_sp/sample_batch.py ADDED
@@ -0,0 +1,298 @@
1
+ import os
2
+ from pathlib import Path
3
+ from loguru import logger
4
+ import torch
5
+ import numpy as np
6
+ import torch.distributed
7
+ import random
8
+ import torchvision.transforms as transforms
9
+ from PIL import Image
10
+ import cv2
11
+
12
+ from torch.utils.data.distributed import DistributedSampler
13
+ from torch.utils.data import DataLoader
14
+ from hymm_sp.config import parse_args
15
+ from hymm_sp.sample_inference import HunyuanVideoSampler
16
+ from hymm_sp.data_kits.video_dataset import VideoCSVDataset
17
+ from hymm_sp.data_kits.data_tools import save_videos_grid
18
+ from hymm_sp.modules.parallel_states import (
19
+ initialize_distributed,
20
+ nccl_info,
21
+ )
22
+
23
+ class CropResize:
24
+ """
25
+ Custom transform to resize and crop images to a target size while preserving aspect ratio.
26
+
27
+ Resizes the image to ensure it covers the target dimensions, then center-crops to the exact size.
28
+ Useful for preparing consistent input dimensions for video generation models.
29
+ """
30
+ def __init__(self, size=(704, 1216)):
31
+ """
32
+ Args:
33
+ size (tuple): Target dimensions (height, width) for the output image
34
+ """
35
+ self.target_h, self.target_w = size
36
+
37
+ def __call__(self, img):
38
+ """
39
+ Apply the transform to an image.
40
+
41
+ Args:
42
+ img (PIL.Image): Input image to transform
43
+
44
+ Returns:
45
+ PIL.Image: Resized and cropped image with target dimensions
46
+ """
47
+ # Get original image dimensions
48
+ w, h = img.size
49
+
50
+ # Calculate scaling factor to ensure image covers target size
51
+ scale = max(
52
+ self.target_w / w, # Scale needed to cover target width
53
+ self.target_h / h # Scale needed to cover target height
54
+ )
55
+
56
+ # Resize image while preserving aspect ratio
57
+ new_size = (int(h * scale), int(w * scale))
58
+ resize_transform = transforms.Resize(
59
+ new_size,
60
+ interpolation=transforms.InterpolationMode.BILINEAR
61
+ )
62
+ resized_img = resize_transform(img)
63
+
64
+ # Center-crop to exact target dimensions
65
+ crop_transform = transforms.CenterCrop((self.target_h, self.target_w))
66
+ return crop_transform(resized_img)
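
# A quick check of the CropResize transform above on a synthetic PIL image; this snippet is
# an illustrative sketch, and the 1920x1080 source size and the Compose pipeline are assumptions.
from PIL import Image
import torchvision.transforms as transforms

probe = Image.new("RGB", (1920, 1080))
probe_transform = transforms.Compose([
    CropResize((704, 1216)),            # class defined above
    transforms.ToTensor(),
])
print(probe_transform(probe).shape)     # torch.Size([3, 704, 1216])
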
67
+
68
+
69
+ def main():
70
+ """
71
+ Main function for video generation using the Hunyuan multimodal model.
72
+
73
+ Handles argument parsing, distributed setup, model loading, data preparation,
74
+ and video generation with action-controlled transitions. Supports both image-to-video
75
+ and video-to-video generation tasks.
76
+ """
77
+ # Parse command-line arguments and configuration
78
+ args = parse_args()
79
+ models_root_path = Path(args.ckpt)
80
+ action_list = args.action_list
81
+ action_speed_list = args.action_speed_list
82
+ negative_prompt = args.add_neg_prompt
83
+
84
+ # Initialize distributed training/evaluation environment
85
+ logger.info("*" * 20)
86
+ initialize_distributed(args.seed)
87
+
88
+ # Validate model checkpoint path exists
89
+ if not models_root_path.exists():
90
+ raise ValueError(f"Model checkpoint path does not exist: {models_root_path}")
91
+ logger.info("+" * 20)
92
+
93
+ # Set up output directory
94
+ save_path = args.save_path if args.save_path_suffix == "" else f'{args.save_path}_{args.save_path_suffix}'
95
+ os.makedirs(save_path, exist_ok=True)
96
+ logger.info(f"Generated videos will be saved to: {save_path}")
97
+
98
+ # Initialize device configuration for distributed processing
99
+ rank = 0
100
+ device = torch.device("cuda")
101
+ if nccl_info.sp_size > 1:
102
+ # Use specific GPU based on process rank in distributed setup
103
+ device = torch.device(f"cuda:{torch.distributed.get_rank()}")
104
+ rank = torch.distributed.get_rank()
105
+
106
+ # Load the Hunyuan video sampler model from checkpoint
107
+ logger.info(f"Loading model from checkpoint: {args.ckpt}")
108
+ hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(
109
+ args.ckpt,
110
+ args=args,
111
+ device=device if not args.cpu_offload else torch.device("cpu")
112
+ )
113
+ # Update args with model-specific configurations from the checkpoint
114
+ args = hunyuan_video_sampler.args
115
+
116
+ # Enable CPU offloading if specified to reduce GPU memory usage
117
+ if args.cpu_offload:
118
+ from diffusers.hooks import apply_group_offloading
119
+ onload_device = torch.device("cuda")
120
+ apply_group_offloading(
121
+ hunyuan_video_sampler.pipeline.transformer,
122
+ onload_device=onload_device,
123
+ offload_type="block_level",
124
+ num_blocks_per_group=1
125
+ )
126
+ logger.info("Enabled CPU offloading for transformer blocks")
127
+
128
+ # Prepare the prompt and reference image inputs for generation
129
+
130
+ prompt = args.prompt
131
+ image_paths = [args.image_path]
132
+ logger.info(f"Prompt: {prompt}, Image Path {args.image_path}")
133
+ # Use the provided seed, or fall back to a random one
134
+ seed = args.seed if args.seed else random.randint(0, 1_000_000)
135
+
136
+ # Define image transformation pipeline for input reference images
137
+ closest_size = (704, 1216)
138
+ ref_image_transform = transforms.Compose([
139
+ CropResize(closest_size),
140
+ transforms.CenterCrop(closest_size),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize([0.5], [0.5]) # Normalize to [-1, 1] range
143
+ ])
144
+
145
+ # Handle image-based generation (start from a single image)
146
+ if args.image_start:
147
+ # Load and preprocess reference images
148
+ raw_ref_images = [Image.open(image_path).convert('RGB') for image_path in image_paths]
149
+
150
+ # Apply transformations and prepare tensor for model input
151
+ ref_images_pixel_values = [ref_image_transform(ref_image) for ref_image in raw_ref_images]
152
+ ref_images_pixel_values = torch.cat(ref_images_pixel_values).unsqueeze(0).unsqueeze(2).to(device)
153
+
154
+ # Encode reference images to latent space using VAE
155
+ with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
156
+ if args.cpu_offload:
157
+ # Move VAE components to GPU temporarily for encoding
158
+ hunyuan_video_sampler.vae.quant_conv.to('cuda')
159
+ hunyuan_video_sampler.vae.encoder.to('cuda')
160
+
161
+ # Enable tiling for VAE to handle large images efficiently
162
+ hunyuan_video_sampler.pipeline.vae.enable_tiling()
163
+
164
+ # Encode image to latents and scale by VAE's scaling factor
165
+ raw_last_latents = hunyuan_video_sampler.vae.encode(
166
+ ref_images_pixel_values
167
+ ).latent_dist.sample().to(dtype=torch.float16) # Shape: (B, C, F, H, W)
168
+ raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
169
+ raw_ref_latents = raw_last_latents.clone()
170
+
171
+ # Clean up
172
+ hunyuan_video_sampler.pipeline.vae.disable_tiling()
173
+ if args.cpu_offload:
174
+ # Move VAE components back to CPU after encoding
175
+ hunyuan_video_sampler.vae.quant_conv.to('cpu')
176
+ hunyuan_video_sampler.vae.encoder.to('cpu')
177
+
178
+
179
+ # Handle video-based generation (start from an existing video)
180
+ else:
181
+ from decord import VideoReader # Lazy import for video handling
182
+
183
+ # Validate video file exists
184
+ video_path = args.video_path
185
+ if not os.path.exists(video_path):
186
+ raise RuntimeError(f"Video file not found: {video_path}")
187
+
188
+ # Load reference images from the provided image paths
189
+ raw_ref_images = [Image.open(image_path).convert('RGB') for image_path in image_paths]
190
+
191
+ # Load video and extract frames
192
+ ref_video = VideoReader(video_path)
193
+ ref_frames_length = len(ref_video)
194
+ logger.info(f"Loaded reference video with {ref_frames_length} frames")
195
+
196
+ # Preprocess video frames
197
+ transformed_images = []
198
+ for index in range(ref_frames_length):
199
+ # Convert video frame to PIL image and apply transformations
200
+ video_image = ref_video[index].numpy()
201
+ transformed_image = ref_image_transform(Image.fromarray(video_image))
202
+ transformed_images.append(transformed_image)
203
+
204
+ # Prepare tensor for model input
205
+ transformed_images = torch.stack(transformed_images, dim=1).unsqueeze(0).to(
206
+ device=hunyuan_video_sampler.device,
207
+ dtype=torch.float16
208
+ )
209
+
210
+ # Encode video frames to latent space using VAE
211
+ with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
212
+ if args.cpu_offload:
213
+ hunyuan_video_sampler.vae.quant_conv.to('cuda')
214
+ hunyuan_video_sampler.vae.encoder.to('cuda')
215
+
216
+ hunyuan_video_sampler.pipeline.vae.enable_tiling()
217
+
218
+ # Encode last 33 frames of video (model-specific requirement)
219
+ raw_last_latents = hunyuan_video_sampler.vae.encode(
220
+ transformed_images[:, :, -33:, ...]
221
+ ).latent_dist.sample().to(dtype=torch.float16)
222
+ raw_last_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
223
+
224
+ # Encode a single reference frame from the video
225
+ raw_ref_latents = hunyuan_video_sampler.vae.encode(
226
+ transformed_images[:, :, -33:-32, ...]
227
+ ).latent_dist.sample().to(dtype=torch.float16)
228
+ raw_ref_latents.mul_(hunyuan_video_sampler.vae.config.scaling_factor)
229
+
230
+ # Clean up
231
+ hunyuan_video_sampler.pipeline.vae.disable_tiling()
232
+ if args.cpu_offload:
233
+ hunyuan_video_sampler.vae.quant_conv.to('cpu')
234
+ hunyuan_video_sampler.vae.encoder.to('cpu')
235
+
236
+ # Store references for generation loop
237
+ ref_images = raw_ref_images
238
+ last_latents = raw_last_latents
239
+ ref_latents = raw_ref_latents
240
+
241
+ # Generate video segments for each action in the action list
242
+ for idx, action_id in enumerate(action_list):
243
+ # Determine if this is the first action and using image start
244
+ is_image = (idx == 0 and args.image_start)
245
+
246
+ logger.info(f"Generating segment {idx+1}/{len(action_list)} with action ID: {action_id}")
247
+ # Generate video segment with the current action
248
+ outputs = hunyuan_video_sampler.predict(
249
+ prompt=prompt,
250
+ action_id=action_id,
251
+ action_speed=action_speed_list[idx],
252
+ is_image=is_image,
253
+ size=(704, 1216),
254
+ seed=seed,
255
+ last_latents=last_latents, # Previous frame latents for continuity
256
+ ref_latents=ref_latents, # Reference latents for style consistency
257
+ video_length=args.sample_n_frames,
258
+ guidance_scale=args.cfg_scale,
259
+ num_images_per_prompt=args.num_images,
260
+ negative_prompt=negative_prompt,
261
+ infer_steps=args.infer_steps,
262
+ flow_shift=args.flow_shift_eval_video,
263
+ use_linear_quadratic_schedule=args.use_linear_quadratic_schedule,
264
+ linear_schedule_end=args.linear_schedule_end,
265
+ use_deepcache=args.use_deepcache,
266
+ cpu_offload=args.cpu_offload,
267
+ ref_images=ref_images,
268
+ output_dir=save_path,
269
+ return_latents=True,
270
+ use_sage=args.use_sage,
271
+ )
272
+
273
+ # Update latents for next iteration (maintain temporal consistency)
274
+ ref_latents = outputs["ref_latents"]
275
+ last_latents = outputs["last_latents"]
276
+
277
+ # Save generated video segments if this is the main process (rank 0)
278
+ if rank == 0:
279
+ sub_samples = outputs['samples'][0]
280
+
281
+ # Initialize or concatenate video segments
282
+ if idx == 0:
283
+ if args.image_start:
284
+ out_cat = sub_samples
285
+ else:
286
+ # Combine original video frames with generated frames
287
+ out_cat = torch.cat([(transformed_images.detach().cpu() + 1) / 2.0, sub_samples], dim=2)
288
+ else:
289
+ # Append new segment to existing video
290
+ out_cat = torch.cat([out_cat, sub_samples], dim=2)
291
+
292
+ # Save final combined video
293
+ save_path_mp4 = f"{save_path}/{os.path.basename(args.image_path).split('.')[0]}.mp4"
294
+ save_videos_grid(out_cat, save_path_mp4, n_rows=1, fps=24)
295
+ logger.info(f"Saved generated video to: {save_path_mp4}")
296
+
297
+ if __name__ == "__main__":
298
+ main()
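A toy illustration of how the loop in main() above grows the output video along the frame axis with torch.cat(dim=2); the shapes here are shrunk for readability, while real tensors are (B, C, F, H, W) at full resolution:

import torch

segments = [torch.rand(1, 3, 33, 44, 76) for _ in range(3)]
out_cat = segments[0]
for seg in segments[1:]:
    out_cat = torch.cat([out_cat, seg], dim=2)
print(out_cat.shape)  # torch.Size([1, 3, 99, 44, 76])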
hymm_sp/sample_inference.py ADDED
@@ -0,0 +1,716 @@
1
+ import math
2
+ import time
3
+ import torch
4
+ import random
5
+ from loguru import logger
6
+ import numpy as np
7
+ import matplotlib as mpl
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.patches import Patch
10
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
11
+
12
+ from hymm_sp.diffusion import load_diffusion_pipeline
13
+ from hymm_sp.helpers import get_nd_rotary_pos_embed_new
14
+ from hymm_sp.inference import Inference
15
+ from hymm_sp.diffusion.schedulers import FlowMatchDiscreteScheduler
16
+ from packaging import version as pver
17
+
18
+ ACTION_DICT = {"w": "forward", "a": "left", "d": "right", "s": "backward"}
19
+
20
+ def custom_meshgrid(*args):
21
+ # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
22
+ if pver.parse(torch.__version__) < pver.parse('1.10'):
23
+ return torch.meshgrid(*args)
24
+ else:
25
+ return torch.meshgrid(*args, indexing='ij')
26
+
27
+ def get_relative_pose(cam_params):
28
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
29
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
30
+ source_cam_c2w = abs_c2ws[0]
31
+ cam_to_origin = 0
32
+ target_cam_c2w = np.array([
33
+ [1, 0, 0, 0],
34
+ [0, 1, 0, -cam_to_origin],
35
+ [0, 0, 1, 0],
36
+ [0, 0, 0, 1]
37
+ ])
38
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
39
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
40
+ for pose in ret_poses:
41
+ pose[:3, -1:] *= 10
42
+ ret_poses = np.array(ret_poses, dtype=np.float32)
43
+ return ret_poses
44
+
45
+ def ray_condition(K, c2w, H, W, device, flip_flag=None):
46
+ # c2w: B, V, 4, 4
47
+ # K: B, V, 4
48
+
49
+ B, V = K.shape[:2]
50
+
51
+ j, i = custom_meshgrid(
52
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
53
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
54
+ )
55
+ i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
56
+ j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5 # [B, V, HxW]
57
+
58
+ n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
59
+ if n_flip > 0:
60
+ j_flip, i_flip = custom_meshgrid(
61
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
62
+ torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
63
+ )
64
+ i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
65
+ j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
66
+ i[:, flip_flag, ...] = i_flip
67
+ j[:, flip_flag, ...] = j_flip
68
+
69
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
70
+
71
+ zs = torch.ones_like(i) # [B, V, HxW]
72
+ xs = (i - cx) / fx * zs
73
+ ys = (j - cy) / fy * zs
74
+ zs = zs.expand_as(ys)
75
+
76
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
77
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
78
+
79
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, HW, 3
80
+ rays_o = c2w[..., :3, 3] # B, V, 3
81
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, HW, 3
82
+ # c2w @ directions
83
+ rays_dxo = torch.cross(rays_o, rays_d) # B, V, HW, 3
84
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
85
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
86
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
87
+ return plucker
88
+
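A quick shape check for ray_condition with an identity camera and a tiny 4x6 grid (assumes the functions above are importable; the pixel-unit intrinsics values are arbitrary):

import torch

K = torch.tensor([[[2.0, 2.0, 3.0, 2.0]]])      # [B=1, V=1, 4] -> fx, fy, cx, cy
c2w = torch.eye(4).reshape(1, 1, 4, 4)          # [B=1, V=1, 4, 4]
plucker = ray_condition(K, c2w, H=4, W=6, device='cpu')
print(plucker.shape)  # torch.Size([1, 1, 4, 6, 6]) -> B, V, H, W, 6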
89
+ def get_c2w(w2cs, transform_matrix, relative_c2w):
90
+ if relative_c2w:
91
+ target_cam_c2w = np.array([
92
+ [1, 0, 0, 0],
93
+ [0, 1, 0, 0],
94
+ [0, 0, 1, 0],
95
+ [0, 0, 0, 1]
96
+ ])
97
+ abs2rel = target_cam_c2w @ w2cs[0]
98
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ np.linalg.inv(w2c) for w2c in w2cs[1:]]
99
+ for pose in ret_poses:
100
+ pose[:3, -1:] *= 2
101
+ # ret_poses = [poses[:, :3]*2 for poses in ret_poses]
102
+ # ret_poses[:, :, :3] *= 2
103
+ else:
104
+ ret_poses = [np.linalg.inv(w2c) for w2c in w2cs]
105
+ ret_poses = [transform_matrix @ x for x in ret_poses]
106
+ return np.array(ret_poses, dtype=np.float32)
107
+
108
+ def generate_motion_segment(current_pose,
109
+ motion_type: str,
110
+ value: float,
111
+ duration: int = 30):
112
+ """
113
+ Parameters:
114
+ motion_type: ('forward', 'backward', 'left', 'right',
115
+ 'left_rot', 'right_rot', 'up_rot', 'down_rot')
116
+ value: Translation(m) or Rotation(degree)
117
+ duration: frames
118
+
119
+ Return:
120
+ positions: [np.array(x,y,z), ...]
121
+ rotations: [np.array(pitch,yaw,roll), ...]
122
+ """
123
+ positions = []
124
+ rotations = []
125
+
126
+ if motion_type in ['forward', 'backward']:
127
+ yaw_rad = np.radians(current_pose['rotation'][1])
128
+ pitch_rad = np.radians(current_pose['rotation'][0])
129
+
130
+ forward_vec = np.array([
131
+ -math.sin(yaw_rad) * math.cos(pitch_rad),
132
+ math.sin(pitch_rad),
133
+ -math.cos(yaw_rad) * math.cos(pitch_rad)
134
+ ])
135
+
136
+ direction = 1 if motion_type == 'forward' else -1
137
+ total_move = forward_vec * value * direction
138
+ step = total_move / duration
139
+
140
+ for i in range(1, duration+1):
141
+ new_pos = current_pose['position'] + step * i
142
+ positions.append(new_pos.copy())
143
+ rotations.append(current_pose['rotation'].copy())
144
+
145
+ current_pose['position'] = positions[-1]
146
+
147
+ elif motion_type in ['left', 'right']:
148
+ yaw_rad = np.radians(current_pose['rotation'][1])
149
+ right_vec = np.array([math.cos(yaw_rad), 0, -math.sin(yaw_rad)])
150
+
151
+ direction = -1 if motion_type == 'right' else 1
152
+ total_move = right_vec * value * direction
153
+ step = total_move / duration
154
+
155
+ for i in range(1, duration+1):
156
+ new_pos = current_pose['position'] + step * i
157
+ positions.append(new_pos.copy())
158
+ rotations.append(current_pose['rotation'].copy())
159
+
160
+ current_pose['position'] = positions[-1]
161
+
162
+ elif motion_type.endswith('rot'):
163
+ axis = motion_type.split('_')[0]
164
+ total_rotation = np.zeros(3)
165
+
166
+ if axis == 'left':
167
+ total_rotation[0] = value
168
+ elif axis == 'right':
169
+ total_rotation[0] = -value
170
+ elif axis == 'up':
171
+ total_rotation[2] = -value
172
+ elif axis == 'down':
173
+ total_rotation[2] = value
174
+
175
+ step = total_rotation / duration
176
+
177
+ for i in range(1, duration+1):
178
+ positions.append(current_pose['position'].copy())
179
+ new_rot = current_pose['rotation'] + step * i
180
+ rotations.append(new_rot.copy())
181
+
182
+ current_pose['rotation'] = rotations[-1]
183
+
184
+ return positions, rotations, current_pose
185
+
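A small sanity check of generate_motion_segment (assumes the function above is importable): with zero yaw and pitch, a 'forward' segment moves the camera along -z in equal per-frame steps while the rotation stays fixed.

import numpy as np

pose = {'position': np.zeros(3), 'rotation': np.zeros(3)}
positions, rotations, pose = generate_motion_segment(pose, 'forward', value=0.2, duration=4)
print(len(positions))   # 4
print(positions[-1])    # approximately [0, 0, -0.2]
print(rotations[-1])    # [0. 0. 0.]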
186
+ def euler_to_quaternion(angles):
187
+ pitch, yaw, roll = np.radians(angles)
188
+
189
+ cy = math.cos(yaw * 0.5)
190
+ sy = math.sin(yaw * 0.5)
191
+ cp = math.cos(pitch * 0.5)
192
+ sp = math.sin(pitch * 0.5)
193
+ cr = math.cos(roll * 0.5)
194
+ sr = math.sin(roll * 0.5)
195
+
196
+ qw = cy * cp * cr + sy * sp * sr
197
+ qx = cy * cp * sr - sy * sp * cr
198
+ qy = sy * cp * sr + cy * sp * cr
199
+ qz = sy * cp * cr - cy * sp * sr
200
+
201
+ return [qw, qx, qy, qz]
202
+
203
+ def quaternion_to_rotation_matrix(q):
204
+ qw, qx, qy, qz = q
205
+ return np.array([
206
+ [1 - 2*(qy**2 + qz**2), 2*(qx*qy - qw*qz), 2*(qx*qz + qw*qy)],
207
+ [2*(qx*qy + qw*qz), 1 - 2*(qx**2 + qz**2), 2*(qy*qz - qw*qx)],
208
+ [2*(qx*qz - qw*qy), 2*(qy*qz + qw*qx), 1 - 2*(qx**2 + qy**2)]
209
+ ])
210
+
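Sanity check of the Euler-to-quaternion helpers above (assumes they are importable): zero angles should give the identity quaternion and the identity rotation matrix.

import numpy as np

q = euler_to_quaternion(np.array([0.0, 0.0, 0.0]))
R = quaternion_to_rotation_matrix(q)
assert np.allclose(q, [1.0, 0.0, 0.0, 0.0])
assert np.allclose(R, np.eye(3))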
211
+ def ActionToPoseFromID(action_id, value=0.2, duration=33):
212
+
213
+ all_positions = []
214
+ all_rotations = []
215
+ current_pose = {
216
+ 'position': np.array([0.0, 0.0, 0.0]), # XYZ
217
+ 'rotation': np.array([0.0, 0.0, 0.0]) # (pitch, yaw, roll)
218
+ }
219
+ intrinsic = [0.50505, 0.8979, 0.5, 0.5]
220
+ motion_type = ACTION_DICT[action_id]
221
+ positions, rotations, current_pose = generate_motion_segment(current_pose, motion_type, value, duration)
222
+ all_positions.extend(positions)
223
+ all_rotations.extend(rotations)
224
+
225
+ pose_list = []
226
+
227
+ row = [0] + intrinsic + [0, 0] + [1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
228
+ first_row = " ".join(map(str, row))
229
+ pose_list.append(first_row)
230
+ for i, (pos, rot) in enumerate(zip(all_positions, all_rotations)):
231
+ quat = euler_to_quaternion(rot)
232
+ R = quaternion_to_rotation_matrix(quat)
233
+ extrinsic = np.hstack([R, pos.reshape(3, 1)])
234
+
235
+ row = [i] + intrinsic + [0, 0] + extrinsic.flatten().tolist()
236
+ pose_list.append(" ".join(map(str, row)))
237
+
238
+ return pose_list
239
+
240
+ class Camera(object):
241
+ def __init__(self, entry):
242
+ fx, fy, cx, cy = entry[1:5]
243
+ self.fx = fx
244
+ self.fy = fy
245
+ self.cx = cx
246
+ self.cy = cy
247
+ w2c_mat = np.array(entry[7:]).reshape(3, 4)
248
+ w2c_mat_4x4 = np.eye(4)
249
+ w2c_mat_4x4[:3, :] = w2c_mat
250
+ self.w2c_mat = w2c_mat_4x4
251
+ self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
252
+
253
+ class CameraPoseVisualizer:
254
+ def __init__(self, xlim, ylim, zlim):
255
+ self.fig = plt.figure(figsize=(7, 7))
256
+ self.ax = self.fig.add_subplot(projection='3d')
257
+ # self.ax.view_init(elev=25, azim=-90)
258
+ self.plotly_data = None # plotly data traces
259
+ self.ax.set_aspect("auto")
260
+ self.ax.set_xlim(xlim)
261
+ self.ax.set_ylim(ylim)
262
+ self.ax.set_zlim(zlim)
263
+ self.ax.set_xlabel('x')
264
+ self.ax.set_ylabel('y')
265
+ self.ax.set_zlabel('z')
266
+ print('initialize camera pose visualizer')
267
+
268
+ def extrinsic2pyramid(self, extrinsic, color_map='red', hw_ratio=9/16, base_xval=1, zval=3):
269
+ vertex_std = np.array([[0, 0, 0, 1],
270
+ [base_xval, -base_xval * hw_ratio, zval, 1],
271
+ [base_xval, base_xval * hw_ratio, zval, 1],
272
+ [-base_xval, base_xval * hw_ratio, zval, 1],
273
+ [-base_xval, -base_xval * hw_ratio, zval, 1]])
274
+ vertex_transformed = vertex_std @ extrinsic.T
275
+ meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]],
276
+ [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]],
277
+ [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]],
278
+ [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]],
279
+ [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1],
280
+ vertex_transformed[4, :-1]]]
281
+
282
+ color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map)
283
+
284
+ self.ax.add_collection3d(
285
+ Poly3DCollection(meshes, facecolors=color, linewidths=0.3, edgecolors=color, alpha=0.35))
286
+
287
+ def customize_legend(self, list_label):
288
+ list_handle = []
289
+ for idx, label in enumerate(list_label):
290
+ color = plt.cm.rainbow(idx / len(list_label))
291
+ patch = Patch(color=color, label=label)
292
+ list_handle.append(patch)
293
+ plt.legend(loc='right', bbox_to_anchor=(1.8, 0.5), handles=list_handle)
294
+
295
+ def colorbar(self, max_frame_length):
296
+ cmap = mpl.cm.rainbow
297
+ norm = mpl.colors.Normalize(vmin=0, vmax=max_frame_length)
298
+ self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
299
+ ax=self.ax, orientation='vertical', label='Frame Number')
300
+
301
+ def show(self, file_name):
302
+ plt.title('Extrinsic Parameters')
303
+ # plt.show()
304
+ plt.savefig(file_name)
305
+
306
+
307
+ def align_to(value, alignment):
308
+ return int(math.ceil(value / alignment) * alignment)
309
+
310
+
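align_to is how predict() below rounds the requested height/width up to a multiple of 16, for example:

assert align_to(704, 16) == 704   # already aligned
assert align_to(705, 16) == 720   # rounded up to the next multiple of 16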
311
+ def GetPoseEmbedsFromPoses(poses, h, w, target_length, flip=False, start_index=None):
312
+
313
+ poses = [pose.split(' ') for pose in poses]
314
+
315
+ start_idx = start_index
316
+ sample_id = [start_idx + i for i in range(target_length)]
317
+
318
+ poses = [poses[i] for i in sample_id]
319
+
320
+ frame = len(poses)
321
+ w2cs = [np.asarray([float(p) for p in pose[7:]]).reshape(3, 4) for pose in poses]
322
+ transform_matrix = np.asarray([[1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]).reshape(4, 4)
323
+ last_row = np.zeros((1, 4))
324
+ last_row[0, -1] = 1.0
325
+ w2cs = [np.concatenate((w2c, last_row), axis=0) for w2c in w2cs]
326
+ c2ws = get_c2w(w2cs, transform_matrix, relative_c2w=True)
327
+
328
+ cam_params = [[float(x) for x in pose] for pose in poses]
329
+ assert len(cam_params) == target_length
330
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
331
+
332
+ monst3r_w = cam_params[0].cx * 2
333
+ monst3r_h = cam_params[0].cy * 2
334
+ ratio_w, ratio_h = w/monst3r_w, h/monst3r_h
335
+ intrinsics = np.asarray([[cam_param.fx * ratio_w,
336
+ cam_param.fy * ratio_h,
337
+ cam_param.cx * ratio_w,
338
+ cam_param.cy * ratio_h]
339
+ for cam_param in cam_params], dtype=np.float32)
340
+ intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
341
+ relative_pose = True
342
+ if relative_pose:
343
+ c2w_poses = get_relative_pose(cam_params)
344
+ else:
345
+ c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
346
+ c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
347
+ uncond_c2w = torch.zeros_like(c2w)
348
+ uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
349
+ flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
350
+ plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
351
+ flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
352
+ uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
353
+ flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
354
+
355
+ return plucker_embedding, uncond_plucker_embedding, poses
356
+
357
+ def GetPoseEmbedsFromTxt(pose_dir, h, w, target_length, flip=False, start_index=None, step=1):
358
+ # get camera pose
359
+ with open(pose_dir, 'r') as f:
360
+ poses = f.readlines()
361
+ poses = [pose.strip().split(' ') for pose in poses[1:]]
362
+ start_idx = start_index
363
+ sample_id = [start_idx + i*step for i in range(target_length)]
364
+ poses = [poses[i] for i in sample_id]
365
+
366
+ cam_params = [[float(x) for x in pose] for pose in poses]
367
+ assert len(cam_params) == target_length
368
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
369
+
370
+ monst3r_w = cam_params[0].cx * 2
371
+ monst3r_h = cam_params[0].cy * 2
372
+ ratio_w, ratio_h = w/monst3r_w, h/monst3r_h
373
+ intrinsics = np.asarray([[cam_param.fx * ratio_w,
374
+ cam_param.fy * ratio_h,
375
+ cam_param.cx * ratio_w,
376
+ cam_param.cy * ratio_h]
377
+ for cam_param in cam_params], dtype=np.float32)
378
+ intrinsics = torch.as_tensor(intrinsics)[None] # [1, n_frame, 4]
379
+ relative_pose = True
380
+ if relative_pose:
381
+ c2w_poses = get_relative_pose(cam_params)
382
+ else:
383
+ c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
384
+ c2w = torch.as_tensor(c2w_poses)[None] # [1, n_frame, 4, 4]
385
+ uncond_c2w = torch.zeros_like(c2w)
386
+ uncond_c2w[:, :] = torch.eye(4, device=c2w.device)
387
+ if flip:
388
+ flip_flag = torch.ones(target_length, dtype=torch.bool, device=c2w.device)
389
+ else:
390
+ flip_flag = torch.zeros(target_length, dtype=torch.bool, device=c2w.device)
391
+ plucker_embedding = ray_condition(intrinsics, c2w, h, w, device='cpu',
392
+ flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
393
+ uncond_plucker_embedding = ray_condition(intrinsics, uncond_c2w, h, w, device='cpu',
394
+ flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
395
+
396
+ return plucker_embedding, uncond_plucker_embedding, poses
397
+
398
+
399
+ class HunyuanVideoSampler(Inference):
400
+ def __init__(self, args, vae, vae_kwargs, text_encoder, model, text_encoder_2=None, pipeline=None,
401
+ device=0, logger=None):
402
+ super().__init__(args, vae, vae_kwargs, text_encoder, model, text_encoder_2=text_encoder_2,
403
+ pipeline=pipeline, device=device, logger=logger)
404
+
405
+ self.args = args
406
+ self.pipeline = load_diffusion_pipeline(
407
+ args, 0, self.vae, self.text_encoder, self.text_encoder_2, self.model,
408
+ device=self.device)
409
+ print('Loaded Hunyuan model successfully.')
410
+
411
+ def get_rotary_pos_embed(self, video_length, height, width, concat_dict={}):
412
+ target_ndim = 3
413
+ ndim = 5 - 2
414
+ if '884' in self.args.vae:
415
+ latents_size = [(video_length-1)//4+1 , height//8, width//8]
416
+ else:
417
+ latents_size = [video_length , height//8, width//8]
418
+
419
+ if isinstance(self.model.patch_size, int):
420
+ assert all(s % self.model.patch_size == 0 for s in latents_size), \
421
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
422
+ f"but got {latents_size}."
423
+ rope_sizes = [s // self.model.patch_size for s in latents_size]
424
+ elif isinstance(self.model.patch_size, list):
425
+ assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
426
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
427
+ f"but got {latents_size}."
428
+ rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
429
+
430
+ if len(rope_sizes) != target_ndim:
431
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
432
+ head_dim = self.model.hidden_size // self.model.num_heads
433
+ rope_dim_list = self.model.rope_dim_list
434
+ if rope_dim_list is None:
435
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
436
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal the head_dim of the attention layer"
437
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
438
+ rope_sizes,
439
+ theta=self.args.rope_theta,
440
+ use_real=True,
441
+ theta_rescale_factor=1,
442
+ concat_dict=concat_dict)
443
+ return freqs_cos, freqs_sin
444
+
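The latent-size arithmetic above for an '884' VAE (temporal compression 4, spatial compression 8), evaluated for the 33-frame, 704x1216 setting used elsewhere in this repo:

video_length, height, width = 33, 704, 1216
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
print(latents_size)  # [9, 88, 152]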
445
+ @torch.no_grad()
446
+ def predict(self,
447
+ prompt,
448
+ is_image=True,
449
+ size=(720, 1280),
450
+ video_length=129,
451
+ seed=None,
452
+ negative_prompt=None,
453
+ infer_steps=50,
454
+ guidance_scale=6.0,
455
+ flow_shift=5.0,
456
+ batch_size=1,
457
+ num_videos_per_prompt=1,
458
+ verbose=1,
459
+ output_type="pil",
460
+ **kwargs):
461
+ """
462
+ Generate a video segment from the given text prompt and action/camera controls.
463
+
464
+ Args:
465
+ prompt (str or List[str]): The input text.
466
+ kwargs:
467
+ size (tuple): The (height, width) of the output video. Default is (720, 1280).
468
+ video_length (int): The number of frames in the output video. Default is 129.
469
+ seed (int or List[int]): The random seed for the generation. Default is a random integer.
470
+ negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
471
+ infer_steps (int): The number of inference steps. Default is 50.
472
+ guidance_scale (float): The guidance scale for the generation. Default is 6.0.
473
+ num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
474
+ verbose (int): 0 for no logging, 1 for full logging, 2 for reduced logging. Default is 1.
475
+ output_type (str): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
476
+ Default is 'pil'.
477
+ """
478
+
479
+ out_dict = dict()
480
+
481
+ # ---------------------------------
482
+ # Prompt
483
+ # ---------------------------------
484
+ prompt_embeds = kwargs.get("prompt_embeds", None)
485
+ attention_mask = kwargs.get("attention_mask", None)
486
+ negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
487
+ negative_attention_mask = kwargs.get("negative_attention_mask", None)
488
+ ref_latents = kwargs.get("ref_latents", None)
489
+ uncond_ref_latents = kwargs.get("uncond_ref_latents", None)
490
+ return_latents = kwargs.get("return_latents", False)
491
+ negative_prompt = kwargs.get("negative_prompt", None)
492
+
493
+ action_id = kwargs.get("action_id", None)
494
+ action_speed = kwargs.get("action_speed", None)
495
+ start_index = kwargs.get("start_index", None)
496
+ last_latents = kwargs.get("last_latents", None)
497
+ ref_latents = kwargs.get("ref_latents", None)
498
+ input_pose = kwargs.get("input_pose", None)
499
+ step = kwargs.get("step", 1)
500
+ use_sage = kwargs.get("use_sage", False)
501
+
502
+ size = self.parse_size(size)
503
+ target_height = align_to(size[0], 16)
504
+ target_width = align_to(size[1], 16)
505
+ # target_video_length = video_length
506
+
507
+ if input_pose is not None:
508
+ pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromTxt(input_pose,
509
+ target_height,
510
+ target_width,
511
+ 33,
512
+ kwargs.get("flip", False),
513
+ start_index,
514
+ step)
515
+ else:
516
+ pose = ActionToPoseFromID(action_id, value=action_speed)
517
+ pose_embeds, uncond_pose_embeds, poses = GetPoseEmbedsFromPoses(pose,
518
+ target_height,
519
+ target_width,
520
+ 33,
521
+ kwargs.get("flip", False),
522
+ 0)
523
+
524
+ if is_image:
525
+ target_length = 34
526
+ else:
527
+ target_length = 66
528
+
529
+ out_dict['frame'] = target_length
530
+ # print("pose embeds: ", pose_embeds.shape, uncond_pose_embeds.shape)
531
+
532
+ pose_embeds = pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
533
+ uncond_pose_embeds = uncond_pose_embeds.unsqueeze(0).to(torch.bfloat16).to('cuda')
534
+
535
+
536
+
537
+ cpu_offload = kwargs.get("cpu_offload", 0)
538
+ use_deepcache = kwargs.get("use_deepcache", 1)
539
+ denoise_strength = kwargs.get("denoise_strength", 1.0)
540
+ init_latents = kwargs.get("init_latents", None)
541
+ mask = kwargs.get("mask", None)
542
+ if prompt is None:
543
+ # prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should not be None
544
+ # pipeline will help to check this
545
+ prompt = None
546
+ negative_prompt = None
547
+ batch_size = prompt_embeds.shape[0]
548
+ assert prompt_embeds is not None
549
+ else:
550
+ # prompt_embeds, attention_mask, negative_prompt_embeds and negative_attention_mask should be None
551
+ # pipeline will help to check this
552
+ if isinstance(prompt, str):
553
+ batch_size = 1
554
+ prompt = [prompt]
555
+ elif isinstance(prompt, (list, tuple)):
556
+ batch_size = len(prompt)
557
+ else:
558
+ raise ValueError(f"Prompt must be a string or a list of strings, got {prompt}.")
559
+
560
+ if negative_prompt is None:
561
+ negative_prompt = [""] * batch_size
562
+ if isinstance(negative_prompt, str):
563
+ negative_prompt = [negative_prompt] * batch_size
564
+
565
+ # ---------------------------------
566
+ # Other arguments
567
+ # ---------------------------------
568
+ scheduler = FlowMatchDiscreteScheduler(shift=flow_shift,
569
+ reverse=self.args.flow_reverse,
570
+ solver=self.args.flow_solver,
571
+ )
572
+ self.pipeline.scheduler = scheduler
573
+
574
+ # ---------------------------------
575
+ # Random seed
576
+ # ---------------------------------
577
+
578
+ if isinstance(seed, torch.Tensor):
579
+ seed = seed.tolist()
580
+ if seed is None:
581
+ seeds = [random.randint(0, 1_000_000) for _ in range(batch_size * num_videos_per_prompt)]
582
+ elif isinstance(seed, int):
583
+ seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
584
+ elif isinstance(seed, (list, tuple)):
585
+ if len(seed) == batch_size:
586
+ seeds = [int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt)]
587
+ elif len(seed) == batch_size * num_videos_per_prompt:
588
+ seeds = [int(s) for s in seed]
589
+ else:
590
+ raise ValueError(
591
+ f"Length of seed must be equal to number of prompt(batch_size) or "
592
+ f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
593
+ )
594
+ else:
595
+ raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.")
596
+ generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
597
+
598
+ # ---------------------------------
599
+ # Image/Video size and frame
600
+ # ---------------------------------
601
+
602
+
603
+ out_dict['size'] = (target_height, target_width)
604
+ out_dict['video_length'] = target_length
605
+ out_dict['seeds'] = seeds
606
+ out_dict['negative_prompt'] = negative_prompt
607
+ # ---------------------------------
608
+ # Build RoPE
609
+ # ---------------------------------
610
+
611
+ concat_dict = {'mode': 'timecat', 'bias': -1}
612
+ if is_image:
613
+ freqs_cos, freqs_sin = self.get_rotary_pos_embed(37, target_height, target_width)
614
+ else:
615
+ freqs_cos, freqs_sin = self.get_rotary_pos_embed(69, target_height, target_width)
616
+
617
+ n_tokens = freqs_cos.shape[0]
618
+
619
+ # ---------------------------------
620
+ # Inference
621
+ # ---------------------------------
622
+ output_dir = kwargs.get("output_dir", None)
623
+
624
+ if verbose == 1:
625
+ debug_str = f"""
626
+ size: {out_dict['size']}
627
+ video_length: {target_length}
628
+ prompt: {prompt}
629
+ neg_prompt: {negative_prompt}
630
+ seed: {seed}
631
+ infer_steps: {infer_steps}
632
+ denoise_strength: {denoise_strength}
633
+ use_deepcache: {use_deepcache}
634
+ use_sage: {use_sage}
635
+ cpu_offload: {cpu_offload}
636
+ num_videos_per_prompt: {num_videos_per_prompt}
637
+ guidance_scale: {guidance_scale}
638
+ n_tokens: {n_tokens}
639
+ flow_shift: {flow_shift}
640
+ output: {output_dir}"""
641
+ self.logger.info(debug_str)
642
+
643
+ start_time = time.time()
644
+ samples = self.pipeline(prompt=prompt,
645
+ last_latents=last_latents,
646
+ cam_latents=pose_embeds,
647
+ uncond_cam_latents=uncond_pose_embeds,
648
+ height=target_height,
649
+ width=target_width,
650
+ video_length=target_length,
651
+ gt_latents = ref_latents,
652
+ num_inference_steps=infer_steps,
653
+ guidance_scale=guidance_scale,
654
+ negative_prompt=negative_prompt,
655
+ num_videos_per_prompt=num_videos_per_prompt,
656
+ generator=generator,
657
+ prompt_embeds=prompt_embeds,
658
+ ref_latents=ref_latents,
659
+ latents=init_latents,
660
+ denoise_strength=denoise_strength,
661
+ mask=mask,
662
+ uncond_ref_latents=uncond_ref_latents,
663
+ ip_cfg_scale=self.args.ip_cfg_scale,
664
+ use_deepcache=use_deepcache,
665
+ attention_mask=attention_mask,
666
+ negative_prompt_embeds=negative_prompt_embeds,
667
+ negative_attention_mask=negative_attention_mask,
668
+ output_type=output_type,
669
+ freqs_cis=(freqs_cos, freqs_sin),
670
+ n_tokens=n_tokens,
671
+ data_type='video' if target_length > 1 else 'image',
672
+ is_progress_bar=True,
673
+ vae_ver=self.args.vae,
674
+ enable_tiling=self.args.vae_tiling,
675
+ cpu_offload=cpu_offload,
676
+ return_latents=return_latents,
677
+ use_sage=use_sage,
678
+ )
679
+ if samples is None:
680
+ return None
681
+ out_dict['samples'] = []
682
+ out_dict["prompts"] = prompt
683
+ out_dict['pose'] = poses
684
+
685
+ if return_latents:
686
+ print("return_latents | TRUE")
687
+ latents, timesteps, last_latents, ref_latents = samples[1], samples[2], samples[3], samples[4]
688
+ # samples = samples[0][0]
689
+ if samples[0] is not None and len(samples[0]) > 0:
690
+ samples = samples[0][0]
691
+ else:
692
+ samples = None
693
+ out_dict["denoised_lantents"] = latents
694
+ out_dict["timesteps"] = timesteps
695
+ out_dict["ref_latents"] = ref_latents
696
+ out_dict["last_latents"] = last_latents
697
+
698
+ else:
699
+ samples = samples[0]
700
+
701
+ if samples is not None:
702
+ for i, sample in enumerate(samples):
703
+ sample = samples[i].unsqueeze(0)
704
+ sub_samples = []
705
+ sub_samples.append(sample)
706
+ sample_num = len(sub_samples)
707
+ sub_samples = torch.concat(sub_samples)
708
+ # only save in tp rank 0
709
+ out_dict['samples'].append(sub_samples)
710
+
711
+ # visualize pose
712
+
713
+ gen_time = time.time() - start_time
714
+ logger.info(f"Success, time: {gen_time}")
715
+ return out_dict
716
+
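The seed handling in predict() expands an integer seed into batch_size * num_videos_per_prompt deterministic generators; a minimal sketch of that logic:

import torch

seed, batch_size, num_videos_per_prompt = 42, 2, 2
seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
generators = [torch.Generator("cpu").manual_seed(s) for s in seeds]
print(seeds)  # [42, 43, 42, 43]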
hymm_sp/text_encoder/__init__.py ADDED
@@ -0,0 +1,310 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+ from copy import deepcopy
4
+
5
+ import torch, os
6
+ import torch.nn as nn
7
+ from transformers import (
8
+ CLIPTextModel, CLIPTokenizer, LlavaForConditionalGeneration,LlamaModel,
9
+ LlamaTokenizerFast
10
+ )
11
+ from transformers.utils import ModelOutput
12
+ from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH, PRECISION_TO_TYPE
13
+
14
+ CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
15
+ print(f'text_encoder: cpu_offload={CPU_OFFLOAD}')
16
+
17
+ def use_default(value, default):
18
+ return value if value is not None else default
19
+
20
+ def load_text_encoder(text_encoder_type,
21
+ text_encoder_precision=None,
22
+ text_encoder_path=None,
23
+ logger=None,
24
+ device=None
25
+ ):
26
+ if text_encoder_path is None:
27
+ text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
28
+ if logger is not None:
29
+ logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")
30
+
31
+ if text_encoder_type == "clipL":
32
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
33
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
34
+ elif text_encoder_type == "llava-llama-3-8b":
35
+ text_encoder = LlavaForConditionalGeneration.from_pretrained(text_encoder_path, low_cpu_mem_usage=True)
36
+ import transformers
37
+ transformers_version = transformers.__version__
38
+ if tuple(int(v) for v in transformers_version.split(".")[:2]) >= (4, 53):
39
+ text_encoder.final_layer_norm = text_encoder.language_model.norm
40
+ else:
41
+ text_encoder.final_layer_norm = text_encoder.language_model.model.norm
42
+
43
+ else:
44
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
45
+
46
+ if text_encoder_precision is not None:
47
+ text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
48
+
49
+ text_encoder.requires_grad_(False)
50
+
51
+ if logger is not None:
52
+ logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
53
+
54
+ if device is not None:
55
+ text_encoder = text_encoder.to(device)
56
+
57
+ return text_encoder, text_encoder_path
58
+
59
+ def load_tokenizer(tokenizer_type,
60
+ tokenizer_path=None,
61
+ padding_side="right",
62
+ logger=None
63
+ ):
64
+ if tokenizer_path is None:
65
+ tokenizer_path = TOKENIZER_PATH[tokenizer_type]
66
+ if logger is not None:
67
+ logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")
68
+
69
+ if tokenizer_type == "clipL":
70
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
71
+ elif tokenizer_type == "llava-llama-3-8b":
72
+ tokenizer = LlamaTokenizerFast.from_pretrained(tokenizer_path, padding_side=padding_side)
73
+ else:
74
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
75
+
76
+ return tokenizer, tokenizer_path
77
+
78
+
79
+ @dataclass
80
+ class TextEncoderModelOutput(ModelOutput):
81
+ """
82
+ Base class for model's outputs that also contains a pooling of the last hidden states.
83
+
84
+ Args:
85
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
86
+ Sequence of hidden-states at the output of the last layer of the model.
87
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
88
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
89
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*,
90
+ returned when `output_hidden_states=True` is passed):
91
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
92
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
93
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
94
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
95
+ List of decoded texts.
96
+ """
97
+
98
+ hidden_state: torch.FloatTensor = None
99
+ attention_mask: Optional[torch.LongTensor] = None
100
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
101
+ text_outputs: Optional[list] = None
102
+
103
+
104
+ class TextEncoder(nn.Module):
105
+ def __init__(self,
106
+ text_encoder_type: str,
107
+ max_length: int,
108
+ text_encoder_precision: Optional[str] = None,
109
+ text_encoder_path: Optional[str] = None,
110
+ tokenizer_type: Optional[str] = None,
111
+ tokenizer_path: Optional[str] = None,
112
+ output_key: Optional[str] = None,
113
+ use_attention_mask: bool = True,
114
+ input_max_length: Optional[int] = None,
115
+ prompt_template_video: Optional[dict] = None,
116
+ hidden_state_skip_layer: Optional[int] = None,
117
+ apply_final_norm: bool = False,
118
+ reproduce: bool = False,
119
+ logger=None,
120
+ device=None,
121
+ ):
122
+ super().__init__()
123
+ self.text_encoder_type = text_encoder_type
124
+ self.max_length = max_length
125
+ self.precision = text_encoder_precision
126
+ self.model_path = text_encoder_path
127
+ self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
128
+ self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
129
+ self.use_attention_mask = use_attention_mask
130
+ if prompt_template_video is not None:
131
+ assert use_attention_mask is True, "Attention mask is True required when training videos."
132
+ self.input_max_length = input_max_length if input_max_length is not None else max_length
133
+ self.prompt_template_video = prompt_template_video
134
+ self.hidden_state_skip_layer = hidden_state_skip_layer
135
+ self.apply_final_norm = apply_final_norm
136
+ self.reproduce = reproduce
137
+ self.logger = logger
138
+
139
+ self.use_video_template = self.prompt_template_video is not None
140
+ if self.use_video_template:
141
+ if self.prompt_template_video is not None:
142
+ assert isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video, (
143
+ f"`prompt_template_video` must be a dictionary with a key 'template', \
144
+ got {self.prompt_template_video}"
145
+ )
146
+ assert '{}' in str(self.prompt_template_video["template"]), (
147
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
148
+ f"got {self.prompt_template_video['template']}"
149
+ )
150
+
151
+ if "clip" in text_encoder_type:
152
+ self.output_key = output_key or "pooler_output"
153
+ elif "llama" in text_encoder_type:
154
+ self.output_key = output_key or "last_hidden_state"
155
+ else:
156
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
157
+
158
+ self.model, self.model_path = load_text_encoder(
159
+ text_encoder_type=self.text_encoder_type,
160
+ text_encoder_precision=self.precision,
161
+ text_encoder_path=self.model_path,
162
+ logger=self.logger,
163
+ device=device
164
+ )
165
+ self.dtype = self.model.dtype
166
+ self.device = self.model.device
167
+
168
+ self.tokenizer, self.tokenizer_path = load_tokenizer(
169
+ tokenizer_type=self.tokenizer_type,
170
+ tokenizer_path=self.tokenizer_path,
171
+ padding_side="right",
172
+ logger=self.logger
173
+ )
174
+
175
+ def __repr__(self):
176
+ return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
177
+
178
+ @staticmethod
179
+ def apply_text_to_template(text, template):
180
+ """
181
+ Apply text to template.
182
+
183
+ Args:
184
+ text (str): Input text.
185
+ template (str or list): Template string or list of chat conversation.
186
+ prevent_empty_text (bool): If Ture, we will prevent the user text from being empty
187
+ by adding a space. Defaults to True.
188
+ """
189
+ if isinstance(template, str):
190
+ # Will send string to tokenizer. Used for llm
191
+ return template.format(text)
192
+ else:
193
+ raise TypeError(f"Unsupported template type: {type(template)}")
194
+
195
+ def text2tokens(self, text, data_type='video', name='person'):
196
+ """
197
+ Tokenize the input text.
198
+
199
+ Args:
200
+ text (str or list): Input text.
201
+ """
202
+ tokenize_input_type = 'str'
203
+ if self.use_video_template:
204
+ if data_type == 'video':
205
+ prompt_template = self.prompt_template_video["template"]
206
+ else:
207
+ raise ValueError(f"Unsupported data type: {data_type}")
208
+ if isinstance(text, (list, tuple)):
209
+ text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
210
+ if isinstance(text[0], list):
211
+ tokenize_input_type = 'list'
212
+ elif isinstance(text, str):
213
+ text = self.apply_text_to_template(text, prompt_template)
214
+ if isinstance(text, list):
215
+ tokenize_input_type = 'list'
216
+ else:
217
+ raise TypeError(f"Unsupported text type: {type(text)}")
218
+
219
+ kwargs = dict(truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
220
+ if self.text_encoder_type == "llava-llama-3-8b":
221
+ if isinstance(text, list):
222
+ for i in range(len(text)):
223
+ text[i] = text[i] + '\nThe %s looks like<image>' % name
224
+ elif isinstance(text, str):
225
+ text = text + '\nThe %s looks like<image>' % name
226
+ else:
227
+ raise NotImplementedError
228
+
229
+ if tokenize_input_type == 'str':
230
+ return self.tokenizer(text,
231
+ return_length=False,
232
+ return_overflowing_tokens=False,
233
+ return_attention_mask=True,
234
+ **kwargs, )
235
+ elif tokenize_input_type == 'list':
236
+ return self.tokenizer.apply_chat_template(text,
237
+ add_generation_prompt=True,
238
+ tokenize=True,
239
+ return_dict=True,
240
+ **kwargs, )
241
+ else:
242
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
243
+
244
+ def encode(self, batch_encoding, use_attention_mask=None, output_hidden_states=False, do_sample=None,
245
+ hidden_state_skip_layer=None, return_texts=False, data_type='image'):
246
+ """
247
+ Args:
248
+ batch_encoding (dict): Batch encoding from tokenizer.
249
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
250
+ Defaults to None.
251
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
252
+ self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
253
+ output_hidden_states will be set True. Defaults to False.
254
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
255
+ When self.produce is False, do_sample is set to True by default.
256
+ hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
257
+ If None, self.output_key will be used. Defaults to None.
258
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
259
+ """
260
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
261
+ hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
262
+ do_sample = use_default(do_sample, not self.reproduce)
263
+ if CPU_OFFLOAD:
264
+ self.model.to('cuda')
265
+ print(f'encode prompt: move text_encoder to cuda')
266
+
267
+ attention_mask = batch_encoding["attention_mask"].to(self.model.device) if use_attention_mask else None
268
+ if 'pixel_value_llava' in batch_encoding:
269
+ outputs = self.model(
270
+ input_ids=batch_encoding["input_ids"].to(self.model.device),
271
+ attention_mask=attention_mask,
272
+ pixel_values=batch_encoding["pixel_value_llava"].to(self.model.device),
273
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None)
274
+ else:
275
+ outputs = self.model(
276
+ input_ids=batch_encoding["input_ids"].to(self.model.device),
277
+ attention_mask=attention_mask,
278
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,)
279
+ if hidden_state_skip_layer is not None:
280
+ last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
281
+ # Real last hidden state already has layer norm applied. So here we only apply it
282
+ # for intermediate layers.
283
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
284
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
285
+ else:
286
+ last_hidden_state = outputs[self.output_key]
287
+
288
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
289
+ if self.use_video_template:
290
+ if data_type == 'video':
291
+ crop_start = self.prompt_template_video.get("crop_start", -1)
292
+ else:
293
+ raise ValueError(f"Unsupported data type: {data_type}")
294
+ if crop_start > 0:
295
+ last_hidden_state = last_hidden_state[:, crop_start:]
296
+ attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None
297
+ if CPU_OFFLOAD:
298
+ self.model.to('cpu')
299
+ torch.cuda.empty_cache()
300
+ print(f'encode prompt successful: move text_encoder to cpu')
301
+ if output_hidden_states:
302
+ return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
303
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
304
+
305
+ def forward(self, text, use_attention_mask=None, output_hidden_states=False, do_sample=False,
306
+ hidden_state_skip_layer=None, return_texts=False):
307
+ batch_encoding = self.text2tokens(text)
308
+ return self.encode(batch_encoding, use_attention_mask=use_attention_mask,
309
+ output_hidden_states=output_hidden_states, do_sample=do_sample,
310
+ hidden_state_skip_layer=hidden_state_skip_layer, return_texts=return_texts)
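How the video prompt template and crop_start interact in TextEncoder (the template string and crop_start value below are illustrative placeholders, not the shipped configuration):

prompt_template_video = {"template": "Describe this video in detail: {}", "crop_start": 8}
user_text = "a small village at sunset"
full_prompt = prompt_template_video["template"].format(user_text)
print(full_prompt)
# After tokenization, encode() drops the first crop_start hidden states so only the
# user prompt tokens, not the instruction prefix, condition the model:
#   last_hidden_state = last_hidden_state[:, crop_start:]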
hymm_sp/vae/__init__.py ADDED
@@ -0,0 +1,79 @@
1
+ import torch
2
+ from pathlib import Path
3
+ from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D
4
+ from ..constants import VAE_PATH, PRECISION_TO_TYPE
5
+
6
+ def load_vae(vae_type,
7
+ vae_precision=None,
8
+ sample_size=None,
9
+ vae_path=None,
10
+ logger=None,
11
+ device=None
12
+ ):
13
+ """
14
+ Load and configure a Variational Autoencoder (VAE) model.
15
+
16
+ This function handles loading 3D causal VAE models, including configuration,
17
+ weight loading, precision setting, and device placement. It ensures the model
18
+ is properly initialized for inference.
19
+
20
+ Parameters:
21
+ vae_type (str): Type identifier for the VAE, must follow '???-*' format for 3D VAEs
22
+ vae_precision (str, optional): Desired precision type (e.g., 'fp16', 'fp32').
23
+ Uses model's default if not specified.
24
+ sample_size (tuple, optional): Input sample dimensions to override config defaults
25
+ vae_path (str, optional): Path to VAE model files. Uses predefined path from
26
+ VAE_PATH constant if not specified.
27
+ logger (logging.Logger, optional): Logger instance for progress/debug messages
28
+ device (torch.device, optional): Target device to place the model (e.g., 'cuda' or 'cpu')
29
+
30
+ Returns:
31
+ tuple: Contains:
32
+ - vae (AutoencoderKLCausal3D): Loaded and configured VAE model
33
+ - vae_path (str): Actual path used to load the VAE
34
+ - spatial_compression_ratio (int): Spatial dimension compression factor
35
+ - time_compression_ratio (int): Temporal dimension compression factor
36
+
37
+ Raises:
38
+ ValueError: If vae_type does not follow the required 3D VAE format '???-*'
39
+ """
40
+ if vae_path is None:
41
+ vae_path = VAE_PATH[vae_type]
42
+ vae_compress_spec, _, _ = vae_type.split("-")
43
+ length = len(vae_compress_spec)
44
+ # Process 3D VAE (valid format with 3-character compression spec)
45
+ if length == 3:
46
+ if logger is not None:
47
+ logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")
48
+ config = AutoencoderKLCausal3D.load_config(vae_path)
49
+ if sample_size:
50
+ vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
51
+ else:
52
+ vae = AutoencoderKLCausal3D.from_config(config)
53
+ ckpt = torch.load(Path(vae_path) / "pytorch_model.pt", map_location=vae.device)
54
+ if "state_dict" in ckpt:
55
+ ckpt = ckpt["state_dict"]
56
+ vae_ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
57
+ vae.load_state_dict(vae_ckpt)
58
+
59
+ spatial_compression_ratio = vae.config.spatial_compression_ratio
60
+ time_compression_ratio = vae.config.time_compression_ratio
61
+ else:
62
+ raise ValueError(f"Invalid VAE model: {vae_type}. Must be 3D VAE in the format of '???-*'.")
63
+
64
+ if vae_precision is not None:
65
+ vae = vae.to(dtype=PRECISION_TO_TYPE[vae_precision])
66
+
67
+ vae.requires_grad_(False)
68
+
69
+ if logger is not None:
70
+ logger.info(f"VAE to dtype: {vae.dtype}")
71
+
72
+ if device is not None:
73
+ vae = vae.to(device)
74
+
75
+ # Ensure model is in evaluation mode (disables dropout/batch norm training behavior)
76
+ # Note: Even with dropout rate 0, eval mode is recommended for consistent inference
77
+ vae.eval()
78
+
79
+ return vae, vae_path, spatial_compression_ratio, time_compression_ratio
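A hedged usage sketch of load_vae (the '884-16c-hy' identifier and local path are placeholders; real values come from hymm_sp.constants and the downloaded checkpoint layout):

vae, vae_path, spatial_ratio, time_ratio = load_vae(
    "884-16c-hy",
    vae_precision="fp16",
    vae_path="./ckpts/vae",
    device="cuda",
)
# For an '884' VAE (time x4, space x8), a (T, H, W) clip encodes to latents of roughly
# ((T - 1) // time_ratio + 1, H // spatial_ratio, W // spatial_ratio).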
hymm_sp/vae/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,781 @@
1
+ import os
2
+ import math
3
+ from typing import Dict, Optional, Tuple, Union
4
+ from dataclasses import dataclass
5
+ from torch import distributed as dist
6
+ import loguru
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.distributed
10
+
12
+
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ try:
15
+ # This import path exists in the modified diffusers build packaged with the mirror.
16
+ from diffusers.loaders import FromOriginalVAEMixin
17
+ except ImportError:
18
+ # Use this to be compatible with the original diffusers.
19
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
20
+ from diffusers.utils.accelerate_utils import apply_forward_hook
21
+ from diffusers.models.attention_processor import (
22
+ ADDED_KV_ATTENTION_PROCESSORS,
23
+ CROSS_ATTENTION_PROCESSORS,
24
+ Attention,
25
+ AttentionProcessor,
26
+ AttnAddedKVProcessor,
27
+ AttnProcessor,
28
+ )
29
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
32
+
33
+ import threading
34
+ from hymm_sp.modules.parallel_states import (
35
+ initialize_sequence_parallel_state,
36
+ get_sequence_parallel_state,
37
+ nccl_info,
38
+ )
39
+
40
+ def cur_rank():
41
+ return nccl_info.rank_within_group
42
+
43
+ def cur_world_size():
44
+ return nccl_info.sp_size
45
+
46
+ """
47
+ use trt need install polygraphy and onnx-graphsurgeon
48
+ python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com
49
+ """
50
+ try:
51
+ from polygraphy.backend.trt import ( TrtRunner, EngineFromBytes)
52
+ from polygraphy.backend.common import BytesFromPath
53
+ except:
54
+ print("TrtRunner or EngineFromBytes is not available, you can not use trt engine")
55
+
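(A common way to make this optional dependency explicit further down the file is a module-level availability flag; the snippet below is only a sketch of that pattern, with a made-up flag name, and is not part of the commit.)

try:
    from polygraphy.backend.trt import TrtRunner, EngineFromBytes
    from polygraphy.backend.common import BytesFromPath
    TRT_AVAILABLE = True   # hypothetical flag, not defined anywhere in this file
except ImportError:
    TRT_AVAILABLE = False  # callers could check this before load_trt_decoder()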
56
+ @dataclass
57
+ class DecoderOutput2(BaseOutput):
58
+ sample: torch.FloatTensor
59
+ posterior: Optional[DiagonalGaussianDistribution] = None
60
+
61
+
62
+ MODEL_OUTPUT_PATH = os.environ.get('MODEL_OUTPUT_PATH')
63
+ MODEL_BASE = os.environ.get('MODEL_BASE')
64
+
65
+ CPU_OFFLOAD = int(os.environ.get("CPU_OFFLOAD", 0))
66
+ DISABLE_SP = int(os.environ.get("DISABLE_SP", 0))
67
+
68
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
69
+ r"""
70
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
71
+
72
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
73
+ for all models (such as downloading or saving).
74
+
75
+ Parameters:
76
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
77
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
78
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
79
+ Tuple of downsample block types.
80
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
81
+ Tuple of upsample block types.
82
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
83
+ Tuple of block output channels.
84
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
85
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
86
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
87
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
88
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
89
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
90
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
91
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
92
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
93
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
94
+ force_upcast (`bool`, *optional*, default to `True`):
95
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
96
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
97
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
98
+ """
99
+
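A small worked example of the scaling_factor convention described in the docstring above; the tensor names are illustrative, `vae` and `posterior` are assumed to come from an instance of this class, and 0.18215 is the class default:

scaling_factor = 0.18215
z = posterior.sample() * scaling_factor        # latents scaled toward unit variance for the diffusion model
x_rec = vae.decode(z / scaling_factor).sample  # undo the scaling before decoding back to pixels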
100
+ _supports_gradient_checkpointing = True
101
+
102
+ @register_to_config
103
+ def __init__(
104
+ self,
105
+ in_channels: int = 3,
106
+ out_channels: int = 3,
107
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
108
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
109
+ block_out_channels: Tuple[int] = (64,),
110
+ layers_per_block: int = 1,
111
+ act_fn: str = "silu",
112
+ latent_channels: int = 4,
113
+ norm_num_groups: int = 32,
114
+ sample_size: int = 32,
115
+ sample_tsize: int = 64,
116
+ scaling_factor: float = 0.18215,
117
+ force_upcast: float = True,
118
+ spatial_compression_ratio: int = 8,
119
+ time_compression_ratio: int = 4,
120
+ disable_causal_conv: bool = False,
121
+ mid_block_add_attention: bool = True,
122
+ mid_block_causal_attn: bool = False,
123
+ use_trt_engine: bool = False,
124
+ nccl_gather: bool = True,
125
+ engine_path: str = f"{MODEL_BASE}/HYVAE_decoder+conv_256x256xT_fp16_H20.engine",
126
+ ):
127
+ super().__init__()
128
+
129
+ self.disable_causal_conv = disable_causal_conv
130
+ self.time_compression_ratio = time_compression_ratio
131
+
132
+ self.encoder = EncoderCausal3D(
133
+ in_channels=in_channels,
134
+ out_channels=latent_channels,
135
+ down_block_types=down_block_types,
136
+ block_out_channels=block_out_channels,
137
+ layers_per_block=layers_per_block,
138
+ act_fn=act_fn,
139
+ norm_num_groups=norm_num_groups,
140
+ double_z=True,
141
+ time_compression_ratio=time_compression_ratio,
142
+ spatial_compression_ratio=spatial_compression_ratio,
143
+ disable_causal=disable_causal_conv,
144
+ mid_block_add_attention=mid_block_add_attention,
145
+ mid_block_causal_attn=mid_block_causal_attn,
146
+ )
147
+
148
+ self.decoder = DecoderCausal3D(
149
+ in_channels=latent_channels,
150
+ out_channels=out_channels,
151
+ up_block_types=up_block_types,
152
+ block_out_channels=block_out_channels,
153
+ layers_per_block=layers_per_block,
154
+ norm_num_groups=norm_num_groups,
155
+ act_fn=act_fn,
156
+ time_compression_ratio=time_compression_ratio,
157
+ spatial_compression_ratio=spatial_compression_ratio,
158
+ disable_causal=disable_causal_conv,
159
+ mid_block_add_attention=mid_block_add_attention,
160
+ mid_block_causal_attn=mid_block_causal_attn,
161
+ )
162
+
163
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
164
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
165
+
166
+ self.use_slicing = False
167
+ self.use_spatial_tiling = False
168
+ self.use_temporal_tiling = False
169
+
170
+
171
+ # only relevant if vae tiling is enabled
172
+ self.tile_sample_min_tsize = sample_tsize
173
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
174
+
175
+ self.tile_sample_min_size = self.config.sample_size
176
+ sample_size = (
177
+ self.config.sample_size[0]
178
+ if isinstance(self.config.sample_size, (list, tuple))
179
+ else self.config.sample_size
180
+ )
181
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
182
+ self.tile_overlap_factor = 0.25
183
+
184
+ # ============= parallelism related code ===================
185
+ world_size = cur_world_size()
186
+ self.parallel_decode = False if CPU_OFFLOAD else get_sequence_parallel_state()
187
+ print("WORLD SIZE: ", world_size)
188
+
189
+
190
+ def _set_gradient_checkpointing(self, module, value=False):
191
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
192
+ module.gradient_checkpointing = value
193
+
194
+ def enable_temporal_tiling(self, use_tiling: bool = True):
195
+ self.use_temporal_tiling = use_tiling
196
+
197
+ def disable_temporal_tiling(self):
198
+ self.enable_temporal_tiling(False)
199
+
200
+ def enable_spatial_tiling(self, use_tiling: bool = True):
201
+ self.use_spatial_tiling = use_tiling
202
+
203
+ def disable_spatial_tiling(self):
204
+ self.enable_spatial_tiling(False)
205
+
206
+ def enable_tiling(self, use_tiling: bool = True):
207
+ r"""
208
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
209
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
210
+ processing larger images.
211
+ """
212
+ self.enable_spatial_tiling(use_tiling)
213
+ self.enable_temporal_tiling(use_tiling)
214
+
215
+ def disable_tiling(self):
216
+ r"""
217
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
218
+ decoding in one step.
219
+ """
220
+ self.disable_spatial_tiling()
221
+ self.disable_temporal_tiling()
222
+
223
+ def enable_slicing(self):
224
+ r"""
225
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
226
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
227
+ """
228
+ self.use_slicing = True
229
+
230
+ def disable_slicing(self):
231
+ r"""
232
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
233
+ decoding in one step.
234
+ """
235
+ self.use_slicing = False
236
+
237
+
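A minimal usage sketch for the tiling and slicing switches defined above, assuming `vae` is an instance of this class and `latents` is a batched latent tensor; names and shapes are illustrative:

import torch

vae.enable_tiling()     # turns on both spatial and temporal tiling
vae.enable_slicing()    # decode batch elements one at a time
with torch.no_grad():
    video = vae.decode(latents).sample
vae.disable_tiling()
vae.disable_slicing()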
238
+ def load_trt_decoder(self):
239
+ self.use_trt_decoder = True
240
+ self.engine = EngineFromBytes(BytesFromPath(self.engine_path))
241
+
242
+ self.trt_decoder_runner = TrtRunner(self.engine)
243
+ self.activate_trt_decoder()
244
+
245
+ def disable_trt_decoder(self):
246
+ self.use_trt_decoder = False
247
+ del self.engine
248
+
249
+ def activate_trt_decoder(self):
250
+ self.trt_decoder_runner.activate()
251
+
252
+ def deactivate_trt_decoder(self):
253
+ self.trt_decoder_runner.deactivate()
254
+
255
+ @property
256
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
257
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
258
+ r"""
259
+ Returns:
260
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
261
+ indexed by its weight name.
262
+ """
263
+ # set recursively
264
+ processors = {}
265
+
266
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
267
+ if hasattr(module, "get_processor"):
268
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
269
+
270
+ for sub_name, child in module.named_children():
271
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
272
+
273
+ return processors
274
+
275
+ for name, module in self.named_children():
276
+ fn_recursive_add_processors(name, module, processors)
277
+
278
+ return processors
279
+
280
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
281
+ def set_attn_processor(
282
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
283
+ ):
284
+ r"""
285
+ Sets the attention processor to use to compute attention.
286
+
287
+ Parameters:
288
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
289
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
290
+ for **all** `Attention` layers.
291
+
292
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
293
+ processor. This is strongly recommended when setting trainable attention processors.
294
+
295
+ """
296
+ count = len(self.attn_processors.keys())
297
+
298
+ if isinstance(processor, dict) and len(processor) != count:
299
+ raise ValueError(
300
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
301
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
302
+ )
303
+
304
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
305
+ if hasattr(module, "set_processor"):
306
+ if not isinstance(processor, dict):
307
+ module.set_processor(processor, _remove_lora=_remove_lora)
308
+ else:
309
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
310
+
311
+ for sub_name, child in module.named_children():
312
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
313
+
314
+ for name, module in self.named_children():
315
+ fn_recursive_attn_processor(name, module, processor)
316
+
317
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
318
+ def set_default_attn_processor(self):
319
+ """
320
+ Disables custom attention processors and sets the default attention implementation.
321
+ """
322
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
323
+ processor = AttnAddedKVProcessor()
324
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
325
+ processor = AttnProcessor()
326
+ else:
327
+ raise ValueError(
328
+ f"Cannot call `set_default_attn_processor` \
329
+ when attention processors are of type {next(iter(self.attn_processors.values()))}"
330
+ )
331
+
332
+ self.set_attn_processor(processor, _remove_lora=True)
333
+
334
+ @apply_forward_hook
335
+ def encode(
336
+ self, x: torch.FloatTensor, return_dict: bool = True
337
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
338
+ """
339
+ Encode a batch of images into latents.
340
+
341
+ Args:
342
+ x (`torch.FloatTensor`): Input batch of images.
343
+ return_dict (`bool`, *optional*, defaults to `True`):
344
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
345
+
346
+ Returns:
347
+ The latent representations of the encoded images. If `return_dict` is True, a
348
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
349
+ """
350
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions"
351
+
352
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
353
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
354
+
355
+ if self.use_spatial_tiling and \
356
+ (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
357
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
358
+
359
+ if self.use_slicing and x.shape[0] > 1:
360
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
361
+ h = torch.cat(encoded_slices)
362
+ else:
363
+ h = self.encoder(x)
364
+
365
+ moments = self.quant_conv(h)
366
+ posterior = DiagonalGaussianDistribution(moments)
367
+
368
+ if not return_dict:
369
+ return (posterior,)
370
+
371
+ return AutoencoderKLOutput(latent_dist=posterior)
372
+
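For illustration, encoding a short clip through this method, assuming the 5D (B, C, T, H, W) layout the assert above requires; the shapes are arbitrary and the latent size is only approximate:

import torch

x = torch.randn(1, 3, 17, 256, 256)      # B, C, T, H, W
posterior = vae.encode(x).latent_dist    # DiagonalGaussianDistribution
z = posterior.sample()                   # or posterior.mode() for a deterministic latent
# with time compression 4 and spatial compression 8, z is roughly (1, latent_channels, 5, 32, 32)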
373
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
374
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions"
375
+
376
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
377
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
378
+
379
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or \
380
+ z.shape[-2] > self.tile_latent_min_size):
381
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
382
+
383
+ if self.use_trt_decoder:
384
+ # For unknown reason, `copy_outputs_to_host` must be set to True
385
+ dec = self.trt_decoder_runner.infer({"input": z.to(RECOMMENDED_DTYPE).contiguous()}, \
386
+ copy_outputs_to_host=True)["output"].to(device=z.device, dtype=z.dtype)
387
+ else:
388
+ z = self.post_quant_conv(z)
389
+ dec = self.decoder(z)
390
+
391
+ if not return_dict:
392
+ return (dec,)
393
+
394
+ return DecoderOutput(sample=dec)
395
+
396
+ @apply_forward_hook
397
+ def decode(
398
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
399
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
400
+ """
401
+ Decode a batch of images.
402
+
403
+ Args:
404
+ z (`torch.FloatTensor`): Input batch of latent vectors.
405
+ return_dict (`bool`, *optional*, defaults to `True`):
406
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
407
+
408
+ Returns:
409
+ [`~models.vae.DecoderOutput`] or `tuple`:
410
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
411
+ returned.
412
+
413
+ """
414
+
415
+ if self.use_slicing and z.shape[0] > 1:
416
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
417
+ decoded = torch.cat(decoded_slices)
418
+ else:
419
+ decoded = self._decode(z).sample
420
+
421
+ if not return_dict:
422
+ return (decoded,)
423
+
424
+ return DecoderOutput(sample=decoded)
425
+
426
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
427
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
428
+ if blend_extent == 0:
429
+ return b
430
+
431
+ a_region = a[..., -blend_extent:, :]
432
+ b_region = b[..., :blend_extent, :]
433
+
434
+ weights = torch.arange(blend_extent, device=a.device, dtype=a.dtype) / blend_extent
435
+ weights = weights.view(1, 1, 1, blend_extent, 1)
436
+
437
+ blended = a_region * (1 - weights) + b_region * weights
438
+
439
+ b[..., :blend_extent, :] = blended
440
+ return b
441
+
442
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
443
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
444
+ if blend_extent == 0:
445
+ return b
446
+
447
+ a_region = a[..., -blend_extent:]
448
+ b_region = b[..., :blend_extent]
449
+
450
+ weights = torch.arange(blend_extent, device=a.device, dtype=a.dtype) / blend_extent
451
+ weights = weights.view(1, 1, 1, 1, blend_extent)
452
+
453
+ blended = a_region * (1 - weights) + b_region * weights
454
+
455
+ b[..., :blend_extent] = blended
456
+ return b
457
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
458
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
459
+ if blend_extent == 0:
460
+ return b
461
+
462
+ a_region = a[..., -blend_extent:, :, :]
463
+ b_region = b[..., :blend_extent, :, :]
464
+
465
+ weights = torch.arange(blend_extent, device=a.device, dtype=a.dtype) / blend_extent
466
+ weights = weights.view(1, 1, blend_extent, 1, 1)
467
+
468
+ blended = a_region * (1 - weights) + b_region * weights
469
+
470
+ b[..., :blend_extent, :, :] = blended
471
+ return b
472
+
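The three blend helpers above apply the same linear cross-fade over the overlap region, just along different axes (height, width, time). A tiny numeric sketch of the weighting, not taken from the commit:

import torch

blend_extent = 4
w = torch.arange(blend_extent) / blend_extent   # tensor([0.00, 0.25, 0.50, 0.75])
# blended overlap = a * (1 - w) + b * w, so the seam fades from tile `a` into tile `b`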
473
+ def spatial_tiled_encode(self,
474
+ x: torch.FloatTensor,
475
+ return_dict: bool = True,
476
+ return_moments: bool = False) -> AutoencoderKLOutput:
477
+ r"""Encode a batch of images using a tiled encoder.
478
+
479
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
480
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
481
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
482
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
483
+ output, but they should be much less noticeable.
484
+
485
+ Args:
486
+ x (`torch.FloatTensor`): Input batch of images.
487
+ return_dict (`bool`, *optional*, defaults to `True`):
488
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
489
+
490
+ Returns:
491
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
492
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
493
+ `tuple` is returned.
494
+ """
495
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
496
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
497
+ row_limit = self.tile_latent_min_size - blend_extent
498
+
499
+ # Split video into tiles and encode them separately.
500
+ rows = []
501
+ for i in range(0, x.shape[-2], overlap_size):
502
+ row = []
503
+ for j in range(0, x.shape[-1], overlap_size):
504
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
505
+ tile = self.encoder(tile)
506
+ tile = self.quant_conv(tile)
507
+ row.append(tile)
508
+ rows.append(row)
509
+ result_rows = []
510
+ for i, row in enumerate(rows):
511
+ result_row = []
512
+ for j, tile in enumerate(row):
513
+ # blend the above tile and the left tile
514
+ # to the current tile and add the current tile to the result row
515
+ if i > 0:
516
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
517
+ if j > 0:
518
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
519
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
520
+ result_rows.append(torch.cat(result_row, dim=-1))
521
+
522
+ moments = torch.cat(result_rows, dim=-2)
523
+ if return_moments:
524
+ return moments
525
+
526
+ posterior = DiagonalGaussianDistribution(moments)
527
+ if not return_dict:
528
+ return (posterior,)
529
+
530
+ return AutoencoderKLOutput(latent_dist=posterior)
531
+
532
+
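To make the tile geometry above concrete, here is the arithmetic under the assumption of sample_size 256, four block_out_channels (a spatial factor of 8) and the default tile_overlap_factor of 0.25; the numbers are derived from the formulas in the code, not hard-coded anywhere:

tile_sample_min_size = 256
tile_latent_min_size = 256 // 8            # 32
overlap_size = int(256 * (1 - 0.25))       # 192-pixel stride between input tiles
blend_extent = int(32 * 0.25)              # 8 latent rows/cols cross-faded at each seam
row_limit = 32 - 8                         # 24 latent rows/cols kept per tile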
533
+ def spatial_tiled_decode(self,
534
+ z: torch.FloatTensor,
535
+ return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
536
+ r"""
537
+ Decode a batch of images using a tiled decoder.
538
+
539
+ Args:
540
+ z (`torch.FloatTensor`): Input batch of latent vectors.
541
+ return_dict (`bool`, *optional*, defaults to `True`):
542
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
543
+
544
+ Returns:
545
+ [`~models.vae.DecoderOutput`] or `tuple`:
546
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
547
+ returned.
548
+ """
549
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
550
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
551
+ row_limit = self.tile_sample_min_size - blend_extent
552
+
553
+ # Split z into overlapping tiles and decode them separately.
554
+ # The tiles have an overlap to avoid seams between tiles.
555
+ rank = cur_rank()
556
+ rows = []
557
+ if self.parallel_decode and rank == 0:
558
+ rank = cur_rank()
559
+ #torch.cuda.set_device(rank) # set device for trt_runner
560
+ world_size = cur_world_size()
561
+
562
+
563
+ cur_device_id = 0
564
+ device_tasks = []
565
+ for i in range(world_size):
566
+ device_tasks.append([])
567
+ for i in range(0, z.shape[-2], overlap_size):
568
+ row = []
569
+ for j in range(0, z.shape[-1], overlap_size):
570
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
571
+ row.append(None)
572
+ device_tasks[cur_device_id].append((i // overlap_size, \
573
+ j // overlap_size, \
574
+ tile.to("cuda:" + str(cur_device_id))))
575
+ #device_tasks[cur_device_id].append((i // overlap_size, j // overlap_size, tile))
576
+ cur_device_id = (cur_device_id + 1) % world_size
577
+ rows.append(row)
578
+
579
+ def thread_run(decoder, device_id, inputs, outputs):
580
+ for input in inputs:
581
+ cur_vae = self.device_vaes[device_id]
582
+ ret = cur_vae.decoder(cur_vae.post_quant_conv(input[2]))
583
+ outputs[input[0]][input[1]] = ret
584
+ return
585
+
586
+ threads = []
587
+ for i in range(world_size):
588
+ cur_thread = threading.Thread(target=thread_run,
589
+ args=(self, i, device_tasks[i], rows),
590
+ name="DecoderThread-" + str(i))
591
+ threads.append(cur_thread)
592
+ cur_thread.start()
593
+
594
+ for cur_thread in threads:
595
+ cur_thread.join()
596
+
597
+ for i in range(len(rows)):
598
+ for j in range(len(rows[i])):
599
+ rows[i][j] = rows[i][j].to("cuda:0")
600
+
601
+ else:
602
+ for i in range(0, z.shape[-2], overlap_size):
603
+ row = []
604
+ for j in range(0, z.shape[-1], overlap_size):
605
+ tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
606
+ tile = self.post_quant_conv(tile)
607
+ decoded = self.decoder(tile)
608
+ row.append(decoded)
609
+ rows.append(row)
610
+
611
+ result_rows = []
612
+ for i, row in enumerate(rows):
613
+ result_row = []
614
+ for j, tile in enumerate(row):
615
+ # blend the above tile and the left tile
616
+ # to the current tile and add the current tile to the result row
617
+ if i > 0:
618
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
619
+ if j > 0:
620
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
621
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
622
+ result_rows.append(torch.cat(result_row, dim=-1))
623
+
624
+ if self.parallel_decode and rank != 0:
625
+ if not return_dict:
626
+ return (None,)
627
+ return DecoderOutput(sample=None)
628
+
629
+ dec = torch.cat(result_rows, dim=-2)
630
+ if not return_dict:
631
+ return (dec,)
632
+
633
+ return DecoderOutput(sample=dec)
634
+
635
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
636
+ assert not self.disable_causal_conv, "Temporal tiling is only compatible with causal convolutions."
637
+
638
+ B, C, T, H, W = x.shape
639
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
640
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
641
+ t_limit = self.tile_latent_min_tsize - blend_extent
642
+
643
+ # Split the video into tiles and encode them separately.
644
+ row = []
645
+ for i in range(0, T, overlap_size):
646
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
647
+ if self.use_spatial_tiling and \
648
+ (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
649
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
650
+ else:
651
+ tile = self.encoder(tile)
652
+ tile = self.quant_conv(tile)
653
+ if i > 0:
654
+ tile = tile[:, :, 1:, :, :]
655
+ row.append(tile)
656
+ result_row = []
657
+ for i, tile in enumerate(row):
658
+ if i > 0:
659
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
660
+ result_row.append(tile[:, :, :t_limit, :, :])
661
+ else:
662
+ result_row.append(tile[:, :, :t_limit+1, :, :])
663
+
664
+ moments = torch.cat(result_row, dim=2)
665
+ posterior = DiagonalGaussianDistribution(moments)
666
+
667
+ if not return_dict:
668
+ return (posterior,)
669
+
670
+ return AutoencoderKLOutput(latent_dist=posterior)
671
+
672
+ def temporal_tiled_decode(self,
673
+ z: torch.FloatTensor,
674
+ return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
675
+ # Split z into overlapping tiles and decode them separately.
676
+ assert not self.disable_causal_conv, "Temporal tiling is only supported with causal convolutions."
677
+
678
+ B, C, T, H, W = z.shape
679
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
680
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
681
+ t_limit = self.tile_sample_min_tsize - blend_extent
682
+ rank = 0 if CPU_OFFLOAD or DISABLE_SP else cur_rank()
683
+ row = []
684
+ for i in range(0, T, overlap_size):
685
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
686
+ if self.use_spatial_tiling and \
687
+ (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
688
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
689
+ else:
690
+ tile = self.post_quant_conv(tile)
691
+ decoded = self.decoder(tile)
692
+ if i > 0 and (not self.parallel_decode or rank == 0):
693
+ decoded = decoded[:, :, 1:, :, :]
694
+ row.append(decoded)
695
+ if not CPU_OFFLOAD and not DISABLE_SP and self.parallel_decode and rank != 0:
696
+ return DecoderOutput(sample=None)
697
+ result_row = []
698
+ for i, tile in enumerate(row):
699
+ if i > 0:
700
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
701
+ result_row.append(tile[:, :, :t_limit, :, :])
702
+ else:
703
+ result_row.append(tile[:, :, :t_limit+1, :, :])
704
+
705
+ dec = torch.cat(result_row, dim=2)
706
+ if not return_dict:
707
+ return (dec,)
708
+
709
+ return DecoderOutput(sample=dec)
710
+
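For the temporal constants above, assuming the defaults sample_tsize=64, time_compression_ratio=4 and tile_overlap_factor=0.25, the decode loop works out as follows (derived from the formulas in the code):

tile_latent_min_tsize = 64 // 4      # 16 latent frames per tile (plus one leading causal frame)
overlap_size = int(16 * (1 - 0.25))  # stride of 12 latent frames between tiles
blend_extent = int(64 * 0.25)        # 16 decoded frames cross-faded at each seam
t_limit = 64 - 16                    # 48 decoded frames kept per tile (49 for the first tile)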
711
+ def forward(
712
+ self,
713
+ sample: torch.FloatTensor,
714
+ sample_posterior: bool = False,
715
+ return_dict: bool = True,
716
+ return_posterior: bool = False,
717
+ generator: Optional[torch.Generator] = None,
718
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
719
+ r"""
720
+ Args:
721
+ sample (`torch.FloatTensor`): Input sample.
722
+ sample_posterior (`bool`, *optional*, defaults to `False`):
723
+ Whether to sample from the posterior.
724
+ return_dict (`bool`, *optional*, defaults to `True`):
725
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
726
+ """
727
+ x = sample
728
+ posterior = self.encode(x).latent_dist
729
+ if sample_posterior:
730
+ z = posterior.sample(generator=generator)
731
+ else:
732
+ z = posterior.mode()
733
+ dec = self.decode(z).sample
734
+
735
+ if not return_dict:
736
+ if return_posterior:
737
+ return (dec, posterior)
738
+ else:
739
+ return (dec,)
740
+ if return_posterior:
741
+ return DecoderOutput2(sample=dec, posterior=posterior)
742
+ else:
743
+ return DecoderOutput2(sample=dec)
744
+
745
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
746
+ def fuse_qkv_projections(self):
747
+ """
748
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
749
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
750
+
751
+ <Tip warning={true}>
752
+
753
+ This API is 🧪 experimental.
754
+
755
+ </Tip>
756
+ """
757
+ self.original_attn_processors = None
758
+
759
+ for _, attn_processor in self.attn_processors.items():
760
+ if "Added" in str(attn_processor.__class__.__name__):
761
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
762
+
763
+ self.original_attn_processors = self.attn_processors
764
+
765
+ for module in self.modules():
766
+ if isinstance(module, Attention):
767
+ module.fuse_projections(fuse=True)
768
+
769
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
770
+ def unfuse_qkv_projections(self):
771
+ """Disables the fused QKV projection if enabled.
772
+
773
+ <Tip warning={true}>
774
+
775
+ This API is 🧪 experimental.
776
+
777
+ </Tip>
778
+
779
+ """
780
+ if self.original_attn_processors is not None:
781
+ self.set_attn_processor(self.original_attn_processors)
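End to end, the forward pass above is encode → (sample or mode) → decode. A short usage sketch, assuming `vae` is an instance of this class on the right device and dtype and `video` is a (B, C, T, H, W) tensor:

import torch

with torch.no_grad():
    out = vae(sample=video, sample_posterior=True, return_posterior=True)
reconstruction = out.sample    # same (B, C, T, H, W) layout as the input
posterior = out.posterior      # DiagonalGaussianDistribution, e.g. for a KL term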
hymm_sp/vae/unet_causal_3d_blocks.py ADDED
@@ -0,0 +1,900 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Any, Dict, Optional, Tuple, Union
15
+
16
+ import numpy as np
17
+ import torch
18
+ import torch.nn.functional as F
19
+ from torch import nn
20
+ from einops import rearrange
21
+
22
+ from diffusers.utils import is_torch_version, logging
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import SpatialNorm
25
+ from diffusers.models.attention_processor import Attention
26
+ from diffusers.models.normalization import AdaGroupNorm
27
+ from diffusers.models.normalization import RMSNorm
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
34
+ seq_len = n_frame * n_hw
35
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
36
+ for i in range(seq_len):
37
+ i_frame = i // n_hw
38
+ mask[i, : (i_frame + 1) * n_hw] = 0
39
+ if batch_size is not None:
40
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
41
+ return mask
42
+
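A tiny worked example of the mask built above, for n_frame=2 and n_hw=2 (four tokens in total): tokens of frame 0 can only attend to frame 0, while frame-1 tokens see both frames; -inf marks blocked positions. The import path is assumed from the file layout:

import torch
from hymm_sp.vae.unet_causal_3d_blocks import prepare_causal_attention_mask

mask = prepare_causal_attention_mask(n_frame=2, n_hw=2, dtype=torch.float32, device="cpu")
# tensor([[0., 0., -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0.,   0.,   0.],
#         [0., 0.,   0.,   0.]])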
43
+
44
+ class CausalConv3d(nn.Module):
45
+ def __init__(
46
+ self,
47
+ chan_in,
48
+ chan_out,
49
+ kernel_size: Union[int, Tuple[int, int, int]],
50
+ stride: Union[int, Tuple[int, int, int]] = 1,
51
+ dilation: Union[int, Tuple[int, int, int]] = 1,
52
+ pad_mode = 'replicate',
53
+ disable_causal=False,
54
+ **kwargs
55
+ ):
56
+ super().__init__()
57
+
58
+ self.pad_mode = pad_mode
59
+ if disable_causal:
60
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2,
61
+ kernel_size // 2, kernel_size // 2, kernel_size // 2)
62
+ else:
63
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2,
64
+ kernel_size // 2, kernel_size - 1, 0) # W, H, T
65
+ self.time_causal_padding = padding
66
+
67
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride = stride, dilation = dilation, **kwargs)
68
+
69
+ def forward(self, x):
70
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
71
+ return self.conv(x)
72
+
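The padding tuple above follows F.pad's last-dimension-first convention for 5D input, i.e. (W_left, W_right, H_left, H_right, T_front, T_back); in the causal case all kernel_size - 1 temporal pad frames go in front, so the convolution never looks at future frames. A quick shape check, assuming an integer kernel size (the import path is assumed from the file layout):

import torch
from hymm_sp.vae.unet_causal_3d_blocks import CausalConv3d

conv = CausalConv3d(chan_in=8, chan_out=8, kernel_size=3)
y = conv(torch.randn(1, 8, 5, 32, 32))   # B, C, T, H, W
print(y.shape)                           # torch.Size([1, 8, 5, 32, 32]); T and the spatial dims are preserved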
73
+ class CausalAvgPool3d(nn.Module):
74
+ def __init__(
75
+ self,
76
+ kernel_size: Union[int, Tuple[int, int, int]],
77
+ stride: Union[int, Tuple[int, int, int]],
78
+ pad_mode = 'replicate',
79
+ disable_causal=False,
80
+ **kwargs
81
+ ):
82
+ super().__init__()
83
+
84
+ self.pad_mode = pad_mode
85
+ if disable_causal:
86
+ padding = (0, 0, 0, 0, 0, 0)
87
+ else:
88
+ padding = (0, 0, 0, 0, stride - 1, 0) # W, H, T
89
+ self.time_causal_padding = padding
90
+
91
+ self.conv = nn.AvgPool3d(kernel_size, stride=stride, ceil_mode=True, **kwargs)
92
+ self.pad_mode = pad_mode
93
+
94
+ def forward(self, x):
95
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
96
+ return self.conv(x)
97
+
98
+ class UpsampleCausal3D(nn.Module):
99
+ """A 3D upsampling layer with an optional convolution.
100
+
101
+ Parameters:
102
+ channels (`int`):
103
+ number of channels in the inputs and outputs.
104
+ use_conv (`bool`, default `False`):
105
+ option to use a convolution.
106
+ use_conv_transpose (`bool`, default `False`):
107
+ option to use a convolution transpose.
108
+ out_channels (`int`, optional):
109
+ number of output channels. Defaults to `channels`.
110
+ name (`str`, default `conv`):
111
+ name of the upsampling 3D layer.
112
+ """
113
+
114
+ def __init__(
115
+ self,
116
+ channels: int,
117
+ use_conv: bool = False,
118
+ use_conv_transpose: bool = False,
119
+ out_channels: Optional[int] = None,
120
+ name: str = "conv",
121
+ kernel_size: Optional[int] = None,
122
+ padding=1,
123
+ norm_type=None,
124
+ eps=None,
125
+ elementwise_affine=None,
126
+ bias=True,
127
+ interpolate=True,
128
+ upsample_factor=(2, 2, 2),
129
+ disable_causal=False,
130
+ ):
131
+ super().__init__()
132
+ self.channels = channels
133
+ self.out_channels = out_channels or channels
134
+ self.use_conv = use_conv
135
+ self.use_conv_transpose = use_conv_transpose
136
+ self.name = name
137
+ self.interpolate = interpolate
138
+ self.upsample_factor = upsample_factor
139
+ self.disable_causal = disable_causal
140
+
141
+ if norm_type == "ln_norm":
142
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
143
+ elif norm_type == "rms_norm":
144
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
145
+ elif norm_type is None:
146
+ self.norm = None
147
+ else:
148
+ raise ValueError(f"unknown norm_type: {norm_type}")
149
+
150
+ conv = None
151
+ if use_conv_transpose:
152
+ assert False, "Not Implement yet"
153
+ if kernel_size is None:
154
+ kernel_size = 4
155
+ conv = nn.ConvTranspose2d(
156
+ channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias
157
+ )
158
+ elif use_conv:
159
+ if kernel_size is None:
160
+ kernel_size = 3
161
+ conv = CausalConv3d(self.channels, self.out_channels,
162
+ kernel_size=kernel_size, bias=bias, disable_causal=disable_causal)
163
+
164
+ if name == "conv":
165
+ self.conv = conv
166
+ else:
167
+ self.Conv2d_0 = conv
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states: torch.FloatTensor,
172
+ output_size: Optional[int] = None,
173
+ scale: float = 1.0,
174
+ ) -> torch.FloatTensor:
175
+ assert hidden_states.shape[1] == self.channels
176
+
177
+ if self.norm is not None:
178
+ assert False, "Not Implement yet"
179
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
180
+
181
+ if self.use_conv_transpose:
182
+ return self.conv(hidden_states)
183
+
184
+ # Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
185
+ # https://github.com/pytorch/pytorch/issues/86679
186
+ dtype = hidden_states.dtype
187
+ if dtype == torch.bfloat16:
188
+ hidden_states = hidden_states.to(torch.float32)
189
+
190
+ # upsample_nearest_nhwc fails with large batch sizes.
191
+ # see https://github.com/huggingface/diffusers/issues/984
192
+ if hidden_states.shape[0] >= 64:
193
+ hidden_states = hidden_states.contiguous()
194
+
195
+ # if `output_size` is passed we force the interpolation output
196
+ # size and do not make use of `scale_factor=2`
197
+ if self.interpolate:
198
+ B, C, T, H, W = hidden_states.shape
199
+ if not self.disable_causal:
200
+ first_h, other_h = hidden_states.split((1, T-1), dim=2)
201
+ if output_size is None:
202
+ if T > 1:
203
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
204
+
205
+ first_h = first_h.squeeze(2)
206
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
207
+ first_h = first_h.unsqueeze(2)
208
+ else:
209
+ assert False, "Not Implement yet"
210
+ other_h = F.interpolate(other_h, size=output_size, mode="nearest")
211
+
212
+ if T > 1:
213
+ hidden_states = torch.cat((first_h, other_h), dim=2)
214
+ else:
215
+ hidden_states = first_h
216
+ else:
217
+ hidden_states = F.interpolate(hidden_states, scale_factor=self.upsample_factor, mode="nearest")
218
+
219
+ if dtype == torch.bfloat16:
220
+ hidden_states = hidden_states.to(dtype)
221
+
222
+ if self.use_conv:
223
+ if self.name == "conv":
224
+ hidden_states = self.conv(hidden_states)
225
+ else:
226
+ hidden_states = self.Conv2d_0(hidden_states)
227
+
228
+ return hidden_states
229
+
230
+ class DownsampleCausal3D(nn.Module):
231
+ """A 3D downsampling layer with an optional convolution.
232
+
233
+ Parameters:
234
+ channels (`int`):
235
+ number of channels in the inputs and outputs.
236
+ use_conv (`bool`, default `False`):
237
+ option to use a convolution.
238
+ out_channels (`int`, optional):
239
+ number of output channels. Defaults to `channels`.
240
+ padding (`int`, default `1`):
241
+ padding for the convolution.
242
+ name (`str`, default `conv`):
243
+ name of the downsampling 3D layer.
244
+ """
245
+
246
+ def __init__(
247
+ self,
248
+ channels: int,
249
+ use_conv: bool = False,
250
+ out_channels: Optional[int] = None,
251
+ padding: int = 1,
252
+ name: str = "conv",
253
+ kernel_size=3,
254
+ norm_type=None,
255
+ eps=None,
256
+ elementwise_affine=None,
257
+ bias=True,
258
+ stride=2,
259
+ disable_causal=False,
260
+ ):
261
+ super().__init__()
262
+ self.channels = channels
263
+ self.out_channels = out_channels or channels
264
+ self.use_conv = use_conv
265
+ self.padding = padding
266
+ stride = stride
267
+ self.name = name
268
+
269
+ if norm_type == "ln_norm":
270
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
271
+ elif norm_type == "rms_norm":
272
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
273
+ elif norm_type is None:
274
+ self.norm = None
275
+ else:
276
+ raise ValueError(f"unknown norm_type: {norm_type}")
277
+
278
+ if use_conv:
279
+ conv = CausalConv3d(
280
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride,
281
+ disable_causal=disable_causal, bias=bias
282
+ )
283
+ else:
284
+ raise NotImplementedError
285
+ if name == "conv":
286
+ self.Conv2d_0 = conv
287
+ self.conv = conv
288
+ elif name == "Conv2d_0":
289
+ self.conv = conv
290
+ else:
291
+ self.conv = conv
292
+
293
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
294
+ assert hidden_states.shape[1] == self.channels
295
+
296
+ if self.norm is not None:
297
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
298
+
299
+ assert hidden_states.shape[1] == self.channels
300
+
301
+ hidden_states = self.conv(hidden_states)
302
+
303
+ return hidden_states
304
+
305
+ class ResnetBlockCausal3D(nn.Module):
306
+ r"""
307
+ A Resnet block.
308
+
309
+ Parameters:
310
+ in_channels (`int`): The number of channels in the input.
311
+ out_channels (`int`, *optional*, default to be `None`):
312
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
313
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
314
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
315
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
316
+ groups_out (`int`, *optional*, default to None):
317
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
318
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
319
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
320
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
321
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
322
+ "ada_group" for a stronger conditioning with scale and shift.
323
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
324
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
325
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
326
+ use_in_shortcut (`bool`, *optional*, default to `True`):
327
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
328
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
329
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
330
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
331
+ `conv_shortcut` output.
332
+ conv_3d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
333
+ If None, same as `out_channels`.
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ *,
339
+ in_channels: int,
340
+ out_channels: Optional[int] = None,
341
+ conv_shortcut: bool = False,
342
+ dropout: float = 0.0,
343
+ temb_channels: int = 512,
344
+ groups: int = 32,
345
+ groups_out: Optional[int] = None,
346
+ pre_norm: bool = True,
347
+ eps: float = 1e-6,
348
+ non_linearity: str = "swish",
349
+ skip_time_act: bool = False,
350
+ time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
351
+ kernel: Optional[torch.FloatTensor] = None,
352
+ output_scale_factor: float = 1.0,
353
+ use_in_shortcut: Optional[bool] = None,
354
+ up: bool = False,
355
+ down: bool = False,
356
+ conv_shortcut_bias: bool = True,
357
+ conv_3d_out_channels: Optional[int] = None,
358
+ disable_causal: bool = False,
359
+ ):
360
+ super().__init__()
361
+ self.pre_norm = pre_norm
362
+ self.pre_norm = True
363
+ self.in_channels = in_channels
364
+ out_channels = in_channels if out_channels is None else out_channels
365
+ self.out_channels = out_channels
366
+ self.use_conv_shortcut = conv_shortcut
367
+ self.up = up
368
+ self.down = down
369
+ self.output_scale_factor = output_scale_factor
370
+ self.time_embedding_norm = time_embedding_norm
371
+ self.skip_time_act = skip_time_act
372
+
373
+ linear_cls = nn.Linear
374
+
375
+ if groups_out is None:
376
+ groups_out = groups
377
+
378
+ if self.time_embedding_norm == "ada_group":
379
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
380
+ elif self.time_embedding_norm == "spatial":
381
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
382
+ else:
383
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
384
+
385
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1, disable_causal=disable_causal)
386
+
387
+ if temb_channels is not None:
388
+ if self.time_embedding_norm == "default":
389
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
390
+ elif self.time_embedding_norm == "scale_shift":
391
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
392
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
393
+ self.time_emb_proj = None
394
+ else:
395
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
396
+ else:
397
+ self.time_emb_proj = None
398
+
399
+ if self.time_embedding_norm == "ada_group":
400
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
401
+ elif self.time_embedding_norm == "spatial":
402
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
403
+ else:
404
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
405
+
406
+ self.dropout = torch.nn.Dropout(dropout)
407
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
408
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels,
409
+ kernel_size=3, stride=1, disable_causal=disable_causal)
410
+
411
+ self.nonlinearity = get_activation(non_linearity)
412
+
413
+ self.upsample = self.downsample = None
414
+ if self.up:
415
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False, disable_causal=disable_causal)
416
+ elif self.down:
417
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False,
418
+ disable_causal=disable_causal, name="op")
419
+
420
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels \
421
+ if use_in_shortcut is None else use_in_shortcut
422
+
423
+ self.conv_shortcut = None
424
+ if self.use_in_shortcut:
425
+ self.conv_shortcut = CausalConv3d(
426
+ in_channels,
427
+ conv_3d_out_channels,
428
+ kernel_size=1,
429
+ stride=1,
430
+ disable_causal=disable_causal,
431
+ bias=conv_shortcut_bias,
432
+ )
433
+
434
+ def forward(
435
+ self,
436
+ input_tensor: torch.FloatTensor,
437
+ temb: torch.FloatTensor,
438
+ scale: float = 1.0,
439
+ ) -> torch.FloatTensor:
440
+ hidden_states = input_tensor
441
+
442
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
443
+ hidden_states = self.norm1(hidden_states, temb)
444
+ else:
445
+ hidden_states = self.norm1(hidden_states)
446
+
447
+ hidden_states = self.nonlinearity(hidden_states)
448
+
449
+ if self.upsample is not None:
450
+ # upsample_nearest_nhwc fails with large batch sizes.
451
+ # see https://github.com/huggingface/diffusers/issues/984
452
+ if hidden_states.shape[0] >= 64:
453
+ input_tensor = input_tensor.contiguous()
454
+ hidden_states = hidden_states.contiguous()
455
+ input_tensor = (
456
+ self.upsample(input_tensor, scale=scale)
457
+ )
458
+ hidden_states = (
459
+ self.upsample(hidden_states, scale=scale)
460
+ )
461
+ elif self.downsample is not None:
462
+ input_tensor = (
463
+ self.downsample(input_tensor, scale=scale)
464
+ )
465
+ hidden_states = (
466
+ self.downsample(hidden_states, scale=scale)
467
+ )
468
+
469
+ hidden_states = self.conv1(hidden_states)
470
+
471
+ if self.time_emb_proj is not None:
472
+ if not self.skip_time_act:
473
+ temb = self.nonlinearity(temb)
474
+ temb = (
475
+ self.time_emb_proj(temb, scale)[:, :, None, None]
476
+ )
477
+
478
+ if temb is not None and self.time_embedding_norm == "default":
479
+ hidden_states = hidden_states + temb
480
+
481
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
482
+ hidden_states = self.norm2(hidden_states, temb)
483
+ else:
484
+ hidden_states = self.norm2(hidden_states)
485
+
486
+ if temb is not None and self.time_embedding_norm == "scale_shift":
487
+ scale, shift = torch.chunk(temb, 2, dim=1)
488
+ hidden_states = hidden_states * (1 + scale) + shift
489
+
490
+ hidden_states = self.nonlinearity(hidden_states)
491
+
492
+ hidden_states = self.dropout(hidden_states)
493
+ hidden_states = self.conv2(hidden_states)
494
+
495
+ if self.conv_shortcut is not None:
496
+ input_tensor = (
497
+ self.conv_shortcut(input_tensor)
498
+ )
499
+
500
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
501
+
502
+ return output_tensor
503
+
504
+ def get_down_block3d(
505
+ down_block_type: str,
506
+ num_layers: int,
507
+ in_channels: int,
508
+ out_channels: int,
509
+ temb_channels: int,
510
+ add_downsample: bool,
511
+ downsample_stride: int,
512
+ resnet_eps: float,
513
+ resnet_act_fn: str,
514
+ transformer_layers_per_block: int = 1,
515
+ num_attention_heads: Optional[int] = None,
516
+ resnet_groups: Optional[int] = None,
517
+ cross_attention_dim: Optional[int] = None,
518
+ downsample_padding: Optional[int] = None,
519
+ dual_cross_attention: bool = False,
520
+ use_linear_projection: bool = False,
521
+ only_cross_attention: bool = False,
522
+ upcast_attention: bool = False,
523
+ resnet_time_scale_shift: str = "default",
524
+ attention_type: str = "default",
525
+ resnet_skip_time_act: bool = False,
526
+ resnet_out_scale_factor: float = 1.0,
527
+ cross_attention_norm: Optional[str] = None,
528
+ attention_head_dim: Optional[int] = None,
529
+ downsample_type: Optional[str] = None,
530
+ dropout: float = 0.0,
531
+ disable_causal: bool = False,
532
+ ):
533
+ # If attn head dim is not defined, we default it to the number of heads
534
+ if attention_head_dim is None:
535
+ logger.warning(
536
+ f"It is recommended to provide `attention_head_dim` when calling \
537
+ `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
538
+ )
539
+ attention_head_dim = num_attention_heads
540
+
541
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
542
+ if down_block_type == "DownEncoderBlockCausal3D":
543
+ return DownEncoderBlockCausal3D(
544
+ num_layers=num_layers,
545
+ in_channels=in_channels,
546
+ out_channels=out_channels,
547
+ dropout=dropout,
548
+ add_downsample=add_downsample,
549
+ downsample_stride=downsample_stride,
550
+ resnet_eps=resnet_eps,
551
+ resnet_act_fn=resnet_act_fn,
552
+ resnet_groups=resnet_groups,
553
+ downsample_padding=downsample_padding,
554
+ resnet_time_scale_shift=resnet_time_scale_shift,
555
+ disable_causal=disable_causal,
556
+ )
557
+ raise ValueError(f"{down_block_type} does not exist.")
558
+
559
+ def get_up_block3d(
560
+ up_block_type: str,
561
+ num_layers: int,
562
+ in_channels: int,
563
+ out_channels: int,
564
+ prev_output_channel: int,
565
+ temb_channels: int,
566
+ add_upsample: bool,
567
+ upsample_scale_factor: Tuple,
568
+ resnet_eps: float,
569
+ resnet_act_fn: str,
570
+ resolution_idx: Optional[int] = None,
571
+ transformer_layers_per_block: int = 1,
572
+ num_attention_heads: Optional[int] = None,
573
+ resnet_groups: Optional[int] = None,
574
+ cross_attention_dim: Optional[int] = None,
575
+ dual_cross_attention: bool = False,
576
+ use_linear_projection: bool = False,
577
+ only_cross_attention: bool = False,
578
+ upcast_attention: bool = False,
579
+ resnet_time_scale_shift: str = "default",
580
+ attention_type: str = "default",
581
+ resnet_skip_time_act: bool = False,
582
+ resnet_out_scale_factor: float = 1.0,
583
+ cross_attention_norm: Optional[str] = None,
584
+ attention_head_dim: Optional[int] = None,
585
+ upsample_type: Optional[str] = None,
586
+ dropout: float = 0.0,
587
+ disable_causal: bool = False,
588
+ ) -> nn.Module:
589
+ # If attn head dim is not defined, we default it to the number of heads
590
+ if attention_head_dim is None:
591
+ logger.warning(
592
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. \
593
+ Defaulting `attention_head_dim` to {num_attention_heads}."
594
+ )
595
+ attention_head_dim = num_attention_heads
596
+
597
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
598
+ if up_block_type == "UpDecoderBlockCausal3D":
599
+ return UpDecoderBlockCausal3D(
600
+ num_layers=num_layers,
601
+ in_channels=in_channels,
602
+ out_channels=out_channels,
603
+ resolution_idx=resolution_idx,
604
+ dropout=dropout,
605
+ add_upsample=add_upsample,
606
+ upsample_scale_factor=upsample_scale_factor,
607
+ resnet_eps=resnet_eps,
608
+ resnet_act_fn=resnet_act_fn,
609
+ resnet_groups=resnet_groups,
610
+ resnet_time_scale_shift=resnet_time_scale_shift,
611
+ temb_channels=temb_channels,
612
+ disable_causal=disable_causal,
613
+ )
614
+ raise ValueError(f"{up_block_type} does not exist.")
615
+
616
+
617
+ class UNetMidBlockCausal3D(nn.Module):
618
+ """
619
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
620
+
621
+ Args:
622
+ in_channels (`int`): The number of input channels.
623
+ temb_channels (`int`): The number of temporal embedding channels.
624
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
625
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
626
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
627
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
628
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
629
+ model on tasks with long-range temporal dependencies.
630
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
631
+ resnet_groups (`int`, *optional*, defaults to 32):
632
+ The number of groups to use in the group normalization layers of the resnet blocks.
633
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
634
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
635
+ Whether to use pre-normalization for the resnet blocks.
636
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
637
+ attention_head_dim (`int`, *optional*, defaults to 1):
638
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
639
+ the number of input channels.
640
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
641
+
642
+ Returns:
643
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
644
+ in_channels, num_frames, height, width)`.
645
+
646
+ """
647
+
648
+ def __init__(
649
+ self,
650
+ in_channels: int,
651
+ temb_channels: int,
652
+ dropout: float = 0.0,
653
+ num_layers: int = 1,
654
+ resnet_eps: float = 1e-6,
655
+ resnet_time_scale_shift: str = "default", # default, spatial
656
+ resnet_act_fn: str = "swish",
657
+ resnet_groups: int = 32,
658
+ attn_groups: Optional[int] = None,
659
+ resnet_pre_norm: bool = True,
660
+ add_attention: bool = True,
661
+ attention_head_dim: int = 1,
662
+ output_scale_factor: float = 1.0,
663
+ disable_causal: bool = False,
664
+ causal_attention: bool = False,
665
+ ):
666
+ super().__init__()
667
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
668
+ self.add_attention = add_attention
669
+ self.causal_attention = causal_attention
670
+
671
+ if attn_groups is None:
672
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
673
+
674
+ # there is always at least one resnet
675
+ resnets = [
676
+ ResnetBlockCausal3D(
677
+ in_channels=in_channels,
678
+ out_channels=in_channels,
679
+ temb_channels=temb_channels,
680
+ eps=resnet_eps,
681
+ groups=resnet_groups,
682
+ dropout=dropout,
683
+ time_embedding_norm=resnet_time_scale_shift,
684
+ non_linearity=resnet_act_fn,
685
+ output_scale_factor=output_scale_factor,
686
+ pre_norm=resnet_pre_norm,
687
+ disable_causal=disable_causal,
688
+ )
689
+ ]
690
+ attentions = []
691
+
692
+ if attention_head_dim is None:
693
+ logger.warning(
694
+ f"It is not recommended to pass `attention_head_dim=None`. \
695
+ Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
696
+ )
697
+ attention_head_dim = in_channels
698
+
699
+ for _ in range(num_layers):
700
+ if self.add_attention:
701
+ #assert False, "Not implemented yet"
702
+ attentions.append(
703
+ Attention(
704
+ in_channels,
705
+ heads=in_channels // attention_head_dim,
706
+ dim_head=attention_head_dim,
707
+ rescale_output_factor=output_scale_factor,
708
+ eps=resnet_eps,
709
+ norm_num_groups=attn_groups,
710
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
711
+ residual_connection=True,
712
+ bias=True,
713
+ upcast_softmax=True,
714
+ _from_deprecated_attn_block=True,
715
+ )
716
+ )
717
+ else:
718
+ attentions.append(None)
719
+
720
+ resnets.append(
721
+ ResnetBlockCausal3D(
722
+ in_channels=in_channels,
723
+ out_channels=in_channels,
724
+ temb_channels=temb_channels,
725
+ eps=resnet_eps,
726
+ groups=resnet_groups,
727
+ dropout=dropout,
728
+ time_embedding_norm=resnet_time_scale_shift,
729
+ non_linearity=resnet_act_fn,
730
+ output_scale_factor=output_scale_factor,
731
+ pre_norm=resnet_pre_norm,
732
+ disable_causal=disable_causal,
733
+ )
734
+ )
735
+
736
+ self.attentions = nn.ModuleList(attentions)
737
+ self.resnets = nn.ModuleList(resnets)
738
+
739
+ def forward(self,
740
+ hidden_states: torch.FloatTensor,
741
+ temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
742
+ hidden_states = self.resnets[0](hidden_states, temb)
743
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
744
+ if attn is not None:
745
+ B, C, T, H, W = hidden_states.shape
746
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
747
+ if self.causal_attention:
748
+ attention_mask = prepare_causal_attention_mask(T, H * W,
749
+ hidden_states.dtype,
750
+ hidden_states.device,
751
+ batch_size=B)
752
+ else:
753
+ attention_mask = None
754
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
755
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
756
+ hidden_states = resnet(hidden_states, temb)
757
+
758
+ return hidden_states
759
+
760
+
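The forward pass above flattens the (frame, height, width) axes into one token sequence and, when `causal_attention` is set, passes the attention layer an additive mask so that no token attends to a later frame. `prepare_causal_attention_mask` itself is defined elsewhere in the repo and may differ in detail; a minimal sketch of the idea, assuming an additive float mask over the flattened (f h w) token order:

import torch

# Hypothetical stand-in for prepare_causal_attention_mask: tokens in frame t may
# attend to every token in frames <= t; blocked positions get a large negative bias.
def make_frame_causal_mask(num_frames, tokens_per_frame, dtype, device, batch_size=1):
    seq_len = num_frames * tokens_per_frame
    # frame index of each flattened (f h w) token
    frame_idx = torch.arange(num_frames, device=device).repeat_interleave(tokens_per_frame)
    # allowed[i, j] is True when key token j lies in a frame no later than query token i's frame
    allowed = frame_idx.unsqueeze(0) <= frame_idx.unsqueeze(1)
    mask = torch.zeros(seq_len, seq_len, dtype=dtype, device=device)
    mask = mask.masked_fill(~allowed, torch.finfo(dtype).min)
    return mask.unsqueeze(0).expand(batch_size, -1, -1)  # (B, seq, seq) additive mask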
761
+ class DownEncoderBlockCausal3D(nn.Module):
762
+ def __init__(
763
+ self,
764
+ in_channels: int,
765
+ out_channels: int,
766
+ dropout: float = 0.0,
767
+ num_layers: int = 1,
768
+ resnet_eps: float = 1e-6,
769
+ resnet_time_scale_shift: str = "default",
770
+ resnet_act_fn: str = "swish",
771
+ resnet_groups: int = 32,
772
+ resnet_pre_norm: bool = True,
773
+ output_scale_factor: float = 1.0,
774
+ add_downsample: bool = True,
775
+ downsample_stride: int = 2,
776
+ downsample_padding: int = 1,
777
+ disable_causal: bool = False,
778
+ ):
779
+ super().__init__()
780
+ resnets = []
781
+
782
+ for i in range(num_layers):
783
+ in_channels = in_channels if i == 0 else out_channels
784
+ resnets.append(
785
+ ResnetBlockCausal3D(
786
+ in_channels=in_channels,
787
+ out_channels=out_channels,
788
+ temb_channels=None,
789
+ eps=resnet_eps,
790
+ groups=resnet_groups,
791
+ dropout=dropout,
792
+ time_embedding_norm=resnet_time_scale_shift,
793
+ non_linearity=resnet_act_fn,
794
+ output_scale_factor=output_scale_factor,
795
+ pre_norm=resnet_pre_norm,
796
+ disable_causal=disable_causal,
797
+ )
798
+ )
799
+
800
+ self.resnets = nn.ModuleList(resnets)
801
+
802
+ if add_downsample:
803
+ self.downsamplers = nn.ModuleList(
804
+ [
805
+ DownsampleCausal3D(
806
+ out_channels,
807
+ use_conv=True,
808
+ out_channels=out_channels,
809
+ padding=downsample_padding,
810
+ name="op",
811
+ stride=downsample_stride,
812
+ disable_causal=disable_causal,
813
+ )
814
+ ]
815
+ )
816
+ else:
817
+ self.downsamplers = None
818
+
819
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
820
+ for resnet in self.resnets:
821
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
822
+
823
+ if self.downsamplers is not None:
824
+ for downsampler in self.downsamplers:
825
+ hidden_states = downsampler(hidden_states, scale)
826
+
827
+ return hidden_states
828
+
829
+
830
+ class UpDecoderBlockCausal3D(nn.Module):
831
+ def __init__(
832
+ self,
833
+ in_channels: int,
834
+ out_channels: int,
835
+ resolution_idx: Optional[int] = None,
836
+ dropout: float = 0.0,
837
+ num_layers: int = 1,
838
+ resnet_eps: float = 1e-6,
839
+ resnet_time_scale_shift: str = "default", # default, spatial
840
+ resnet_act_fn: str = "swish",
841
+ resnet_groups: int = 32,
842
+ resnet_pre_norm: bool = True,
843
+ output_scale_factor: float = 1.0,
844
+ add_upsample: bool = True,
845
+ upsample_scale_factor = (2, 2, 2),
846
+ temb_channels: Optional[int] = None,
847
+ disable_causal: bool = False,
848
+ ):
849
+ super().__init__()
850
+ resnets = []
851
+
852
+ for i in range(num_layers):
853
+ input_channels = in_channels if i == 0 else out_channels
854
+
855
+ resnets.append(
856
+ ResnetBlockCausal3D(
857
+ in_channels=input_channels,
858
+ out_channels=out_channels,
859
+ temb_channels=temb_channels,
860
+ eps=resnet_eps,
861
+ groups=resnet_groups,
862
+ dropout=dropout,
863
+ time_embedding_norm=resnet_time_scale_shift,
864
+ non_linearity=resnet_act_fn,
865
+ output_scale_factor=output_scale_factor,
866
+ pre_norm=resnet_pre_norm,
867
+ disable_causal=disable_causal,
868
+ )
869
+ )
870
+
871
+ self.resnets = nn.ModuleList(resnets)
872
+
873
+ if add_upsample:
874
+ self.upsamplers = nn.ModuleList(
875
+ [
876
+ UpsampleCausal3D(
877
+ out_channels,
878
+ use_conv=True,
879
+ out_channels=out_channels,
880
+ upsample_factor=upsample_scale_factor,
881
+ disable_causal=disable_causal
882
+ )
883
+ ]
884
+ )
885
+ else:
886
+ self.upsamplers = None
887
+
888
+ self.resolution_idx = resolution_idx
889
+
890
+ def forward(
891
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
892
+ ) -> torch.FloatTensor:
893
+ for resnet in self.resnets:
894
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
895
+
896
+ if self.upsamplers is not None:
897
+ for upsampler in self.upsamplers:
898
+ hidden_states = upsampler(hidden_states)
899
+
900
+ return hidden_states
hymm_sp/vae/vae.py ADDED
@@ -0,0 +1,433 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from diffusers.utils import BaseOutput, is_torch_version
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.models.attention_processor import SpatialNorm
11
+ from .unet_causal_3d_blocks import (
12
+ CausalConv3d,
13
+ UNetMidBlockCausal3D,
14
+ get_down_block3d,
15
+ get_up_block3d,
16
+ )
17
+
18
+ @dataclass
19
+ class DecoderOutput(BaseOutput):
20
+ r"""
21
+ Output of decoding method.
22
+
23
+ Args:
24
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
25
+ The decoded output sample from the last layer of the model.
26
+ """
27
+
28
+ sample: torch.FloatTensor
29
+
30
+
31
+ class EncoderCausal3D(nn.Module):
32
+ r"""
33
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
34
+
35
+ Args:
36
+ in_channels (`int`, *optional*, defaults to 3):
37
+ The number of input channels.
38
+ out_channels (`int`, *optional*, defaults to 3):
39
+ The number of output channels.
40
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlockCausal3D",)`):
41
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
42
+ options.
43
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
44
+ The number of output channels for each block.
45
+ layers_per_block (`int`, *optional*, defaults to 2):
46
+ The number of layers per block.
47
+ norm_num_groups (`int`, *optional*, defaults to 32):
48
+ The number of groups for normalization.
49
+ act_fn (`str`, *optional*, defaults to `"silu"`):
50
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
51
+ double_z (`bool`, *optional*, defaults to `True`):
52
+ Whether to double the number of output channels for the last block.
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ in_channels: int = 3,
58
+ out_channels: int = 3,
59
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
60
+ block_out_channels: Tuple[int, ...] = (64,),
61
+ layers_per_block: int = 2,
62
+ norm_num_groups: int = 32,
63
+ act_fn: str = "silu",
64
+ double_z: bool = True,
65
+ mid_block_add_attention=True,
66
+ time_compression_ratio: int = 4,
67
+ spatial_compression_ratio: int = 8,
68
+ disable_causal: bool = False,
69
+ mid_block_causal_attn: bool = False,
70
+ ):
71
+ super().__init__()
72
+ self.layers_per_block = layers_per_block
73
+
74
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0],
75
+ kernel_size=3, stride=1, disable_causal=disable_causal)
76
+ self.mid_block = None
77
+ self.down_blocks = nn.ModuleList([])
78
+
79
+ # down
80
+ output_channel = block_out_channels[0]
81
+ for i, down_block_type in enumerate(down_block_types):
82
+ input_channel = output_channel
83
+ output_channel = block_out_channels[i]
84
+ is_final_block = i == len(block_out_channels) - 1
85
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
86
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
87
+
88
+ if time_compression_ratio == 4:
89
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
90
+ add_time_downsample = bool(i >= (len(block_out_channels) - 1 - \
91
+ num_time_downsample_layers) and not is_final_block)
92
+ elif time_compression_ratio == 8:
93
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
94
+ add_time_downsample = bool(i < num_time_downsample_layers)
95
+ else:
96
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
97
+
98
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
99
+ downsample_stride_T = (2, ) if add_time_downsample else (1, )
100
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
101
+ down_block = get_down_block3d(
102
+ down_block_type,
103
+ num_layers=self.layers_per_block,
104
+ in_channels=input_channel,
105
+ out_channels=output_channel,
106
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
107
+ downsample_stride=downsample_stride,
108
+ resnet_eps=1e-6,
109
+ downsample_padding=0,
110
+ resnet_act_fn=act_fn,
111
+ resnet_groups=norm_num_groups,
112
+ attention_head_dim=output_channel,
113
+ temb_channels=None,
114
+ disable_causal=disable_causal,
115
+ )
116
+ self.down_blocks.append(down_block)
117
+
118
+ # mid
119
+ self.mid_block = UNetMidBlockCausal3D(
120
+ in_channels=block_out_channels[-1],
121
+ resnet_eps=1e-6,
122
+ resnet_act_fn=act_fn,
123
+ output_scale_factor=1,
124
+ resnet_time_scale_shift="default",
125
+ attention_head_dim=block_out_channels[-1],
126
+ resnet_groups=norm_num_groups,
127
+ temb_channels=None,
128
+ add_attention=mid_block_add_attention,
129
+ disable_causal=disable_causal,
130
+ causal_attention=mid_block_causal_attn,
131
+ )
132
+
133
+ # out
134
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
135
+ self.conv_act = nn.SiLU()
136
+
137
+ conv_out_channels = 2 * out_channels if double_z else out_channels
138
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels,
139
+ kernel_size=3, disable_causal=disable_causal)
140
+
141
+ self.gradient_checkpointing = False
142
+
143
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
144
+ r"""The forward method of the `EncoderCausal3D` class."""
145
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
146
+
147
+ sample = self.conv_in(sample)
148
+
149
+ if self.training and self.gradient_checkpointing:
150
+
151
+ def create_custom_forward(module):
152
+ def custom_forward(*inputs):
153
+ return module(*inputs)
154
+
155
+ return custom_forward
156
+
157
+ # down
158
+ if is_torch_version(">=", "1.11.0"):
159
+ for down_block in self.down_blocks:
160
+ sample = torch.utils.checkpoint.checkpoint(
161
+ create_custom_forward(down_block), sample, use_reentrant=False
162
+ )
163
+ # middle
164
+ sample = torch.utils.checkpoint.checkpoint(
165
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
166
+ )
167
+ else:
168
+ for down_block in self.down_blocks:
169
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
170
+ # middle
171
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
172
+
173
+ else:
174
+ # down
175
+ for down_block in self.down_blocks:
176
+ sample = down_block(sample)
177
+
178
+ # middle
179
+ sample = self.mid_block(sample)
180
+
181
+ # post-process
182
+ sample = self.conv_norm_out(sample)
183
+ sample = self.conv_act(sample)
184
+ sample = self.conv_out(sample)
185
+
186
+ return sample
187
+
188
+
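The loop above picks, per encoder block, whether to halve the spatial and/or temporal resolution so that the accumulated strides reach `spatial_compression_ratio` and `time_compression_ratio`. A small worked example of that schedule, assuming the illustrative config `block_out_channels=(128, 256, 512, 512)` with `time_compression_ratio=4` and `spatial_compression_ratio=8` (the channel counts are an assumption, not taken from this diff):

import numpy as np

block_out_channels = (128, 256, 512, 512)
time_compression_ratio, spatial_compression_ratio = 4, 8
num_spatial = int(np.log2(spatial_compression_ratio))  # 3 spatial halvings
num_time = int(np.log2(time_compression_ratio))        # 2 temporal halvings

for i in range(len(block_out_channels)):
    is_final = i == len(block_out_channels) - 1
    add_spatial = i < num_spatial
    add_time = i >= (len(block_out_channels) - 1 - num_time) and not is_final
    stride = ((2,) if add_time else (1,)) + ((2, 2) if add_spatial else (1, 1))
    print(f"block {i}: downsample_stride={stride}")
# block 0: (1, 2, 2), blocks 1-2: (2, 2, 2), block 3: (1, 1, 1)
# -> 2*2*2 = 8x spatial and 2*2 = 4x temporal compression, with the final block left untouched.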
189
+ class DecoderCausal3D(nn.Module):
190
+ r"""
191
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its
192
+ latent representation into an output sample.
193
+
194
+ Args:
195
+ in_channels (`int`, *optional*, defaults to 3):
196
+ The number of input channels.
197
+ out_channels (`int`, *optional*, defaults to 3):
198
+ The number of output channels.
199
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlockCausal3D",)`):
200
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
201
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
202
+ The number of output channels for each block.
203
+ layers_per_block (`int`, *optional*, defaults to 2):
204
+ The number of layers per block.
205
+ norm_num_groups (`int`, *optional*, defaults to 32):
206
+ The number of groups for normalization.
207
+ act_fn (`str`, *optional*, defaults to `"silu"`):
208
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
209
+ norm_type (`str`, *optional*, defaults to `"group"`):
210
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ in_channels: int = 3,
216
+ out_channels: int = 3,
217
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
218
+ block_out_channels: Tuple[int, ...] = (64,),
219
+ layers_per_block: int = 2,
220
+ norm_num_groups: int = 32,
221
+ act_fn: str = "silu",
222
+ norm_type: str = "group", # group, spatial
223
+ mid_block_add_attention=True,
224
+ time_compression_ratio: int = 4,
225
+ spatial_compression_ratio: int = 8,
226
+ disable_causal: bool = False,
227
+ mid_block_causal_attn: bool = False,
228
+ ):
229
+ super().__init__()
230
+ self.layers_per_block = layers_per_block
231
+
232
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3,
233
+ stride=1, disable_causal=disable_causal)
234
+ self.mid_block = None
235
+ self.up_blocks = nn.ModuleList([])
236
+
237
+ temb_channels = in_channels if norm_type == "spatial" else None
238
+
239
+ # mid
240
+ self.mid_block = UNetMidBlockCausal3D(
241
+ in_channels=block_out_channels[-1],
242
+ resnet_eps=1e-6,
243
+ resnet_act_fn=act_fn,
244
+ output_scale_factor=1,
245
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
246
+ attention_head_dim=block_out_channels[-1],
247
+ resnet_groups=norm_num_groups,
248
+ temb_channels=temb_channels,
249
+ add_attention=mid_block_add_attention,
250
+ disable_causal=disable_causal,
251
+ causal_attention=mid_block_causal_attn,
252
+ )
253
+
254
+ # up
255
+ reversed_block_out_channels = list(reversed(block_out_channels))
256
+ output_channel = reversed_block_out_channels[0]
257
+ for i, up_block_type in enumerate(up_block_types):
258
+ prev_output_channel = output_channel
259
+ output_channel = reversed_block_out_channels[i]
260
+ is_final_block = i == len(block_out_channels) - 1
261
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
262
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
263
+
264
+ if time_compression_ratio == 4:
265
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
266
+ add_time_upsample = bool(i >= len(block_out_channels) - 1 - \
267
+ num_time_upsample_layers and not is_final_block)
268
+ elif time_compression_ratio == 8:
269
+ add_spatial_upsample = bool(i >= len(block_out_channels) - num_spatial_upsample_layers)
270
+ add_time_upsample = bool(i >= len(block_out_channels) - num_time_upsample_layers)
271
+ else:
272
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
273
+
274
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
275
+ upsample_scale_factor_T = (2, ) if add_time_upsample else (1, )
276
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
277
+ up_block = get_up_block3d(
278
+ up_block_type,
279
+ num_layers=self.layers_per_block + 1,
280
+ in_channels=prev_output_channel,
281
+ out_channels=output_channel,
282
+ prev_output_channel=None,
283
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
284
+ upsample_scale_factor=upsample_scale_factor,
285
+ resnet_eps=1e-6,
286
+ resnet_act_fn=act_fn,
287
+ resnet_groups=norm_num_groups,
288
+ attention_head_dim=output_channel,
289
+ temb_channels=temb_channels,
290
+ resnet_time_scale_shift=norm_type,
291
+ disable_causal=disable_causal,
292
+ )
293
+ self.up_blocks.append(up_block)
294
+ prev_output_channel = output_channel
295
+
296
+ # out
297
+ if norm_type == "spatial":
298
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
299
+ else:
300
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
301
+ self.conv_act = nn.SiLU()
302
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3, disable_causal=disable_causal)
303
+
304
+ self.gradient_checkpointing = False
305
+
306
+ def forward(
307
+ self,
308
+ sample: torch.FloatTensor,
309
+ latent_embeds: Optional[torch.FloatTensor] = None,
310
+ ) -> torch.FloatTensor:
311
+ r"""The forward method of the `DecoderCausal3D` class."""
312
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
313
+
314
+ sample = self.conv_in(sample)
315
+
316
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
317
+ if self.training and self.gradient_checkpointing:
318
+
319
+ def create_custom_forward(module):
320
+ def custom_forward(*inputs):
321
+ return module(*inputs)
322
+
323
+ return custom_forward
324
+
325
+ if is_torch_version(">=", "1.11.0"):
326
+ # middle
327
+ sample = torch.utils.checkpoint.checkpoint(
328
+ create_custom_forward(self.mid_block),
329
+ sample,
330
+ latent_embeds,
331
+ use_reentrant=False,
332
+ )
333
+ sample = sample.to(upscale_dtype)
334
+
335
+ # up
336
+ for up_block in self.up_blocks:
337
+ sample = torch.utils.checkpoint.checkpoint(
338
+ create_custom_forward(up_block),
339
+ sample,
340
+ latent_embeds,
341
+ use_reentrant=False,
342
+ )
343
+ else:
344
+ # middle
345
+ sample = torch.utils.checkpoint.checkpoint(
346
+ create_custom_forward(self.mid_block), sample, latent_embeds
347
+ )
348
+ sample = sample.to(upscale_dtype)
349
+
350
+ # up
351
+ for up_block in self.up_blocks:
352
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
353
+ else:
354
+ # middle
355
+ sample = self.mid_block(sample, latent_embeds)
356
+ sample = sample.to(upscale_dtype)
357
+
358
+ # up
359
+ for up_block in self.up_blocks:
360
+ sample = up_block(sample, latent_embeds)
361
+
362
+ # post-process
363
+ if latent_embeds is None:
364
+ sample = self.conv_norm_out(sample)
365
+ else:
366
+ sample = self.conv_norm_out(sample, latent_embeds)
367
+ sample = self.conv_act(sample)
368
+ sample = self.conv_out(sample)
369
+
370
+ return sample
371
+
372
+
373
+ class DiagonalGaussianDistribution(object):
374
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
375
+ if parameters.ndim == 3:
376
+ dim = 2 # (B, L, C)
377
+ elif parameters.ndim == 5 or parameters.ndim == 4:
378
+ dim = 1 # (B, C, T, H ,W) / (B, C, H, W)
379
+ else:
380
+ raise NotImplementedError
381
+ self.parameters = parameters
382
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
383
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
384
+ self.deterministic = deterministic
385
+ self.std = torch.exp(0.5 * self.logvar)
386
+ self.var = torch.exp(self.logvar)
387
+ if self.deterministic:
388
+ self.var = self.std = torch.zeros_like(
389
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
390
+ )
391
+
392
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
393
+ # make sure sample is on the same device as the parameters and has same dtype
394
+ sample = randn_tensor(
395
+ self.mean.shape,
396
+ generator=generator,
397
+ device=self.parameters.device,
398
+ dtype=self.parameters.dtype,
399
+ )
400
+ x = self.mean + self.std * sample
401
+ return x
402
+
403
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
404
+ if self.deterministic:
405
+ return torch.Tensor([0.0])
406
+ else:
407
+ reduce_dim = list(range(1, self.mean.ndim))
408
+ if other is None:
409
+ return 0.5 * torch.sum(
410
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
411
+ dim=reduce_dim,
412
+ )
413
+ else:
414
+ return 0.5 * torch.sum(
415
+ torch.pow(self.mean - other.mean, 2) / other.var
416
+ + self.var / other.var
417
+ - 1.0
418
+ - self.logvar
419
+ + other.logvar,
420
+ dim=reduce_dim,
421
+ )
422
+
423
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
424
+ if self.deterministic:
425
+ return torch.Tensor([0.0])
426
+ logtwopi = np.log(2.0 * np.pi)
427
+ return 0.5 * torch.sum(
428
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
429
+ dim=dims,
430
+ )
431
+
432
+ def mode(self) -> torch.Tensor:
433
+ return self.mean
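`DiagonalGaussianDistribution` is the usual VAE reparameterization: the encoder's `double_z` output is split into mean and log-variance along the channel axis, `sample()` draws `mean + std * eps`, `mode()` returns the mean, and `kl()` gives the per-sample KL against a unit Gaussian. A minimal usage sketch (the import path and tensor shapes are assumptions for illustration):

import torch
from hymm_sp.vae.vae import DiagonalGaussianDistribution  # assumed module path

# moments as an encoder with double_z=True would produce: 2*C channels on dim=1
moments = torch.randn(2, 2 * 16, 5, 32, 32)  # (B, 2C, T, H, W), illustrative shape
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()    # (B, C, T, H, W) reparameterized draw
z_det = posterior.mode()  # deterministic latent (the mean)
kl = posterior.kl()       # shape (B,): KL divergence against N(0, I)
print(z.shape, z_det.shape, kl.shape)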
requirements.txt ADDED
@@ -0,0 +1,60 @@
1
+ accelerate==1.9.0
2
+ av==15.0.0
3
+ certifi==2025.8.3
4
+ charset-normalizer==3.4.2
5
+ contourpy==1.3.3
6
+ cycler==0.12.1
7
+ decord==0.6.0
8
+ diffusers==0.34.0
9
+ einops==0.8.1
10
+ filelock==3.13.1
11
+ fonttools==4.59.0
12
+ fsspec==2024.6.1
13
+ hf-xet==1.1.5
14
+ huggingface-hub==0.34.3
15
+ idna==3.10
16
+ imageio==2.37.0
17
+ imageio-ffmpeg==0.6.0
18
+ importlib_metadata==8.7.0
19
+ Jinja2==3.1.4
20
+ kiwisolver==1.4.8
21
+ loguru==0.7.3
22
+ MarkupSafe==2.1.5
23
+ matplotlib==3.10.5
24
+ mpmath==1.3.0
25
+ networkx==3.3
26
+ ninja==1.11.1.4
27
+ numpy==2.1.2
28
+ nvidia-ml-py==12.575.51
29
+ nvidia-nccl-cu12==2.21.5
30
+ nvidia-nvjitlink-cu12==12.4.127
31
+ nvidia-nvtx-cu12==12.4.127
32
+ nvitop==1.5.2
33
+ opencv-python-headless==4.12.0.88
34
+ packaging==25.0
35
+ pandas==2.3.1
36
+ pillow==11.0.0
37
+ protobuf==6.31.1
38
+ psutil==7.0.0
39
+ pyparsing==3.2.3
40
+ python-dateutil==2.9.0.post0
41
+ pytz==2025.2
42
+ PyYAML==6.0.2
43
+ regex==2025.7.34
44
+ requests==2.32.4
45
+ safetensors==0.5.3
46
+ sentencepiece==0.2.0
47
+ setuptools==78.1.1
48
+ six==1.17.0
49
+ sympy==1.13.1
50
+ tokenizers==0.21.4
51
+ tqdm==4.67.1
52
+ transformers==4.54.1
53
+ triton==3.1.0
54
+ typing_extensions==4.12.2
55
+ tzdata==2025.2
56
+ urllib3==2.5.0
57
+ wheel==0.45.1
58
+ zipp==3.23.0
59
+ gradio==5.42.0
60
+ sageattention==1.0.6
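The pins above are exact versions; a small optional sanity check that a few key pins resolved as expected (the helper below is a convenience, not part of the repo, and only spot-checks packages named in the list):

from importlib.metadata import PackageNotFoundError, version

pins = {"diffusers": "0.34.0", "transformers": "4.54.1", "accelerate": "1.9.0", "gradio": "5.42.0"}
for name, expected in pins.items():
    try:
        installed = version(name)
        status = "OK" if installed == expected else f"MISMATCH (installed {installed})"
    except PackageNotFoundError:
        status = "MISSING"
    print(f"{name}=={expected}: {status}")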
scripts/run_sample_batch_4090.sh ADDED
@@ -0,0 +1,35 @@
1
+ #!/bin/bash
2
+ JOBS_DIR=$(dirname $(dirname "$0"))
3
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
4
+ export MODEL_BASE="/path/to/models"
5
+ checkpoint_path="/path/to/ckpts"
6
+
7
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
8
+ modelname='Tencent_hunyuanGameCraft_720P'
9
+
10
+ # disable sp and enable cpu offload
11
+ export DISABLE_SP=1
12
+ export CPU_OFFLOAD=1
13
+ export NUM_GPU=1
14
+
15
+ # # enable both sp and cpu offload
16
+ # export DISABLE_SP=0
17
+ # export CPU_OFFLOAD=1
18
+ # export NUM_GPU=8
19
+
20
+ torchrun --nnodes=1 --nproc_per_node=${NUM_GPU} --master_port 29605 hymm_sp/sample_batch.py \
21
+ --image-path "asset/village.png" \
22
+ --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
23
+ --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
24
+ --ckpt ${checkpoint_path} \
25
+ --video-size 704 1216 \
26
+ --cfg-scale 2.0 \
27
+ --image-start \
28
+ --action-list w s d a \
29
+ --action-speed-list 0.2 0.2 0.2 0.2 \
30
+ --seed 250160 \
31
+ --infer-steps 50 \
32
+ --flow-shift-eval-video 5.0 \
33
+ --cpu-offload \
34
+ --use-fp8 \
35
+ --save-path './results/'
scripts/run_sample_batch_distill.sh ADDED
@@ -0,0 +1,24 @@
1
+ #!/bin/bash
2
+ JOBS_DIR=$(dirname $(dirname "$0"))
3
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
4
+ export MODEL_BASE="weights/stdmodels"
5
+ checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states_distill.pt"
6
+
7
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
8
+ modelname='Tencent_hunyuanGameCraft_720P'
9
+
10
+ torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
11
+ --image-path "asset/village.png" \
12
+ --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
13
+ --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
14
+ --ckpt ${checkpoint_path} \
15
+ --video-size 704 1216 \
16
+ --cfg-scale 1.0 \
17
+ --image-start \
18
+ --action-list w s d a \
19
+ --action-speed-list 0.2 0.2 0.2 0.2 \
20
+ --seed 250160 \
21
+ --infer-steps 8 \
22
+ --use-fp8 \
23
+ --flow-shift-eval-video 5.0 \
24
+ --save-path './results_distill/'
scripts/run_sample_batch_sp.sh ADDED
@@ -0,0 +1,24 @@
1
+ #!/bin/bash
2
+ JOBS_DIR=$(dirname $(dirname "$0"))
3
+ export PYTHONPATH=${JOBS_DIR}:$PYTHONPATH
4
+ export MODEL_BASE="weights/stdmodels"
5
+ checkpoint_path="weights/gamecraft_models/mp_rank_00_model_states.pt"
6
+
7
+ current_time=$(date "+%Y.%m.%d-%H.%M.%S")
8
+ modelname='Tencent_hunyuanGameCraft_720P'
9
+
10
+ torchrun --nnodes=1 --nproc_per_node=8 --master_port 29605 hymm_sp/sample_batch.py \
11
+ --image-path "asset/village.png" \
12
+ --prompt "A charming medieval village with cobblestone streets, thatched-roof houses, and vibrant flower gardens under a bright blue sky." \
13
+ --add-pos-prompt "Realistic, High-quality." \
14
+ --add-neg-prompt "overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, text, subtitles, static, picture, black border." \
15
+ --ckpt ${checkpoint_path} \
16
+ --video-size 704 1216 \
17
+ --cfg-scale 2.0 \
18
+ --image-start \
19
+ --action-list w s d a \
20
+ --action-speed-list 0.2 0.2 0.2 0.2 \
21
+ --seed 250160 \
22
+ --infer-steps 50 \
23
+ --flow-shift-eval-video 5.0 \
24
+ --save-path './results/'
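The three scripts differ mainly in parallelism and sampling budget: the 4090 script runs a single GPU with CPU offload and FP8 weights at 50 steps and cfg-scale 2.0, the distill script keeps 8 GPUs but drops to 8 steps at cfg-scale 1.0 with the distilled checkpoint, and the SP script is the full 8-GPU, 50-step run. For completeness, a hedged Python wrapper around the single-GPU invocation (paths and flag values are copied from the 4090 script, the negative prompt is omitted for brevity, and everything will need adapting to a real setup):

import os
import subprocess

env = dict(os.environ,
           MODEL_BASE="/path/to/models",  # placeholder path, as in the script
           DISABLE_SP="1", CPU_OFFLOAD="1")

cmd = [
    "torchrun", "--nnodes=1", "--nproc_per_node=1", "--master_port", "29605",
    "hymm_sp/sample_batch.py",
    "--image-path", "asset/village.png",
    "--prompt", "A charming medieval village with cobblestone streets, "
                "thatched-roof houses, and vibrant flower gardens under a bright blue sky.",
    "--ckpt", "/path/to/ckpts",            # placeholder checkpoint path, as in the script
    "--video-size", "704", "1216",
    "--cfg-scale", "2.0",
    "--image-start",
    "--action-list", "w", "s", "d", "a",
    "--action-speed-list", "0.2", "0.2", "0.2", "0.2",
    "--seed", "250160",
    "--infer-steps", "50",
    "--flow-shift-eval-video", "5.0",
    "--cpu-offload", "--use-fp8",
    "--save-path", "./results/",
]
subprocess.run(cmd, env=env, check=True)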