	Migrated from GitHub
This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full list.
- .gitattributes +14 -0
- LICENSE.txt +201 -0
- ORIGINAL_README.md +322 -0
- assets/efficiency.png +3 -0
- assets/logo.png +3 -0
- assets/logo2.jpeg +3 -0
- assets/pipe.png +3 -0
- examples/multi/1/1.WAV +3 -0
- examples/multi/1/2.WAV +3 -0
- examples/multi/1/multi1.png +3 -0
- examples/multi/2/1.wav +3 -0
- examples/multi/2/multi2.png +3 -0
- examples/multi/3/1-man.WAV +3 -0
- examples/multi/3/1-woman.WAV +3 -0
- examples/multi/3/multi3.png +3 -0
- examples/multitalk_example_1.json +13 -0
- examples/multitalk_example_2.json +9 -0
- examples/multitalk_example_3.json +9 -0
- examples/single/1.wav +3 -0
- examples/single/single1.png +3 -0
- examples/single_example_1.json +7 -0
- generate_multitalk.py +500 -0
- requirements.txt +15 -0
- src/audio_analysis/torch_utils.py +20 -0
- src/audio_analysis/wav2vec2.py +125 -0
- src/utils.py +60 -0
- src/vram_management/__init__.py +1 -0
- src/vram_management/layers.py +179 -0
- wan/__init__.py +6 -0
- wan/configs/__init__.py +58 -0
- wan/configs/shared_config.py +19 -0
- wan/configs/wan_i2v_14B.py +24 -0
- wan/configs/wan_multitalk_14B.py +36 -0
- wan/configs/wan_t2v_14B.py +29 -0
- wan/configs/wan_t2v_1_3B.py +29 -0
- wan/distributed/__init__.py +0 -0
- wan/distributed/fsdp.py +43 -0
- wan/distributed/xdit_context_parallel.py +550 -0
- wan/first_last_frame2video.py +377 -0
- wan/image2video.py +350 -0
- wan/modules/__init__.py +18 -0
- wan/modules/attention.py +393 -0
- wan/modules/clip.py +542 -0
- wan/modules/model.py +631 -0
- wan/modules/multitalk_model.py +799 -0
- wan/modules/t5.py +513 -0
- wan/modules/tokenizers.py +82 -0
- wan/modules/vace_model.py +250 -0
- wan/modules/vae.py +663 -0
- wan/modules/xlm_roberta.py +170 -0
    	
.gitattributes
CHANGED

@@ -33,3 +33,17 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/efficiency.png filter=lfs diff=lfs merge=lfs -text
+assets/logo.png filter=lfs diff=lfs merge=lfs -text
+assets/logo2.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/pipe.png filter=lfs diff=lfs merge=lfs -text
+examples/multi/1/1.WAV filter=lfs diff=lfs merge=lfs -text
+examples/multi/1/2.WAV filter=lfs diff=lfs merge=lfs -text
+examples/multi/1/multi1.png filter=lfs diff=lfs merge=lfs -text
+examples/multi/2/1.wav filter=lfs diff=lfs merge=lfs -text
+examples/multi/2/multi2.png filter=lfs diff=lfs merge=lfs -text
+examples/multi/3/1-man.WAV filter=lfs diff=lfs merge=lfs -text
+examples/multi/3/1-woman.WAV filter=lfs diff=lfs merge=lfs -text
+examples/multi/3/multi3.png filter=lfs diff=lfs merge=lfs -text
+examples/single/1.wav filter=lfs diff=lfs merge=lfs -text
+examples/single/single1.png filter=lfs diff=lfs merge=lfs -text
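These 14 added rules tell Git to store the new binary assets and example media with Git LFS. As a non-authoritative sketch (not necessarily how this commit was produced), entries like these are typically generated with `git lfs track`, which appends the corresponding filter lines to `.gitattributes`:

```sh
# Hypothetical reproduction of the new .gitattributes entries via git-lfs.
# Each command appends a "<pattern> filter=lfs diff=lfs merge=lfs -text" line.
git lfs install
git lfs track "assets/*.png" "assets/*.jpeg"
git lfs track "examples/**/*.wav" "examples/**/*.WAV" "examples/**/*.png"
git add .gitattributes
```

Note that this commit pins exact file paths rather than glob patterns; `git lfs track` accepts either form.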
    	
LICENSE.txt
ADDED

@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
    	
ORIGINAL_README.md
ADDED

@@ -0,0 +1,322 @@
<div align="center">

<p align="center">
  <img src="assets/logo2.jpeg" alt="MultiTalk" width="240"/>
</p>

<h1>Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation</h1>

[Zhe Kong*](https://scholar.google.com/citations?user=4X3yLwsAAAAJ&hl=zh-CN) · [Feng Gao*](https://scholar.google.com/citations?user=lFkCeoYAAAAJ) · [Yong Zhang](https://yzhang2016.github.io/)<sup>✉</sup> · [Zhuoliang Kang](https://scholar.google.com/citations?user=W1ZXjMkAAAAJ&hl=en) · [Xiaoming Wei](https://scholar.google.com/citations?user=JXV5yrZxj5MC&hl=zh-CN) · [Xunliang Cai](https://openreview.net/profile?id=~Xunliang_Cai1)

[Guanying Chen](https://guanyingc.github.io/) · [Wenhan Luo](https://whluo.github.io/)<sup>✉</sup>

<sup>*</sup>Equal Contribution
<sup>✉</sup>Corresponding Authors

<a href='https://meigen-ai.github.io/multi-talk/'><img src='https://img.shields.io/badge/Project-Page-green'></a>
<a href='https://arxiv.org/abs/2505.22647'><img src='https://img.shields.io/badge/Technique-Report-red'></a>
<a href='https://huggingface.co/MeiGen-AI/MeiGen-MultiTalk'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a>
</div>

> **TL;DR:** MultiTalk is an audio-driven framework for multi-person conversational video generation. It enables the creation of videos featuring multi-person conversation 💬, singing 🎤, interaction control 👬, and cartoon characters 🙊.

<p align="center">
  <img src="assets/pipe.png">
</p>

## Video Demos

<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
  <tr>
      <td>
          <video src="https://github.com/user-attachments/assets/e55952e6-e1b2-44a5-9887-a89307a378da" width="320" controls loop></video>
      </td>
      <td>
          <video src="https://github.com/user-attachments/assets/f0396c19-d459-42aa-9d78-34fdea10de18" width="320" controls loop></video>
      </td>
      <td>
          <video src="https://github.com/user-attachments/assets/3576fd04-3e5f-4933-ac7b-1c4e6a601379" width="320" controls loop></video>
      </td>
  </tr>
  <tr>
      <td>
          <video src="https://github.com/user-attachments/assets/5589056e-3202-442d-a62a-2cad7a7ecb19" width="320" controls loop></video>
      </td>
      <td>
          <video src="https://github.com/user-attachments/assets/554bfbe7-0090-492c-94be-329f5e39e175" width="320" controls loop></video>
      </td>
      <td>
          <video src="https://github.com/user-attachments/assets/9e961f35-9413-4846-a806-8186d54061da" width="320" controls loop></video>
      </td>
  </tr>
  <tr>
      <td>
          <video src="https://github.com/user-attachments/assets/342595ab-cf75-4872-8182-f20fe8c95611" width="320" controls loop></video>
      </td>
      <td>
          <video src="https://github.com/user-attachments/assets/6476f9f0-35e0-4484-91a4-8aa646aa994a" width="320" controls loop></video>
      </td>
      <td>
          <video src="https://github.com/user-attachments/assets/d8fc8e94-0cba-4c25-9f3a-a8d7e0a785e1" width="320" controls loop></video>
      </td>
  </tr>
</table>

## ✨ Key Features

We propose **MultiTalk**, a novel framework for audio-driven multi-person conversational video generation. Given a multi-stream audio input, a reference image, and a prompt, MultiTalk generates a video containing interactions that follow the prompt, with consistent lip motions aligned with the audio.

> - 💬 **Realistic Conversations** - Supports single- and multi-person generation
> - 👥 **Interactive Character Control** - Direct virtual humans via prompts
> - 🎤 **Generalization Performance** - Supports the generation of cartoon characters and singing
> - 📺 **Resolution Flexibility** - 480p & 720p output at arbitrary aspect ratios
> - ⏱️ **Long Video Generation** - Supports video generation up to 15 seconds

## 🔥 Latest News

* June 14, 2025: 🔥🔥 We release `MultiTalk` with support for `multi-GPU inference`, `TeaCache acceleration`, `APG`, and `low-VRAM inference` (enabling 480P video generation on a single RTX 4090). [APG](https://arxiv.org/abs/2410.02416) is used to alleviate color error accumulation in long video generation. TeaCache can increase speed by approximately 2~3x.
* June 9, 2025: 🔥🔥 We release the [weights](https://huggingface.co/MeiGen-AI/MeiGen-MultiTalk) and inference code of **MultiTalk**.
* May 29, 2025: We release the [Technique-Report](https://arxiv.org/abs/2505.22647) of **MultiTalk**.
* May 29, 2025: We release the [project page](https://meigen-ai.github.io/multi-talk/) of **MultiTalk**.

## 🌐 Community Works
- [ComfyUI](https://github.com/kijai/ComfyUI-WanVideoWrapper/tree/multitalk): thanks to [kijai](https://github.com/kijai) for integrating MultiTalk into ComfyUI-WanVideoWrapper. [Rudra](https://github.com/Rudra-ai-coder) found that MultiTalk can be combined with Wanx T2V and VACE; see the linked [issue](https://github.com/kijai/ComfyUI-WanVideoWrapper/issues/635).
- [Google Colab example](https://colab.research.google.com/drive/185OyRIpJDlpnRjhBRb7FnaRlq11BLZTa?usp=sharing), an example of inference on an A100, provided by [Braffolk](https://github.com/Braffolk).

## 📑 Todo List

- [x] Release the technical report
- [x] Inference
- [x] Checkpoints
- [x] Multi-GPU inference
- [ ] Inference acceleration
  - [x] TeaCache
  - [ ] int8 quantization
  - [ ] LCM distillation
  - [ ] Sparse attention
- [x] Run with very low VRAM
- [ ] TTS integration
- [ ] Gradio demo
- [ ] ComfyUI
- [ ] 1.3B model

## Quick Start

### 🛠️ Installation

#### 1. Create a conda environment and install PyTorch and xformers
```
conda create -n multitalk python=3.10
conda activate multitalk
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
pip install -U xformers==0.0.28 --index-url https://download.pytorch.org/whl/cu121
```
#### 2. Flash-attn installation
```
pip install ninja
pip install psutil
pip install packaging
pip install flash_attn
```

#### 3. Other dependencies
```
pip install -r requirements.txt
conda install -c conda-forge librosa
```

#### 4. FFmpeg installation
```
conda install -c conda-forge ffmpeg
```
or
```
sudo yum install ffmpeg ffmpeg-devel
```

### 🧱 Model Preparation

#### 1. Model Download

| Models                | Download Link                                                                   | Notes                        |
| --------------------- | ------------------------------------------------------------------------------- | ---------------------------- |
| Wan2.1-I2V-14B-480P   | 🤗 [Huggingface](https://huggingface.co/Wan-AI/Wan2.1-I2V-14B-480P)              | Base model                   |
| chinese-wav2vec2-base | 🤗 [Huggingface](https://huggingface.co/TencentGameMate/chinese-wav2vec2-base)   | Audio encoder                |
| MeiGen-MultiTalk      | 🤗 [Huggingface](https://huggingface.co/MeiGen-AI/MeiGen-MultiTalk)              | Our audio condition weights  |

Download the models using huggingface-cli:
``` sh
huggingface-cli download Wan-AI/Wan2.1-I2V-14B-480P --local-dir ./weights/Wan2.1-I2V-14B-480P
huggingface-cli download TencentGameMate/chinese-wav2vec2-base --local-dir ./weights/chinese-wav2vec2-base
huggingface-cli download TencentGameMate/chinese-wav2vec2-base model.safetensors --revision refs/pr/1 --local-dir ./weights/chinese-wav2vec2-base
huggingface-cli download MeiGen-AI/MeiGen-MultiTalk --local-dir ./weights/MeiGen-MultiTalk
```

#### 2. Link or Copy the MultiTalk Model into the Wan2.1-I2V-14B-480P Directory

Link:
```
mv weights/Wan2.1-I2V-14B-480P/diffusion_pytorch_model.safetensors.index.json weights/Wan2.1-I2V-14B-480P/diffusion_pytorch_model.safetensors.index.json_old
sudo ln -s {Absolute path}/weights/MeiGen-MultiTalk/diffusion_pytorch_model.safetensors.index.json weights/Wan2.1-I2V-14B-480P/
sudo ln -s {Absolute path}/weights/MeiGen-MultiTalk/multitalk.safetensors weights/Wan2.1-I2V-14B-480P/
```

Or copy:
```
mv weights/Wan2.1-I2V-14B-480P/diffusion_pytorch_model.safetensors.index.json weights/Wan2.1-I2V-14B-480P/diffusion_pytorch_model.safetensors.index.json_old
cp weights/MeiGen-MultiTalk/diffusion_pytorch_model.safetensors.index.json weights/Wan2.1-I2V-14B-480P/
cp weights/MeiGen-MultiTalk/multitalk.safetensors weights/Wan2.1-I2V-14B-480P/
```
### 🔑 Quick Inference

Our model is compatible with both 480P and 720P resolutions. The current code only supports 480P inference; 720P inference requires multiple GPUs, and we will provide an update soon.
> Some tips
> - Lip synchronization accuracy: Audio CFG works optimally between 3–5. Increase the audio CFG value for better synchronization.
> - Video clip length: The model was trained on 81-frame videos at 25 FPS. For optimal prompt-following performance, generate clips of 81 frames. Generating up to 201 frames is possible, though longer clips may reduce prompt-following performance.
> - Long video generation: Audio CFG influences color-tone consistency across segments. Set this value to 3 to alleviate tonal variations.
> - Sampling steps: To generate a video quickly, you can reduce the sampling steps to as few as 10 without hurting lip-synchronization accuracy, although motion and visual quality will suffer. More sampling steps give better video quality.
> - TeaCache acceleration: The optimal range for `--teacache_thresh` is between 0.2 and 0.5. Increasing this value can further improve acceleration, but may also degrade the quality of the generated video.

#### Usage of MultiTalk
```
--mode streaming: long video generation.
--mode clip: generate a short video with one chunk.
--use_teacache: run with TeaCache.
--size multitalk-480: generate 480P video.
--size multitalk-720: generate 720P video.
--use_apg: run with APG.
--teacache_thresh: a coefficient used for TeaCache acceleration.
```
         | 
| 196 | 
            +
             | 
| 197 | 
            +
            #### 1. Single-Person
         | 
| 198 | 
            +
             | 
| 199 | 
            +
            ##### 1) Run with single GPU
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
            ```
         | 
| 203 | 
            +
            python generate_multitalk.py \
         | 
| 204 | 
            +
                --ckpt_dir weights/Wan2.1-I2V-14B-480P \
         | 
| 205 | 
            +
                --wav2vec_dir 'weights/chinese-wav2vec2-base' \
         | 
| 206 | 
            +
                --input_json examples/single_example_1.json \
         | 
| 207 | 
            +
                --sample_steps 40 \
         | 
| 208 | 
            +
                --mode streaming \
         | 
| 209 | 
            +
                --use_teacache \
         | 
| 210 | 
            +
                --save_file single_long_exp
         | 
| 211 | 
            +
             | 
| 212 | 
            +
            ```
         | 
| 213 | 
            +
             | 
| 214 | 
            +
            ##### 2) Run with very low VRAM
         | 
| 215 | 
            +
             | 
| 216 | 
            +
            If you want run with very low VRAM, set `--num_persistent_param_in_dit 0`:
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
            ```
         | 
| 220 | 
            +
            python generate_multitalk.py \
         | 
| 221 | 
            +
                --ckpt_dir weights/Wan2.1-I2V-14B-480P \
         | 
| 222 | 
            +
                --wav2vec_dir 'weights/chinese-wav2vec2-base' \
         | 
| 223 | 
            +
                --input_json examples/single_example_1.json \
         | 
| 224 | 
            +
                --sample_steps 40 \
         | 
| 225 | 
            +
                --mode streaming \
         | 
| 226 | 
            +
                --num_persistent_param_in_dit 0 \
         | 
| 227 | 
            +
                --use_teacache \
         | 
| 228 | 
            +
                --save_file single_long_lowvram_exp
         | 
| 229 | 
            +
             | 
| 230 | 
            +
            ```
         | 
| 231 | 
            +
             | 
| 232 | 
            +
            ##### 3) Multi-GPU inference
         | 
| 233 | 
            +
             | 
| 234 | 
            +
            ```
         | 
| 235 | 
            +
            GPU_NUM=8
         | 
| 236 | 
            +
            torchrun --nproc_per_node=$GPU_NUM --standalone generate_multitalk.py \
         | 
| 237 | 
            +
                --ckpt_dir weights/Wan2.1-I2V-14B-480P \
         | 
| 238 | 
            +
                --wav2vec_dir 'weights/chinese-wav2vec2-base' \
         | 
| 239 | 
            +
                --dit_fsdp --t5_fsdp \
         | 
| 240 | 
            +
                --ulysses_size=$GPU_NUM \
         | 
| 241 | 
            +
                --input_json examples/single_example_1.json \
         | 
| 242 | 
            +
                --sample_steps 40 \
         | 
| 243 | 
            +
                --mode streaming \
         | 
| 244 | 
            +
                --use_teacache \
         | 
| 245 | 
            +
                --save_file single_long_multigpu_exp
         | 
| 246 | 
            +
             | 
| 247 | 
            +
            ```
         | 
| 248 | 
            +
             | 
| 249 | 
            +
             | 
| 250 | 
            +
             | 
| 251 | 
            +
            #### 2. Multi-Person
         | 
| 252 | 
            +
             | 
| 253 | 
            +
            ##### 1) Run with single GPU
         | 
| 254 | 
            +
             | 
| 255 | 
            +
            ```
         | 
| 256 | 
            +
            python generate_multitalk.py \
         | 
| 257 | 
            +
                --ckpt_dir weights/Wan2.1-I2V-14B-480P \
         | 
| 258 | 
            +
                --wav2vec_dir 'weights/chinese-wav2vec2-base' \
         | 
| 259 | 
            +
                --input_json examples/multitalk_example_2.json \
         | 
| 260 | 
            +
                --sample_steps 40 \
         | 
| 261 | 
            +
                --mode streaming \
         | 
| 262 | 
            +
                --use_teacache \
         | 
| 263 | 
            +
                --save_file multi_long_exp
         | 
| 264 | 
            +
            ```
         | 
| 265 | 
            +
            ##### 2) Run with very low VRAM
         | 
| 266 | 
            +
             | 
| 267 | 
            +
             | 
| 268 | 
            +
            ```
         | 
| 269 | 
            +
            python generate_multitalk.py \
         | 
| 270 | 
            +
                --ckpt_dir weights/Wan2.1-I2V-14B-480P \
         | 
| 271 | 
            +
                --wav2vec_dir 'weights/chinese-wav2vec2-base' \
         | 
| 272 | 
            +
                --input_json examples/multitalk_example_2.json \
         | 
| 273 | 
            +
                --sample_steps 40 \
         | 
| 274 | 
            +
                --mode streaming \
         | 
| 275 | 
            +
                --num_persistent_param_in_dit 0 \
         | 
| 276 | 
            +
                --use_teacache \
         | 
| 277 | 
            +
                --save_file multi_long_lowvram_exp
         | 
| 278 | 
            +
             | 
| 279 | 
            +
            ```
         | 
| 280 | 
            +
             | 
| 281 | 
            +
            ##### 3) Multi-GPU inference
         | 
| 282 | 
            +
             | 
| 283 | 
            +
            ```
         | 
| 284 | 
            +
            GPU_NUM=8
         | 
| 285 | 
            +
            torchrun --nproc_per_node=$GPU_NUM --standalone generate_multitalk.py \
         | 
| 286 | 
            +
                --ckpt_dir weights/Wan2.1-I2V-14B-480P \
         | 
| 287 | 
            +
                --wav2vec_dir 'weights/chinese-wav2vec2-base' \
         | 
| 288 | 
            +
                --dit_fsdp --t5_fsdp --ulysses_size=$GPU_NUM \
         | 
| 289 | 
            +
                --input_json examples/multitalk_example_2.json \
         | 
| 290 | 
            +
                --sample_steps 40 \
         | 
| 291 | 
            +
                --mode streaming --use_teacache \
         | 
| 292 | 
            +
                --save_file multi_long_multigpu_exp
         | 
| 293 | 
            +
             | 
| 294 | 
            +
            ```
         | 
| 295 | 
            +
             | 
| 296 | 
            +
            ## 🚀Computational Efficiency
         | 
| 297 | 
            +
            The results are evaluated on A100 GPUs for multi-person generation. Single-person generation uses less memory and provides faster inference.
         | 
| 298 | 
            +
            <p align="center">
         | 
| 299 | 
            +
              <img src="assets/efficiency.png">
         | 
| 300 | 
            +
            </p>
         | 
| 301 | 
            +
            TeaCache is capable of increasing speed by approximately 2~3x.
         | 
| 302 | 
            +
             | 
| 303 | 
            +
             | 
| 304 | 
            +
            ## 📚 Citation
         | 
| 305 | 
            +
             | 
| 306 | 
            +
            If you find our work useful in your research, please consider citing:
         | 
| 307 | 
            +
             | 
| 308 | 
            +
            ```
         | 
| 309 | 
            +
            @article{kong2025let,
         | 
| 310 | 
            +
              title={Let Them Talk: Audio-Driven Multi-Person Conversational Video Generation},
         | 
| 311 | 
            +
              author={Kong, Zhe and Gao, Feng and Zhang, Yong and Kang, Zhuoliang and Wei, Xiaoming and Cai, Xunliang and Chen, Guanying and Luo, Wenhan},
         | 
| 312 | 
            +
              journal={arXiv preprint arXiv:2505.22647},
         | 
| 313 | 
            +
              year={2025}
         | 
| 314 | 
            +
            }
         | 
| 315 | 
            +
            ```
         | 
| 316 | 
            +
             | 
| 317 | 
            +
            ## 📜 License
         | 
| 318 | 
            +
            The models in this repository are licensed under the Apache 2.0 License. We claim no rights over the your generated contents, 
         | 
| 319 | 
            +
            granting you the freedom to use them while ensuring that your usage complies with the provisions of this license. 
         | 
| 320 | 
            +
            You are fully accountable for your use of the models, which must not involve sharing any content that violates applicable laws, 
         | 
| 321 | 
            +
            causes harm to individuals or groups, disseminates personal information intended for harm, spreads misinformation, or targets vulnerable populations. 
         | 
| 322 | 
            +
             | 
    	
assets/efficiency.png
ADDED
(binary image stored with Git LFS)

assets/logo.png
ADDED
(binary image stored with Git LFS)

assets/logo2.jpeg
ADDED
(binary image stored with Git LFS)

assets/pipe.png
ADDED
(binary image stored with Git LFS)
    	
examples/multi/1/1.WAV
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8397a9b3c0add26384afe7e544e36cbc4806d8f2d7c705e11bb2897dc1bc993b
size 315436
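The three lines above are a Git LFS pointer, not the audio itself: the repository stores only the object hash and size. As a hedged sketch of standard Git LFS usage (not a command taken from this repository), the actual WAV data can be fetched after cloning with:

```sh
# Fetch and check out the real LFS objects behind the example audio pointers.
git lfs install
git lfs pull --include="examples/multi/1/*.WAV"
```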
    	
examples/multi/1/2.WAV
ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:753120ceadbdab3ce206423a1419f73018695682787414ca2f4613306be50bfc
size 544812
    	
        examples/multi/1/multi1.png
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        examples/multi/2/1.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:51eb6a408a8b5b33a732378e2a38e7412ba273186b85c324ec6a099d23fe38af
         | 
| 3 | 
            +
            size 1273592
         | 
    	
        examples/multi/2/multi2.png
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        examples/multi/3/1-man.WAV
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:d304fd88850d6673649d1844db2894e03bf5a775123048eebcb01ab3b79bff5e
         | 
| 3 | 
            +
            size 1503276
         | 
    	
        examples/multi/3/1-woman.WAV
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:3e1ebd7ae1587ebc7f0986f8b61e7fcc99c6fb57fbb15ab9373968e701afc8bf
         | 
| 3 | 
            +
            size 1503276
         | 
    	
        examples/multi/3/multi3.png
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        examples/multitalk_example_1.json
    ADDED
    
    | @@ -0,0 +1,13 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "prompt": "In a casual, intimate setting, a man and a woman are engaged in a heartfelt conversation inside a car. The man, sporting a denim jacket over a blue shirt, sits attentively with a seatbelt fastened, his gaze fixed on the woman beside him. The woman, wearing a black tank top and a denim jacket draped over her shoulders, smiles warmly, her eyes reflecting genuine interest and connection. The car's interior, with its beige seats and simple design, provides a backdrop that emphasizes their interaction. The scene captures a moment of shared understanding and connection, set against the soft, diffused light of an overcast day. A medium shot from a slightly angled perspective, focusing on their expressions and body language.",
         | 
| 3 | 
            +
                "cond_image": "examples/multi/1/multi1.png",
         | 
| 4 | 
            +
                "audio_type": "add",
         | 
| 5 | 
            +
                "cond_audio": {
         | 
| 6 | 
            +
                    "person1": "examples/multi/1/1.WAV",
         | 
| 7 | 
            +
                    "person2": "examples/multi/1/2.WAV"
         | 
| 8 | 
            +
                },
         | 
| 9 | 
            +
                "bbox": {
         | 
| 10 | 
            +
                    "person1": [160, 120, 1280, 1080], 
         | 
| 11 | 
            +
                    "person2": [160, 1320, 1280, 2280]
         | 
| 12 | 
            +
                }
         | 
| 13 | 
            +
            }
         | 
    	
        examples/multitalk_example_2.json
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "prompt": "In a cozy recording studio, a man and a woman are singing together with passion and emotion. The man, with short brown hair, wears a light gray button-up shirt, his expression filled with concentration and warmth. The woman, with long wavy brown hair, dons a sleeveless dress adorned with small polka dots, her eyes closed as she belts out a heartfelt melody. The studio is equipped with professional microphones, and the background features soundproofing panels, creating an intimate and focused atmosphere. A close-up shot captures their expressions and the intensity of their performance.",
         | 
| 3 | 
            +
                "cond_image": "examples/multi/2/multi2.png",
         | 
| 4 | 
            +
                "audio_type": "para",
         | 
| 5 | 
            +
                "cond_audio": {
         | 
| 6 | 
            +
                    "person1": "examples/multi/2/1.wav",
         | 
| 7 | 
            +
                    "person2": "examples/multi/2/1.wav"
         | 
| 8 | 
            +
                }
         | 
| 9 | 
            +
            }
         | 
    	
        examples/multitalk_example_3.json
    ADDED
    
    | @@ -0,0 +1,9 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "prompt": "In a cozy recording studio, a man and a woman are singing together. The man, with tousled brown hair, stands to the left, wearing a light green button-down shirt. His gaze is directed towards the woman, who is smiling warmly. She, with wavy dark hair, is dressed in a black floral dress and stands to the right, her eyes closed in enjoyment. Between them is a professional microphone, capturing their harmonious voices. The background features wooden panels and various audio equipment, creating an intimate and focused atmosphere. The lighting is soft and warm, highlighting their expressions and the intimate setting. A medium shot captures their interaction closely.",
         | 
| 3 | 
            +
                "cond_image": "examples/multi/3/multi3.png",
         | 
| 4 | 
            +
                "audio_type": "para",
         | 
| 5 | 
            +
                "cond_audio": {
         | 
| 6 | 
            +
                    "person1": "examples/multi/3/1-man.WAV",
         | 
| 7 | 
            +
                    "person2": "examples/multi/3/1-woman.WAV"
         | 
| 8 | 
            +
                }
         | 
| 9 | 
            +
            }
         | 
    	
        examples/single/1.wav
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:ba2733897f561f747e6508734bff4eeee29d0a73638e5c39c0c0b806701d4e8b
         | 
| 3 | 
            +
            size 1888320
         | 
    	
        examples/single/single1.png
    ADDED
    
    |   | 
| Git LFS Details
 | 
    	
        examples/single_example_1.json
    ADDED
    
    | @@ -0,0 +1,7 @@ | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
                "prompt": "A woman is passionately singing into a professional microphone in a recording studio. She wears large black headphones and a dark cardigan over a gray top. Her long, wavy brown hair frames her face as she looks slightly upwards, her mouth open mid-song. The studio is equipped with various audio equipment, including a mixing console and a keyboard, with soundproofing panels on the walls. The lighting is warm and focused on her, creating a professional and intimate atmosphere. A close-up shot captures her expressive performance.",
         | 
| 3 | 
            +
                "cond_image": "examples/single/single1.png",
         | 
| 4 | 
            +
                "cond_audio": {
         | 
| 5 | 
            +
                    "person1": "examples/single/1.wav"
         | 
| 6 | 
            +
                }
         | 
| 7 | 
            +
            }
         | 
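            Each of the metas above is passed to `generate_multitalk.py` through `--input_json`: the script reads `prompt`, `cond_image`, and one `cond_audio` entry per speaker, plus `audio_type` for multi-person cases (a `bbox` entry per person also appears in the first multi-person example). A hypothetical run of the two-person conversation example, again with placeholder checkpoint paths, could look like:
            ```
            # Sketch only: the checkpoint directories below are placeholders.
            python generate_multitalk.py \
                --ckpt_dir ./weights/MultiTalk \
                --wav2vec_dir ./weights/wav2vec2 \
                --input_json examples/multitalk_example_1.json \
                --mode clip \
                --save_file multi1_demo
            ```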
    	
        generate_multitalk.py
    ADDED
    
    | @@ -0,0 +1,500 @@ | |
| 1 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 2 | 
            +
            import argparse
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import sys
         | 
| 6 | 
            +
            import json
         | 
| 7 | 
            +
            import warnings
         | 
| 8 | 
            +
            from datetime import datetime
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            warnings.filterwarnings('ignore')
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import random
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            import torch
         | 
| 15 | 
            +
            import torch.distributed as dist
         | 
| 16 | 
            +
            from PIL import Image
         | 
| 17 | 
            +
            import subprocess
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            import wan
         | 
| 20 | 
            +
            from wan.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
         | 
| 21 | 
            +
            from wan.utils.utils import cache_image, cache_video, str2bool
         | 
| 22 | 
            +
            from wan.utils.multitalk_utils import save_video_ffmpeg
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            from transformers import Wav2Vec2FeatureExtractor
         | 
| 25 | 
            +
            from src.audio_analysis.wav2vec2 import Wav2Vec2Model
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            import librosa
         | 
| 28 | 
            +
            import pyloudnorm as pyln
         | 
| 29 | 
            +
            import numpy as np
         | 
| 30 | 
            +
            from einops import rearrange
         | 
| 31 | 
            +
            import soundfile as sf
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            def _validate_args(args):
         | 
| 34 | 
            +
                # Basic check
         | 
| 35 | 
            +
                assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
         | 
| 36 | 
            +
                assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                # Default to 40 sampling steps if not specified.
         | 
| 39 | 
            +
                if args.sample_steps is None:
         | 
| 40 | 
            +
                    args.sample_steps = 40
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                if args.sample_shift is None:
         | 
| 43 | 
            +
                    if args.size == 'multitalk-480':
         | 
| 44 | 
            +
                        args.sample_shift = 7
         | 
| 45 | 
            +
                    elif args.size == 'multitalk-720':
         | 
| 46 | 
            +
                        args.sample_shift = 11
         | 
| 47 | 
            +
                    else:
         | 
| 48 | 
            +
                        raise NotImplementedError(f'Unsupported size: {args.size}')
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
         | 
| 51 | 
            +
                    0, 99999999)
         | 
| 52 | 
            +
                # Size check
         | 
| 53 | 
            +
                assert args.size in SUPPORTED_SIZES[
         | 
| 54 | 
            +
                    args.
         | 
| 55 | 
            +
                    task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
         | 
| 56 | 
            +
             | 
| 57 | 
            +
             | 
| 58 | 
            +
            def _parse_args():
         | 
| 59 | 
            +
                parser = argparse.ArgumentParser(
         | 
| 60 | 
            +
                    description="Generate a image or video from a text prompt or image using Wan"
         | 
| 61 | 
            +
                )
         | 
| 62 | 
            +
                parser.add_argument(
         | 
| 63 | 
            +
                    "--task",
         | 
| 64 | 
            +
                    type=str,
         | 
| 65 | 
            +
                    default="multitalk-14B",
         | 
| 66 | 
            +
                    choices=list(WAN_CONFIGS.keys()),
         | 
| 67 | 
            +
                    help="The task to run.")
         | 
| 68 | 
            +
                parser.add_argument(
         | 
| 69 | 
            +
                    "--size",
         | 
| 70 | 
            +
                    type=str,
         | 
| 71 | 
            +
                    default="multitalk-480",
         | 
| 72 | 
            +
                    choices=list(SIZE_CONFIGS.keys()),
         | 
| 73 | 
            +
                    help="The buckget size of the generated video. The aspect ratio of the output video will follow that of the input image."
         | 
| 74 | 
            +
                )
         | 
| 75 | 
            +
                parser.add_argument(
         | 
| 76 | 
            +
                    "--frame_num",
         | 
| 77 | 
            +
                    type=int,
         | 
| 78 | 
            +
                    default=81,
         | 
| 79 | 
            +
                    help="How many frames to be generated in one clip. The number should be 4n+1"
         | 
| 80 | 
            +
                )
         | 
| 81 | 
            +
                parser.add_argument(
         | 
| 82 | 
            +
                    "--ckpt_dir",
         | 
| 83 | 
            +
                    type=str,
         | 
| 84 | 
            +
                    default=None,
         | 
| 85 | 
            +
                    help="The path to the Wan checkpoint directory.")
         | 
| 86 | 
            +
                parser.add_argument(
         | 
| 87 | 
            +
                    "--wav2vec_dir",
         | 
| 88 | 
            +
                    type=str,
         | 
| 89 | 
            +
                    default=None,
         | 
| 90 | 
            +
                    help="The path to the wav2vec checkpoint directory.")
         | 
| 91 | 
            +
                parser.add_argument(
         | 
| 92 | 
            +
                    "--offload_model",
         | 
| 93 | 
            +
                    type=str2bool,
         | 
| 94 | 
            +
                    default=None,
         | 
| 95 | 
            +
                    help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
         | 
| 96 | 
            +
                )
         | 
| 97 | 
            +
                parser.add_argument(
         | 
| 98 | 
            +
                    "--ulysses_size",
         | 
| 99 | 
            +
                    type=int,
         | 
| 100 | 
            +
                    default=1,
         | 
| 101 | 
            +
                    help="The size of the ulysses parallelism in DiT.")
         | 
| 102 | 
            +
                parser.add_argument(
         | 
| 103 | 
            +
                    "--ring_size",
         | 
| 104 | 
            +
                    type=int,
         | 
| 105 | 
            +
                    default=1,
         | 
| 106 | 
            +
                    help="The size of the ring attention parallelism in DiT.")
         | 
| 107 | 
            +
                parser.add_argument(
         | 
| 108 | 
            +
                    "--t5_fsdp",
         | 
| 109 | 
            +
                    action="store_true",
         | 
| 110 | 
            +
                    default=False,
         | 
| 111 | 
            +
                    help="Whether to use FSDP for T5.")
         | 
| 112 | 
            +
                parser.add_argument(
         | 
| 113 | 
            +
                    "--t5_cpu",
         | 
| 114 | 
            +
                    action="store_true",
         | 
| 115 | 
            +
                    default=False,
         | 
| 116 | 
            +
                    help="Whether to place T5 model on CPU.")
         | 
| 117 | 
            +
                parser.add_argument(
         | 
| 118 | 
            +
                    "--dit_fsdp",
         | 
| 119 | 
            +
                    action="store_true",
         | 
| 120 | 
            +
                    default=False,
         | 
| 121 | 
            +
                    help="Whether to use FSDP for DiT.")
         | 
| 122 | 
            +
                parser.add_argument(
         | 
| 123 | 
            +
                    "--save_file",
         | 
| 124 | 
            +
                    type=str,
         | 
| 125 | 
            +
                    default=None,
         | 
| 126 | 
            +
                    help="The file to save the generated image or video to.")
         | 
| 127 | 
            +
                parser.add_argument(
         | 
| 128 | 
            +
                    "--audio_save_dir",
         | 
| 129 | 
            +
                    type=str,
         | 
| 130 | 
            +
                    default='save_audio',
         | 
| 131 | 
            +
                    help="The path to save the audio embedding.")
         | 
| 132 | 
            +
                parser.add_argument(
         | 
| 133 | 
            +
                    "--base_seed",
         | 
| 134 | 
            +
                    type=int,
         | 
| 135 | 
            +
                    default=42,
         | 
| 136 | 
            +
                    help="The seed to use for generating the image or video.")
         | 
| 137 | 
            +
                parser.add_argument(
         | 
| 138 | 
            +
                    "--input_json",
         | 
| 139 | 
            +
                    type=str,
         | 
| 140 | 
            +
                    default='examples.json',
         | 
| 141 | 
            +
                    help="[meta file] The condition path to generate the video.")
         | 
| 142 | 
            +
                parser.add_argument(
         | 
| 143 | 
            +
                    "--motion_frame",
         | 
| 144 | 
            +
                    type=int,
         | 
| 145 | 
            +
                    default=25,
         | 
| 146 | 
            +
                    help="Driven frame length used in the mode of long video genration.")
         | 
| 147 | 
            +
                parser.add_argument(
         | 
| 148 | 
            +
                    "--mode",
         | 
| 149 | 
            +
                    type=str,
         | 
| 150 | 
            +
                    default="clip",
         | 
| 151 | 
            +
                    choices=['clip', 'streaming'],
         | 
| 152 | 
            +
                    help="clip: generate one video chunk, streaming: long video generation")
         | 
| 153 | 
            +
                parser.add_argument(
         | 
| 154 | 
            +
                    "--sample_steps", type=int, default=None, help="The sampling steps.")
         | 
| 155 | 
            +
                parser.add_argument(
         | 
| 156 | 
            +
                    "--sample_shift",
         | 
| 157 | 
            +
                    type=float,
         | 
| 158 | 
            +
                    default=None,
         | 
| 159 | 
            +
                    help="Sampling shift factor for flow matching schedulers.")
         | 
| 160 | 
            +
                parser.add_argument(
         | 
| 161 | 
            +
                    "--sample_text_guide_scale",
         | 
| 162 | 
            +
                    type=float,
         | 
| 163 | 
            +
                    default=5.0,
         | 
| 164 | 
            +
                    help="Classifier free guidance scale for text control.")
         | 
| 165 | 
            +
                parser.add_argument(
         | 
| 166 | 
            +
                    "--sample_audio_guide_scale",
         | 
| 167 | 
            +
                    type=float,
         | 
| 168 | 
            +
                    default=4.0,
         | 
| 169 | 
            +
                    help="Classifier free guidance scale for audio control.")
         | 
| 170 | 
            +
                parser.add_argument(
         | 
| 171 | 
            +
                    "--num_persistent_param_in_dit",
         | 
| 172 | 
            +
                    type=int,
         | 
| 173 | 
            +
                    default=None,
         | 
| 174 | 
            +
                    required=False,
         | 
| 175 | 
            +
                    help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required",
         | 
| 176 | 
            +
                )
         | 
| 177 | 
            +
                parser.add_argument(
         | 
| 178 | 
            +
                    "--use_teacache",
         | 
| 179 | 
            +
                    action="store_true",
         | 
| 180 | 
            +
                    default=False,
         | 
| 181 | 
            +
                    help="Enable teacache for video generation."
         | 
| 182 | 
            +
                )
         | 
| 183 | 
            +
                parser.add_argument(
         | 
| 184 | 
            +
                    "--teacache_thresh",
         | 
| 185 | 
            +
                    type=float,
         | 
| 186 | 
            +
                    default=0.2,
         | 
| 187 | 
            +
                    help="Threshold for teacache."
         | 
| 188 | 
            +
                )
         | 
| 189 | 
            +
                parser.add_argument(
         | 
| 190 | 
            +
                    "--use_apg",
         | 
| 191 | 
            +
                    action="store_true",
         | 
| 192 | 
            +
                    default=False,
         | 
| 193 | 
            +
                    help="Enable adaptive projected guidance for video generation (APG)."
         | 
| 194 | 
            +
                )
         | 
| 195 | 
            +
                parser.add_argument(
         | 
| 196 | 
            +
                    "--apg_momentum",
         | 
| 197 | 
            +
                    type=float,
         | 
| 198 | 
            +
                    default=-0.75,
         | 
| 199 | 
            +
                    help="Momentum used in adaptive projected guidance (APG)."
         | 
| 200 | 
            +
                )
         | 
| 201 | 
            +
                parser.add_argument(
         | 
| 202 | 
            +
                    "--apg_norm_threshold",
         | 
| 203 | 
            +
                    type=float,
         | 
| 204 | 
            +
                    default=55,
         | 
| 205 | 
            +
                    help="Norm threshold used in adaptive projected guidance (APG)."
         | 
| 206 | 
            +
                )
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                
         | 
| 209 | 
            +
                args = parser.parse_args()
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                _validate_args(args)
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                return args
         | 
| 214 | 
            +
             | 
| 215 | 
            +
            def custom_init(device, wav2vec):    
         | 
| 216 | 
            +
                audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
         | 
| 217 | 
            +
                audio_encoder.feature_extractor._freeze_parameters()
         | 
| 218 | 
            +
                wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
         | 
| 219 | 
            +
                return wav2vec_feature_extractor, audio_encoder
         | 
| 220 | 
            +
             | 
| 221 | 
            +
            def loudness_norm(audio_array, sr=16000, lufs=-23):
         | 
| 222 | 
            +
                meter = pyln.Meter(sr)
         | 
| 223 | 
            +
                loudness = meter.integrated_loudness(audio_array)
         | 
| 224 | 
            +
                if abs(loudness) > 100:
         | 
| 225 | 
            +
                    return audio_array
         | 
| 226 | 
            +
                normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
         | 
| 227 | 
            +
                return normalized_audio
         | 
| 228 | 
            +
             | 
| 229 | 
            +
            def audio_prepare_multi(left_path, right_path, audio_type, sample_rate=16000):
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                if not (left_path=='None' or right_path=='None'):
         | 
| 232 | 
            +
                    human_speech_array1 = audio_prepare_single(left_path)
         | 
| 233 | 
            +
                    human_speech_array2 = audio_prepare_single(right_path)
         | 
| 234 | 
            +
                elif left_path=='None':
         | 
| 235 | 
            +
                    human_speech_array2 = audio_prepare_single(right_path)
         | 
| 236 | 
            +
                    human_speech_array1 = np.zeros(human_speech_array2.shape[0])
         | 
| 237 | 
            +
                elif right_path=='None':
         | 
| 238 | 
            +
                    human_speech_array1 = audio_prepare_single(left_path)
         | 
| 239 | 
            +
                    human_speech_array2 = np.zeros(human_speech_array1.shape[0])
         | 
| 240 | 
            +
             | 
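                # 'para': keep both speakers' tracks unchanged on a shared timeline;
                # 'add': play speaker 1 first and speaker 2 second, padding each track
                # with zeros (silence) while the other person is speaking.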
| 241 | 
            +
                if audio_type=='para':
         | 
| 242 | 
            +
                    new_human_speech1 = human_speech_array1
         | 
| 243 | 
            +
                    new_human_speech2 = human_speech_array2
         | 
| 244 | 
            +
                elif audio_type=='add':
         | 
| 245 | 
            +
                    new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])]) 
         | 
| 246 | 
            +
                    new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[:human_speech_array2.shape[0]]])
         | 
| 247 | 
            +
                sum_human_speechs = new_human_speech1 + new_human_speech2
         | 
| 248 | 
            +
                return new_human_speech1, new_human_speech2, sum_human_speechs
         | 
| 249 | 
            +
             | 
| 250 | 
            +
            def _init_logging(rank):
         | 
| 251 | 
            +
                # logging
         | 
| 252 | 
            +
                if rank == 0:
         | 
| 253 | 
            +
                    # set format
         | 
| 254 | 
            +
                    logging.basicConfig(
         | 
| 255 | 
            +
                        level=logging.INFO,
         | 
| 256 | 
            +
                        format="[%(asctime)s] %(levelname)s: %(message)s",
         | 
| 257 | 
            +
                        handlers=[logging.StreamHandler(stream=sys.stdout)])
         | 
| 258 | 
            +
                else:
         | 
| 259 | 
            +
                    logging.basicConfig(level=logging.ERROR)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
            def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu'):
         | 
| 262 | 
            +
                audio_duration = len(speech_array) / sr
         | 
| 263 | 
            +
                video_length = audio_duration * 25 # Assume the video fps is 25
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                # wav2vec_feature_extractor
         | 
| 266 | 
            +
                audio_feature = np.squeeze(
         | 
| 267 | 
            +
                    wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
         | 
| 268 | 
            +
                )
         | 
| 269 | 
            +
                audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
         | 
| 270 | 
            +
                audio_feature = audio_feature.unsqueeze(0)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                # audio encoder
         | 
| 273 | 
            +
                with torch.no_grad():
         | 
| 274 | 
            +
                    embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                if len(embeddings) == 0:
         | 
| 277 | 
            +
                    print("Fail to extract audio embedding")
         | 
| 278 | 
            +
                    return None
         | 
| 279 | 
            +
             | 
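                # Stack the hidden states of every encoder layer (index 0 is the
                # feature-projection output, so it is skipped) into one tensor.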
| 280 | 
            +
                audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
         | 
| 281 | 
            +
                audio_emb = rearrange(audio_emb, "b s d -> s b d")
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                audio_emb = audio_emb.cpu().detach()
         | 
| 284 | 
            +
                return audio_emb
         | 
| 285 | 
            +
             | 
| 286 | 
            +
            def extract_audio_from_video(filename, sample_rate):
         | 
| 287 | 
            +
                raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav'
         | 
| 288 | 
            +
                ffmpeg_command = [
         | 
| 289 | 
            +
                    "ffmpeg",
         | 
| 290 | 
            +
                    "-y",
         | 
| 291 | 
            +
                    "-i",
         | 
| 292 | 
            +
                    str(filename),
         | 
| 293 | 
            +
                    "-vn",
         | 
| 294 | 
            +
                    "-acodec",
         | 
| 295 | 
            +
                    "pcm_s16le",
         | 
| 296 | 
            +
                    "-ar",
         | 
| 297 | 
            +
                    "16000",
         | 
| 298 | 
            +
                    "-ac",
         | 
| 299 | 
            +
                    "2",
         | 
| 300 | 
            +
                    str(raw_audio_path),
         | 
| 301 | 
            +
                ]
         | 
| 302 | 
            +
                subprocess.run(ffmpeg_command, check=True)
         | 
| 303 | 
            +
                human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate)
         | 
| 304 | 
            +
                human_speech_array = loudness_norm(human_speech_array, sr)
         | 
| 305 | 
            +
                os.remove(raw_audio_path)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                return human_speech_array
         | 
| 308 | 
            +
             | 
| 309 | 
            +
            def audio_prepare_single(audio_path, sample_rate=16000):
         | 
| 310 | 
            +
                ext = os.path.splitext(audio_path)[1].lower()
         | 
| 311 | 
            +
                if ext in ['.mp4', '.mov', '.avi', '.mkv']:
         | 
| 312 | 
            +
                    human_speech_array = extract_audio_from_video(audio_path, sample_rate)
         | 
| 313 | 
            +
                    return human_speech_array
         | 
| 314 | 
            +
                else:
         | 
| 315 | 
            +
                    human_speech_array, sr = librosa.load(audio_path, sr=sample_rate)
         | 
| 316 | 
            +
                    human_speech_array = loudness_norm(human_speech_array, sr)
         | 
| 317 | 
            +
                    return human_speech_array
         | 
| 318 | 
            +
             | 
| 319 | 
            +
            def generate(args):
         | 
| 320 | 
            +
                rank = int(os.getenv("RANK", 0))
         | 
| 321 | 
            +
                world_size = int(os.getenv("WORLD_SIZE", 1))
         | 
| 322 | 
            +
                local_rank = int(os.getenv("LOCAL_RANK", 0))
         | 
| 323 | 
            +
                device = local_rank
         | 
| 324 | 
            +
                _init_logging(rank)
         | 
| 325 | 
            +
             | 
| 326 | 
            +
                if args.offload_model is None:
         | 
| 327 | 
            +
                    args.offload_model = False if world_size > 1 else True
         | 
| 328 | 
            +
                    logging.info(
         | 
| 329 | 
            +
                        f"offload_model is not specified, set to {args.offload_model}.")
         | 
| 330 | 
            +
                if world_size > 1:
         | 
| 331 | 
            +
                    torch.cuda.set_device(local_rank)
         | 
| 332 | 
            +
                    dist.init_process_group(
         | 
| 333 | 
            +
                        backend="nccl",
         | 
| 334 | 
            +
                        init_method="env://",
         | 
| 335 | 
            +
                        rank=rank,
         | 
| 336 | 
            +
                        world_size=world_size)
         | 
| 337 | 
            +
                else:
         | 
| 338 | 
            +
                    assert not (
         | 
| 339 | 
            +
                        args.t5_fsdp or args.dit_fsdp
         | 
| 340 | 
            +
                    ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
         | 
| 341 | 
            +
                    assert not (
         | 
| 342 | 
            +
                        args.ulysses_size > 1 or args.ring_size > 1
         | 
| 343 | 
            +
                    ), f"context parallel are not supported in non-distributed environments."
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                if args.ulysses_size > 1 or args.ring_size > 1:
         | 
| 346 | 
            +
                    assert args.ulysses_size * args.ring_size == world_size, "The product of ulysses_size and ring_size should be equal to the world size."
         | 
| 347 | 
            +
                    from xfuser.core.distributed import (
         | 
| 348 | 
            +
                        init_distributed_environment,
         | 
| 349 | 
            +
                        initialize_model_parallel,
         | 
| 350 | 
            +
                    )
         | 
| 351 | 
            +
                    init_distributed_environment(
         | 
| 352 | 
            +
                        rank=dist.get_rank(), world_size=dist.get_world_size())
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                    initialize_model_parallel(
         | 
| 355 | 
            +
                        sequence_parallel_degree=dist.get_world_size(),
         | 
| 356 | 
            +
                        ring_degree=args.ring_size,
         | 
| 357 | 
            +
                        ulysses_degree=args.ulysses_size,
         | 
| 358 | 
            +
                    )
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                # TODO: use prompt refine
         | 
| 361 | 
            +
                # if args.use_prompt_extend:
         | 
| 362 | 
            +
                #     if args.prompt_extend_method == "dashscope":
         | 
| 363 | 
            +
                #         prompt_expander = DashScopePromptExpander(
         | 
| 364 | 
            +
                #             model_name=args.prompt_extend_model,
         | 
| 365 | 
            +
                #             is_vl="i2v" in args.task or "flf2v" in args.task)
         | 
| 366 | 
            +
                #     elif args.prompt_extend_method == "local_qwen":
         | 
| 367 | 
            +
                #         prompt_expander = QwenPromptExpander(
         | 
| 368 | 
            +
                #             model_name=args.prompt_extend_model,
         | 
| 369 | 
            +
                #             is_vl="i2v" in args.task,
         | 
| 370 | 
            +
                #             device=rank)
         | 
| 371 | 
            +
                #     else:
         | 
| 372 | 
            +
                #         raise NotImplementedError(
         | 
| 373 | 
            +
                #             f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                cfg = WAN_CONFIGS[args.task]
         | 
| 376 | 
            +
                if args.ulysses_size > 1:
         | 
| 377 | 
            +
                    assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                logging.info(f"Generation job args: {args}")
         | 
| 380 | 
            +
                logging.info(f"Generation model config: {cfg}")
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                if dist.is_initialized():
         | 
| 383 | 
            +
                    base_seed = [args.base_seed] if rank == 0 else [None]
         | 
| 384 | 
            +
                    dist.broadcast_object_list(base_seed, src=0)
         | 
| 385 | 
            +
                    args.base_seed = base_seed[0]
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                assert args.task == "multitalk-14B", 'args.task must be "multitalk-14B" for MultiTalk generation.'
         | 
| 388 | 
            +
                
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                # TODO: add prompt refine
         | 
| 391 | 
            +
                # img = Image.open(args.image).convert("RGB")
         | 
| 392 | 
            +
                # if args.use_prompt_extend:
         | 
| 393 | 
            +
                #     logging.info("Extending prompt ...")
         | 
| 394 | 
            +
                #     if rank == 0:
         | 
| 395 | 
            +
                #         prompt_output = prompt_expander(
         | 
| 396 | 
            +
                #             args.prompt,
         | 
| 397 | 
            +
                #             tar_lang=args.prompt_extend_target_lang,
         | 
| 398 | 
            +
                #             image=img,
         | 
| 399 | 
            +
                #             seed=args.base_seed)
         | 
| 400 | 
            +
                #         if prompt_output.status == False:
         | 
| 401 | 
            +
                #             logging.info(
         | 
| 402 | 
            +
                #                 f"Extending prompt failed: {prompt_output.message}")
         | 
| 403 | 
            +
                #             logging.info("Falling back to original prompt.")
         | 
| 404 | 
            +
                #             input_prompt = args.prompt
         | 
| 405 | 
            +
                #         else:
         | 
| 406 | 
            +
                #             input_prompt = prompt_output.prompt
         | 
| 407 | 
            +
                #         input_prompt = [input_prompt]
         | 
| 408 | 
            +
                #     else:
         | 
| 409 | 
            +
                #         input_prompt = [None]
         | 
| 410 | 
            +
                #     if dist.is_initialized():
         | 
| 411 | 
            +
                #         dist.broadcast_object_list(input_prompt, src=0)
         | 
| 412 | 
            +
                #     args.prompt = input_prompt[0]
         | 
| 413 | 
            +
                #     logging.info(f"Extended prompt: {args.prompt}")
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                # read input files
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                with open(args.input_json, 'r', encoding='utf-8') as f:
         | 
| 420 | 
            +
                    input_data = json.load(f)
         | 
| 421 | 
            +
                    
         | 
| 422 | 
            +
                    wav2vec_feature_extractor, audio_encoder= custom_init('cpu', args.wav2vec_dir)
         | 
| 423 | 
            +
                    args.audio_save_dir = os.path.join(args.audio_save_dir, input_data['cond_image'].split('/')[-1].split('.')[0])
         | 
| 424 | 
            +
                    os.makedirs(args.audio_save_dir,exist_ok=True)
         | 
| 425 | 
            +
                    
         | 
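                    # Two cond_audio entries -> two-person generation; one entry -> single person.
                    # Audio embeddings (*.pt) and the mixed waveform (sum.wav) are cached under
                    # args.audio_save_dir, and their paths are written back into input_data.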
| 426 | 
            +
                    if len(input_data['cond_audio'])==2:
         | 
| 427 | 
            +
                        new_human_speech1, new_human_speech2, sum_human_speechs = audio_prepare_multi(input_data['cond_audio']['person1'], input_data['cond_audio']['person2'], input_data['audio_type'])
         | 
| 428 | 
            +
                        audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder)
         | 
| 429 | 
            +
                        audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder)
         | 
| 430 | 
            +
                        emb1_path = os.path.join(args.audio_save_dir, '1.pt')
         | 
| 431 | 
            +
                        emb2_path = os.path.join(args.audio_save_dir, '2.pt')
         | 
| 432 | 
            +
                        sum_audio = os.path.join(args.audio_save_dir, 'sum.wav')
         | 
| 433 | 
            +
                        sf.write(sum_audio, sum_human_speechs, 16000)
         | 
| 434 | 
            +
                        torch.save(audio_embedding_1, emb1_path)
         | 
| 435 | 
            +
                        torch.save(audio_embedding_2, emb2_path)
         | 
| 436 | 
            +
                        input_data['cond_audio']['person1'] = emb1_path
         | 
| 437 | 
            +
                        input_data['cond_audio']['person2'] = emb2_path
         | 
| 438 | 
            +
                        input_data['video_audio'] = sum_audio
         | 
| 439 | 
            +
                    elif len(input_data['cond_audio'])==1:
         | 
| 440 | 
            +
                        human_speech = audio_prepare_single(input_data['cond_audio']['person1'])
         | 
| 441 | 
            +
                        audio_embedding = get_embedding(human_speech, wav2vec_feature_extractor, audio_encoder)
         | 
| 442 | 
            +
                        emb_path = os.path.join(args.audio_save_dir, '1.pt')
         | 
| 443 | 
            +
                        sum_audio = os.path.join(args.audio_save_dir, 'sum.wav')
         | 
| 444 | 
            +
                        sf.write(sum_audio, human_speech, 16000)
         | 
| 445 | 
            +
                        torch.save(audio_embedding, emb_path)
         | 
| 446 | 
            +
                        input_data['cond_audio']['person1'] = emb_path
         | 
| 447 | 
            +
                        input_data['video_audio'] = sum_audio
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                logging.info("Creating MultiTalk pipeline.")
         | 
| 450 | 
            +
                wan_i2v = wan.MultiTalkPipeline(
         | 
| 451 | 
            +
                    config=cfg,
         | 
| 452 | 
            +
                    checkpoint_dir=args.ckpt_dir,
         | 
| 453 | 
            +
                    device_id=device,
         | 
| 454 | 
            +
                    rank=rank,
         | 
| 455 | 
            +
                    t5_fsdp=args.t5_fsdp,
         | 
| 456 | 
            +
                    dit_fsdp=args.dit_fsdp, 
         | 
| 457 | 
            +
                    use_usp=(args.ulysses_size > 1 or args.ring_size > 1),  
         | 
| 458 | 
            +
                    t5_cpu=args.t5_cpu
         | 
| 459 | 
            +
                )
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                if args.num_persistent_param_in_dit is not None:
         | 
| 462 | 
            +
                    wan_i2v.vram_management = True
         | 
| 463 | 
            +
                    wan_i2v.enable_vram_management(
         | 
| 464 | 
            +
                        num_persistent_param_in_dit=args.num_persistent_param_in_dit
         | 
| 465 | 
            +
                    )
         | 
| 466 | 
            +
                
         | 
| 467 | 
            +
                logging.info("Generating video ...")
         | 
| 468 | 
            +
                video = wan_i2v.generate(
         | 
| 469 | 
            +
                    input_data,
         | 
| 470 | 
            +
                    size_buckget=args.size,
         | 
| 471 | 
            +
                    motion_frame=args.motion_frame,
         | 
| 472 | 
            +
                    frame_num=args.frame_num,
         | 
| 473 | 
            +
                    shift=args.sample_shift,
         | 
| 474 | 
            +
                    sampling_steps=args.sample_steps,
         | 
| 475 | 
            +
                    text_guide_scale=args.sample_text_guide_scale,
         | 
| 476 | 
            +
                    audio_guide_scale=args.sample_audio_guide_scale,
         | 
| 477 | 
            +
                    seed=args.base_seed,
         | 
| 478 | 
            +
                    offload_model=args.offload_model,
         | 
| 479 | 
            +
                    max_frames_num=args.frame_num if args.mode == 'clip' else 1000,
         | 
| 480 | 
            +
                    extra_args=args,
         | 
| 481 | 
            +
                    )
         | 
| 482 | 
            +
                
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                if rank == 0:
         | 
| 485 | 
            +
                    
         | 
| 486 | 
            +
                    if args.save_file is None:
         | 
| 487 | 
            +
                        formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
         | 
| 488 | 
            +
                        formatted_prompt = input_data['prompt'].replace(" ", "_").replace("/",
         | 
| 489 | 
            +
                                                                                    "_")[:50]
         | 
| 490 | 
            +
                        args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{args.ring_size}_{formatted_prompt}_{formatted_time}"
         | 
| 491 | 
            +
                    
         | 
| 492 | 
            +
                    logging.info(f"Saving generated video to {args.save_file}.mp4")
         | 
| 493 | 
            +
                    save_video_ffmpeg(video, args.save_file, [input_data['video_audio']])
         | 
| 494 | 
            +
                    
         | 
| 495 | 
            +
                logging.info("Finished.")
         | 
| 496 | 
            +
             | 
| 497 | 
            +
             | 
| 498 | 
            +
            if __name__ == "__main__":
         | 
| 499 | 
            +
                args = _parse_args()
         | 
| 500 | 
            +
                generate(args)
         | 
    	
        requirements.txt
    ADDED
    
    | @@ -0,0 +1,15 @@ | |
| 1 | 
            +
            opencv-python>=4.9.0.80
         | 
| 2 | 
            +
            diffusers>=0.31.0
         | 
| 3 | 
            +
            transformers>=4.49.0
         | 
| 4 | 
            +
            tokenizers>=0.20.3
         | 
| 5 | 
            +
            accelerate>=1.1.1
         | 
| 6 | 
            +
            tqdm
         | 
| 7 | 
            +
            imageio
         | 
| 8 | 
            +
            easydict
         | 
| 9 | 
            +
            ftfy
         | 
| 10 | 
            +
            dashscope
         | 
| 11 | 
            +
            imageio-ffmpeg
         | 
| 12 | 
            +
            gradio>=5.0.0
         | 
| 13 | 
            +
            numpy>=1.23.5,<2
         | 
| 14 | 
            +
            xfuser>=0.4.1
         | 
| 15 | 
            +
            pyloudnorm
         | 
    	
        src/audio_analysis/torch_utils.py
    ADDED
    
    | @@ -0,0 +1,20 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn.functional as F
         | 
| 3 | 
            +
             | 
| 4 | 
            +
             | 
| 5 | 
            +
            def get_mask_from_lengths(lengths, max_len=None):
         | 
| 6 | 
            +
                lengths = lengths.to(torch.long)
         | 
| 7 | 
            +
                if max_len is None:
         | 
| 8 | 
            +
                    max_len = torch.max(lengths).item()
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
         | 
| 11 | 
            +
                mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                return mask
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
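            # Linearly resample features of shape (B, T, C) along the time axis to length seq_len.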
| 16 | 
            +
            def linear_interpolation(features, seq_len):
         | 
| 17 | 
            +
                features = features.transpose(1, 2)
         | 
| 18 | 
            +
                output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
         | 
| 19 | 
            +
                return output_features.transpose(1, 2)
         | 
| 20 | 
            +
             | 
    	
        src/audio_analysis/wav2vec2.py
    ADDED
    
    | @@ -0,0 +1,125 @@ | |
| 1 | 
            +
            from transformers import Wav2Vec2Config, Wav2Vec2Model
         | 
| 2 | 
            +
            from transformers.modeling_outputs import BaseModelOutput
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from src.audio_analysis.torch_utils import linear_interpolation
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            # the implementation of Wav2Vec2Model is borrowed from
         | 
| 7 | 
            +
            # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
         | 
| 8 | 
            +
            # initialize our encoder with the pre-trained wav2vec 2.0 weights.
         | 
| 9 | 
            +
            class Wav2Vec2Model(Wav2Vec2Model):
         | 
| 10 | 
            +
                def __init__(self, config: Wav2Vec2Config):
         | 
| 11 | 
            +
                    super().__init__(config)
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                def forward(
         | 
| 14 | 
            +
                    self,
         | 
| 15 | 
            +
                    input_values,
         | 
| 16 | 
            +
                    seq_len,
         | 
| 17 | 
            +
                    attention_mask=None,
         | 
| 18 | 
            +
                    mask_time_indices=None,
         | 
| 19 | 
            +
                    output_attentions=None,
         | 
| 20 | 
            +
                    output_hidden_states=None,
         | 
| 21 | 
            +
                    return_dict=None,
         | 
| 22 | 
            +
                ):
         | 
| 23 | 
            +
                    self.config.output_attentions = True
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                    output_hidden_states = (
         | 
| 26 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 27 | 
            +
                    )
         | 
| 28 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    extract_features = self.feature_extractor(input_values)
         | 
| 31 | 
            +
                    extract_features = extract_features.transpose(1, 2)
         | 
| 32 | 
            +
                    extract_features = linear_interpolation(extract_features, seq_len=seq_len)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                    if attention_mask is not None:
         | 
| 35 | 
            +
                        # compute reduced attention_mask corresponding to feature vectors
         | 
| 36 | 
            +
                        attention_mask = self._get_feature_vector_attention_mask(
         | 
| 37 | 
            +
                            extract_features.shape[1], attention_mask, add_adapter=False
         | 
| 38 | 
            +
                        )
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    hidden_states, extract_features = self.feature_projection(extract_features)
         | 
| 41 | 
            +
                    hidden_states = self._mask_hidden_states(
         | 
| 42 | 
            +
                        hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
         | 
| 43 | 
            +
                    )
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    encoder_outputs = self.encoder(
         | 
| 46 | 
            +
                        hidden_states,
         | 
| 47 | 
            +
                        attention_mask=attention_mask,
         | 
| 48 | 
            +
                        output_attentions=output_attentions,
         | 
| 49 | 
            +
                        output_hidden_states=output_hidden_states,
         | 
| 50 | 
            +
                        return_dict=return_dict,
         | 
| 51 | 
            +
                    )
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    hidden_states = encoder_outputs[0]
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                    if self.adapter is not None:
         | 
| 56 | 
            +
                        hidden_states = self.adapter(hidden_states)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    if not return_dict:
         | 
| 59 | 
            +
                        return (hidden_states, ) + encoder_outputs[1:]
         | 
| 60 | 
            +
                    return BaseModelOutput(
         | 
| 61 | 
            +
                        last_hidden_state=hidden_states,
         | 
| 62 | 
            +
                        hidden_states=encoder_outputs.hidden_states,
         | 
| 63 | 
            +
                        attentions=encoder_outputs.attentions,
         | 
| 64 | 
            +
                    )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
                def feature_extract(
         | 
| 68 | 
            +
                    self,
         | 
| 69 | 
            +
                    input_values,
         | 
| 70 | 
            +
                    seq_len,
         | 
| 71 | 
            +
                ):
         | 
| 72 | 
            +
                    extract_features = self.feature_extractor(input_values)
         | 
| 73 | 
            +
                    extract_features = extract_features.transpose(1, 2)
         | 
| 74 | 
            +
                    extract_features = linear_interpolation(extract_features, seq_len=seq_len)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    return extract_features
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def encode(
         | 
| 79 | 
            +
                    self,
         | 
| 80 | 
            +
                    extract_features,
         | 
| 81 | 
            +
                    attention_mask=None,
         | 
| 82 | 
            +
                    mask_time_indices=None,
         | 
| 83 | 
            +
                    output_attentions=None,
         | 
| 84 | 
            +
                    output_hidden_states=None,
         | 
| 85 | 
            +
                    return_dict=None,
         | 
| 86 | 
            +
                ):
         | 
| 87 | 
            +
                    self.config.output_attentions = True
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    output_hidden_states = (
         | 
| 90 | 
            +
                        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         | 
| 91 | 
            +
                    )
         | 
| 92 | 
            +
                    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    if attention_mask is not None:
         | 
| 95 | 
            +
                        # compute reduced attention_mask corresponding to feature vectors
         | 
| 96 | 
            +
                        attention_mask = self._get_feature_vector_attention_mask(
         | 
| 97 | 
            +
                            extract_features.shape[1], attention_mask, add_adapter=False
         | 
| 98 | 
            +
                        )
         | 
| 99 | 
            +
                        
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    hidden_states, extract_features = self.feature_projection(extract_features)
         | 
| 102 | 
            +
                    hidden_states = self._mask_hidden_states(
         | 
| 103 | 
            +
                        hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
         | 
| 104 | 
            +
                    )
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    encoder_outputs = self.encoder(
         | 
| 107 | 
            +
                        hidden_states,
         | 
| 108 | 
            +
                        attention_mask=attention_mask,
         | 
| 109 | 
            +
                        output_attentions=output_attentions,
         | 
| 110 | 
            +
                        output_hidden_states=output_hidden_states,
         | 
| 111 | 
            +
                        return_dict=return_dict,
         | 
| 112 | 
            +
                    )
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    hidden_states = encoder_outputs[0]
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    if self.adapter is not None:
         | 
| 117 | 
            +
                        hidden_states = self.adapter(hidden_states)
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                    if not return_dict:
         | 
| 120 | 
            +
                        return (hidden_states, ) + encoder_outputs[1:]
         | 
| 121 | 
            +
                    return BaseModelOutput(
         | 
| 122 | 
            +
                        last_hidden_state=hidden_states,
         | 
| 123 | 
            +
                        hidden_states=encoder_outputs.hidden_states,
         | 
| 124 | 
            +
                        attentions=encoder_outputs.attentions,
         | 
| 125 | 
            +
                    )
         | 
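
A rough usage sketch, not part of the file above: the overridden forward resamples the CNN features to a caller-chosen seq_len before the transformer encoder, so audio embeddings can be aligned one-to-one with video frames. The checkpoint name and the 16 kHz / 25 fps figures below are illustrative assumptions.

import torch

from src.audio_analysis.wav2vec2 import Wav2Vec2Model

# Hypothetical checkpoint; any pretrained wav2vec 2.0 weights with a
# compatible architecture could be loaded here.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

waveform = torch.randn(1, 16000 * 4)     # 4 s of 16 kHz audio (random placeholder)
num_video_frames = 4 * 25                # assume the target video runs at 25 fps

with torch.no_grad():
    out = model(waveform, seq_len=num_video_frames)

print(out.last_hidden_state.shape)       # [1, 100, hidden_size]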
    	
        src/utils.py
    ADDED
    
@@ -0,0 +1,60 @@
from contextlib import contextmanager

import torch

@contextmanager
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
    old_register_parameter = torch.nn.Module.register_parameter
    if include_buffers:
        old_register_buffer = torch.nn.Module.register_buffer

    def register_empty_parameter(module, name, param):
        old_register_parameter(module, name, param)
        if param is not None:
            param_cls = type(module._parameters[name])
            kwargs = module._parameters[name].__dict__
            kwargs["requires_grad"] = param.requires_grad
            module._parameters[name] = param_cls(
                module._parameters[name].to(device), **kwargs
            )

    def register_empty_buffer(module, name, buffer, persistent=True):
        old_register_buffer(module, name, buffer, persistent=persistent)
        if buffer is not None:
            module._buffers[name] = module._buffers[name].to(device)

    def patch_tensor_constructor(fn):
        def wrapper(*args, **kwargs):
            kwargs["device"] = device
            return fn(*args, **kwargs)

        return wrapper

    if include_buffers:
        tensor_constructors_to_patch = {
            torch_function_name: getattr(torch, torch_function_name)
            for torch_function_name in ["empty", "zeros", "ones", "full"]
        }
    else:
        tensor_constructors_to_patch = {}

    try:
        torch.nn.Module.register_parameter = register_empty_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = register_empty_buffer
        for torch_function_name in tensor_constructors_to_patch.keys():
            setattr(
                torch,
                torch_function_name,
                patch_tensor_constructor(getattr(torch, torch_function_name)),
            )
        yield
    finally:
        torch.nn.Module.register_parameter = old_register_parameter
        if include_buffers:
            torch.nn.Module.register_buffer = old_register_buffer
        for (
            torch_function_name,
            old_torch_function,
        ) in tensor_constructors_to_patch.items():
            setattr(torch, torch_function_name, old_torch_function)
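
A minimal sketch of what this context manager is for (not part of the file above; the layer size is arbitrary): while it is active, newly registered parameters are redirected to the given device, typically meta, so large module skeletons can be built without allocating real weight memory before a checkpoint is loaded into them.

import torch

from src.utils import init_weights_on_device

# Build a module skeleton on the meta device: no real memory is kept
# for its parameters, which is useful before loading a checkpoint.
with init_weights_on_device(device=torch.device("meta")):
    big_linear = torch.nn.Linear(8192, 8192, bias=True)

print(big_linear.weight.device)   # meta
print(big_linear.weight.shape)    # torch.Size([8192, 8192])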
    	
        src/vram_management/__init__.py
    ADDED
    
@@ -0,0 +1 @@
from .layers import *
    	
        src/vram_management/layers.py
    ADDED
    
@@ -0,0 +1,179 @@
import copy

import torch

from src.utils import init_weights_on_device


def cast_to(weight, dtype, device):
    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight)
    return r


class AutoWrappedModule(torch.nn.Module):
    def __init__(
        self,
        module: torch.nn.Module,
        offload_dtype,
        offload_device,
        onload_dtype,
        onload_device,
        computation_dtype,
        computation_device,
    ):
        super().__init__()
        self.module = module.to(dtype=offload_dtype, device=offload_device)
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (
            self.offload_dtype != self.onload_dtype
            or self.offload_device != self.onload_device
        ):
            self.module.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (
            self.offload_dtype != self.onload_dtype
            or self.offload_device != self.onload_device
        ):
            self.module.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, *args, **kwargs):
        if (
            self.onload_dtype == self.computation_dtype
            and self.onload_device == self.computation_device
        ):
            module = self.module
        else:
            module = copy.deepcopy(self.module).to(
                dtype=self.computation_dtype, device=self.computation_device
            )
        return module(*args, **kwargs)


class AutoWrappedLinear(torch.nn.Linear):
    def __init__(
        self,
        module: torch.nn.Linear,
        offload_dtype,
        offload_device,
        onload_dtype,
        onload_device,
        computation_dtype,
        computation_device,
    ):
        with init_weights_on_device(device=torch.device("meta")):
            super().__init__(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                dtype=offload_dtype,
                device=offload_device,
            )
        self.weight = module.weight
        self.bias = module.bias
        self.offload_dtype = offload_dtype
        self.offload_device = offload_device
        self.onload_dtype = onload_dtype
        self.onload_device = onload_device
        self.computation_dtype = computation_dtype
        self.computation_device = computation_device
        self.state = 0

    def offload(self):
        if self.state == 1 and (
            self.offload_dtype != self.onload_dtype
            or self.offload_device != self.onload_device
        ):
            self.to(dtype=self.offload_dtype, device=self.offload_device)
            self.state = 0

    def onload(self):
        if self.state == 0 and (
            self.offload_dtype != self.onload_dtype
            or self.offload_device != self.onload_device
        ):
            self.to(dtype=self.onload_dtype, device=self.onload_device)
            self.state = 1

    def forward(self, x, *args, **kwargs):
        if (
            self.onload_dtype == self.computation_dtype
            and self.onload_device == self.computation_device
        ):
            weight, bias = self.weight, self.bias
        else:
            weight = cast_to(
                self.weight, self.computation_dtype, self.computation_device
            )
            bias = (
                None
                if self.bias is None
                else cast_to(self.bias, self.computation_dtype, self.computation_device)
            )
        return torch.nn.functional.linear(x, weight, bias)


def enable_vram_management_recursively(
    model: torch.nn.Module,
    module_map: dict,
    module_config: dict,
    max_num_param=None,
    overflow_module_config: dict = None,
    total_num_param=0,
):
    for name, module in model.named_children():
        for source_module, target_module in module_map.items():
            if isinstance(module, source_module):
                num_param = sum(p.numel() for p in module.parameters())
                # print(str(module) + ':' + str(num_param))
                if (
                    max_num_param is not None
                    and total_num_param + num_param > max_num_param
                ):
                    # print(str(module) + '-->\t\t num:' + str(num_param) + "\t total:" + str(total_num_param))
                    module_config_ = overflow_module_config
                else:
                    module_config_ = module_config
                module_ = target_module(module, **module_config_)
                setattr(model, name, module_)
                total_num_param += num_param
                break
        else:
            total_num_param = enable_vram_management_recursively(
                module,
                module_map,
                module_config,
                max_num_param,
                overflow_module_config,
                total_num_param,
            )
    return total_num_param


def enable_vram_management(
    model: torch.nn.Module,
    module_map: dict,
    module_config: dict,
    max_num_param=None,
    overflow_module_config: dict = None,
):
    enable_vram_management_recursively(
        model,
        module_map,
        module_config,
        max_num_param,
        overflow_module_config,
        total_num_param=0,
    )
    model.vram_management_enabled = True
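
A hedged sketch of how these pieces fit together (not part of the file above; the model, dtypes, and devices are placeholders and a CUDA device is assumed): every nn.Linear in a model is replaced by AutoWrappedLinear, whose weights stay offloaded until the forward pass casts them to the computation dtype/device on the fly.

import torch

from src.vram_management import AutoWrappedLinear, enable_vram_management

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 1024),
)

# Weights stay on the CPU until a layer runs; they are then cast to fp16 on
# the GPU for the actual matmul. These particular choices are illustrative.
config = dict(
    offload_dtype=torch.bfloat16,
    offload_device="cpu",
    onload_dtype=torch.bfloat16,
    onload_device="cpu",
    computation_dtype=torch.float16,
    computation_device="cuda",
)
enable_vram_management(
    model,
    module_map={torch.nn.Linear: AutoWrappedLinear},
    module_config=config,
)

x = torch.randn(2, 1024, dtype=torch.float16, device="cuda")
y = model(x)
print(y.shape, y.dtype)   # torch.Size([2, 1024]) torch.float16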
    	
        wan/__init__.py
    ADDED
    
@@ -0,0 +1,6 @@
from . import configs, distributed, modules
from .first_last_frame2video import WanFLF2V
from .image2video import WanI2V
from .text2video import WanT2V
from .vace import WanVace, WanVaceMP
from .multitalk import MultiTalkPipeline
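
For orientation only (not part of the file above; it assumes the full package and its dependencies are installed): the package root re-exports the user-facing pipeline classes, so callers do not need to know the internal module layout.

# The package root exposes the pipeline classes directly.
from wan import MultiTalkPipeline, WanI2V, WanT2V, WanFLF2V

print(MultiTalkPipeline.__module__)   # wan.multitalk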
    	
        wan/configs/__init__.py
    ADDED
    
@@ -0,0 +1,58 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import os

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

from .wan_i2v_14B import i2v_14B
from .wan_t2v_1_3B import t2v_1_3B
from .wan_t2v_14B import t2v_14B
from .wan_multitalk_14B import multitalk_14B

# the config of t2i_14B is the same as t2v_14B
t2i_14B = copy.deepcopy(t2v_14B)
t2i_14B.__name__ = 'Config: Wan T2I 14B'

# the config of flf2v_14B is the same as i2v_14B
flf2v_14B = copy.deepcopy(i2v_14B)
flf2v_14B.__name__ = 'Config: Wan FLF2V 14B'
flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt  # "镜头切换" = "camera cut"

WAN_CONFIGS = {
    't2v-14B': t2v_14B,
    't2v-1.3B': t2v_1_3B,
    'i2v-14B': i2v_14B,
    't2i-14B': t2i_14B,
    'flf2v-14B': flf2v_14B,
    'vace-1.3B': t2v_1_3B,
    'vace-14B': t2v_14B,
    'multitalk-14B': multitalk_14B,
}

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '1024*1024': (1024, 1024),
    'multitalk-480': (640, 640),
    'multitalk-720': (960, 960),
}

MAX_AREA_CONFIGS = {
    '720*1280': 720 * 1280,
    '1280*720': 1280 * 720,
    '480*832': 480 * 832,
    '832*480': 832 * 480,
}

SUPPORTED_SIZES = {
    't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2v-1.3B': ('480*832', '832*480'),
    'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    't2i-14B': tuple(SIZE_CONFIGS.keys()),
    'vace-1.3B': ('480*832', '832*480'),
    'vace-14B': ('720*1280', '1280*720', '480*832', '832*480'),
    'multitalk-14B': ('multitalk-480', 'multitalk-720'),
}
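
A short sketch of how the registries resolve a task/size pair (not part of the file above; values are read from the dictionaries defined in it): the multitalk-* keys map to square pixel buckets rather than the literal H*W strings used by the other tasks.

from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

task, size = 'multitalk-14B', 'multitalk-480'
assert size in SUPPORTED_SIZES[task]

cfg = WAN_CONFIGS[task]
print(SIZE_CONFIGS[size])       # (640, 640)
print(cfg.num_layers, cfg.dim)  # 40 5120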
    	
        wan/configs/shared_config.py
    ADDED
    
@@ -0,0 +1,19 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

#------------------------ Wan shared config ------------------------#
wan_shared_cfg = EasyDict()

# t5
wan_shared_cfg.t5_model = 'umt5_xxl'
wan_shared_cfg.t5_dtype = torch.bfloat16
wan_shared_cfg.text_len = 512

# transformer
wan_shared_cfg.param_dtype = torch.bfloat16

# inference
wan_shared_cfg.num_train_timesteps = 1000
wan_shared_cfg.sample_fps = 16
# Default negative prompt (Chinese); an English equivalent of the same prompt is used in wan_multitalk_14B.py.
wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
    	
        wan/configs/wan_i2v_14B.py
    ADDED
    
@@ -0,0 +1,24 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan I2V 14B ------------------------#

i2v_14B = EasyDict(__name__='Config: Wan I2V 14B')
i2v_14B.update(wan_shared_cfg)
i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt  # "镜头晃动" = "camera shake"

i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
i2v_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
i2v_14B.clip_dtype = torch.float16
i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
i2v_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
i2v_14B.vae_stride = (4, 8, 8)
    	
        wan/configs/wan_multitalk_14B.py
    ADDED
    
@@ -0,0 +1,36 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan MultiTalk 14B ------------------------#

multitalk_14B = EasyDict(__name__='Config: Wan MultiTalk AI2V 14B')
multitalk_14B.update(wan_shared_cfg)
multitalk_14B.sample_neg_prompt = 'bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards'

multitalk_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
multitalk_14B.t5_tokenizer = 'google/umt5-xxl'

# clip
multitalk_14B.clip_model = 'clip_xlm_roberta_vit_h_14'
multitalk_14B.clip_dtype = torch.float16
multitalk_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'
multitalk_14B.clip_tokenizer = 'xlm-roberta-large'

# vae
multitalk_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
multitalk_14B.vae_stride = (4, 8, 8)

# transformer
multitalk_14B.patch_size = (1, 2, 2)
multitalk_14B.dim = 5120
multitalk_14B.ffn_dim = 13824
multitalk_14B.freq_dim = 256
multitalk_14B.num_heads = 40
multitalk_14B.num_layers = 40
multitalk_14B.window_size = (-1, -1)
multitalk_14B.qk_norm = True
multitalk_14B.cross_attn_norm = True
multitalk_14B.eps = 1e-6
    	
        wan/configs/wan_t2v_14B.py
    ADDED
    
@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 14B ------------------------#

t2v_14B = EasyDict(__name__='Config: Wan T2V 14B')
t2v_14B.update(wan_shared_cfg)

# t5
t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_14B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_14B.vae_stride = (4, 8, 8)

# transformer
t2v_14B.patch_size = (1, 2, 2)
t2v_14B.dim = 5120
t2v_14B.ffn_dim = 13824
t2v_14B.freq_dim = 256
t2v_14B.num_heads = 40
t2v_14B.num_layers = 40
t2v_14B.window_size = (-1, -1)
t2v_14B.qk_norm = True
t2v_14B.cross_attn_norm = True
t2v_14B.eps = 1e-6
    	
        wan/configs/wan_t2v_1_3B.py
    ADDED
    
@@ -0,0 +1,29 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from easydict import EasyDict

from .shared_config import wan_shared_cfg

#------------------------ Wan T2V 1.3B ------------------------#

t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B')
t2v_1_3B.update(wan_shared_cfg)

# t5
t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
t2v_1_3B.t5_tokenizer = 'google/umt5-xxl'

# vae
t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth'
t2v_1_3B.vae_stride = (4, 8, 8)

# transformer
t2v_1_3B.patch_size = (1, 2, 2)
t2v_1_3B.dim = 1536
t2v_1_3B.ffn_dim = 8960
t2v_1_3B.freq_dim = 256
t2v_1_3B.num_heads = 12
t2v_1_3B.num_layers = 30
t2v_1_3B.window_size = (-1, -1)
t2v_1_3B.qk_norm = True
t2v_1_3B.cross_attn_norm = True
t2v_1_3B.eps = 1e-6
    	
        wan/distributed/__init__.py
    ADDED
    
File without changes (empty package __init__.py)
    	
        wan/distributed/fsdp.py
    ADDED
    
@@ -0,0 +1,43 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
        # mixed_precision=MixedPrecision(
        #     param_dtype=param_dtype,
        #     reduce_dtype=reduce_dtype,
        #     buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model


def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
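
A hedged sketch of how shard_model might be called (not part of the file above; the process-group setup and the toy model are assumptions about surrounding code): the wrap policy shards each entry of model.blocks as its own FSDP unit, so the model passed in is expected to expose its transformer layers under a blocks attribute, as the Wan DiT model does. Run under torchrun with at least one GPU per rank.

import torch
import torch.distributed as dist

from wan.distributed.fsdp import shard_model

# Assumes torchrun has set the usual env vars and each rank owns one GPU.
dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

class TinyDiT(torch.nn.Module):
    """Toy module that mimics the `.blocks` layout the wrap policy relies on."""
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(256, 256) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = shard_model(TinyDiT(), device_id=local_rank)
out = model(torch.randn(2, 256, device=f"cuda:{local_rank}"))
print(out.shape)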
    	
        wan/distributed/xdit_context_parallel.py
    ADDED
    
    | @@ -0,0 +1,550 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torch.cuda.amp as amp
         | 
| 6 | 
            +
            from xfuser.core.distributed import (
         | 
| 7 | 
            +
                get_sequence_parallel_rank,
         | 
| 8 | 
            +
                get_sequence_parallel_world_size,
         | 
| 9 | 
            +
                get_sp_group,
         | 
| 10 | 
            +
            )
         | 
| 11 | 
            +
            from einops import rearrange
         | 
| 12 | 
            +
            from xfuser.core.long_ctx_attention import xFuserLongContextAttention
         | 
| 13 | 
            +
            import xformers.ops
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from ..modules.model import sinusoidal_embedding_1d
         | 
| 16 | 
            +
            from ..utils.multitalk_utils import get_attn_map_with_target, split_token_counts_and_frame_ids, normalize_and_scale
         | 
| 17 | 
            +
            from ..modules.attention import SingleStreamAttention, SingleStreamMutiAttention
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def pad_freqs(original_tensor, target_len):
         | 
| 21 | 
            +
                seq_len, s1, s2 = original_tensor.shape
         | 
| 22 | 
            +
                pad_size = target_len - seq_len
         | 
| 23 | 
            +
                padding_tensor = torch.ones(
         | 
| 24 | 
            +
                    pad_size,
         | 
| 25 | 
            +
                    s1,
         | 
| 26 | 
            +
                    s2,
         | 
| 27 | 
            +
                    dtype=original_tensor.dtype,
         | 
| 28 | 
            +
                    device=original_tensor.device)
         | 
| 29 | 
            +
                padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
         | 
| 30 | 
            +
                return padded_tensor
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            @amp.autocast(enabled=False)
         | 
| 34 | 
            +
            def rope_apply(x, grid_sizes, freqs):
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
                x:          [B, L, N, C].
         | 
| 37 | 
            +
                grid_sizes: [B, 3].
         | 
| 38 | 
            +
                freqs:      [M, C // 2].
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
                s, n, c = x.size(1), x.size(2), x.size(3) // 2
         | 
| 41 | 
            +
                # split freqs
         | 
| 42 | 
            +
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # [[N, head_dim/2], [N, head_dim/2], [N, head_dim/2]] # T, H, W frequencies in polar (complex) form
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                # loop over samples
         | 
| 45 | 
            +
                output = []
         | 
| 46 | 
            +
                for i, (f, h, w) in enumerate(grid_sizes.tolist()):
         | 
| 47 | 
            +
                    seq_len = f * h * w
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    # precompute multipliers
         | 
| 50 | 
            +
                    x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
         | 
| 51 | 
            +
        s, n, -1, 2)) # [L, N, C/2], viewed as complex (polar) values
         | 
| 52 | 
            +
                    freqs_i = torch.cat([
         | 
| 53 | 
            +
                        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
         | 
| 54 | 
            +
                        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
         | 
| 55 | 
            +
                        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         | 
| 56 | 
            +
                    ],
         | 
| 57 | 
            +
                                        dim=-1).reshape(seq_len, 1, -1) # seq_lens, 1,  3 * dim / 2 (T H W)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    # apply rotary embedding
         | 
| 60 | 
            +
                    sp_size = get_sequence_parallel_world_size()
         | 
| 61 | 
            +
                    sp_rank = get_sequence_parallel_rank()
         | 
| 62 | 
            +
                    freqs_i = pad_freqs(freqs_i, s * sp_size)
         | 
| 63 | 
            +
                    s_per_rank = s
         | 
| 64 | 
            +
                    freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
         | 
| 65 | 
            +
                                                                   s_per_rank), :, :]
         | 
| 66 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
         | 
| 67 | 
            +
                    x_i = torch.cat([x_i, x[i, s:]])
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # append to collection
         | 
| 70 | 
            +
                    output.append(x_i)
         | 
| 71 | 
            +
                return torch.stack(output).float()
         | 
| 72 | 
            +
             | 
| 73 | 
            +
             | 
| 74 | 
            +
            def usp_dit_forward_vace(self, x, vace_context, seq_len, kwargs):
         | 
| 75 | 
            +
                # embeddings
         | 
| 76 | 
            +
                c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
         | 
| 77 | 
            +
                c = [u.flatten(2).transpose(1, 2) for u in c]
         | 
| 78 | 
            +
                c = torch.cat([
         | 
| 79 | 
            +
                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
         | 
| 80 | 
            +
                    for u in c
         | 
| 81 | 
            +
                ])
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                # arguments
         | 
| 84 | 
            +
                new_kwargs = dict(x=x)
         | 
| 85 | 
            +
                new_kwargs.update(kwargs)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                # Context Parallel
         | 
| 88 | 
            +
                c = torch.chunk(
         | 
| 89 | 
            +
                    c, get_sequence_parallel_world_size(),
         | 
| 90 | 
            +
                    dim=1)[get_sequence_parallel_rank()]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                hints = []
         | 
| 93 | 
            +
                for block in self.vace_blocks:
         | 
| 94 | 
            +
                    c, c_skip = block(c, **new_kwargs)
         | 
| 95 | 
            +
                    hints.append(c_skip)
         | 
| 96 | 
            +
                return hints
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            def usp_dit_forward(
         | 
| 100 | 
            +
                self,
         | 
| 101 | 
            +
                x,
         | 
| 102 | 
            +
                t,
         | 
| 103 | 
            +
                context,
         | 
| 104 | 
            +
                seq_len,
         | 
| 105 | 
            +
                vace_context=None,
         | 
| 106 | 
            +
                vace_context_scale=1.0,
         | 
| 107 | 
            +
                clip_fea=None,
         | 
| 108 | 
            +
                y=None,
         | 
| 109 | 
            +
            ):
         | 
| 110 | 
            +
                """
         | 
| 111 | 
            +
                x:              A list of videos each with shape [C, T, H, W].
         | 
| 112 | 
            +
                t:              [B].
         | 
| 113 | 
            +
                context:        A list of text embeddings each with shape [L, C].
         | 
| 114 | 
            +
                """
         | 
| 115 | 
            +
                if self.model_type == 'i2v':
         | 
| 116 | 
            +
                    assert clip_fea is not None and y is not None
         | 
| 117 | 
            +
                # params
         | 
| 118 | 
            +
                device = self.patch_embedding.weight.device
         | 
| 119 | 
            +
                if self.freqs.device != device:
         | 
| 120 | 
            +
                    self.freqs = self.freqs.to(device)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                if self.model_type != 'vace' and y is not None:
         | 
| 123 | 
            +
                    x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                # embeddings
         | 
| 126 | 
            +
                x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 127 | 
            +
                grid_sizes = torch.stack(
         | 
| 128 | 
            +
                    [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 129 | 
            +
                x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 130 | 
            +
                seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         | 
| 131 | 
            +
                assert seq_lens.max() <= seq_len
         | 
| 132 | 
            +
                x = torch.cat([
         | 
| 133 | 
            +
                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
         | 
| 134 | 
            +
                    for u in x
         | 
| 135 | 
            +
                ])
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                # time embeddings
         | 
| 138 | 
            +
                with amp.autocast(dtype=torch.float32):
         | 
| 139 | 
            +
                    e = self.time_embedding(
         | 
| 140 | 
            +
                        sinusoidal_embedding_1d(self.freq_dim, t).float())
         | 
| 141 | 
            +
                    e0 = self.time_projection(e).unflatten(1, (6, self.dim))
         | 
| 142 | 
            +
                    assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                # context
         | 
| 145 | 
            +
                context_lens = None
         | 
| 146 | 
            +
                context = self.text_embedding(
         | 
| 147 | 
            +
                    torch.stack([
         | 
| 148 | 
            +
                        torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 149 | 
            +
                        for u in context
         | 
| 150 | 
            +
                    ]))
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                if self.model_type != 'vace' and clip_fea is not None:
         | 
| 153 | 
            +
                    context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
         | 
| 154 | 
            +
                    context = torch.concat([context_clip, context], dim=1)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                # arguments
         | 
| 157 | 
            +
                kwargs = dict(
         | 
| 158 | 
            +
                    e=e0,
         | 
| 159 | 
            +
                    seq_lens=seq_lens,
         | 
| 160 | 
            +
                    grid_sizes=grid_sizes,
         | 
| 161 | 
            +
                    freqs=self.freqs,
         | 
| 162 | 
            +
                    context=context,
         | 
| 163 | 
            +
                    context_lens=context_lens)
         | 
| 164 | 
            +
                
         | 
| 165 | 
            +
                # Context Parallel
         | 
| 166 | 
            +
                x = torch.chunk(
         | 
| 167 | 
            +
                    x, get_sequence_parallel_world_size(),
         | 
| 168 | 
            +
                    dim=1)[get_sequence_parallel_rank()]
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                for block in self.blocks:
         | 
| 171 | 
            +
                    x = block(x, **kwargs)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                # head
         | 
| 174 | 
            +
                x = self.head(x, e)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                # Context Parallel
         | 
| 177 | 
            +
                x = get_sp_group().all_gather(x, dim=1)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                # unpatchify
         | 
| 180 | 
            +
                x = self.unpatchify(x, grid_sizes)
         | 
| 181 | 
            +
                return [u.float() for u in x]
         | 
| 182 | 
            +
             | 
| 183 | 
            +
             | 
| 184 | 
            +
            def usp_attn_forward(self,
         | 
| 185 | 
            +
                                 x,
         | 
| 186 | 
            +
                                 seq_lens,
         | 
| 187 | 
            +
                                 grid_sizes,
         | 
| 188 | 
            +
                                 freqs,
         | 
| 189 | 
            +
                                 dtype=torch.bfloat16):
         | 
| 190 | 
            +
                b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 191 | 
            +
                half_dtypes = (torch.float16, torch.bfloat16)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                def half(x):
         | 
| 194 | 
            +
                    return x if x.dtype in half_dtypes else x.to(dtype)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                # query, key, value function
         | 
| 197 | 
            +
                def qkv_fn(x):
         | 
| 198 | 
            +
                    q = self.norm_q(self.q(x)).view(b, s, n, d)
         | 
| 199 | 
            +
                    k = self.norm_k(self.k(x)).view(b, s, n, d)
         | 
| 200 | 
            +
                    v = self.v(x).view(b, s, n, d)
         | 
| 201 | 
            +
                    return q, k, v
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                q, k, v = qkv_fn(x)
         | 
| 204 | 
            +
                q = rope_apply(q, grid_sizes, freqs)
         | 
| 205 | 
            +
                k = rope_apply(k, grid_sizes, freqs)
         | 
| 206 | 
            +
             | 
| 207 | 
            +
    # TODO: We should use unpadded q, k, v for attention.
         | 
| 208 | 
            +
                # k_lens = seq_lens // get_sequence_parallel_world_size()
         | 
| 209 | 
            +
                # if k_lens is not None:
         | 
| 210 | 
            +
                #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
         | 
| 211 | 
            +
                #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
         | 
| 212 | 
            +
                #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                x = xFuserLongContextAttention()(
         | 
| 215 | 
            +
                    None,
         | 
| 216 | 
            +
                    query=half(q),
         | 
| 217 | 
            +
                    key=half(k),
         | 
| 218 | 
            +
                    value=half(v),
         | 
| 219 | 
            +
                    window_size=self.window_size)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                # TODO: padding after attention.
         | 
| 222 | 
            +
                # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                # output
         | 
| 225 | 
            +
                x = x.flatten(2)
         | 
| 226 | 
            +
                x = self.o(x)
         | 
| 227 | 
            +
                return x
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
             | 
| 232 | 
            +
            def usp_dit_forward_multitalk(
         | 
| 233 | 
            +
                self,
         | 
| 234 | 
            +
                x,
         | 
| 235 | 
            +
                t,
         | 
| 236 | 
            +
                context,
         | 
| 237 | 
            +
                seq_len,
         | 
| 238 | 
            +
                clip_fea=None,
         | 
| 239 | 
            +
                y=None,
         | 
| 240 | 
            +
                audio=None,
         | 
| 241 | 
            +
                ref_target_masks=None,
         | 
| 242 | 
            +
            ):
         | 
| 243 | 
            +
                """
         | 
| 244 | 
            +
                x:              A list of videos each with shape [C, T, H, W].
         | 
| 245 | 
            +
                t:              [B].
         | 
| 246 | 
            +
                context:        A list of text embeddings each with shape [L, C].
         | 
| 247 | 
            +
                """
         | 
| 248 | 
            +
                
         | 
| 249 | 
            +
                assert clip_fea is not None and y is not None
         | 
| 250 | 
            +
                # params
         | 
| 251 | 
            +
                device = self.patch_embedding.weight.device
         | 
| 252 | 
            +
                if self.freqs.device != device:
         | 
| 253 | 
            +
                    self.freqs = self.freqs.to(device)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                _, T, H, W = x[0].shape
         | 
| 256 | 
            +
                N_t = T // self.patch_size[0]
         | 
| 257 | 
            +
                N_h = H // self.patch_size[1]
         | 
| 258 | 
            +
                N_w = W // self.patch_size[2]
         | 
| 259 | 
            +
             | 
| 260 | 
            +
                if y is not None:
         | 
| 261 | 
            +
                    x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 262 | 
            +
                x[0] = x[0].to(context[0].dtype)
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                # embeddings
         | 
| 265 | 
            +
                x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 266 | 
            +
                grid_sizes = torch.stack(
         | 
| 267 | 
            +
                    [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 268 | 
            +
                x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 269 | 
            +
                seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         | 
| 270 | 
            +
                assert seq_lens.max() <= seq_len
         | 
| 271 | 
            +
                x = torch.cat([
         | 
| 272 | 
            +
                    torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
         | 
| 273 | 
            +
                    for u in x
         | 
| 274 | 
            +
                ])
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                # time embeddings
         | 
| 277 | 
            +
                with amp.autocast(dtype=torch.float32):
         | 
| 278 | 
            +
                    e = self.time_embedding(
         | 
| 279 | 
            +
                        sinusoidal_embedding_1d(self.freq_dim, t).float())
         | 
| 280 | 
            +
                    e0 = self.time_projection(e).unflatten(1, (6, self.dim))
         | 
| 281 | 
            +
                    assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                # context
         | 
| 284 | 
            +
                context_lens = None
         | 
| 285 | 
            +
                context = self.text_embedding(
         | 
| 286 | 
            +
                    torch.stack([
         | 
| 287 | 
            +
                        torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 288 | 
            +
                        for u in context
         | 
| 289 | 
            +
                    ]))
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                if clip_fea is not None:
         | 
| 292 | 
            +
                    context_clip = self.img_emb(clip_fea)  
         | 
| 293 | 
            +
                    context = torch.concat([context_clip, context], dim=1)
         | 
| 294 | 
            +
             | 
| 295 | 
            +
                # get audio token
         | 
| 296 | 
            +
                audio_cond = audio.to(device=x.device, dtype=x.dtype)
         | 
| 297 | 
            +
                first_frame_audio_emb_s = audio_cond[:, :1, ...] 
         | 
| 298 | 
            +
                latter_frame_audio_emb = audio_cond[:, 1:, ...] 
         | 
| 299 | 
            +
                latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale) 
         | 
| 300 | 
            +
                middle_index = self.audio_window // 2
         | 
| 301 | 
            +
                latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] 
         | 
| 302 | 
            +
                latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
         | 
| 303 | 
            +
                latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] 
         | 
| 304 | 
            +
                latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
         | 
| 305 | 
            +
                latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] 
         | 
| 306 | 
            +
                latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
         | 
| 307 | 
            +
                latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) 
         | 
| 308 | 
            +
                audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) 
         | 
| 309 | 
            +
                human_num = len(audio_embedding)
         | 
| 310 | 
            +
                audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
             | 
| 313 | 
            +
                # convert ref_target_masks to token_ref_target_masks
         | 
| 314 | 
            +
                if ref_target_masks is not None:
         | 
| 315 | 
            +
                    ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32) 
         | 
| 316 | 
            +
                    token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest') 
         | 
| 317 | 
            +
                    token_ref_target_masks = token_ref_target_masks.squeeze(0) 
         | 
| 318 | 
            +
                    token_ref_target_masks = (token_ref_target_masks > 0)
         | 
| 319 | 
            +
                    token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1) 
         | 
| 320 | 
            +
                    token_ref_target_masks = token_ref_target_masks.to(x.dtype)
         | 
| 321 | 
            +
                
         | 
| 322 | 
            +
                if self.enable_teacache:
         | 
| 323 | 
            +
                    modulated_inp = e0 if self.use_ret_steps else e
         | 
| 324 | 
            +
                    if self.cnt%3==0: # cond
         | 
| 325 | 
            +
                        if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
         | 
| 326 | 
            +
                            should_calc_cond = True
         | 
| 327 | 
            +
                            self.accumulated_rel_l1_distance_cond = 0
         | 
| 328 | 
            +
                        else:
         | 
| 329 | 
            +
                            rescale_func = np.poly1d(self.coefficients)
         | 
| 330 | 
            +
                            self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
         | 
| 331 | 
            +
                            # print("accumulated_rel_l1_distance_even", self.accumulated_rel_l1_distance_even)
         | 
| 332 | 
            +
                            if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
         | 
| 333 | 
            +
                                should_calc_cond = False
         | 
| 334 | 
            +
                            else:
         | 
| 335 | 
            +
                                should_calc_cond = True
         | 
| 336 | 
            +
                                self.accumulated_rel_l1_distance_cond = 0
         | 
| 337 | 
            +
                        self.previous_e0_cond = modulated_inp.clone()
         | 
| 338 | 
            +
                    elif self.cnt%3==1: # drop_text
         | 
| 339 | 
            +
                        if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
         | 
| 340 | 
            +
                            should_calc_drop_text = True
         | 
| 341 | 
            +
                            self.accumulated_rel_l1_distance_drop_text = 0
         | 
| 342 | 
            +
                        else:
         | 
| 343 | 
            +
                            rescale_func = np.poly1d(self.coefficients)
         | 
| 344 | 
            +
                            self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
         | 
| 345 | 
            +
                            if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
         | 
| 346 | 
            +
                                should_calc_drop_text = False
         | 
| 347 | 
            +
                            else:
         | 
| 348 | 
            +
                                should_calc_drop_text = True
         | 
| 349 | 
            +
                                self.accumulated_rel_l1_distance_drop_text = 0
         | 
| 350 | 
            +
                        self.previous_e0_drop_text = modulated_inp.clone()
         | 
| 351 | 
            +
                    else: # uncond
         | 
| 352 | 
            +
                        if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
         | 
| 353 | 
            +
                            should_calc_uncond = True
         | 
| 354 | 
            +
                            self.accumulated_rel_l1_distance_uncond = 0
         | 
| 355 | 
            +
                        else:
         | 
| 356 | 
            +
                            rescale_func = np.poly1d(self.coefficients)
         | 
| 357 | 
            +
                            self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
         | 
| 358 | 
            +
                            if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
         | 
| 359 | 
            +
                                should_calc_uncond = False
         | 
| 360 | 
            +
                            else:
         | 
| 361 | 
            +
                                should_calc_uncond = True
         | 
| 362 | 
            +
                                self.accumulated_rel_l1_distance_uncond = 0
         | 
| 363 | 
            +
                        self.previous_e0_uncond = modulated_inp.clone()
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                # Context Parallel
         | 
| 366 | 
            +
                x = torch.chunk(
         | 
| 367 | 
            +
                    x, get_sequence_parallel_world_size(),
         | 
| 368 | 
            +
                    dim=1)[get_sequence_parallel_rank()]
         | 
| 369 | 
            +
             | 
| 370 | 
            +
                # arguments
         | 
| 371 | 
            +
                kwargs = dict(
         | 
| 372 | 
            +
                    e=e0,
         | 
| 373 | 
            +
                    seq_lens=seq_lens,
         | 
| 374 | 
            +
                    grid_sizes=grid_sizes,
         | 
| 375 | 
            +
                    freqs=self.freqs,
         | 
| 376 | 
            +
                    context=context,
         | 
| 377 | 
            +
                    context_lens=context_lens,
         | 
| 378 | 
            +
                    audio_embedding=audio_embedding,
         | 
| 379 | 
            +
                    ref_target_masks=token_ref_target_masks,
         | 
| 380 | 
            +
                    human_num=human_num,
         | 
| 381 | 
            +
                    )
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                if self.enable_teacache:
         | 
| 384 | 
            +
                    if self.cnt%3==0:
         | 
| 385 | 
            +
                        if not should_calc_cond:
         | 
| 386 | 
            +
                            x +=  self.previous_residual_cond
         | 
| 387 | 
            +
                        else:
         | 
| 388 | 
            +
                            ori_x = x.clone()
         | 
| 389 | 
            +
                            for block in self.blocks:
         | 
| 390 | 
            +
                                x = block(x, **kwargs)
         | 
| 391 | 
            +
                            self.previous_residual_cond = x - ori_x
         | 
| 392 | 
            +
                    elif self.cnt%3==1:
         | 
| 393 | 
            +
                        if not should_calc_drop_text:
         | 
| 394 | 
            +
                            x +=  self.previous_residual_drop_text
         | 
| 395 | 
            +
                        else:
         | 
| 396 | 
            +
                            ori_x = x.clone()
         | 
| 397 | 
            +
                            for block in self.blocks:
         | 
| 398 | 
            +
                                x = block(x, **kwargs)
         | 
| 399 | 
            +
                            self.previous_residual_drop_text = x - ori_x
         | 
| 400 | 
            +
                    else:
         | 
| 401 | 
            +
                        if not should_calc_uncond:
         | 
| 402 | 
            +
                            x +=  self.previous_residual_uncond
         | 
| 403 | 
            +
                        else:
         | 
| 404 | 
            +
                            ori_x = x.clone()
         | 
| 405 | 
            +
                            for block in self.blocks:
         | 
| 406 | 
            +
                                x = block(x, **kwargs)
         | 
| 407 | 
            +
                            self.previous_residual_uncond = x - ori_x
         | 
| 408 | 
            +
                else:
         | 
| 409 | 
            +
                    for block in self.blocks:
         | 
| 410 | 
            +
                        x = block(x, **kwargs)
         | 
| 411 | 
            +
             | 
| 412 | 
            +
                # head
         | 
| 413 | 
            +
                x = self.head(x, e)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                # Context Parallel
         | 
| 416 | 
            +
                x = get_sp_group().all_gather(x, dim=1)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                # unpatchify
         | 
| 419 | 
            +
                x = self.unpatchify(x, grid_sizes)
         | 
| 420 | 
            +
                if self.enable_teacache:
         | 
| 421 | 
            +
                    self.cnt += 1
         | 
| 422 | 
            +
                    if self.cnt >= self.num_steps:
         | 
| 423 | 
            +
                        self.cnt = 0
         | 
| 424 | 
            +
                    
         | 
| 425 | 
            +
                return torch.stack(x).float()
         | 
| 426 | 
            +
             | 
| 427 | 
            +
             | 
| 428 | 
            +
            def usp_attn_forward_multitalk(self,
         | 
| 429 | 
            +
                                 x,
         | 
| 430 | 
            +
                                 seq_lens,
         | 
| 431 | 
            +
                                 grid_sizes,
         | 
| 432 | 
            +
                                 freqs,
         | 
| 433 | 
            +
                                 dtype=torch.bfloat16,
         | 
| 434 | 
            +
                                 ref_target_masks=None):
         | 
| 435 | 
            +
                b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 436 | 
            +
                half_dtypes = (torch.float16, torch.bfloat16)
         | 
| 437 | 
            +
             | 
| 438 | 
            +
                def half(x):
         | 
| 439 | 
            +
                    return x if x.dtype in half_dtypes else x.to(dtype)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                # query, key, value function
         | 
| 442 | 
            +
                def qkv_fn(x):
         | 
| 443 | 
            +
                    q = self.norm_q(self.q(x)).view(b, s, n, d)
         | 
| 444 | 
            +
                    k = self.norm_k(self.k(x)).view(b, s, n, d)
         | 
| 445 | 
            +
                    v = self.v(x).view(b, s, n, d)
         | 
| 446 | 
            +
                    return q, k, v
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                q, k, v = qkv_fn(x)
         | 
| 449 | 
            +
                q = rope_apply(q, grid_sizes, freqs)
         | 
| 450 | 
            +
                k = rope_apply(k, grid_sizes, freqs)
         | 
| 451 | 
            +
             | 
| 452 | 
            +
             | 
| 453 | 
            +
                x = xFuserLongContextAttention()(
         | 
| 454 | 
            +
                    None,
         | 
| 455 | 
            +
                    query=half(q),
         | 
| 456 | 
            +
                    key=half(k),
         | 
| 457 | 
            +
                    value=half(v),
         | 
| 458 | 
            +
                    window_size=self.window_size)
         | 
| 459 | 
            +
             | 
| 460 | 
            +
             | 
| 461 | 
            +
                # output
         | 
| 462 | 
            +
                x = x.flatten(2)
         | 
| 463 | 
            +
                x = self.o(x)
         | 
| 464 | 
            +
             | 
| 465 | 
            +
                with torch.no_grad():
         | 
| 466 | 
            +
                    x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0], 
         | 
| 467 | 
            +
                                                        ref_target_masks=ref_target_masks, enable_sp=True) 
         | 
| 468 | 
            +
             | 
| 469 | 
            +
                return x, x_ref_attn_map
         | 
| 470 | 
            +
             | 
| 471 | 
            +
             | 
| 472 | 
            +
             | 
| 473 | 
            +
             | 
| 474 | 
            +
            def usp_crossattn_multi_forward_multitalk(self, 
         | 
| 475 | 
            +
                                                    x: torch.Tensor, 
         | 
| 476 | 
            +
                                                    encoder_hidden_states: torch.Tensor,  # 1, 21, 64, C
         | 
| 477 | 
            +
                                                    shape=None, 
         | 
| 478 | 
            +
                                                    x_ref_attn_map=None,
         | 
| 479 | 
            +
                                                    human_num=None) -> torch.Tensor:
         | 
| 480 | 
            +
                    
         | 
| 481 | 
            +
                    N_t, N_h, N_w = shape 
         | 
| 482 | 
            +
                    sp_size = get_sequence_parallel_world_size()
         | 
| 483 | 
            +
                    sp_rank = get_sequence_parallel_rank()
         | 
| 484 | 
            +
                    audio_tokens_per_frame = 32
         | 
| 485 | 
            +
                    visual_seqlen, frame_ids = split_token_counts_and_frame_ids(N_t, N_h * N_w, sp_size, sp_rank)
         | 
| 486 | 
            +
                    encoder_hidden_states = encoder_hidden_states[:, min(frame_ids):max(frame_ids)+1, ...]
         | 
| 487 | 
            +
                    encoder_hidden_states = rearrange(encoder_hidden_states, "B T N C -> B (T N) C")
         | 
| 488 | 
            +
                    N_a = len(frame_ids)
         | 
| 489 | 
            +
                    kv_seq = [audio_tokens_per_frame * human_num] * N_a
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                    if human_num == 1:
         | 
| 492 | 
            +
                        return super(SingleStreamMutiAttention, self).forward(x, encoder_hidden_states, shape, enable_sp=True, kv_seq=kv_seq)
         | 
| 493 | 
            +
             | 
| 494 | 
            +
             | 
| 495 | 
            +
                    # get q for hidden_state
         | 
| 496 | 
            +
                    B, N, C = x.shape
         | 
| 497 | 
            +
                    q = self.q_linear(x) 
         | 
| 498 | 
            +
                    q_shape = (B, N, self.num_heads, self.head_dim) 
         | 
| 499 | 
            +
                    q = q.view(q_shape).permute((0, 2, 1, 3))
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                    if self.qk_norm:
         | 
| 502 | 
            +
                        q = self.q_norm(q)
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                    max_values = x_ref_attn_map.max(1).values[:, None, None] 
         | 
| 505 | 
            +
                    min_values = x_ref_attn_map.min(1).values[:, None, None] 
         | 
| 506 | 
            +
                    max_min_values = torch.cat([max_values, min_values], dim=2)
         | 
| 507 | 
            +
                    max_min_values = get_sp_group().all_gather(max_min_values, dim=1)
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                    human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
         | 
| 510 | 
            +
                    human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()
         | 
| 511 | 
            +
             | 
| 512 | 
            +
                    human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
         | 
| 513 | 
            +
                    human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
         | 
| 514 | 
            +
                    back   = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
         | 
| 515 | 
            +
                    max_indices = x_ref_attn_map.argmax(dim=0)
         | 
| 516 | 
            +
                    normalized_map = torch.stack([human1, human2, back], dim=1)
         | 
| 517 | 
            +
                    normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N 
         | 
| 518 | 
            +
                    q = self.rope_1d(q, normalized_pos)
         | 
| 519 | 
            +
             
         | 
| 520 | 
            +
                    encoder_kv = self.kv_linear(encoder_hidden_states) 
         | 
| 521 | 
            +
                    encoder_kv_shape = (B, encoder_hidden_states.size(1), 2, self.num_heads, self.head_dim)
         | 
| 522 | 
            +
                    encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) 
         | 
| 523 | 
            +
                    encoder_k, encoder_v = encoder_kv.unbind(0) # B H N C
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    if self.qk_norm:
         | 
| 526 | 
            +
                        encoder_k = self.add_k_norm(encoder_k)
         | 
| 527 | 
            +
             | 
| 528 | 
            +
                    # position embedding for condition audio embeddings
         | 
| 529 | 
            +
                    per_frame = torch.zeros(audio_tokens_per_frame * human_num, dtype=encoder_k.dtype).to(encoder_k.device)
         | 
| 530 | 
            +
                    per_frame[:audio_tokens_per_frame] = (self.rope_h1[0] + self.rope_h1[1]) / 2
         | 
| 531 | 
            +
                    per_frame[audio_tokens_per_frame:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
         | 
| 532 | 
            +
                    encoder_pos = torch.concat([per_frame]*N_a, dim=0)
         | 
| 533 | 
            +
                    encoder_k = self.rope_1d(encoder_k, encoder_pos)
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                    # get attn
         | 
| 536 | 
            +
                    q = rearrange(q, "B H M K -> B M H K")
         | 
| 537 | 
            +
                    encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
         | 
| 538 | 
            +
                    encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
         | 
| 539 | 
            +
                    attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(visual_seqlen, kv_seq)
         | 
| 540 | 
            +
                    x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
         | 
| 541 | 
            +
                    x = rearrange(x, "B M H K -> B H M K")
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                    # linear transform
         | 
| 544 | 
            +
                    x_output_shape = (B, N, C)
         | 
| 545 | 
            +
                    x = x.transpose(1, 2) 
         | 
| 546 | 
            +
                    x = x.reshape(x_output_shape) 
         | 
| 547 | 
            +
                    x = self.proj(x) 
         | 
| 548 | 
            +
                    x = self.proj_drop(x)
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                    return x
         | 
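For orientation, below is a minimal sketch of how the USP forwards defined above are typically attached to a loaded model. The non-MultiTalk pair (usp_attn_forward / usp_dit_forward) is bound exactly this way in WanFLF2V.__init__ further down; binding the MultiTalk variants with the same types.MethodType pattern, and the apply_multitalk_usp helper name itself, are assumptions for illustration, not code from this diff.

    import types

    from wan.distributed.xdit_context_parallel import (
        usp_attn_forward_multitalk,
        usp_dit_forward_multitalk,
    )

    def apply_multitalk_usp(model):
        # Hypothetical helper: patch each block's self-attention and the top-level
        # forward so the model runs with sequence (context) parallelism enabled.
        # Note: usp_attn_forward_multitalk returns (x, x_ref_attn_map); the MultiTalk
        # block forward is assumed to consume that pair.
        for block in model.blocks:
            block.self_attn.forward = types.MethodType(
                usp_attn_forward_multitalk, block.self_attn)
        model.forward = types.MethodType(usp_dit_forward_multitalk, model)
        return model

The audio cross-attention forward (usp_crossattn_multi_forward_multitalk) would be bound to the corresponding cross-attention modules in the same way.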
    	
        wan/first_last_frame2video.py
    ADDED
    
    | @@ -0,0 +1,377 @@ | |
| 1 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 2 | 
            +
            import gc
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
            import math
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
            import sys
         | 
| 8 | 
            +
            import types
         | 
| 9 | 
            +
            from contextlib import contextmanager
         | 
| 10 | 
            +
            from functools import partial
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import numpy as np
         | 
| 13 | 
            +
            import torch
         | 
| 14 | 
            +
            import torch.cuda.amp as amp
         | 
| 15 | 
            +
            import torch.distributed as dist
         | 
| 16 | 
            +
            import torchvision.transforms.functional as TF
         | 
| 17 | 
            +
            from tqdm import tqdm
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from .distributed.fsdp import shard_model
         | 
| 20 | 
            +
            from .modules.clip import CLIPModel
         | 
| 21 | 
            +
            from .modules.model import WanModel
         | 
| 22 | 
            +
            from .modules.t5 import T5EncoderModel
         | 
| 23 | 
            +
            from .modules.vae import WanVAE
         | 
| 24 | 
            +
            from .utils.fm_solvers import (
         | 
| 25 | 
            +
                FlowDPMSolverMultistepScheduler,
         | 
| 26 | 
            +
                get_sampling_sigmas,
         | 
| 27 | 
            +
                retrieve_timesteps,
         | 
| 28 | 
            +
            )
         | 
| 29 | 
            +
            from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
         | 
| 30 | 
            +
             | 
| 31 | 
            +
             | 
| 32 | 
            +
            class WanFLF2V:
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def __init__(
         | 
| 35 | 
            +
                    self,
         | 
| 36 | 
            +
                    config,
         | 
| 37 | 
            +
                    checkpoint_dir,
         | 
| 38 | 
            +
                    device_id=0,
         | 
| 39 | 
            +
                    rank=0,
         | 
| 40 | 
            +
                    t5_fsdp=False,
         | 
| 41 | 
            +
                    dit_fsdp=False,
         | 
| 42 | 
            +
                    use_usp=False,
         | 
| 43 | 
            +
                    t5_cpu=False,
         | 
| 44 | 
            +
                    init_on_cpu=True,
         | 
| 45 | 
            +
                ):
         | 
| 46 | 
            +
                    r"""
         | 
| 47 | 
            +
        Initializes the first-last-frame-to-video (FLF2V) generation model components.
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    Args:
         | 
| 50 | 
            +
                        config (EasyDict):
         | 
| 51 | 
            +
                            Object containing model parameters initialized from config.py
         | 
| 52 | 
            +
                        checkpoint_dir (`str`):
         | 
| 53 | 
            +
                            Path to directory containing model checkpoints
         | 
| 54 | 
            +
                        device_id (`int`,  *optional*, defaults to 0):
         | 
| 55 | 
            +
                            Id of target GPU device
         | 
| 56 | 
            +
                        rank (`int`,  *optional*, defaults to 0):
         | 
| 57 | 
            +
                            Process rank for distributed training
         | 
| 58 | 
            +
                        t5_fsdp (`bool`, *optional*, defaults to False):
         | 
| 59 | 
            +
                            Enable FSDP sharding for T5 model
         | 
| 60 | 
            +
                        dit_fsdp (`bool`, *optional*, defaults to False):
         | 
| 61 | 
            +
                            Enable FSDP sharding for DiT model
         | 
| 62 | 
            +
                        use_usp (`bool`, *optional*, defaults to False):
         | 
| 63 | 
            +
                Enable the USP (unified sequence parallelism) distribution strategy.
         | 
| 64 | 
            +
                        t5_cpu (`bool`, *optional*, defaults to False):
         | 
| 65 | 
            +
                            Whether to place T5 model on CPU. Only works without t5_fsdp.
         | 
| 66 | 
            +
                        init_on_cpu (`bool`, *optional*, defaults to True):
         | 
| 67 | 
            +
                            Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
         | 
| 68 | 
            +
                    """
         | 
| 69 | 
            +
                    self.device = torch.device(f"cuda:{device_id}")
         | 
| 70 | 
            +
                    self.config = config
         | 
| 71 | 
            +
                    self.rank = rank
         | 
| 72 | 
            +
                    self.use_usp = use_usp
         | 
| 73 | 
            +
                    self.t5_cpu = t5_cpu
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    self.num_train_timesteps = config.num_train_timesteps
         | 
| 76 | 
            +
                    self.param_dtype = config.param_dtype
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    shard_fn = partial(shard_model, device_id=device_id)
         | 
| 79 | 
            +
                    self.text_encoder = T5EncoderModel(
         | 
| 80 | 
            +
                        text_len=config.text_len,
         | 
| 81 | 
            +
                        dtype=config.t5_dtype,
         | 
| 82 | 
            +
                        device=torch.device('cpu'),
         | 
| 83 | 
            +
                        checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
         | 
| 84 | 
            +
                        tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
         | 
| 85 | 
            +
                        shard_fn=shard_fn if t5_fsdp else None,
         | 
| 86 | 
            +
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 first_frame,
                 last_frame,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=16,
                 sample_solver='unipc',
                 sampling_steps=50,
                 guide_scale=5.5,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from the input first-last frame pair and text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            first_frame (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
            last_frame (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
                [NOTE] If the sizes of first_frame and last_frame are mismatched, last_frame will be cropped & resized
                to match first_frame.
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 16):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 50):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.5):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        first_frame_size = first_frame.size
        last_frame_size = last_frame.size
        first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).to(
            self.device)
        last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).to(
            self.device)

        F = frame_num
        first_frame_h, first_frame_w = first_frame.shape[1:]
        aspect_ratio = first_frame_h / first_frame_w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        first_frame_h = lat_h * self.vae_stride[1]
        first_frame_w = lat_w * self.vae_stride[2]
        if first_frame_size != last_frame_size:
            # 1. resize
            last_frame_resize_ratio = max(
                first_frame_size[0] / last_frame_size[0],
                first_frame_size[1] / last_frame_size[1])
            last_frame_size = [
                round(last_frame_size[0] * last_frame_resize_ratio),
                round(last_frame_size[1] * last_frame_resize_ratio),
            ]
            # 2. center crop
            last_frame = TF.center_crop(last_frame, last_frame_size)

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
        msk[:, 1:-1] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ],
                           dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual(
            [first_frame[:, None, :, :], last_frame[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    first_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
                torch.zeros(3, F - 2, first_frame_h, first_frame_w),
                torch.nn.functional.interpolate(
                    last_frame[None].cpu(),
                    size=(first_frame_h, first_frame_w),
                    mode='bicubic').transpose(0, 1),
            ],
                         dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latent = noise

            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latent.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latent = temp_x0.squeeze(0)

                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
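
The generate() method above drives the first-last-frame pipeline end to end: it normalizes the two key frames, derives the latent grid from max_area, builds a temporal mask that pins only the first and last latent frames, and runs classifier-free guidance with the chosen flow scheduler. A minimal, illustrative call is sketched below; the class name WanFLF2V, the config object, the checkpoint path, and the image paths are placeholders assumed for the example, not values taken from this diff.

# Illustrative sketch only -- "WanFLF2V", "config", and the paths below are
# placeholder assumptions; the real names come from the rest of the repository.
from PIL import Image

first = Image.open("first.png").convert("RGB")
last = Image.open("last.png").convert("RGB")

pipe = WanFLF2V(config, checkpoint_dir="./weights", device_id=0, rank=0)
video = pipe.generate(
    "a person turns toward the camera and smiles",
    first,
    last,
    max_area=720 * 1280,
    frame_num=81,        # must be 4n + 1
    shift=16,            # noise schedule shift (see the docstring above)
    sampling_steps=50,
    guide_scale=5.5,
    seed=42,
    offload_model=True)  # move submodules to CPU between steps to save VRAM
# On rank 0 the result is a (C, N, H, W) tensor; other ranks receive None.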
    	
        wan/image2video.py
    ADDED
    
@@ -0,0 +1,350 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
import logging
import math
import os
import random
import sys
import types
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torchvision.transforms.functional as TF
from tqdm import tqdm

from .distributed.fsdp import shard_model
from .modules.clip import CLIPModel
from .modules.model import WanModel
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .utils.fm_solvers import (
    FlowDPMSolverMultistepScheduler,
    get_sampling_sigmas,
    retrieve_timesteps,
)
from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler


class WanI2V:

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        init_on_cpu=True,
    ):
        r"""
        Initializes the image-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`, *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`, *optional*, defaults to 0):
                Process rank for distributed training
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_usp (`bool`, *optional*, defaults to False):
                Enable the USP distribution strategy.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing the Transformer model on CPU. Only works without FSDP or USP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.use_usp = use_usp
        self.t5_cpu = t5_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae_stride = config.vae_stride
        self.patch_size = config.patch_size
        self.vae = WanVAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        self.clip = CLIPModel(
            dtype=config.clip_dtype,
            device=self.device,
            checkpoint_path=os.path.join(checkpoint_dir,
                                         config.clip_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.clip_tokenizer))

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        self.model = WanModel.from_pretrained(checkpoint_dir)
        self.model.eval().requires_grad_(False)

        if t5_fsdp or dit_fsdp or use_usp:
            init_on_cpu = False

        if use_usp:
            from xfuser.core.distributed import get_sequence_parallel_world_size

            from .distributed.xdit_context_parallel import (
                usp_attn_forward,
                usp_dit_forward,
            )
            for block in self.model.blocks:
                block.self_attn.forward = types.MethodType(
                    usp_attn_forward, block.self_attn)
            self.model.forward = types.MethodType(usp_dit_forward, self.model)
            self.sp_size = get_sequence_parallel_world_size()
        else:
            self.sp_size = 1

        if dist.is_initialized():
            dist.barrier()
        if dit_fsdp:
            self.model = shard_fn(self.model)
        else:
            if not init_on_cpu:
                self.model.to(self.device)

        self.sample_neg_prompt = config.sample_neg_prompt

    def generate(self,
                 input_prompt,
                 img,
                 max_area=720 * 1280,
                 frame_num=81,
                 shift=5.0,
                 sample_solver='unipc',
                 sampling_steps=40,
                 guide_scale=5.0,
                 n_prompt="",
                 seed=-1,
                 offload_model=True):
        r"""
        Generates video frames from the input image and text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            img (PIL.Image.Image):
                Input image tensor. Shape: [3, H, W]
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            frame_num (`int`, *optional*, defaults to 81):
                How many frames to sample from a video. The number should be 4n+1
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video.
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults to 5.0):
                Classifier-free guidance scale. Controls prompt adherence vs. creativity
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If -1, use random seed
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM

        Returns:
            torch.Tensor:
                Generated video frames tensor. Dimensions: (C, N, H, W) where:
                - C: Color channels (3 for RGB)
                - N: Number of frames (81)
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
        """
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)

        F = frame_num
        h, w = img.shape[1:]
        aspect_ratio = h / w
        lat_h = round(
            np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
            self.patch_size[1] * self.patch_size[1])
        lat_w = round(
            np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
            self.patch_size[2] * self.patch_size[2])
        h = lat_h * self.vae_stride[1]
        w = lat_w * self.vae_stride[2]

        max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
            self.patch_size[1] * self.patch_size[2])
        max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
        seed_g = torch.Generator(device=self.device)
        seed_g.manual_seed(seed)
        noise = torch.randn(
            16, (F - 1) // 4 + 1,
            lat_h,
            lat_w,
            dtype=torch.float32,
            generator=seed_g,
            device=self.device)

        msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
        msk[:, 1:] = 0
        msk = torch.concat([
            torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
        ],
                           dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # preprocess
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        self.clip.model.to(self.device)
        clip_context = self.clip.visual([img[:, None, :, :]])
        if offload_model:
            self.clip.model.cpu()

        y = self.vae.encode([
            torch.concat([
                torch.nn.functional.interpolate(
                    img[None].cpu(), size=(h, w), mode='bicubic').transpose(
                        0, 1),
                torch.zeros(3, F - 1, h, w)
            ],
                         dim=1).to(self.device)
        ])[0]
        y = torch.concat([msk, y])

        @contextmanager
        def noop_no_sync():
            yield

        no_sync = getattr(self.model, 'no_sync', noop_no_sync)

        # evaluation mode
        with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync():

            if sample_solver == 'unipc':
                sample_scheduler = FlowUniPCMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sample_scheduler.set_timesteps(
                    sampling_steps, device=self.device, shift=shift)
                timesteps = sample_scheduler.timesteps
            elif sample_solver == 'dpm++':
                sample_scheduler = FlowDPMSolverMultistepScheduler(
                    num_train_timesteps=self.num_train_timesteps,
                    shift=1,
                    use_dynamic_shifting=False)
                sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
                timesteps, _ = retrieve_timesteps(
                    sample_scheduler,
                    device=self.device,
                    sigmas=sampling_sigmas)
            else:
                raise NotImplementedError("Unsupported solver.")

            # sample videos
            latent = noise

            arg_c = {
                'context': [context[0]],
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            arg_null = {
                'context': context_null,
                'clip_fea': clip_context,
                'seq_len': max_seq_len,
                'y': [y],
            }

            if offload_model:
                torch.cuda.empty_cache()

            self.model.to(self.device)
            for _, t in enumerate(tqdm(timesteps)):
                latent_model_input = [latent.to(self.device)]
                timestep = [t]

                timestep = torch.stack(timestep).to(self.device)

                noise_pred_cond = self.model(
                    latent_model_input, t=timestep, **arg_c)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred_uncond = self.model(
                    latent_model_input, t=timestep, **arg_null)[0].to(
                        torch.device('cpu') if offload_model else self.device)
                if offload_model:
                    torch.cuda.empty_cache()
                noise_pred = noise_pred_uncond + guide_scale * (
                    noise_pred_cond - noise_pred_uncond)

                latent = latent.to(
                    torch.device('cpu') if offload_model else self.device)

                temp_x0 = sample_scheduler.step(
                    noise_pred.unsqueeze(0),
                    t,
                    latent.unsqueeze(0),
                    return_dict=False,
                    generator=seed_g)[0]
                latent = temp_x0.squeeze(0)

                x0 = [latent.to(self.device)]
                del latent_model_input, timestep

            if offload_model:
                self.model.cpu()
                torch.cuda.empty_cache()

            if self.rank == 0:
                videos = self.vae.decode(x0)

        del noise, latent
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
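
Compared with the first-last-frame variant, WanI2V.generate() conditions on a single reference image: the temporal mask keeps only the first latent frame (msk[:, 1:] = 0) and the VAE input pads the remaining F - 1 frames with zeros. A minimal, illustrative call is sketched below; the config object, checkpoint directory, and image path are placeholders, not values from this diff.

# Illustrative sketch only -- "config", the checkpoint directory, and the image
# path are placeholders; the keyword defaults mirror the generate() signature.
from PIL import Image

ref = Image.open("reference.png").convert("RGB")

i2v = WanI2V(config, checkpoint_dir="./weights", device_id=0, rank=0)
video = i2v.generate(
    "the camera slowly pans across a rain-soaked street at night",
    ref,
    max_area=720 * 1280,
    frame_num=81,        # must be 4n + 1
    shift=5.0,           # the docstring suggests ~3.0 for 480p output
    sample_solver='unipc',
    sampling_steps=40,
    guide_scale=5.0,
    seed=42,
    offload_model=True)
# On rank 0 the result is a (C, N, H, W) tensor; other ranks receive None.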
    	
        wan/modules/__init__.py
    ADDED
    
@@ -0,0 +1,18 @@
from .attention import flash_attention
from .model import WanModel
from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
from .tokenizers import HuggingfaceTokenizer
from .vace_model import VaceWanModel
from .vae import WanVAE

__all__ = [
    'WanVAE',
    'WanModel',
    'VaceWanModel',
    'T5Model',
    'T5Encoder',
    'T5Decoder',
    'T5EncoderModel',
    'HuggingfaceTokenizer',
    'flash_attention',
]
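
With these re-exports in place, downstream code can import the public components from wan.modules directly instead of reaching into the individual submodules, for example:

# Equivalent to importing from .model, .vae, .t5 and .attention individually.
from wan.modules import WanModel, WanVAE, T5EncoderModel, flash_attention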
    	
        wan/modules/attention.py
    ADDED
    
@@ -0,0 +1,393 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
from einops import rearrange, repeat
from ..utils.multitalk_utils import RotaryPositionalEmbedding1D, normalize_and_scale, split_token_counts_and_frame_ids
from xfuser.core.distributed import (
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
    get_sp_group,
)
import xformers.ops

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

import warnings

__all__ = [
    'flash_attention',
    'attention',
]


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """
    q:              [B, Lq, Nq, C1].
    k:              [B, Lk, Nk, C1].
    v:              [B, Lk, Nk, C2]. Nq must be divisible by Nk.
    q_lens:         [B].
    k_lens:         [B].
    dropout_p:      float. Dropout probability.
    softmax_scale:  float. The scaling of QK^T before applying softmax.
    causal:         bool. Whether to apply causal attention mask.
    window_size:    (left right). If not (-1, -1), apply sliding window local attention.
    deterministic:  bool. If True, slightly slower and uses more memory.
    dtype:          torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == 'cuda' and q.size(-1) <= 256

    # params
    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor(
            [lq] * b, dtype=torch.int32).to(
                device=q.device, non_blocking=True)
    else:
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor(
            [lk] * b, dtype=torch.int32).to(
                device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)

    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn(
            'Flash attention 3 is not available, use flash attention 2 instead.'
        )

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # Note: dropout_p, window_size are not supported in FA3 now.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic)[0].unflatten(0, (b, lq))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
                0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=lq,
            max_seqlen_k=lk,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic).unflatten(0, (b, lq))

    # output
    return x.type(out_dtype)


def attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    fa_version=None,
):
    if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
        return flash_attention(
            q=q,
            k=k,
            v=v,
            q_lens=q_lens,
            k_lens=k_lens,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            q_scale=q_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
            dtype=dtype,
            version=fa_version,
        )
    else:
        if q_lens is not None or k_lens is not None:
            warnings.warn(
                'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
            )
        attn_mask = None

        q = q.transpose(1, 2).to(dtype)
        k = k.transpose(1, 2).to(dtype)
        v = v.transpose(1, 2).to(dtype)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)

        out = out.transpose(1, 2).contiguous()
        return out


class SingleStreamAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qk_norm = qk_norm

        self.q_linear = nn.Linear(dim, dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim, eps=eps) if qk_norm else nn.Identity()

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.kv_linear = nn.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias)

        self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:

        N_t, N_h, N_w = shape
        if not enable_sp:
            x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        # get kv from encoder_hidden_states
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")

        if enable_sp:
            # context parallel
            sp_size = get_sequence_parallel_world_size()
            sp_rank = get_sequence_parallel_rank()
            visual_seqlen, _ = split_token_counts_and_frame_ids(N_t, N_h * N_w, sp_size, sp_rank)
            assert kv_seq is not None, "kv_seq should not be None."
            attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(visual_seqlen, kv_seq)
        else:
            attn_bias = None
        x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        if not enable_sp:
            # reshape x to origin shape
            x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x


class SingleStreamMutiAttention(SingleStreamAttention):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        qk_norm: bool,
        norm_layer: nn.Module,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        eps: float = 1e-6,
        class_range: int = 24,
        class_interval: int = 4,
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            norm_layer=norm_layer,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            eps=eps,
        )
        self.class_interval = class_interval
        self.class_range = class_range
        self.rope_h1 = (0, self.class_interval)
        self.rope_h2 = (self.class_range - self.class_interval, self.class_range)
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(self,
                x: torch.Tensor,
                encoder_hidden_states: torch.Tensor,
                shape=None,
                x_ref_attn_map=None,
                human_num=None) -> torch.Tensor:

        encoder_hidden_states = encoder_hidden_states.squeeze(0)
        if human_num == 1:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, _, _ = shape
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # get q for hidden_state
        B, N, C = x.shape
        q = self.q_linear(x)
        q_shape = (B, N, self.num_heads, self.head_dim)
        q = q.view(q_shape).permute((0, 2, 1, 3))

        if self.qk_norm:
            q = self.q_norm(q)

        max_values = x_ref_attn_map.max(1).values[:, None, None]
        min_values = x_ref_attn_map.min(1).values[:, None, None]
        max_min_values = torch.cat([max_values, min_values], dim=2)

        human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
        human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

        human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), (self.rope_h1[0], self.rope_h1[1]))
        human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), (self.rope_h2[0], self.rope_h2[1]))
        back = torch.full((x_ref_attn_map.size(1),), self.rope_bak, dtype=human1.dtype).to(human1.device)
        max_indices = x_ref_attn_map.argmax(dim=0)
        normalized_map = torch.stack([human1, human2, back], dim=1)
        normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices]  # N

        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
        encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
        encoder_k, encoder_v = encoder_kv.unbind(0)

        if self.qk_norm:
            encoder_k = self.add_k_norm(encoder_k)

        per_frame = torch.zeros(N_a, dtype=encoder_k.dtype).to(encoder_k.device)
        per_frame[:per_frame.size(0)//2] = (self.rope_h1[0] + self.rope_h1[1]) / 2
        per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
        encoder_pos = torch.concat([per_frame] * N_t, dim=0)
        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
        x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None)
        x = rearrange(x, "B M H K -> B H M K")

        # linear transform
        x_output_shape = (B, N, C)
        x = x.transpose(1, 2)
        x = x.reshape(x_output_shape)
        x = self.proj(x)
        x = self.proj_drop(x)

        # reshape x to origin shape
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)

        return x
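For orientation (not part of the diff), below is a minimal usage sketch of the attention wrapper and the SingleStreamAttention cross-attention defined above. It assumes the repository and its dependencies (xformers and xfuser; flash-attn optional) are installed and a CUDA device is available; shapes follow the docstring convention [B, L, N, C], and the norm_layer=nn.LayerNorm choice as well as all sizes are illustrative assumptions, not values taken from the repo.

# Hedged usage sketch, illustrative only; not repository code.
import torch
import torch.nn as nn

from wan.modules.attention import attention, SingleStreamAttention

device, dtype = 'cuda', torch.float16

# Self-attention over a toy sequence: q/k/v are [B, L, N, C].
B, L, N, C = 1, 128, 16, 64
q = torch.randn(B, L, N, C, device=device, dtype=dtype)
k = torch.randn(B, L, N, C, device=device, dtype=dtype)
v = torch.randn(B, L, N, C, device=device, dtype=dtype)
out = attention(q, k, v, causal=False)   # flash-attn 2/3 if installed, else SDPA fallback
print(out.shape)                         # torch.Size([1, 128, 16, 64])

# Audio-to-video cross-attention: x is [B, N_t*S, dim], audio features are
# [B*N_t, N_a, encoder_hidden_states_dim], and shape=(N_t, N_h, N_w) with S == N_h*N_w.
xattn = SingleStreamAttention(
    dim=128, encoder_hidden_states_dim=768, num_heads=8,
    qkv_bias=True, qk_norm=True, norm_layer=nn.LayerNorm).to(device, dtype)

N_t, N_h, N_w, N_a = 4, 16, 16, 32
x = torch.randn(B, N_t * N_h * N_w, 128, device=device, dtype=dtype)
audio = torch.randn(B * N_t, N_a, 768, device=device, dtype=dtype)
y = xattn(x, audio, shape=(N_t, N_h, N_w))
print(y.shape)                           # torch.Size([1, 1024, 128])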
    	
        wan/modules/clip.py
    ADDED
    
@@ -0,0 +1,542 @@
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

from .attention import flash_attention
from .tokenizers import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta

__all__ = [
    'XLMRobertaCLIP',
    'clip_xlm_roberta_vit_h_14',
    'CLIPModel',
]


def pos_interpolate(pos, seq_len):
    if pos.size(1) == seq_len:
        return pos
    else:
        src_grid = int(math.sqrt(pos.size(1)))
        tar_grid = int(math.sqrt(seq_len))
        n = pos.size(1) - src_grid * src_grid
        return torch.cat([
            pos[:, :n],
            F.interpolate(
                pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
                    0, 3, 1, 2),
                size=(tar_grid, tar_grid),
                mode='bicubic',
                align_corners=False).flatten(2).transpose(1, 2)
        ],
                         dim=1)


class QuickGELU(nn.Module):

    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):

    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 causal=False,
                 attn_dropout=0.0,
                 proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.causal = causal
        self.attn_dropout = attn_dropout
        self.proj_dropout = proj_dropout

        # layers
        self.to_qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """
        x:   [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
        p = self.attn_dropout if self.training else 0.0
        x = flash_attention(q, k, v, dropout_p=p, causal=self.causal, version=2)
        x = x.reshape(b, s, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)
        return x


class SwiGLU(nn.Module):

    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
        self.mid_dim = mid_dim

        # layers
        self.fc1 = nn.Linear(dim, mid_dim)
        self.fc2 = nn.Linear(dim, mid_dim)
        self.fc3 = nn.Linear(mid_dim, dim)

    def forward(self, x):
        x = F.silu(self.fc1(x)) * self.fc2(x)
        x = self.fc3(x)
        return x


class AttentionBlock(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 post_norm=False,
                 causal=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.causal = causal
        self.norm_eps = norm_eps

        # layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
                                  proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
        if activation == 'swi_glu':
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
            self.mlp = nn.Sequential(
                nn.Linear(dim, int(dim * mlp_ratio)),
                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
            x = x + self.norm1(self.attn(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return x


class AttentionPool(nn.Module):

    def __init__(self,
                 dim,
                 mlp_ratio,
                 num_heads,
                 activation='gelu',
                 proj_dropout=0.0,
                 norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.proj_dropout = proj_dropout
        self.norm_eps = norm_eps

        # layers
        gain = 1.0 / math.sqrt(dim)
        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.to_q = nn.Linear(dim, dim)
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * mlp_ratio)),
            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
        x:  [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
        x = flash_attention(q, k, v, version=2)
        x = x.reshape(b, 1, c)

        # output
        x = self.proj(x)
        x = F.dropout(x, self.proj_dropout, self.training)

        # mlp
        x = x + self.mlp(self.norm(x))
        return x[:, 0]


class VisionTransformer(nn.Module):

    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 dim=768,
                 mlp_ratio=4,
                 out_dim=512,
                 num_heads=12,
                 num_layers=12,
                 pool_type='token',
                 pre_norm=True,
                 post_norm=False,
                 activation='quick_gelu',
                 attn_dropout=0.0,
                 proj_dropout=0.0,
                 embedding_dropout=0.0,
                 norm_eps=1e-5):
        if image_size % patch_size != 0:
            print(
                '[WARNING] image_size is not divisible by patch_size',
                flush=True)
        assert pool_type in ('token', 'token_fc', 'attn_pool')
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size)**2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.pool_type = pool_type
        self.post_norm = post_norm
        self.norm_eps = norm_eps

        # embeddings
        gain = 1.0 / math.sqrt(dim)
        self.patch_embedding = nn.Conv2d(
            3,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm)
        if pool_type in ('token', 'token_fc'):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(gain * torch.randn(
            1, self.num_patches +
            (1 if pool_type in ('token', 'token_fc') else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
        self.transformer = nn.Sequential(*[
            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
                           activation, attn_dropout, proj_dropout, norm_eps)
            for _ in range(num_layers)
        ])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
        if pool_type == 'token':
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
        elif pool_type == 'token_fc':
            self.head = nn.Linear(dim, out_dim)
        elif pool_type == 'attn_pool':
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
                                      proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
        if self.pool_type in ('token', 'token_fc'):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
        else:
            e = self.pos_embedding
        x = self.dropout(x + e)
        if self.pre_norm is not None:
            x = self.pre_norm(x)

        # transformer
        if use_31_block:
            x = self.transformer[:-1](x)
            return x
        else:
            x = self.transformer(x)
            return x


class XLMRobertaWithHead(XLMRoberta):

    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop('out_dim')
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(
            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
            nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)
            +
                             dim,
         | 
| 116 | 
            +
                             mlp_ratio,
         | 
| 117 | 
            +
                             num_heads,
         | 
| 118 | 
            +
                             post_norm=False,
         | 
| 119 | 
            +
                             causal=False,
         | 
| 120 | 
            +
                             activation='quick_gelu',
         | 
| 121 | 
            +
                             attn_dropout=0.0,
         | 
| 122 | 
            +
                             proj_dropout=0.0,
         | 
| 123 | 
            +
                             norm_eps=1e-5):
         | 
| 124 | 
            +
                    assert activation in ['quick_gelu', 'gelu', 'swi_glu']
         | 
| 125 | 
            +
                    super().__init__()
         | 
| 126 | 
            +
                    self.dim = dim
         | 
| 127 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 128 | 
            +
                    self.num_heads = num_heads
         | 
| 129 | 
            +
                    self.post_norm = post_norm
         | 
| 130 | 
            +
                    self.causal = causal
         | 
| 131 | 
            +
                    self.norm_eps = norm_eps
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                    # layers
         | 
| 134 | 
            +
                    self.norm1 = LayerNorm(dim, eps=norm_eps)
         | 
| 135 | 
            +
                    self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
         | 
| 136 | 
            +
                                              proj_dropout)
         | 
| 137 | 
            +
                    self.norm2 = LayerNorm(dim, eps=norm_eps)
         | 
| 138 | 
            +
                    if activation == 'swi_glu':
         | 
| 139 | 
            +
                        self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
         | 
| 140 | 
            +
                    else:
         | 
| 141 | 
            +
                        self.mlp = nn.Sequential(
         | 
| 142 | 
            +
                            nn.Linear(dim, int(dim * mlp_ratio)),
         | 
| 143 | 
            +
                            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
         | 
| 144 | 
            +
                            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def forward(self, x):
         | 
| 147 | 
            +
                    if self.post_norm:
         | 
| 148 | 
            +
                        x = x + self.norm1(self.attn(x))
         | 
| 149 | 
            +
                        x = x + self.norm2(self.mlp(x))
         | 
| 150 | 
            +
                    else:
         | 
| 151 | 
            +
                        x = x + self.attn(self.norm1(x))
         | 
| 152 | 
            +
                        x = x + self.mlp(self.norm2(x))
         | 
| 153 | 
            +
                    return x
         | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
| 156 | 
            +
            class AttentionPool(nn.Module):
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                def __init__(self,
         | 
| 159 | 
            +
                             dim,
         | 
| 160 | 
            +
                             mlp_ratio,
         | 
| 161 | 
            +
                             num_heads,
         | 
| 162 | 
            +
                             activation='gelu',
         | 
| 163 | 
            +
                             proj_dropout=0.0,
         | 
| 164 | 
            +
                             norm_eps=1e-5):
         | 
| 165 | 
            +
                    assert dim % num_heads == 0
         | 
| 166 | 
            +
                    super().__init__()
         | 
| 167 | 
            +
                    self.dim = dim
         | 
| 168 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 169 | 
            +
                    self.num_heads = num_heads
         | 
| 170 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 171 | 
            +
                    self.proj_dropout = proj_dropout
         | 
| 172 | 
            +
                    self.norm_eps = norm_eps
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                    # layers
         | 
| 175 | 
            +
                    gain = 1.0 / math.sqrt(dim)
         | 
| 176 | 
            +
                    self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
         | 
| 177 | 
            +
                    self.to_q = nn.Linear(dim, dim)
         | 
| 178 | 
            +
                    self.to_kv = nn.Linear(dim, dim * 2)
         | 
| 179 | 
            +
                    self.proj = nn.Linear(dim, dim)
         | 
| 180 | 
            +
                    self.norm = LayerNorm(dim, eps=norm_eps)
         | 
| 181 | 
            +
                    self.mlp = nn.Sequential(
         | 
| 182 | 
            +
                        nn.Linear(dim, int(dim * mlp_ratio)),
         | 
| 183 | 
            +
                        QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
         | 
| 184 | 
            +
                        nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def forward(self, x):
         | 
| 187 | 
            +
                    """
         | 
| 188 | 
            +
                    x:  [B, L, C].
         | 
| 189 | 
            +
                    """
         | 
| 190 | 
            +
                    b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                    # compute query, key, value
         | 
| 193 | 
            +
                    q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
         | 
| 194 | 
            +
                    k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    # compute attention
         | 
| 197 | 
            +
                    x = flash_attention(q, k, v, version=2)
         | 
| 198 | 
            +
                    x = x.reshape(b, 1, c)
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                    # output
         | 
| 201 | 
            +
                    x = self.proj(x)
         | 
| 202 | 
            +
                    x = F.dropout(x, self.proj_dropout, self.training)
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                    # mlp
         | 
| 205 | 
            +
                    x = x + self.mlp(self.norm(x))
         | 
| 206 | 
            +
                    return x[:, 0]
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
            class VisionTransformer(nn.Module):
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                def __init__(self,
         | 
| 212 | 
            +
                             image_size=224,
         | 
| 213 | 
            +
                             patch_size=16,
         | 
| 214 | 
            +
                             dim=768,
         | 
| 215 | 
            +
                             mlp_ratio=4,
         | 
| 216 | 
            +
                             out_dim=512,
         | 
| 217 | 
            +
                             num_heads=12,
         | 
| 218 | 
            +
                             num_layers=12,
         | 
| 219 | 
            +
                             pool_type='token',
         | 
| 220 | 
            +
                             pre_norm=True,
         | 
| 221 | 
            +
                             post_norm=False,
         | 
| 222 | 
            +
                             activation='quick_gelu',
         | 
| 223 | 
            +
                             attn_dropout=0.0,
         | 
| 224 | 
            +
                             proj_dropout=0.0,
         | 
| 225 | 
            +
                             embedding_dropout=0.0,
         | 
| 226 | 
            +
                             norm_eps=1e-5):
         | 
| 227 | 
            +
                    if image_size % patch_size != 0:
         | 
| 228 | 
            +
                        print(
         | 
| 229 | 
            +
                            '[WARNING] image_size is not divisible by patch_size',
         | 
| 230 | 
            +
                            flush=True)
         | 
| 231 | 
            +
                    assert pool_type in ('token', 'token_fc', 'attn_pool')
         | 
| 232 | 
            +
                    out_dim = out_dim or dim
         | 
| 233 | 
            +
                    super().__init__()
         | 
| 234 | 
            +
                    self.image_size = image_size
         | 
| 235 | 
            +
                    self.patch_size = patch_size
         | 
| 236 | 
            +
                    self.num_patches = (image_size // patch_size)**2
         | 
| 237 | 
            +
                    self.dim = dim
         | 
| 238 | 
            +
                    self.mlp_ratio = mlp_ratio
         | 
| 239 | 
            +
                    self.out_dim = out_dim
         | 
| 240 | 
            +
                    self.num_heads = num_heads
         | 
| 241 | 
            +
                    self.num_layers = num_layers
         | 
| 242 | 
            +
                    self.pool_type = pool_type
         | 
| 243 | 
            +
                    self.post_norm = post_norm
         | 
| 244 | 
            +
                    self.norm_eps = norm_eps
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    # embeddings
         | 
| 247 | 
            +
                    gain = 1.0 / math.sqrt(dim)
         | 
| 248 | 
            +
                    self.patch_embedding = nn.Conv2d(
         | 
| 249 | 
            +
                        3,
         | 
| 250 | 
            +
                        dim,
         | 
| 251 | 
            +
                        kernel_size=patch_size,
         | 
| 252 | 
            +
                        stride=patch_size,
         | 
| 253 | 
            +
                        bias=not pre_norm)
         | 
| 254 | 
            +
                    if pool_type in ('token', 'token_fc'):
         | 
| 255 | 
            +
                        self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
         | 
| 256 | 
            +
                    self.pos_embedding = nn.Parameter(gain * torch.randn(
         | 
| 257 | 
            +
                        1, self.num_patches +
         | 
| 258 | 
            +
                        (1 if pool_type in ('token', 'token_fc') else 0), dim))
         | 
| 259 | 
            +
                    self.dropout = nn.Dropout(embedding_dropout)
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                    # transformer
         | 
| 262 | 
            +
                    self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
         | 
| 263 | 
            +
                    self.transformer = nn.Sequential(*[
         | 
| 264 | 
            +
                        AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
         | 
| 265 | 
            +
                                       activation, attn_dropout, proj_dropout, norm_eps)
         | 
| 266 | 
            +
                        for _ in range(num_layers)
         | 
| 267 | 
            +
                    ])
         | 
| 268 | 
            +
                    self.post_norm = LayerNorm(dim, eps=norm_eps)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    # head
         | 
| 271 | 
            +
                    if pool_type == 'token':
         | 
| 272 | 
            +
                        self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
         | 
| 273 | 
            +
                    elif pool_type == 'token_fc':
         | 
| 274 | 
            +
                        self.head = nn.Linear(dim, out_dim)
         | 
| 275 | 
            +
                    elif pool_type == 'attn_pool':
         | 
| 276 | 
            +
                        self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
         | 
| 277 | 
            +
                                                  proj_dropout, norm_eps)
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                def forward(self, x, interpolation=False, use_31_block=False):
         | 
| 280 | 
            +
                    b = x.size(0)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                    # embeddings
         | 
| 283 | 
            +
                    x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
         | 
| 284 | 
            +
                    if self.pool_type in ('token', 'token_fc'):
         | 
| 285 | 
            +
                        x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
         | 
| 286 | 
            +
                    if interpolation:
         | 
| 287 | 
            +
                        e = pos_interpolate(self.pos_embedding, x.size(1))
         | 
| 288 | 
            +
                    else:
         | 
| 289 | 
            +
                        e = self.pos_embedding
         | 
| 290 | 
            +
                    x = self.dropout(x + e)
         | 
| 291 | 
            +
                    if self.pre_norm is not None:
         | 
| 292 | 
            +
                        x = self.pre_norm(x)
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                    # transformer
         | 
| 295 | 
            +
                    if use_31_block:
         | 
| 296 | 
            +
                        x = self.transformer[:-1](x)
         | 
| 297 | 
            +
                        return x
         | 
| 298 | 
            +
                    else:
         | 
| 299 | 
            +
                        x = self.transformer(x)
         | 
| 300 | 
            +
                        return x
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
            class XLMRobertaWithHead(XLMRoberta):
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                def __init__(self, **kwargs):
         | 
| 306 | 
            +
                    self.out_dim = kwargs.pop('out_dim')
         | 
| 307 | 
            +
                    super().__init__(**kwargs)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                    # head
         | 
| 310 | 
            +
                    mid_dim = (self.dim + self.out_dim) // 2
         | 
| 311 | 
            +
                    self.head = nn.Sequential(
         | 
| 312 | 
            +
                        nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
         | 
| 313 | 
            +
                        nn.Linear(mid_dim, self.out_dim, bias=False))
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                def forward(self, ids):
         | 
| 316 | 
            +
                    # xlm-roberta
         | 
| 317 | 
            +
                    x = super().forward(ids)
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    # average pooling
         | 
| 320 | 
            +
                    mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
         | 
| 321 | 
            +
                    x = (x * mask).sum(dim=1) / mask.sum(dim=1)
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                    # head
         | 
| 324 | 
            +
                    x = self.head(x)
         | 
| 325 | 
            +
                    return x
         | 
| 326 | 
            +
             | 
| 327 | 
            +
             | 
| 328 | 
            +
            class XLMRobertaCLIP(nn.Module):
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                def __init__(self,
         | 
| 331 | 
            +
                             embed_dim=1024,
         | 
| 332 | 
            +
                             image_size=224,
         | 
| 333 | 
            +
                             patch_size=14,
         | 
| 334 | 
            +
                             vision_dim=1280,
         | 
| 335 | 
            +
                             vision_mlp_ratio=4,
         | 
| 336 | 
            +
                             vision_heads=16,
         | 
| 337 | 
            +
                             vision_layers=32,
         | 
| 338 | 
            +
                             vision_pool='token',
         | 
| 339 | 
            +
                             vision_pre_norm=True,
         | 
| 340 | 
            +
                             vision_post_norm=False,
         | 
| 341 | 
            +
                             activation='gelu',
         | 
| 342 | 
            +
                             vocab_size=250002,
         | 
| 343 | 
            +
                             max_text_len=514,
         | 
| 344 | 
            +
                             type_size=1,
         | 
| 345 | 
            +
                             pad_id=1,
         | 
| 346 | 
            +
                             text_dim=1024,
         | 
| 347 | 
            +
                             text_heads=16,
         | 
| 348 | 
            +
                             text_layers=24,
         | 
| 349 | 
            +
                             text_post_norm=True,
         | 
| 350 | 
            +
                             text_dropout=0.1,
         | 
| 351 | 
            +
                             attn_dropout=0.0,
         | 
| 352 | 
            +
                             proj_dropout=0.0,
         | 
| 353 | 
            +
                             embedding_dropout=0.0,
         | 
| 354 | 
            +
                             norm_eps=1e-5):
         | 
| 355 | 
            +
                    super().__init__()
         | 
| 356 | 
            +
                    self.embed_dim = embed_dim
         | 
| 357 | 
            +
                    self.image_size = image_size
         | 
| 358 | 
            +
                    self.patch_size = patch_size
         | 
| 359 | 
            +
                    self.vision_dim = vision_dim
         | 
| 360 | 
            +
                    self.vision_mlp_ratio = vision_mlp_ratio
         | 
| 361 | 
            +
                    self.vision_heads = vision_heads
         | 
| 362 | 
            +
                    self.vision_layers = vision_layers
         | 
| 363 | 
            +
                    self.vision_pre_norm = vision_pre_norm
         | 
| 364 | 
            +
                    self.vision_post_norm = vision_post_norm
         | 
| 365 | 
            +
                    self.activation = activation
         | 
| 366 | 
            +
                    self.vocab_size = vocab_size
         | 
| 367 | 
            +
                    self.max_text_len = max_text_len
         | 
| 368 | 
            +
                    self.type_size = type_size
         | 
| 369 | 
            +
                    self.pad_id = pad_id
         | 
| 370 | 
            +
                    self.text_dim = text_dim
         | 
| 371 | 
            +
                    self.text_heads = text_heads
         | 
| 372 | 
            +
                    self.text_layers = text_layers
         | 
| 373 | 
            +
                    self.text_post_norm = text_post_norm
         | 
| 374 | 
            +
                    self.norm_eps = norm_eps
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    # models
         | 
| 377 | 
            +
                    self.visual = VisionTransformer(
         | 
| 378 | 
            +
                        image_size=image_size,
         | 
| 379 | 
            +
                        patch_size=patch_size,
         | 
| 380 | 
            +
                        dim=vision_dim,
         | 
| 381 | 
            +
                        mlp_ratio=vision_mlp_ratio,
         | 
| 382 | 
            +
                        out_dim=embed_dim,
         | 
| 383 | 
            +
                        num_heads=vision_heads,
         | 
| 384 | 
            +
                        num_layers=vision_layers,
         | 
| 385 | 
            +
                        pool_type=vision_pool,
         | 
| 386 | 
            +
                        pre_norm=vision_pre_norm,
         | 
| 387 | 
            +
                        post_norm=vision_post_norm,
         | 
| 388 | 
            +
                        activation=activation,
         | 
| 389 | 
            +
                        attn_dropout=attn_dropout,
         | 
| 390 | 
            +
                        proj_dropout=proj_dropout,
         | 
| 391 | 
            +
                        embedding_dropout=embedding_dropout,
         | 
| 392 | 
            +
                        norm_eps=norm_eps)
         | 
| 393 | 
            +
                    self.textual = XLMRobertaWithHead(
         | 
| 394 | 
            +
                        vocab_size=vocab_size,
         | 
| 395 | 
            +
                        max_seq_len=max_text_len,
         | 
| 396 | 
            +
                        type_size=type_size,
         | 
| 397 | 
            +
                        pad_id=pad_id,
         | 
| 398 | 
            +
                        dim=text_dim,
         | 
| 399 | 
            +
                        out_dim=embed_dim,
         | 
| 400 | 
            +
                        num_heads=text_heads,
         | 
| 401 | 
            +
                        num_layers=text_layers,
         | 
| 402 | 
            +
                        post_norm=text_post_norm,
         | 
| 403 | 
            +
                        dropout=text_dropout)
         | 
| 404 | 
            +
                    self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                def forward(self, imgs, txt_ids):
         | 
| 407 | 
            +
                    """
         | 
| 408 | 
            +
                    imgs:       [B, 3, H, W] of torch.float32.
         | 
| 409 | 
            +
                    - mean:     [0.48145466, 0.4578275, 0.40821073]
         | 
| 410 | 
            +
                    - std:      [0.26862954, 0.26130258, 0.27577711]
         | 
| 411 | 
            +
                    txt_ids:    [B, L] of torch.long.
         | 
| 412 | 
            +
                                Encoded by data.CLIPTokenizer.
         | 
| 413 | 
            +
                    """
         | 
| 414 | 
            +
                    xi = self.visual(imgs)
         | 
| 415 | 
            +
                    xt = self.textual(txt_ids)
         | 
| 416 | 
            +
                    return xi, xt
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                def param_groups(self):
         | 
| 419 | 
            +
                    groups = [{
         | 
| 420 | 
            +
                        'params': [
         | 
| 421 | 
            +
                            p for n, p in self.named_parameters()
         | 
| 422 | 
            +
                            if 'norm' in n or n.endswith('bias')
         | 
| 423 | 
            +
                        ],
         | 
| 424 | 
            +
                        'weight_decay': 0.0
         | 
| 425 | 
            +
                    }, {
         | 
| 426 | 
            +
                        'params': [
         | 
| 427 | 
            +
                            p for n, p in self.named_parameters()
         | 
| 428 | 
            +
                            if not ('norm' in n or n.endswith('bias'))
         | 
| 429 | 
            +
                        ]
         | 
| 430 | 
            +
                    }]
         | 
| 431 | 
            +
                    return groups
         | 
| 432 | 
            +
             | 
| 433 | 
            +
             | 
| 434 | 
            +
            def _clip(pretrained=False,
         | 
| 435 | 
            +
                      pretrained_name=None,
         | 
| 436 | 
            +
                      model_cls=XLMRobertaCLIP,
         | 
| 437 | 
            +
                      return_transforms=False,
         | 
| 438 | 
            +
                      return_tokenizer=False,
         | 
| 439 | 
            +
                      tokenizer_padding='eos',
         | 
| 440 | 
            +
                      dtype=torch.float32,
         | 
| 441 | 
            +
                      device='cpu',
         | 
| 442 | 
            +
                      **kwargs):
         | 
| 443 | 
            +
                # init a model on device
         | 
| 444 | 
            +
                with torch.device(device):
         | 
| 445 | 
            +
                    model = model_cls(**kwargs)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
                # set device
         | 
| 448 | 
            +
                model = model.to(dtype=dtype, device=device)
         | 
| 449 | 
            +
                output = (model,)
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                # init transforms
         | 
| 452 | 
            +
                if return_transforms:
         | 
| 453 | 
            +
                    # mean and std
         | 
| 454 | 
            +
                    if 'siglip' in pretrained_name.lower():
         | 
| 455 | 
            +
                        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
         | 
| 456 | 
            +
                    else:
         | 
| 457 | 
            +
                        mean = [0.48145466, 0.4578275, 0.40821073]
         | 
| 458 | 
            +
                        std = [0.26862954, 0.26130258, 0.27577711]
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    # transforms
         | 
| 461 | 
            +
                    transforms = T.Compose([
         | 
| 462 | 
            +
                        T.Resize((model.image_size, model.image_size),
         | 
| 463 | 
            +
                                 interpolation=T.InterpolationMode.BICUBIC),
         | 
| 464 | 
            +
                        T.ToTensor(),
         | 
| 465 | 
            +
                        T.Normalize(mean=mean, std=std)
         | 
| 466 | 
            +
                    ])
         | 
| 467 | 
            +
                    output += (transforms,)
         | 
| 468 | 
            +
                return output[0] if len(output) == 1 else output
         | 
| 469 | 
            +
             | 
| 470 | 
            +
             | 
| 471 | 
            +
            def clip_xlm_roberta_vit_h_14(
         | 
| 472 | 
            +
                    pretrained=False,
         | 
| 473 | 
            +
                    pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
         | 
| 474 | 
            +
                    **kwargs):
         | 
| 475 | 
            +
                cfg = dict(
         | 
| 476 | 
            +
                    embed_dim=1024,
         | 
| 477 | 
            +
                    image_size=224,
         | 
| 478 | 
            +
                    patch_size=14,
         | 
| 479 | 
            +
                    vision_dim=1280,
         | 
| 480 | 
            +
                    vision_mlp_ratio=4,
         | 
| 481 | 
            +
                    vision_heads=16,
         | 
| 482 | 
            +
                    vision_layers=32,
         | 
| 483 | 
            +
                    vision_pool='token',
         | 
| 484 | 
            +
                    activation='gelu',
         | 
| 485 | 
            +
                    vocab_size=250002,
         | 
| 486 | 
            +
                    max_text_len=514,
         | 
| 487 | 
            +
                    type_size=1,
         | 
| 488 | 
            +
                    pad_id=1,
         | 
| 489 | 
            +
                    text_dim=1024,
         | 
| 490 | 
            +
                    text_heads=16,
         | 
| 491 | 
            +
                    text_layers=24,
         | 
| 492 | 
            +
                    text_post_norm=True,
         | 
| 493 | 
            +
                    text_dropout=0.1,
         | 
| 494 | 
            +
                    attn_dropout=0.0,
         | 
| 495 | 
            +
                    proj_dropout=0.0,
         | 
| 496 | 
            +
                    embedding_dropout=0.0)
         | 
| 497 | 
            +
                cfg.update(**kwargs)
         | 
| 498 | 
            +
                return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
         | 
| 499 | 
            +
             | 
| 500 | 
            +
             | 
| 501 | 
            +
            class CLIPModel:
         | 
| 502 | 
            +
             | 
| 503 | 
            +
                def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
         | 
| 504 | 
            +
                    self.dtype = dtype
         | 
| 505 | 
            +
                    self.device = device
         | 
| 506 | 
            +
                    self.checkpoint_path = checkpoint_path
         | 
| 507 | 
            +
                    self.tokenizer_path = tokenizer_path
         | 
| 508 | 
            +
             | 
| 509 | 
            +
                    # init model
         | 
| 510 | 
            +
                    self.model, self.transforms = clip_xlm_roberta_vit_h_14(
         | 
| 511 | 
            +
                        pretrained=False,
         | 
| 512 | 
            +
                        return_transforms=True,
         | 
| 513 | 
            +
                        return_tokenizer=False,
         | 
| 514 | 
            +
                        dtype=dtype,
         | 
| 515 | 
            +
                        device=device)
         | 
| 516 | 
            +
                    self.model = self.model.eval().requires_grad_(False)
         | 
| 517 | 
            +
                    logging.info(f'loading {checkpoint_path}')
         | 
| 518 | 
            +
                    self.model.load_state_dict(
         | 
| 519 | 
            +
                        torch.load(checkpoint_path, map_location='cpu'))
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                    # init tokenizer
         | 
| 522 | 
            +
                    self.tokenizer = HuggingfaceTokenizer(
         | 
| 523 | 
            +
                        name=tokenizer_path,
         | 
| 524 | 
            +
                        seq_len=self.model.max_text_len - 2,
         | 
| 525 | 
            +
                        clean='whitespace')
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                def visual(self, videos):
         | 
| 528 | 
            +
                    # preprocess
         | 
| 529 | 
            +
                    size = (self.model.image_size,) * 2
         | 
| 530 | 
            +
                    videos = torch.cat([
         | 
| 531 | 
            +
                        F.interpolate(
         | 
| 532 | 
            +
                            u.transpose(0, 1),
         | 
| 533 | 
            +
                            size=size,
         | 
| 534 | 
            +
                            mode='bicubic',
         | 
| 535 | 
            +
                            align_corners=False) for u in videos
         | 
| 536 | 
            +
                    ])
         | 
| 537 | 
            +
                    videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
         | 
| 538 | 
            +
             | 
| 539 | 
            +
                    # forward
         | 
| 540 | 
            +
                    with torch.cuda.amp.autocast(dtype=self.dtype):
         | 
| 541 | 
            +
                        out = self.model.visual(videos, use_31_block=True)
         | 
| 542 | 
            +
                        return out
         | 
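Below is a minimal, illustrative sketch (not part of the diff) of how the `CLIPModel` class above could be driven. The checkpoint and tokenizer paths are placeholders, and the input is assumed to be a list of video tensors shaped [C, T, H, W] with pixel values already in [-1, 1]; `visual()` rescales them to [0, 1] before applying the CLIP normalization and returns penultimate-block tokens (`use_31_block=True`), typically 257 tokens of width 1280 per frame for a ViT-H/14 at 224 px.
    # Hypothetical usage sketch; paths below are placeholders, not guaranteed filenames.
    import torch

    clip = CLIPModel(
        dtype=torch.float16,
        device='cuda',
        checkpoint_path='/path/to/clip_checkpoint.pth',   # placeholder path
        tokenizer_path='xlm-roberta-large')               # placeholder tokenizer name
    videos = [torch.rand(3, 8, 480, 832, device='cuda') * 2 - 1]  # one [C, T, H, W] clip in [-1, 1]
    with torch.no_grad():
        feats = clip.visual(videos)   # per-frame penultimate-block features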
    	
        wan/modules/model.py
    ADDED
    
    | @@ -0,0 +1,631 @@ | |
| 1 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.cuda.amp as amp
         | 
| 6 | 
            +
            import torch.nn as nn
         | 
| 7 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 8 | 
            +
            from diffusers.models.modeling_utils import ModelMixin
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from .attention import flash_attention
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            __all__ = ['WanModel']
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            T5_CONTEXT_TOKEN_NUMBER = 512
         | 
| 15 | 
            +
            FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER = 257 * 2
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def sinusoidal_embedding_1d(dim, position):
         | 
| 19 | 
            +
                # preprocess
         | 
| 20 | 
            +
                assert dim % 2 == 0
         | 
| 21 | 
            +
                half = dim // 2
         | 
| 22 | 
            +
                position = position.type(torch.float64)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                # calculation
         | 
| 25 | 
            +
                sinusoid = torch.outer(
         | 
| 26 | 
            +
                    position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
         | 
| 27 | 
            +
                x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
         | 
| 28 | 
            +
                return x
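A quick, illustrative shape check for the helper above (assuming `torch` and this module's `sinusoidal_embedding_1d` are importable):
    # For a batch of integer timesteps, the embedding concatenates the cos and sin
    # halves along the last dimension, so dim=256 yields a [3, 256] float64 tensor.
    import torch
    t = torch.tensor([0, 500, 999])
    emb = sinusoidal_embedding_1d(256, t)
    print(emb.shape, emb.dtype)  # torch.Size([3, 256]) torch.float64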
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            @amp.autocast(enabled=False)
         | 
| 32 | 
            +
            def rope_params(max_seq_len, dim, theta=10000):
         | 
| 33 | 
            +
                assert dim % 2 == 0
         | 
| 34 | 
            +
                freqs = torch.outer(
         | 
| 35 | 
            +
                    torch.arange(max_seq_len),
         | 
| 36 | 
            +
                    1.0 / torch.pow(theta,
         | 
| 37 | 
            +
                                    torch.arange(0, dim, 2).to(torch.float64).div(dim)))
         | 
| 38 | 
            +
                freqs = torch.polar(torch.ones_like(freqs), freqs)
         | 
| 39 | 
            +
                return freqs
         | 
| 40 | 
            +
             | 
| 41 | 
            +
             | 
| 42 | 
            +
            @amp.autocast(enabled=False)
         | 
| 43 | 
            +
            def rope_apply(x, grid_sizes, freqs):
         | 
| 44 | 
            +
                n, c = x.size(2), x.size(3) // 2
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                # split freqs
         | 
| 47 | 
            +
                freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                # loop over samples
         | 
| 50 | 
            +
                output = []
         | 
| 51 | 
            +
                for i, (f, h, w) in enumerate(grid_sizes.tolist()):
         | 
| 52 | 
            +
                    seq_len = f * h * w
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                    # precompute multipliers
         | 
| 55 | 
            +
                    x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
         | 
| 56 | 
            +
                        seq_len, n, -1, 2))
         | 
| 57 | 
            +
                    freqs_i = torch.cat([
         | 
| 58 | 
            +
                        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
         | 
| 59 | 
            +
                        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
         | 
| 60 | 
            +
                        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         | 
| 61 | 
            +
                    ],
         | 
| 62 | 
            +
                                        dim=-1).reshape(seq_len, 1, -1)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # apply rotary embedding
         | 
| 65 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
         | 
| 66 | 
            +
                    x_i = torch.cat([x_i, x[i, seq_len:]])
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    # append to collection
         | 
| 69 | 
            +
                    output.append(x_i)
         | 
| 70 | 
            +
                return torch.stack(output).float()
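A worked example of the frequency split used in `rope_apply` above, showing how the per-head rotary pairs are divided between the frame, height, and width axes (the numbers are just the arithmetic of the split expression, not values taken from the source):
    # For head_dim = 128 there are c = 64 complex rotary pairs per head, split as
    # [c - 2 * (c // 3), c // 3, c // 3]: 22 pairs rotate with the frame index,
    # 21 with the row index and 21 with the column index of the latent grid.
    c = 128 // 2
    print([c - 2 * (c // 3), c // 3, c // 3])  # [22, 21, 21]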
         | 
| 71 | 
            +
             | 
| 72 | 
            +
             | 
| 73 | 
            +
            class WanRMSNorm(nn.Module):
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                def __init__(self, dim, eps=1e-5):
         | 
| 76 | 
            +
                    super().__init__()
         | 
| 77 | 
            +
                    self.dim = dim
         | 
| 78 | 
            +
                    self.eps = eps
         | 
| 79 | 
            +
                    self.weight = nn.Parameter(torch.ones(dim))
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                def forward(self, x):
         | 
| 82 | 
            +
                    r"""
         | 
| 83 | 
            +
                    Args:
         | 
| 84 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 85 | 
            +
                    """
         | 
| 86 | 
            +
                    return self._norm(x.float()).type_as(x) * self.weight
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def _norm(self, x):
         | 
| 89 | 
            +
                    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            class WanLayerNorm(nn.LayerNorm):
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def __init__(self, dim, eps=1e-6, elementwise_affine=False):
         | 
| 95 | 
            +
                    super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                def forward(self, x):
         | 
| 98 | 
            +
                    r"""
         | 
| 99 | 
            +
                    Args:
         | 
| 100 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 101 | 
            +
                    """
         | 
| 102 | 
            +
                    return super().forward(x.float()).type_as(x)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            class WanSelfAttention(nn.Module):
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def __init__(self,
         | 
| 108 | 
            +
                             dim,
         | 
| 109 | 
            +
                             num_heads,
         | 
| 110 | 
            +
                             window_size=(-1, -1),
         | 
| 111 | 
            +
                             qk_norm=True,
         | 
| 112 | 
            +
                             eps=1e-6):
         | 
| 113 | 
            +
                    assert dim % num_heads == 0
         | 
| 114 | 
            +
                    super().__init__()
         | 
| 115 | 
            +
                    self.dim = dim
         | 
| 116 | 
            +
                    self.num_heads = num_heads
         | 
| 117 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 118 | 
            +
                    self.window_size = window_size
         | 
| 119 | 
            +
                    self.qk_norm = qk_norm
         | 
| 120 | 
            +
                    self.eps = eps
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    # layers
         | 
| 123 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 124 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 125 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 126 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 127 | 
            +
                    self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 128 | 
            +
                    self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                def forward(self, x, seq_lens, grid_sizes, freqs):
         | 
| 131 | 
            +
                    r"""
         | 
| 132 | 
            +
                    Args:
         | 
| 133 | 
            +
                        x(Tensor): Shape [B, L, num_heads, C / num_heads]
         | 
| 134 | 
            +
                        seq_lens(Tensor): Shape [B]
         | 
| 135 | 
            +
                        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
         | 
| 136 | 
            +
                        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         | 
| 137 | 
            +
                    """
         | 
| 138 | 
            +
                    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    # query, key, value function
         | 
| 141 | 
            +
                    def qkv_fn(x):
         | 
| 142 | 
            +
                        q = self.norm_q(self.q(x)).view(b, s, n, d)
         | 
| 143 | 
            +
                        k = self.norm_k(self.k(x)).view(b, s, n, d)
         | 
| 144 | 
            +
                        v = self.v(x).view(b, s, n, d)
         | 
| 145 | 
            +
                        return q, k, v
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    q, k, v = qkv_fn(x)
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    x = flash_attention(
         | 
| 150 | 
            +
                        q=rope_apply(q, grid_sizes, freqs),
         | 
| 151 | 
            +
                        k=rope_apply(k, grid_sizes, freqs),
         | 
| 152 | 
            +
                        v=v,
         | 
| 153 | 
            +
                        k_lens=seq_lens,
         | 
| 154 | 
            +
                        window_size=self.window_size)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    # output
         | 
| 157 | 
            +
                    x = x.flatten(2)
         | 
| 158 | 
            +
                    x = self.o(x)
         | 
| 159 | 
            +
                    return x
         | 
| 160 | 
            +
             | 
| 161 | 
            +
             | 
| 162 | 
            +
            class WanT2VCrossAttention(WanSelfAttention):
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                def forward(self, x, context, context_lens):
         | 
| 165 | 
            +
                    r"""
         | 
| 166 | 
            +
                    Args:
         | 
| 167 | 
            +
                        x(Tensor): Shape [B, L1, C]
         | 
| 168 | 
            +
                        context(Tensor): Shape [B, L2, C]
         | 
| 169 | 
            +
                        context_lens(Tensor): Shape [B]
         | 
| 170 | 
            +
                    """
         | 
| 171 | 
            +
                    b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    # compute query, key, value
         | 
| 174 | 
            +
                    q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 175 | 
            +
                    k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 176 | 
            +
                    v = self.v(context).view(b, -1, n, d)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
                    # compute attention
         | 
| 179 | 
            +
                    x = flash_attention(q, k, v, k_lens=context_lens)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    # output
         | 
| 182 | 
            +
                    x = x.flatten(2)
         | 
| 183 | 
            +
                    x = self.o(x)
         | 
| 184 | 
            +
                    return x
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
| 187 | 
            +
            class WanI2VCrossAttention(WanSelfAttention):
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                def __init__(self,
         | 
| 190 | 
            +
                             dim,
         | 
| 191 | 
            +
                             num_heads,
         | 
| 192 | 
            +
                             window_size=(-1, -1),
         | 
| 193 | 
            +
                             qk_norm=True,
         | 
| 194 | 
            +
                             eps=1e-6):
         | 
| 195 | 
            +
                    super().__init__(dim, num_heads, window_size, qk_norm, eps)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                    self.k_img = nn.Linear(dim, dim)
         | 
| 198 | 
            +
                    self.v_img = nn.Linear(dim, dim)
         | 
| 199 | 
            +
                    # self.alpha = nn.Parameter(torch.zeros((1, )))
         | 
| 200 | 
            +
                    self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                def forward(self, x, context, context_lens):
         | 
| 203 | 
            +
                    r"""
         | 
| 204 | 
            +
                    Args:
         | 
| 205 | 
            +
                        x(Tensor): Shape [B, L1, C]
         | 
| 206 | 
            +
                        context(Tensor): Shape [B, L2, C]
         | 
| 207 | 
            +
                        context_lens(Tensor): Shape [B]
         | 
| 208 | 
            +
                    """
         | 
| 209 | 
            +
                    image_context_length = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
         | 
| 210 | 
            +
                    context_img = context[:, :image_context_length]
         | 
| 211 | 
            +
                    context = context[:, image_context_length:]
         | 
| 212 | 
            +
                    b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    # compute query, key, value
         | 
| 215 | 
            +
                    q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 216 | 
            +
                    k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 217 | 
            +
                    v = self.v(context).view(b, -1, n, d)
         | 
| 218 | 
            +
                    k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         | 
| 219 | 
            +
                    v_img = self.v_img(context_img).view(b, -1, n, d)
         | 
| 220 | 
            +
                    img_x = flash_attention(q, k_img, v_img, k_lens=None)
         | 
| 221 | 
            +
                    # compute attention
         | 
| 222 | 
            +
                    x = flash_attention(q, k, v, k_lens=context_lens)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # output
         | 
| 225 | 
            +
                    x = x.flatten(2)
         | 
| 226 | 
            +
                    img_x = img_x.flatten(2)
         | 
| 227 | 
            +
                    x = x + img_x
         | 
| 228 | 
            +
                    x = self.o(x)
         | 
| 229 | 
            +
                    return x
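A small sketch of the context layout implied by the split above: the fused conditioning is read as image tokens followed by a fixed block of `T5_CONTEXT_TOKEN_NUMBER` text tokens, and the image and text streams are attended separately before being summed and projected by `self.o`.
    # Illustrative arithmetic only, assuming 257 CLIP image tokens ahead of 512 T5 tokens.
    T5_CONTEXT_TOKEN_NUMBER = 512
    context_len = 257 + 512
    image_context_length = context_len - T5_CONTEXT_TOKEN_NUMBER
    print(image_context_length)  # 257 -> the image block sits at the front of `context`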
         | 
| 230 | 
            +
             | 
| 231 | 
            +
             | 
| 232 | 
            +
            WAN_CROSSATTENTION_CLASSES = {
         | 
| 233 | 
            +
                't2v_cross_attn': WanT2VCrossAttention,
         | 
| 234 | 
            +
                'i2v_cross_attn': WanI2VCrossAttention,
         | 
| 235 | 
            +
            }
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
            class WanAttentionBlock(nn.Module):
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def __init__(self,
         | 
| 241 | 
            +
                             cross_attn_type,
         | 
| 242 | 
            +
                             dim,
         | 
| 243 | 
            +
                             ffn_dim,
         | 
| 244 | 
            +
                             num_heads,
         | 
| 245 | 
            +
                             window_size=(-1, -1),
         | 
| 246 | 
            +
                             qk_norm=True,
         | 
| 247 | 
            +
                             cross_attn_norm=False,
         | 
| 248 | 
            +
                             eps=1e-6):
         | 
| 249 | 
            +
                    super().__init__()
         | 
| 250 | 
            +
                    self.dim = dim
         | 
| 251 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 252 | 
            +
                    self.num_heads = num_heads
         | 
| 253 | 
            +
                    self.window_size = window_size
         | 
| 254 | 
            +
                    self.qk_norm = qk_norm
         | 
| 255 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 256 | 
            +
                    self.eps = eps
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    # layers
         | 
| 259 | 
            +
                    self.norm1 = WanLayerNorm(dim, eps)
         | 
| 260 | 
            +
                    self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
         | 
| 261 | 
            +
                                                      eps)
         | 
| 262 | 
            +
                    self.norm3 = WanLayerNorm(
         | 
| 263 | 
            +
                        dim, eps,
         | 
| 264 | 
            +
                        elementwise_affine=True) if cross_attn_norm else nn.Identity()
         | 
| 265 | 
            +
                    self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
         | 
| 266 | 
            +
                                                                                  num_heads,
         | 
| 267 | 
            +
                                                                                  (-1, -1),
         | 
| 268 | 
            +
                                                                                  qk_norm,
         | 
| 269 | 
            +
                                                                                  eps)
         | 
| 270 | 
            +
                    self.norm2 = WanLayerNorm(dim, eps)
         | 
| 271 | 
            +
                    self.ffn = nn.Sequential(
         | 
| 272 | 
            +
                        nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
         | 
| 273 | 
            +
                        nn.Linear(ffn_dim, dim))
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    # modulation
         | 
| 276 | 
            +
                    self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                def forward(
         | 
| 279 | 
            +
                    self,
         | 
| 280 | 
            +
                    x,
         | 
| 281 | 
            +
                    e,
         | 
| 282 | 
            +
                    seq_lens,
         | 
| 283 | 
            +
                    grid_sizes,
         | 
| 284 | 
            +
                    freqs,
         | 
| 285 | 
            +
                    context,
         | 
| 286 | 
            +
                    context_lens,
         | 
| 287 | 
            +
                ):
         | 
| 288 | 
            +
                    r"""
         | 
| 289 | 
            +
                    Args:
         | 
| 290 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 291 | 
            +
                        e(Tensor): Shape [B, 6, C]
         | 
| 292 | 
            +
                        seq_lens(Tensor): Shape [B], length of each sequence in batch
         | 
| 293 | 
            +
                        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
         | 
| 294 | 
            +
                        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
         | 
| 295 | 
            +
                    """
         | 
| 296 | 
            +
                    assert e.dtype == torch.float32
         | 
| 297 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 298 | 
            +
                        e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
         | 
| 299 | 
            +
                    assert e[0].dtype == torch.float32
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    # self-attention
         | 
| 302 | 
            +
                    y = self.self_attn(
         | 
| 303 | 
            +
                        self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
         | 
| 304 | 
            +
                        freqs)
         | 
| 305 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 306 | 
            +
                        x = x + y * e[2]
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    # cross-attention & ffn function
         | 
| 309 | 
            +
                    def cross_attn_ffn(x, context, context_lens, e):
         | 
| 310 | 
            +
                        x = x + self.cross_attn(self.norm3(x), context, context_lens)
         | 
| 311 | 
            +
                        y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
         | 
| 312 | 
            +
                        with amp.autocast(dtype=torch.float32):
         | 
| 313 | 
            +
                            x = x + y * e[5]
         | 
| 314 | 
            +
                        return x
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    x = cross_attn_ffn(x, context, context_lens, e)
         | 
| 317 | 
            +
                    return x
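One way to read the modulation in the forward pass above (an interpretation added for clarity, not code from the diff): the conditioning `e` of shape [B, 6, C] is chunked into shift/scale/gate triples, one triple per residual branch.
    # Interpretation sketch of the adaLN-style modulation (assumed reading of the code).
    import torch
    e = torch.zeros(1, 6, 2048).chunk(6, dim=1)   # six [1, 1, C] tensors, as in the block
    # e[0], e[1], e[2] -> shift, scale, gate around the self-attention branch
    # e[3], e[4], e[5] -> shift, scale, gate around the feed-forward branch
    # so each branch behaves roughly like:
    #   x = x + gate * branch(norm(x) * (1 + scale) + shift)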
         | 
| 318 | 
            +
             | 
| 319 | 
            +
             | 
| 320 | 
            +
            class Head(nn.Module):
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                def __init__(self, dim, out_dim, patch_size, eps=1e-6):
         | 
| 323 | 
            +
                    super().__init__()
         | 
| 324 | 
            +
                    self.dim = dim
         | 
| 325 | 
            +
                    self.out_dim = out_dim
         | 
| 326 | 
            +
                    self.patch_size = patch_size
         | 
| 327 | 
            +
                    self.eps = eps
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                    # layers
         | 
| 330 | 
            +
                    out_dim = math.prod(patch_size) * out_dim
         | 
| 331 | 
            +
                    self.norm = WanLayerNorm(dim, eps)
         | 
| 332 | 
            +
                    self.head = nn.Linear(dim, out_dim)
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    # modulation
         | 
| 335 | 
            +
                    self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
         | 
| 336 | 
            +
             | 
| 337 | 
            +
                def forward(self, x, e):
         | 
| 338 | 
            +
                    r"""
         | 
| 339 | 
            +
                    Args:
         | 
| 340 | 
            +
                        x(Tensor): Shape [B, L1, C]
         | 
| 341 | 
            +
                        e(Tensor): Shape [B, C]
         | 
| 342 | 
            +
                    """
         | 
| 343 | 
            +
                    assert e.dtype == torch.float32
         | 
| 344 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 345 | 
            +
                        e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
         | 
| 346 | 
            +
                        x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         | 
| 347 | 
            +
                    return x
         | 
| 348 | 
            +
             | 
| 349 | 
            +
             | 
| 350 | 
            +
            class MLPProj(torch.nn.Module):
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def __init__(self, in_dim, out_dim, flf_pos_emb=False):
         | 
| 353 | 
            +
                    super().__init__()
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                    self.proj = torch.nn.Sequential(
         | 
| 356 | 
            +
                        torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
         | 
| 357 | 
            +
                        torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
         | 
| 358 | 
            +
                        torch.nn.LayerNorm(out_dim))
         | 
| 359 | 
            +
                    if flf_pos_emb:  # NOTE: we only use this for `flf2v`
         | 
| 360 | 
            +
                        self.emb_pos = nn.Parameter(
         | 
| 361 | 
            +
                            torch.zeros(1, FIRST_LAST_FRAME_CONTEXT_TOKEN_NUMBER, 1280))
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                def forward(self, image_embeds):
         | 
| 364 | 
            +
                    if hasattr(self, 'emb_pos'):
         | 
| 365 | 
            +
                        bs, n, d = image_embeds.shape
         | 
| 366 | 
            +
                        image_embeds = image_embeds.view(-1, 2 * n, d)
         | 
| 367 | 
            +
                        image_embeds = image_embeds + self.emb_pos
         | 
| 368 | 
            +
                    clip_extra_context_tokens = self.proj(image_embeds)
         | 
| 369 | 
            +
                    return clip_extra_context_tokens
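A shape sketch for the `flf2v` branch above (illustrative, assuming 257 CLIP tokens of width 1280 per frame): first- and last-frame embeddings arrive stacked along the batch axis and are folded into a single 514-token sequence before the learned positional embedding of shape [1, 2 * 257, 1280] is added and the MLP projects to the model width.
    # view(-1, 2 * n, d) pairs consecutive batch entries into one sequence:
    import torch
    image_embeds = torch.zeros(2 * 1, 257, 1280)        # [first, last] frames for one sample
    image_embeds = image_embeds.view(-1, 2 * 257, 1280)
    print(image_embeds.shape)  # torch.Size([1, 514, 1280])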
         | 
| 370 | 
            +
             | 
| 371 | 
            +
             | 
| 372 | 
            +
            class WanModel(ModelMixin, ConfigMixin):
         | 
| 373 | 
            +
                r"""
         | 
| 374 | 
            +
                Wan diffusion backbone supporting both text-to-video and image-to-video.
         | 
| 375 | 
            +
                """
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                ignore_for_config = [
         | 
| 378 | 
            +
                    'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
         | 
| 379 | 
            +
                ]
         | 
| 380 | 
            +
                _no_split_modules = ['WanAttentionBlock']
         | 
| 381 | 
            +
             | 
| 382 | 
            +
                @register_to_config
         | 
| 383 | 
            +
                def __init__(self,
         | 
| 384 | 
            +
                             model_type='t2v',
         | 
| 385 | 
            +
                             patch_size=(1, 2, 2),
         | 
| 386 | 
            +
                             text_len=512,
         | 
| 387 | 
            +
                             in_dim=16,
         | 
| 388 | 
            +
                             dim=2048,
         | 
| 389 | 
            +
                             ffn_dim=8192,
         | 
| 390 | 
            +
                             freq_dim=256,
         | 
| 391 | 
            +
                             text_dim=4096,
         | 
| 392 | 
            +
                             out_dim=16,
         | 
| 393 | 
            +
                             num_heads=16,
         | 
| 394 | 
            +
                             num_layers=32,
         | 
| 395 | 
            +
                             window_size=(-1, -1),
         | 
| 396 | 
            +
                             qk_norm=True,
         | 
| 397 | 
            +
                             cross_attn_norm=True,
         | 
| 398 | 
            +
                             eps=1e-6):
         | 
| 399 | 
            +
                    r"""
         | 
| 400 | 
            +
                    Initialize the diffusion model backbone.
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                    Args:
         | 
| 403 | 
            +
                        model_type (`str`, *optional*, defaults to 't2v'):
         | 
| 404 | 
            +
                            Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) or 'flf2v' (first-last-frame-to-video) or 'vace'
         | 
| 405 | 
            +
                        patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
         | 
| 406 | 
            +
                            3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
         | 
| 407 | 
            +
                        text_len (`int`, *optional*, defaults to 512):
         | 
| 408 | 
            +
                            Fixed length for text embeddings
         | 
| 409 | 
            +
                        in_dim (`int`, *optional*, defaults to 16):
         | 
| 410 | 
            +
                            Input video channels (C_in)
         | 
| 411 | 
            +
                        dim (`int`, *optional*, defaults to 2048):
         | 
| 412 | 
            +
                            Hidden dimension of the transformer
         | 
| 413 | 
            +
                        ffn_dim (`int`, *optional*, defaults to 8192):
         | 
| 414 | 
            +
                            Intermediate dimension in feed-forward network
         | 
| 415 | 
            +
                        freq_dim (`int`, *optional*, defaults to 256):
         | 
| 416 | 
            +
                            Dimension for sinusoidal time embeddings
         | 
| 417 | 
            +
                        text_dim (`int`, *optional*, defaults to 4096):
         | 
| 418 | 
            +
                            Input dimension for text embeddings
         | 
| 419 | 
            +
                        out_dim (`int`, *optional*, defaults to 16):
         | 
| 420 | 
            +
                            Output video channels (C_out)
         | 
| 421 | 
            +
                        num_heads (`int`, *optional*, defaults to 16):
         | 
| 422 | 
            +
                            Number of attention heads
         | 
| 423 | 
            +
                        num_layers (`int`, *optional*, defaults to 32):
         | 
| 424 | 
            +
                            Number of transformer blocks
         | 
| 425 | 
            +
                        window_size (`tuple`, *optional*, defaults to (-1, -1)):
         | 
| 426 | 
            +
                            Window size for local attention (-1 indicates global attention)
         | 
| 427 | 
            +
                        qk_norm (`bool`, *optional*, defaults to True):
         | 
| 428 | 
            +
                            Enable query/key normalization
         | 
| 429 | 
            +
                    cross_attn_norm (`bool`, *optional*, defaults to True):
         | 
| 430 | 
            +
                            Enable cross-attention normalization
         | 
| 431 | 
            +
                        eps (`float`, *optional*, defaults to 1e-6):
         | 
| 432 | 
            +
                            Epsilon value for normalization layers
         | 
| 433 | 
            +
                    """
         | 
| 434 | 
            +
             | 
| 435 | 
            +
                    super().__init__()
         | 
| 436 | 
            +
             | 
| 437 | 
            +
                    assert model_type in ['t2v', 'i2v', 'flf2v', 'vace']
         | 
| 438 | 
            +
                    self.model_type = model_type
         | 
| 439 | 
            +
             | 
| 440 | 
            +
                    self.patch_size = patch_size
         | 
| 441 | 
            +
                    self.text_len = text_len
         | 
| 442 | 
            +
                    self.in_dim = in_dim
         | 
| 443 | 
            +
                    self.dim = dim
         | 
| 444 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 445 | 
            +
                    self.freq_dim = freq_dim
         | 
| 446 | 
            +
                    self.text_dim = text_dim
         | 
| 447 | 
            +
                    self.out_dim = out_dim
         | 
| 448 | 
            +
                    self.num_heads = num_heads
         | 
| 449 | 
            +
                    self.num_layers = num_layers
         | 
| 450 | 
            +
                    self.window_size = window_size
         | 
| 451 | 
            +
                    self.qk_norm = qk_norm
         | 
| 452 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 453 | 
            +
                    self.eps = eps
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                    # embeddings
         | 
| 456 | 
            +
                    self.patch_embedding = nn.Conv3d(
         | 
| 457 | 
            +
                        in_dim, dim, kernel_size=patch_size, stride=patch_size)
         | 
| 458 | 
            +
                    self.text_embedding = nn.Sequential(
         | 
| 459 | 
            +
                        nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
         | 
| 460 | 
            +
                        nn.Linear(dim, dim))
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    self.time_embedding = nn.Sequential(
         | 
| 463 | 
            +
                        nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
         | 
| 464 | 
            +
                    self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                    # blocks
         | 
| 467 | 
            +
                    cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         | 
| 468 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 469 | 
            +
                        WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
         | 
| 470 | 
            +
                                          window_size, qk_norm, cross_attn_norm, eps)
         | 
| 471 | 
            +
                        for _ in range(num_layers)
         | 
| 472 | 
            +
                    ])
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    # head
         | 
| 475 | 
            +
                    self.head = Head(dim, out_dim, patch_size, eps)
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    # buffers (don't use register_buffer otherwise dtype will be changed in to())
         | 
| 478 | 
            +
                    assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         | 
| 479 | 
            +
                    d = dim // num_heads
         | 
| 480 | 
            +
                    self.freqs = torch.cat([
         | 
| 481 | 
            +
                        rope_params(1024, d - 4 * (d // 6)),
         | 
| 482 | 
            +
                        rope_params(1024, 2 * (d // 6)),
         | 
| 483 | 
            +
                        rope_params(1024, 2 * (d // 6))
         | 
| 484 | 
            +
                    ],
         | 
| 485 | 
            +
                                           dim=1)
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                    if model_type == 'i2v' or model_type == 'flf2v':
         | 
| 488 | 
            +
                        self.img_emb = MLPProj(1280, dim, flf_pos_emb=model_type == 'flf2v')
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    # initialize weights
         | 
| 491 | 
            +
                    self.init_weights()
         | 
| 492 | 
            +
             | 
| 493 | 
            +
                def forward(
         | 
| 494 | 
            +
                    self,
         | 
| 495 | 
            +
                    x,
         | 
| 496 | 
            +
                    t,
         | 
| 497 | 
            +
                    context,
         | 
| 498 | 
            +
                    seq_len,
         | 
| 499 | 
            +
                    clip_fea=None,
         | 
| 500 | 
            +
                    y=None,
         | 
| 501 | 
            +
                ):
         | 
| 502 | 
            +
                    r"""
         | 
| 503 | 
            +
                    Forward pass through the diffusion model
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                    Args:
         | 
| 506 | 
            +
                        x (List[Tensor]):
         | 
| 507 | 
            +
                            List of input video tensors, each with shape [C_in, F, H, W]
         | 
| 508 | 
            +
                        t (Tensor):
         | 
| 509 | 
            +
                            Diffusion timesteps tensor of shape [B]
         | 
| 510 | 
            +
                        context (List[Tensor]):
         | 
| 511 | 
            +
                            List of text embeddings each with shape [L, C]
         | 
| 512 | 
            +
                        seq_len (`int`):
         | 
| 513 | 
            +
                            Maximum sequence length for positional encoding
         | 
| 514 | 
            +
                        clip_fea (Tensor, *optional*):
         | 
| 515 | 
            +
                            CLIP image features for image-to-video mode or first-last-frame-to-video mode
         | 
| 516 | 
            +
                        y (List[Tensor], *optional*):
         | 
| 517 | 
            +
                            Conditional video inputs for image-to-video mode, same shape as x
         | 
| 518 | 
            +
             | 
| 519 | 
            +
                    Returns:
         | 
| 520 | 
            +
                        List[Tensor]:
         | 
| 521 | 
            +
                            List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
         | 
| 522 | 
            +
                    """
         | 
| 523 | 
            +
                    if self.model_type == 'i2v' or self.model_type == 'flf2v':
         | 
| 524 | 
            +
                        assert clip_fea is not None and y is not None
         | 
| 525 | 
            +
                    # params
         | 
| 526 | 
            +
                    device = self.patch_embedding.weight.device
         | 
| 527 | 
            +
                    if self.freqs.device != device:
         | 
| 528 | 
            +
                        self.freqs = self.freqs.to(device)
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    if y is not None:
         | 
| 531 | 
            +
                        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                    # embeddings
         | 
| 534 | 
            +
                    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 535 | 
            +
                    grid_sizes = torch.stack(
         | 
| 536 | 
            +
                        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 537 | 
            +
                    x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 538 | 
            +
                    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         | 
| 539 | 
            +
                    assert seq_lens.max() <= seq_len
         | 
| 540 | 
            +
                    x = torch.cat([
         | 
| 541 | 
            +
                        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
         | 
| 542 | 
            +
                                  dim=1) for u in x
         | 
| 543 | 
            +
                    ])
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                    # time embeddings
         | 
| 546 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 547 | 
            +
                        e = self.time_embedding(
         | 
| 548 | 
            +
                            sinusoidal_embedding_1d(self.freq_dim, t).float())
         | 
| 549 | 
            +
                        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
         | 
| 550 | 
            +
                        assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    # context
         | 
| 553 | 
            +
                    context_lens = None
         | 
| 554 | 
            +
                    context = self.text_embedding(
         | 
| 555 | 
            +
                        torch.stack([
         | 
| 556 | 
            +
                            torch.cat(
         | 
| 557 | 
            +
                                [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 558 | 
            +
                            for u in context
         | 
| 559 | 
            +
                        ]))
         | 
| 560 | 
            +
             | 
| 561 | 
            +
                    if clip_fea is not None:
         | 
| 562 | 
            +
                        context_clip = self.img_emb(clip_fea)  # bs x 257 (x2) x dim
         | 
| 563 | 
            +
                        context = torch.concat([context_clip, context], dim=1)
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                    # arguments
         | 
| 566 | 
            +
                    kwargs = dict(
         | 
| 567 | 
            +
                        e=e0,
         | 
| 568 | 
            +
                        seq_lens=seq_lens,
         | 
| 569 | 
            +
                        grid_sizes=grid_sizes,
         | 
| 570 | 
            +
                        freqs=self.freqs,
         | 
| 571 | 
            +
                        context=context,
         | 
| 572 | 
            +
                        context_lens=context_lens)
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    for block in self.blocks:
         | 
| 575 | 
            +
                        x = block(x, **kwargs)
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                    # head
         | 
| 578 | 
            +
                    x = self.head(x, e)
         | 
| 579 | 
            +
             | 
| 580 | 
            +
                    # unpatchify
         | 
| 581 | 
            +
                    x = self.unpatchify(x, grid_sizes)
         | 
| 582 | 
            +
                    return [u.float() for u in x]
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                def unpatchify(self, x, grid_sizes):
         | 
| 585 | 
            +
                    r"""
         | 
| 586 | 
            +
                    Reconstruct video tensors from patch embeddings.
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                    Args:
         | 
| 589 | 
            +
                        x (List[Tensor]):
         | 
| 590 | 
            +
                            List of patchified features, each with shape [L, C_out * prod(patch_size)]
         | 
| 591 | 
            +
                        grid_sizes (Tensor):
         | 
| 592 | 
            +
                            Original spatial-temporal grid dimensions before patching,
         | 
| 593 | 
            +
                                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    Returns:
         | 
| 596 | 
            +
                        List[Tensor]:
         | 
| 597 | 
            +
                            Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
         | 
| 598 | 
            +
                    """
         | 
| 599 | 
            +
             | 
| 600 | 
            +
                    c = self.out_dim
         | 
| 601 | 
            +
                    out = []
         | 
| 602 | 
            +
                    for u, v in zip(x, grid_sizes.tolist()):
         | 
| 603 | 
            +
                        u = u[:math.prod(v)].view(*v, *self.patch_size, c)
         | 
| 604 | 
            +
                        u = torch.einsum('fhwpqrc->cfphqwr', u)
         | 
| 605 | 
            +
                        u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
         | 
| 606 | 
            +
                        out.append(u)
         | 
| 607 | 
            +
                    return out
         | 
| 608 | 
            +
             | 
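As a quick sanity check on the patchify/unpatchify bookkeeping above, here is a small sketch; the latent shape is a made-up example, not a value taken from this repo. With the default patch_size=(1, 2, 2), a latent of F x H x W = 21 x 60 x 104 becomes a (21, 30, 52) token grid, so seq_len must be at least 21 * 30 * 52 = 32760 for the zero-padding in forward() to succeed.

import math

patch_size = (1, 2, 2)          # default (t_patch, h_patch, w_patch)
latent_fhw = (21, 60, 104)      # hypothetical latent frames/height/width
grid_sizes = tuple(s // p for s, p in zip(latent_fhw, patch_size))
tokens = math.prod(grid_sizes)  # number of patch tokens fed to the transformer
print(grid_sizes, tokens)       # (21, 30, 52) 32760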
| 609 | 
            +
                def init_weights(self):
         | 
| 610 | 
            +
                    r"""
         | 
| 611 | 
            +
                    Initialize model parameters using Xavier initialization.
         | 
| 612 | 
            +
                    """
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                    # basic init
         | 
| 615 | 
            +
                    for m in self.modules():
         | 
| 616 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 617 | 
            +
                            nn.init.xavier_uniform_(m.weight)
         | 
| 618 | 
            +
                            if m.bias is not None:
         | 
| 619 | 
            +
                                nn.init.zeros_(m.bias)
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                    # init embeddings
         | 
| 622 | 
            +
                    nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
         | 
| 623 | 
            +
                    for m in self.text_embedding.modules():
         | 
| 624 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 625 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 626 | 
            +
                    for m in self.time_embedding.modules():
         | 
| 627 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 628 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 629 | 
            +
             | 
| 630 | 
            +
                    # init output layer
         | 
| 631 | 
            +
                    nn.init.zeros_(self.head.head.weight)
         | 
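A minimal smoke test for the backbone above, assuming this file is importable as wan.modules.model; the sizes are deliberately tiny and purely illustrative (the released checkpoints use dim=2048 and 32+ layers). No forward pass is run, so this does not depend on flash-attention being installed.

from wan.modules.model import WanModel   # assumed import path for the file above

tiny = WanModel(model_type='t2v', dim=192, ffn_dim=512, num_heads=12,
                num_layers=2, text_dim=64, freq_dim=32)
n_params = sum(p.numel() for p in tiny.parameters())
print(f'{n_params / 1e6:.1f}M parameters')   # a few million for this toy config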
    	
        wan/modules/multitalk_model.py
    ADDED
    
    | @@ -0,0 +1,799 @@ | |
| 1 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 2 | 
            +
            import math
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            import os
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            import torch.cuda.amp as amp
         | 
| 7 | 
            +
            import torch.nn as nn
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from einops import rearrange
         | 
| 11 | 
            +
            from diffusers import ModelMixin
         | 
| 12 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from .attention import flash_attention, SingleStreamMutiAttention
         | 
| 15 | 
            +
            from ..utils.multitalk_utils import get_attn_map_with_target
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            __all__ = ['WanModel']
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def sinusoidal_embedding_1d(dim, position):
         | 
| 22 | 
            +
                # preprocess
         | 
| 23 | 
            +
                assert dim % 2 == 0
         | 
| 24 | 
            +
                half = dim // 2
         | 
| 25 | 
            +
                position = position.type(torch.float64)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                # calculation
         | 
| 28 | 
            +
                sinusoid = torch.outer(
         | 
| 29 | 
            +
                    position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
         | 
| 30 | 
            +
                x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
         | 
| 31 | 
            +
                return x
         | 
| 32 | 
            +
             | 
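For reference, a small illustrative call of the helper above: it maps a batch of timesteps to dim-dimensional cos/sin features (float64, because position is upcast inside the function).

import torch

t = torch.tensor([0, 500, 999])             # example diffusion timesteps
emb = sinusoidal_embedding_1d(256, t)       # 256 = the freq_dim used by the model below
print(emb.shape, emb.dtype)                 # torch.Size([3, 256]) torch.float64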
| 33 | 
            +
             | 
| 34 | 
            +
            @amp.autocast(enabled=False)
         | 
| 35 | 
            +
            def rope_params(max_seq_len, dim, theta=10000):
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                assert dim % 2 == 0
         | 
| 38 | 
            +
                freqs = torch.outer(
         | 
| 39 | 
            +
                    torch.arange(max_seq_len),
         | 
| 40 | 
            +
                    1.0 / torch.pow(theta,
         | 
| 41 | 
            +
                                    torch.arange(0, dim, 2).to(torch.float64).div(dim)))
         | 
| 42 | 
            +
                freqs = torch.polar(torch.ones_like(freqs), freqs)
         | 
| 43 | 
            +
                return freqs
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            @amp.autocast(enabled=False)
         | 
| 47 | 
            +
            def rope_apply(x, grid_sizes, freqs):
         | 
| 48 | 
            +
                s, n, c = x.size(1), x.size(2), x.size(3) // 2
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                output = []
         | 
| 53 | 
            +
                for i, (f, h, w) in enumerate(grid_sizes.tolist()):
         | 
| 54 | 
            +
                    seq_len = f * h * w
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
         | 
| 57 | 
            +
                        s, n, -1, 2))
         | 
| 58 | 
            +
                    freqs_i = torch.cat([
         | 
| 59 | 
            +
                        freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
         | 
| 60 | 
            +
                        freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
         | 
| 61 | 
            +
                        freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
         | 
| 62 | 
            +
                    ],
         | 
| 63 | 
            +
                                        dim=-1).reshape(seq_len, 1, -1)
         | 
| 64 | 
            +
                    freqs_i = freqs_i.to(device=x_i.device)
         | 
| 65 | 
            +
                    x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
         | 
| 66 | 
            +
                    x_i = torch.cat([x_i, x[i, seq_len:]])
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    output.append(x_i)
         | 
| 69 | 
            +
                return torch.stack(output).float()
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
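A shape sketch of how these helpers are combined further down, mirroring the freqs table built in WanModel.__init__: with dim=2048 and num_heads=16 the per-head size is d=128, split into a temporal band of d - 4*(d//6) = 44 and height/width bands of 2*(d//6) = 42 each; rope_apply then splits the complex table with c = d//2 = 64 into [22, 21, 21], i.e. the same bands halved (one complex entry per real pair).

import torch

d = 128  # head dim for dim=2048, num_heads=16
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),   # temporal rotary band  -> [1024, 22]
    rope_params(1024, 2 * (d // 6)),       # height rotary band    -> [1024, 21]
    rope_params(1024, 2 * (d // 6)),       # width rotary band     -> [1024, 21]
], dim=1)
print(freqs.shape, freqs.dtype)            # torch.Size([1024, 64]) torch.complex128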
| 72 | 
            +
            class WanRMSNorm(nn.Module):
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def __init__(self, dim, eps=1e-5):
         | 
| 75 | 
            +
                    super().__init__()
         | 
| 76 | 
            +
                    self.dim = dim
         | 
| 77 | 
            +
                    self.eps = eps
         | 
| 78 | 
            +
                    self.weight = nn.Parameter(torch.ones(dim))
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                def forward(self, x):
         | 
| 81 | 
            +
                    r"""
         | 
| 82 | 
            +
                    Args:
         | 
| 83 | 
            +
                        x(Tensor): Shape [B, L, C]
         | 
| 84 | 
            +
                    """
         | 
| 85 | 
            +
                    return self._norm(x.float()).type_as(x) * self.weight
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                def _norm(self, x):
         | 
| 88 | 
            +
                    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            class WanLayerNorm(nn.LayerNorm):
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                def __init__(self, dim, eps=1e-6, elementwise_affine=False):
         | 
| 94 | 
            +
                    super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         | 
| 97 | 
            +
                    origin_dtype = inputs.dtype
         | 
| 98 | 
            +
                    out = F.layer_norm(
         | 
| 99 | 
            +
                        inputs.float(), 
         | 
| 100 | 
            +
                        self.normalized_shape, 
         | 
| 101 | 
            +
                        None if self.weight is None else self.weight.float(), 
         | 
| 102 | 
            +
            None if self.bias is None else self.bias.float(),
         | 
| 103 | 
            +
                        self.eps
         | 
| 104 | 
            +
                    ).to(origin_dtype)
         | 
| 105 | 
            +
                    return out
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
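Both norm variants above deliberately compute in float32 and cast back, so low-precision activations only pay the cost of the final cast. A tiny illustrative check (sizes made up):

import torch

norm = WanRMSNorm(dim=8)
x = torch.randn(2, 4, 8, dtype=torch.bfloat16)
ref = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-5)
print((norm(x).float() - ref).abs().max())   # only bf16 rounding from the final type_as(x)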
| 108 | 
            +
            class WanSelfAttention(nn.Module):
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def __init__(self,
         | 
| 111 | 
            +
                             dim,
         | 
| 112 | 
            +
                             num_heads,
         | 
| 113 | 
            +
                             window_size=(-1, -1),
         | 
| 114 | 
            +
                             qk_norm=True,
         | 
| 115 | 
            +
                             eps=1e-6):
         | 
| 116 | 
            +
                    assert dim % num_heads == 0
         | 
| 117 | 
            +
                    super().__init__()
         | 
| 118 | 
            +
                    self.dim = dim
         | 
| 119 | 
            +
                    self.num_heads = num_heads
         | 
| 120 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 121 | 
            +
                    self.window_size = window_size
         | 
| 122 | 
            +
                    self.qk_norm = qk_norm
         | 
| 123 | 
            +
                    self.eps = eps
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # layers
         | 
| 126 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 127 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 128 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 129 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 130 | 
            +
                    self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 131 | 
            +
                    self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 132 | 
            +
             | 
| 133 | 
            +
                def forward(self, x, seq_lens, grid_sizes, freqs, ref_target_masks=None):
         | 
| 134 | 
            +
                    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    # query, key, value function
         | 
| 137 | 
            +
                    def qkv_fn(x):
         | 
| 138 | 
            +
                        q = self.norm_q(self.q(x)).view(b, s, n, d)
         | 
| 139 | 
            +
                        k = self.norm_k(self.k(x)).view(b, s, n, d)
         | 
| 140 | 
            +
                        v = self.v(x).view(b, s, n, d)
         | 
| 141 | 
            +
                        return q, k, v
         | 
| 142 | 
            +
                    q, k, v = qkv_fn(x)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    q = rope_apply(q, grid_sizes, freqs)
         | 
| 145 | 
            +
                    k = rope_apply(k, grid_sizes, freqs)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    
         | 
| 148 | 
            +
                    x = flash_attention(
         | 
| 149 | 
            +
                        q=q,
         | 
| 150 | 
            +
                        k=k,
         | 
| 151 | 
            +
                        v=v,
         | 
| 152 | 
            +
                        k_lens=seq_lens,
         | 
| 153 | 
            +
                        window_size=self.window_size
         | 
| 154 | 
            +
                    ).type_as(x)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                    # output
         | 
| 157 | 
            +
                    x = x.flatten(2)
         | 
| 158 | 
            +
                    x = self.o(x)
         | 
| 159 | 
            +
                    with torch.no_grad():
         | 
| 160 | 
            +
                        x_ref_attn_map = get_attn_map_with_target(q.type_as(x), k.type_as(x), grid_sizes[0], 
         | 
| 161 | 
            +
                                                                ref_target_masks=ref_target_masks)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    return x, x_ref_attn_map
         | 
| 164 | 
            +
             | 
| 165 | 
            +
             | 
| 166 | 
            +
            class WanI2VCrossAttention(WanSelfAttention):
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                def __init__(self,
         | 
| 169 | 
            +
                             dim,
         | 
| 170 | 
            +
                             num_heads,
         | 
| 171 | 
            +
                             window_size=(-1, -1),
         | 
| 172 | 
            +
                             qk_norm=True,
         | 
| 173 | 
            +
                             eps=1e-6):
         | 
| 174 | 
            +
                    super().__init__(dim, num_heads, window_size, qk_norm, eps)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    self.k_img = nn.Linear(dim, dim)
         | 
| 177 | 
            +
                    self.v_img = nn.Linear(dim, dim)
         | 
| 178 | 
            +
                    self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def forward(self, x, context, context_lens):
         | 
| 181 | 
            +
                    context_img = context[:, :257]
         | 
| 182 | 
            +
                    context = context[:, 257:]
         | 
| 183 | 
            +
                    b, n, d = x.size(0), self.num_heads, self.head_dim
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    # compute query, key, value
         | 
| 186 | 
            +
                    q = self.norm_q(self.q(x)).view(b, -1, n, d)
         | 
| 187 | 
            +
                    k = self.norm_k(self.k(context)).view(b, -1, n, d)
         | 
| 188 | 
            +
                    v = self.v(context).view(b, -1, n, d)
         | 
| 189 | 
            +
                    k_img = self.norm_k_img(self.k_img(context_img)).view(b, -1, n, d)
         | 
| 190 | 
            +
                    v_img = self.v_img(context_img).view(b, -1, n, d)
         | 
| 191 | 
            +
                    img_x = flash_attention(q, k_img, v_img, k_lens=None)
         | 
| 192 | 
            +
                    # compute attention
         | 
| 193 | 
            +
                    x = flash_attention(q, k, v, k_lens=context_lens)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    # output
         | 
| 196 | 
            +
                    x = x.flatten(2)
         | 
| 197 | 
            +
                    img_x = img_x.flatten(2)
         | 
| 198 | 
            +
                    x = x + img_x
         | 
| 199 | 
            +
                    x = self.o(x)
         | 
| 200 | 
            +
                    return x
         | 
| 201 | 
            +
             | 
| 202 | 
            +
             | 
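The hard-coded 257 above mirrors how WanModel.forward assembles the context: the CLIP image tokens produced by the MLPProj image embedder (bs x 257 x dim) are concatenated in front of the padded text embeddings, and this cross-attention peels them off again. A shapes-only sketch (all tensors random, dims illustrative):

import torch

bs, dim, text_len = 1, 2048, 512
context_clip = torch.randn(bs, 257, dim)       # img_emb(clip_fea): 257 CLIP tokens
context_text = torch.randn(bs, text_len, dim)  # padded text embeddings
context = torch.cat([context_clip, context_text], dim=1)
assert context[:, :257].shape[1] == 257        # image branch (k_img / v_img)
assert context[:, 257:].shape[1] == text_len   # text branch (k / v)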
| 203 | 
            +
            class WanAttentionBlock(nn.Module):
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                def __init__(self,
         | 
| 206 | 
            +
                             cross_attn_type,
         | 
| 207 | 
            +
                             dim,
         | 
| 208 | 
            +
                             ffn_dim,
         | 
| 209 | 
            +
                             num_heads,
         | 
| 210 | 
            +
                             window_size=(-1, -1),
         | 
| 211 | 
            +
                             qk_norm=True,
         | 
| 212 | 
            +
                             cross_attn_norm=False,
         | 
| 213 | 
            +
                             eps=1e-6,
         | 
| 214 | 
            +
                             output_dim=768,
         | 
| 215 | 
            +
                             norm_input_visual=True,
         | 
| 216 | 
            +
                             class_range=24,
         | 
| 217 | 
            +
                             class_interval=4):
         | 
| 218 | 
            +
                    super().__init__()
         | 
| 219 | 
            +
                    self.dim = dim
         | 
| 220 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 221 | 
            +
                    self.num_heads = num_heads
         | 
| 222 | 
            +
                    self.window_size = window_size
         | 
| 223 | 
            +
                    self.qk_norm = qk_norm
         | 
| 224 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 225 | 
            +
                    self.eps = eps
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # layers
         | 
| 228 | 
            +
                    self.norm1 = WanLayerNorm(dim, eps)
         | 
| 229 | 
            +
                    self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
         | 
| 230 | 
            +
                    self.norm3 = WanLayerNorm(
         | 
| 231 | 
            +
                        dim, eps,
         | 
| 232 | 
            +
                        elementwise_affine=True) if cross_attn_norm else nn.Identity()
         | 
| 233 | 
            +
                    self.cross_attn = WanI2VCrossAttention(dim,
         | 
| 234 | 
            +
                                                            num_heads,
         | 
| 235 | 
            +
                                                            (-1, -1),
         | 
| 236 | 
            +
                                                            qk_norm,
         | 
| 237 | 
            +
                                                            eps)
         | 
| 238 | 
            +
                    self.norm2 = WanLayerNorm(dim, eps)
         | 
| 239 | 
            +
                    self.ffn = nn.Sequential(
         | 
| 240 | 
            +
                        nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
         | 
| 241 | 
            +
                        nn.Linear(ffn_dim, dim))
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    # modulation
         | 
| 244 | 
            +
                    self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    # init audio module
         | 
| 247 | 
            +
                    self.audio_cross_attn = SingleStreamMutiAttention(
         | 
| 248 | 
            +
                            dim=dim,
         | 
| 249 | 
            +
                            encoder_hidden_states_dim=output_dim,
         | 
| 250 | 
            +
                            num_heads=num_heads,
         | 
| 251 | 
            +
                            qk_norm=False,
         | 
| 252 | 
            +
                            qkv_bias=True,
         | 
| 253 | 
            +
                            eps=eps,
         | 
| 254 | 
            +
                            norm_layer=WanRMSNorm,
         | 
| 255 | 
            +
                            class_range=class_range,
         | 
| 256 | 
            +
                            class_interval=class_interval
         | 
| 257 | 
            +
                        )
         | 
| 258 | 
            +
                    self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True)  if norm_input_visual else nn.Identity()
         | 
| 259 | 
            +
                    
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def forward(
         | 
| 262 | 
            +
                    self,
         | 
| 263 | 
            +
                    x,
         | 
| 264 | 
            +
                    e,
         | 
| 265 | 
            +
                    seq_lens,
         | 
| 266 | 
            +
                    grid_sizes,
         | 
| 267 | 
            +
                    freqs,
         | 
| 268 | 
            +
                    context,
         | 
| 269 | 
            +
                    context_lens,
         | 
| 270 | 
            +
                    audio_embedding=None,
         | 
| 271 | 
            +
                    ref_target_masks=None,
         | 
| 272 | 
            +
                    human_num=None,
         | 
| 273 | 
            +
                ):
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                    dtype = x.dtype
         | 
| 276 | 
            +
                    assert e.dtype == torch.float32
         | 
| 277 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 278 | 
            +
                        e = (self.modulation.to(e.device) + e).chunk(6, dim=1)
         | 
| 279 | 
            +
                    assert e[0].dtype == torch.float32
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                    # self-attention
         | 
| 282 | 
            +
                    y, x_ref_attn_map = self.self_attn(
         | 
| 283 | 
            +
                        (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), seq_lens, grid_sizes,
         | 
| 284 | 
            +
                        freqs, ref_target_masks=ref_target_masks)
         | 
| 285 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 286 | 
            +
                        x = x + y * e[2]
         | 
| 287 | 
            +
                    
         | 
| 288 | 
            +
                    x = x.to(dtype)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    # cross-attention of text
         | 
| 291 | 
            +
                    x = x + self.cross_attn(self.norm3(x), context, context_lens)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
            # cross-attention of audio
         | 
| 294 | 
            +
                    x_a = self.audio_cross_attn(self.norm_x(x), encoder_hidden_states=audio_embedding,
         | 
| 295 | 
            +
                                                    shape=grid_sizes[0], x_ref_attn_map=x_ref_attn_map, human_num=human_num)
         | 
| 296 | 
            +
                    x = x + x_a
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    y = self.ffn((self.norm2(x).float() * (1 + e[4]) + e[3]).to(dtype))
         | 
| 299 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 300 | 
            +
                        x = x + y * e[5]
         | 
| 301 | 
            +
             | 
| 302 | 
            +
             | 
| 303 | 
            +
                    x = x.to(dtype)
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    return x
         | 
| 306 | 
            +
             | 
| 307 | 
            +
             | 
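The six modulation chunks follow the familiar adaLN-style layout; reading the forward pass above, e[0]/e[1] shift and scale the self-attention input, e[2] gates its residual, e[3]/e[4] shift and scale the FFN input, and e[5] gates the FFN residual. A small sketch of the chunking itself (the names are only an interpretation, not identifiers from the code):

import torch

dim = 2048
modulation = torch.randn(1, 6, dim) / dim**0.5   # learned table, as in __init__
e_time = torch.randn(1, 6, dim)                  # per-timestep projection (e0 in forward)
shift_sa, scale_sa, gate_sa, shift_ffn, scale_ffn, gate_ffn = \
    (modulation + e_time).chunk(6, dim=1)
print(shift_sa.shape)                            # torch.Size([1, 1, 2048])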
| 308 | 
            +
            class Head(nn.Module):
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                def __init__(self, dim, out_dim, patch_size, eps=1e-6):
         | 
| 311 | 
            +
                    super().__init__()
         | 
| 312 | 
            +
                    self.dim = dim
         | 
| 313 | 
            +
                    self.out_dim = out_dim
         | 
| 314 | 
            +
                    self.patch_size = patch_size
         | 
| 315 | 
            +
                    self.eps = eps
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                    # layers
         | 
| 318 | 
            +
                    out_dim = math.prod(patch_size) * out_dim
         | 
| 319 | 
            +
                    self.norm = WanLayerNorm(dim, eps)
         | 
| 320 | 
            +
                    self.head = nn.Linear(dim, out_dim)
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    # modulation
         | 
| 323 | 
            +
                    self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                def forward(self, x, e):
         | 
| 326 | 
            +
                    r"""
         | 
| 327 | 
            +
                    Args:
         | 
| 328 | 
            +
                        x(Tensor): Shape [B, L1, C]
         | 
| 329 | 
            +
                        e(Tensor): Shape [B, C]
         | 
| 330 | 
            +
                    """
         | 
| 331 | 
            +
                    assert e.dtype == torch.float32
         | 
| 332 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 333 | 
            +
                        e = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
         | 
| 334 | 
            +
                        x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         | 
| 335 | 
            +
                    return x
         | 
| 336 | 
            +
             | 
| 337 | 
            +
             | 
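The head widens each token from dim to prod(patch_size) * out_dim values, which unpatchify later folds back into a (1, 2, 2) patch of out_dim-channel latents. With the defaults used below that is:

import math

patch_size, out_dim = (1, 2, 2), 16
print(math.prod(patch_size) * out_dim)   # 64 output values per token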
| 338 | 
            +
            class MLPProj(torch.nn.Module):
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                def __init__(self, in_dim, out_dim):
         | 
| 341 | 
            +
                    super().__init__()
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                    self.proj = torch.nn.Sequential(
         | 
| 344 | 
            +
                        torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
         | 
| 345 | 
            +
                        torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
         | 
| 346 | 
            +
                        torch.nn.LayerNorm(out_dim))
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                def forward(self, image_embeds):
         | 
| 349 | 
            +
                    clip_extra_context_tokens = self.proj(image_embeds)
         | 
| 350 | 
            +
                    return clip_extra_context_tokens
         | 
| 351 | 
            +
             | 
| 352 | 
            +
             | 
| 353 | 
            +
            class AudioProjModel(ModelMixin, ConfigMixin):
         | 
| 354 | 
            +
                def __init__(
         | 
| 355 | 
            +
                    self,
         | 
| 356 | 
            +
                    seq_len=5,
         | 
| 357 | 
            +
                    seq_len_vf=12,
         | 
| 358 | 
            +
                    blocks=12,  
         | 
| 359 | 
            +
                    channels=768, 
         | 
| 360 | 
            +
                    intermediate_dim=512,
         | 
| 361 | 
            +
                    output_dim=768,
         | 
| 362 | 
            +
                    context_tokens=32,
         | 
| 363 | 
            +
                    norm_output_audio=False,
         | 
| 364 | 
            +
                ):
         | 
| 365 | 
            +
                    super().__init__()
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    self.seq_len = seq_len
         | 
| 368 | 
            +
                    self.blocks = blocks
         | 
| 369 | 
            +
                    self.channels = channels
         | 
| 370 | 
            +
                    self.input_dim = seq_len * blocks * channels  
         | 
| 371 | 
            +
                    self.input_dim_vf = seq_len_vf * blocks * channels
         | 
| 372 | 
            +
                    self.intermediate_dim = intermediate_dim
         | 
| 373 | 
            +
                    self.context_tokens = context_tokens
         | 
| 374 | 
            +
                    self.output_dim = output_dim
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    # define multiple linear layers
         | 
| 377 | 
            +
                    self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
         | 
| 378 | 
            +
                    self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
         | 
| 379 | 
            +
                    self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
         | 
| 380 | 
            +
                    self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
         | 
| 381 | 
            +
                    self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                def forward(self, audio_embeds, audio_embeds_vf):
         | 
| 384 | 
            +
                    video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
         | 
| 385 | 
            +
                    B, _, _, S, C = audio_embeds.shape
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    # process audio of first frame
         | 
| 388 | 
            +
                    audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
         | 
| 389 | 
            +
                    batch_size, window_size, blocks, channels = audio_embeds.shape
         | 
| 390 | 
            +
                    audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
         | 
| 391 | 
            +
             | 
| 392 | 
            +
            # process audio of the following frames
         | 
| 393 | 
            +
                    audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
         | 
| 394 | 
            +
                    batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
         | 
| 395 | 
            +
                    audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    # first projection
         | 
| 398 | 
            +
                    audio_embeds = torch.relu(self.proj1(audio_embeds)) 
         | 
| 399 | 
            +
                    audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) 
         | 
| 400 | 
            +
                    audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
         | 
| 401 | 
            +
                    audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
         | 
| 402 | 
            +
                    audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) 
         | 
| 403 | 
            +
                    batch_size_c, N_t, C_a = audio_embeds_c.shape
         | 
| 404 | 
            +
                    audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
         | 
| 405 | 
            +
             | 
| 406 | 
            +
                    # second projection
         | 
| 407 | 
            +
                    audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                    # normalization and reshape
         | 
| 412 | 
            +
                    context_tokens = self.norm(context_tokens)
         | 
| 413 | 
            +
                    context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                    return context_tokens
         | 
| 416 | 
            +
             | 
| 417 | 
            +
             | 
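A minimal dummy pass through the audio projector above (shapes chosen tiny for illustration; the real model uses wav2vec-sized blocks=12, channels=768). The first latent frame consumes a seq_len=5 audio window and each later latent frame a seq_len_vf=12 window; the output is context_tokens audio tokens per frame.

import torch

proj = AudioProjModel(seq_len=5, seq_len_vf=12, blocks=2, channels=8,
                      intermediate_dim=16, output_dim=8, context_tokens=4)
first = torch.randn(1, 1, 5, 2, 8)     # [B, first frame, window, blocks, channels]
later = torch.randn(1, 20, 12, 2, 8)   # [B, later frames, window_vf, blocks, channels]
out = proj(first, later)
print(out.shape)                       # torch.Size([1, 21, 4, 8]) -> [B, frames, tokens, output_dim]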
| 418 | 
            +
            class WanModel(ModelMixin, ConfigMixin):
         | 
| 419 | 
            +
                r"""
         | 
| 420 | 
            +
                Wan diffusion backbone supporting both text-to-video and image-to-video.
         | 
| 421 | 
            +
                """
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                ignore_for_config = [
         | 
| 424 | 
            +
                    'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
         | 
| 425 | 
            +
                ]
         | 
| 426 | 
            +
                _no_split_modules = ['WanAttentionBlock']
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                @register_to_config
         | 
| 429 | 
            +
                def __init__(self,
         | 
| 430 | 
            +
                             model_type='i2v',
         | 
| 431 | 
            +
                             patch_size=(1, 2, 2),
         | 
| 432 | 
            +
                             text_len=512,
         | 
| 433 | 
            +
                             in_dim=16,
         | 
| 434 | 
            +
                             dim=2048,
         | 
| 435 | 
            +
                             ffn_dim=8192,
         | 
| 436 | 
            +
                             freq_dim=256,
         | 
| 437 | 
            +
                             text_dim=4096,
         | 
| 438 | 
            +
                             out_dim=16,
         | 
| 439 | 
            +
                             num_heads=16,
         | 
| 440 | 
            +
                             num_layers=32,
         | 
| 441 | 
            +
                             window_size=(-1, -1),
         | 
| 442 | 
            +
                             qk_norm=True,
         | 
| 443 | 
            +
                             cross_attn_norm=True,
         | 
| 444 | 
            +
                             eps=1e-6,
         | 
| 445 | 
            +
                             # audio params
         | 
| 446 | 
            +
                             audio_window=5,
         | 
| 447 | 
            +
                             intermediate_dim=512,
         | 
| 448 | 
            +
                             output_dim=768,
         | 
| 449 | 
            +
                             context_tokens=32,
         | 
| 450 | 
            +
                         vae_scale=4, # VAE temporal downsample scale
         | 
| 451 | 
            +
             | 
| 452 | 
            +
                             norm_input_visual=True,
         | 
| 453 | 
            +
                             norm_output_audio=True):
         | 
| 454 | 
            +
                    super().__init__()
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                    assert model_type == 'i2v', 'MultiTalk model requires model_type to be i2v.'
         | 
| 457 | 
            +
                    self.model_type = model_type
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    self.patch_size = patch_size
         | 
| 460 | 
            +
                    self.text_len = text_len
         | 
| 461 | 
            +
                    self.in_dim = in_dim
         | 
| 462 | 
            +
                    self.dim = dim
         | 
| 463 | 
            +
                    self.ffn_dim = ffn_dim
         | 
| 464 | 
            +
                    self.freq_dim = freq_dim
         | 
| 465 | 
            +
                    self.text_dim = text_dim
         | 
| 466 | 
            +
                    self.out_dim = out_dim
         | 
| 467 | 
            +
                    self.num_heads = num_heads
         | 
| 468 | 
            +
                    self.num_layers = num_layers
         | 
| 469 | 
            +
                    self.window_size = window_size
         | 
| 470 | 
            +
                    self.qk_norm = qk_norm
         | 
| 471 | 
            +
                    self.cross_attn_norm = cross_attn_norm
         | 
| 472 | 
            +
                    self.eps = eps
         | 
| 473 | 
            +
             | 
| 474 | 
            +
             | 
| 475 | 
            +
                    self.norm_output_audio = norm_output_audio
         | 
| 476 | 
            +
                    self.audio_window = audio_window
         | 
| 477 | 
            +
                    self.intermediate_dim = intermediate_dim
         | 
| 478 | 
            +
                    self.vae_scale = vae_scale
         | 
| 479 | 
            +
                    
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                    # embeddings
         | 
| 482 | 
            +
                    self.patch_embedding = nn.Conv3d(
         | 
| 483 | 
            +
                        in_dim, dim, kernel_size=patch_size, stride=patch_size)
         | 
| 484 | 
            +
                    self.text_embedding = nn.Sequential(
         | 
| 485 | 
            +
                        nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
         | 
| 486 | 
            +
                        nn.Linear(dim, dim))
         | 
| 487 | 
            +
             | 
| 488 | 
            +
                    self.time_embedding = nn.Sequential(
         | 
| 489 | 
            +
                        nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
         | 
| 490 | 
            +
                    self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                    # blocks
         | 
| 493 | 
            +
                    cross_attn_type = 'i2v_cross_attn'
         | 
| 494 | 
            +
                    self.blocks = nn.ModuleList([
         | 
| 495 | 
            +
                        WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
         | 
| 496 | 
            +
                                          window_size, qk_norm, cross_attn_norm, eps, 
         | 
| 497 | 
            +
                                          output_dim=output_dim, norm_input_visual=norm_input_visual)
         | 
| 498 | 
            +
                        for _ in range(num_layers)
         | 
| 499 | 
            +
                    ])
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                    # head
         | 
| 502 | 
            +
                    self.head = Head(dim, out_dim, patch_size, eps)
         | 
| 503 | 
            +
             | 
| 504 | 
            +
                    assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
         | 
| 505 | 
            +
                    d = dim // num_heads
         | 
| 506 | 
            +
                    self.freqs = torch.cat([
         | 
| 507 | 
            +
                        rope_params(1024, d - 4 * (d // 6)),
         | 
| 508 | 
            +
                        rope_params(1024, 2 * (d // 6)),
         | 
| 509 | 
            +
                        rope_params(1024, 2 * (d // 6))
         | 
| 510 | 
            +
                    ],
         | 
| 511 | 
            +
                                           dim=1)
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                    if model_type == 'i2v':
         | 
| 514 | 
            +
                        self.img_emb = MLPProj(1280, dim)
         | 
| 515 | 
            +
                    else:
         | 
| 516 | 
            +
                        raise NotImplementedError('Unsupported model type.')
         | 
| 517 | 
            +
                    
         | 
| 518 | 
            +
                    # init audio adapter
         | 
| 519 | 
            +
                    self.audio_proj = AudioProjModel(
         | 
| 520 | 
            +
                                seq_len=audio_window,
         | 
| 521 | 
            +
                                seq_len_vf=audio_window+vae_scale-1,
         | 
| 522 | 
            +
                                intermediate_dim=intermediate_dim,
         | 
| 523 | 
            +
                                output_dim=output_dim,
         | 
| 524 | 
            +
                                context_tokens=context_tokens,
         | 
| 525 | 
            +
                                norm_output_audio=norm_output_audio,
         | 
| 526 | 
            +
                            )
         | 
| 527 | 
            +
             | 
| 528 | 
            +
             | 
| 529 | 
            +
                    # initialize weights
         | 
| 530 | 
            +
                    self.init_weights()
         | 
| 531 | 
            +
             | 
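The rotary frequency table built above splits each attention head's channels into one temporal and two spatial groups; a quick arithmetic check with the default dim and num_heads:

dim, num_heads = 2048, 16                               # defaults above
d = dim // num_heads                                    # 128 channels per head
parts = (d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6))  # temporal, height, width
print(parts, sum(parts) == d)                           # (44, 42, 42) True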
| 532 | 
            +
                def teacache_init(
         | 
| 533 | 
            +
                    self,
         | 
| 534 | 
            +
                    use_ret_steps=True,
         | 
| 535 | 
            +
                    teacache_thresh=0.2,
         | 
| 536 | 
            +
                    sample_steps=40,
         | 
| 537 | 
            +
                    model_scale='multitalk-480',
         | 
| 538 | 
            +
                ):
         | 
| 539 | 
            +
                    print("teacache_init")
         | 
| 540 | 
            +
                    self.enable_teacache = True
         | 
| 541 | 
            +
                    
         | 
| 542 | 
            +
                    self.__class__.cnt = 0
         | 
| 543 | 
            +
                    self.__class__.num_steps = sample_steps*3
         | 
| 544 | 
            +
                    self.__class__.teacache_thresh = teacache_thresh
         | 
| 545 | 
            +
                    self.__class__.accumulated_rel_l1_distance_even = 0
         | 
| 546 | 
            +
                    self.__class__.accumulated_rel_l1_distance_odd = 0
         | 
| 547 | 
            +
                    self.__class__.previous_e0_even = None
         | 
| 548 | 
            +
                    self.__class__.previous_e0_odd = None
         | 
| 549 | 
            +
                    self.__class__.previous_residual_even = None
         | 
| 550 | 
            +
                    self.__class__.previous_residual_odd = None
         | 
| 551 | 
            +
                    self.__class__.use_ret_steps = use_ret_steps
         | 
| 552 | 
            +
             | 
| 553 | 
            +
                    if use_ret_steps:
         | 
| 554 | 
            +
                        if model_scale == 'multitalk-480':
         | 
| 555 | 
            +
                            self.__class__.coefficients = [ 2.57151496e+05, -3.54229917e+04,  1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
         | 
| 556 | 
            +
                        if model_scale == 'multitalk-720':
         | 
| 557 | 
            +
                            self.__class__.coefficients = [ 8.10705460e+03,  2.13393892e+03, -3.72934672e+02,  1.66203073e+01, -4.17769401e-02]
         | 
| 558 | 
            +
                        self.__class__.ret_steps = 5*3
         | 
| 559 | 
            +
                        self.__class__.cutoff_steps = sample_steps*3
         | 
| 560 | 
            +
                    else:
         | 
| 561 | 
            +
                        if model_scale == 'multitalk-480':
         | 
| 562 | 
            +
                            self.__class__.coefficients = [-3.02331670e+02,  2.23948934e+02, -5.25463970e+01,  5.87348440e+00, -2.01973289e-01]
         | 
| 563 | 
            +
                    
         | 
| 564 | 
            +
                        if model_scale == 'multitalk-720':
         | 
| 565 | 
            +
                            self.__class__.coefficients = [-114.36346466,   65.26524496,  -18.82220707,    4.91518089,   -0.23412683]
         | 
| 566 | 
            +
                        self.__class__.ret_steps = 1*3
         | 
| 567 | 
            +
                        self.__class__.cutoff_steps = sample_steps*3 - 3
         | 
| 568 | 
            +
                    print("teacache_init done")
         | 
| 569 | 
            +
                
         | 
| 570 | 
            +
                def disable_teacache(self):
         | 
| 571 | 
            +
                    self.enable_teacache = False
         | 
| 572 | 
            +
             | 
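Note that the _even/_odd attributes set in teacache_init() are never read back; the per-branch caches used in forward() (previous_e0_cond, previous_residual_drop_text, ...) are created lazily during the warm-up steps. The skip rule itself is small; a hedged stand-alone sketch (coefficients, threshold and tensor arguments are illustrative):

import numpy as np

def should_recompute(e0, prev_e0, acc, coefficients, thresh):
    # Accumulate a polynomial-rescaled relative L1 change of the modulated timestep
    # embedding (torch tensors); skip the transformer blocks and reuse the cached
    # residual while the accumulated value stays below the threshold.
    rescale = np.poly1d(coefficients)
    acc += rescale(((e0 - prev_e0).abs().mean() / prev_e0.abs().mean()).item())
    if acc < thresh:
        return False, acc    # reuse the previous residual for this branch
    return True, 0.0         # run all blocks and reset the accumulator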
| 573 | 
            +
                def forward(
         | 
| 574 | 
            +
                        self,
         | 
| 575 | 
            +
                        x,
         | 
| 576 | 
            +
                        t,
         | 
| 577 | 
            +
                        context,
         | 
| 578 | 
            +
                        seq_len,
         | 
| 579 | 
            +
                        clip_fea=None,
         | 
| 580 | 
            +
                        y=None,
         | 
| 581 | 
            +
                        audio=None,
         | 
| 582 | 
            +
                        ref_target_masks=None,
         | 
| 583 | 
            +
                    ):
         | 
| 584 | 
            +
                    assert clip_fea is not None and y is not None
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                    _, T, H, W = x[0].shape
         | 
| 587 | 
            +
                    N_t = T // self.patch_size[0]
         | 
| 588 | 
            +
                    N_h = H // self.patch_size[1]
         | 
| 589 | 
            +
                    N_w = W // self.patch_size[2]
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                    if y is not None:
         | 
| 592 | 
            +
                        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
         | 
| 593 | 
            +
                    x[0] = x[0].to(context[0].dtype)
         | 
| 594 | 
            +
             | 
| 595 | 
            +
                    # embeddings
         | 
| 596 | 
            +
                    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
         | 
| 597 | 
            +
                    grid_sizes = torch.stack(
         | 
| 598 | 
            +
                        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
         | 
| 599 | 
            +
                    x = [u.flatten(2).transpose(1, 2) for u in x]
         | 
| 600 | 
            +
                    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
         | 
| 601 | 
            +
                    assert seq_lens.max() <= seq_len
         | 
| 602 | 
            +
                    x = torch.cat([
         | 
| 603 | 
            +
                        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
         | 
| 604 | 
            +
                                  dim=1) for u in x
         | 
| 605 | 
            +
                    ])
         | 
| 606 | 
            +
             | 
| 607 | 
            +
                    # time embeddings
         | 
| 608 | 
            +
                    with amp.autocast(dtype=torch.float32):
         | 
| 609 | 
            +
                        e = self.time_embedding(
         | 
| 610 | 
            +
                            sinusoidal_embedding_1d(self.freq_dim, t).float())
         | 
| 611 | 
            +
                        e0 = self.time_projection(e).unflatten(1, (6, self.dim))
         | 
| 612 | 
            +
                        assert e.dtype == torch.float32 and e0.dtype == torch.float32
         | 
| 613 | 
            +
             | 
| 614 | 
            +
                    # text embedding
         | 
| 615 | 
            +
                    context_lens = None
         | 
| 616 | 
            +
                    context = self.text_embedding(
         | 
| 617 | 
            +
                        torch.stack([
         | 
| 618 | 
            +
                            torch.cat(
         | 
| 619 | 
            +
                                [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
         | 
| 620 | 
            +
                            for u in context
         | 
| 621 | 
            +
                        ]))
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                    # clip embedding
         | 
| 624 | 
            +
                    if clip_fea is not None:
         | 
| 625 | 
            +
                        context_clip = self.img_emb(clip_fea) 
         | 
| 626 | 
            +
                        context = torch.concat([context_clip, context], dim=1).to(x.dtype)
         | 
| 627 | 
            +
             | 
| 628 | 
            +
                    
         | 
| 629 | 
            +
                    audio_cond = audio.to(device=x.device, dtype=x.dtype)
         | 
| 630 | 
            +
                    first_frame_audio_emb_s = audio_cond[:, :1, ...] 
         | 
| 631 | 
            +
                    latter_frame_audio_emb = audio_cond[:, 1:, ...] 
         | 
| 632 | 
            +
                    latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale) 
         | 
| 633 | 
            +
                    middle_index = self.audio_window // 2
         | 
| 634 | 
            +
                    latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] 
         | 
| 635 | 
            +
                    latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
         | 
| 636 | 
            +
                    latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] 
         | 
| 637 | 
            +
                    latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
         | 
| 638 | 
            +
                    latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] 
         | 
| 639 | 
            +
                    latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") 
         | 
| 640 | 
            +
                    latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) 
         | 
| 641 | 
            +
                    audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) 
         | 
| 642 | 
            +
                    human_num = len(audio_embedding)
         | 
| 643 | 
            +
                    audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)
         | 
| 644 | 
            +
             | 
| 645 | 
            +
             | 
| 646 | 
            +
                    # convert ref_target_masks to token_ref_target_masks
         | 
| 647 | 
            +
                    if ref_target_masks is not None:
         | 
| 648 | 
            +
                        ref_target_masks = ref_target_masks.unsqueeze(0).to(torch.float32) 
         | 
| 649 | 
            +
                        token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest') 
         | 
| 650 | 
            +
                        token_ref_target_masks = token_ref_target_masks.squeeze(0)
         | 
| 651 | 
            +
                        token_ref_target_masks = (token_ref_target_masks > 0)
         | 
| 652 | 
            +
                        token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1) 
         | 
| 653 | 
            +
                        token_ref_target_masks = token_ref_target_masks.to(x.dtype)
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                    # teacache
         | 
| 656 | 
            +
                    if self.enable_teacache:
         | 
| 657 | 
            +
                        modulated_inp = e0 if self.use_ret_steps else e
         | 
| 658 | 
            +
                        if self.cnt%3==0: # cond
         | 
| 659 | 
            +
                            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
         | 
| 660 | 
            +
                                should_calc_cond = True
         | 
| 661 | 
            +
                                self.accumulated_rel_l1_distance_cond = 0
         | 
| 662 | 
            +
                            else:
         | 
| 663 | 
            +
                                rescale_func = np.poly1d(self.coefficients)
         | 
| 664 | 
            +
                                self.accumulated_rel_l1_distance_cond += rescale_func(((modulated_inp-self.previous_e0_cond).abs().mean() / self.previous_e0_cond.abs().mean()).cpu().item())
         | 
| 665 | 
            +
                                if self.accumulated_rel_l1_distance_cond < self.teacache_thresh:
         | 
| 666 | 
            +
                                    should_calc_cond = False
         | 
| 667 | 
            +
                                else:
         | 
| 668 | 
            +
                                    should_calc_cond = True
         | 
| 669 | 
            +
                                    self.accumulated_rel_l1_distance_cond = 0
         | 
| 670 | 
            +
                            self.previous_e0_cond = modulated_inp.clone()
         | 
| 671 | 
            +
                        elif self.cnt%3==1: # drop_text
         | 
| 672 | 
            +
                            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
         | 
| 673 | 
            +
                                should_calc_drop_text = True
         | 
| 674 | 
            +
                                self.accumulated_rel_l1_distance_drop_text = 0
         | 
| 675 | 
            +
                            else:
         | 
| 676 | 
            +
                                rescale_func = np.poly1d(self.coefficients)
         | 
| 677 | 
            +
                                self.accumulated_rel_l1_distance_drop_text += rescale_func(((modulated_inp-self.previous_e0_drop_text).abs().mean() / self.previous_e0_drop_text.abs().mean()).cpu().item())
         | 
| 678 | 
            +
                                if self.accumulated_rel_l1_distance_drop_text < self.teacache_thresh:
         | 
| 679 | 
            +
                                    should_calc_drop_text = False
         | 
| 680 | 
            +
                                else:
         | 
| 681 | 
            +
                                    should_calc_drop_text = True
         | 
| 682 | 
            +
                                    self.accumulated_rel_l1_distance_drop_text = 0
         | 
| 683 | 
            +
                            self.previous_e0_drop_text = modulated_inp.clone()
         | 
| 684 | 
            +
                        else: # uncond
         | 
| 685 | 
            +
                            if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
         | 
| 686 | 
            +
                                should_calc_uncond = True
         | 
| 687 | 
            +
                                self.accumulated_rel_l1_distance_uncond = 0
         | 
| 688 | 
            +
                            else:
         | 
| 689 | 
            +
                                rescale_func = np.poly1d(self.coefficients)
         | 
| 690 | 
            +
                                self.accumulated_rel_l1_distance_uncond += rescale_func(((modulated_inp-self.previous_e0_uncond).abs().mean() / self.previous_e0_uncond.abs().mean()).cpu().item())
         | 
| 691 | 
            +
                                if self.accumulated_rel_l1_distance_uncond < self.teacache_thresh:
         | 
| 692 | 
            +
                                    should_calc_uncond = False
         | 
| 693 | 
            +
                                else:
         | 
| 694 | 
            +
                                    should_calc_uncond = True
         | 
| 695 | 
            +
                                    self.accumulated_rel_l1_distance_uncond = 0
         | 
| 696 | 
            +
                            self.previous_e0_uncond = modulated_inp.clone()
         | 
| 697 | 
            +
             | 
| 698 | 
            +
                    # arguments
         | 
| 699 | 
            +
                    kwargs = dict(
         | 
| 700 | 
            +
                        e=e0,
         | 
| 701 | 
            +
                        seq_lens=seq_lens,
         | 
| 702 | 
            +
                        grid_sizes=grid_sizes,
         | 
| 703 | 
            +
                        freqs=self.freqs,
         | 
| 704 | 
            +
                        context=context,
         | 
| 705 | 
            +
                        context_lens=context_lens,
         | 
| 706 | 
            +
                        audio_embedding=audio_embedding,
         | 
| 707 | 
            +
                        ref_target_masks=token_ref_target_masks,
         | 
| 708 | 
            +
                        human_num=human_num,
         | 
| 709 | 
            +
                        )
         | 
| 710 | 
            +
                    if self.enable_teacache:
         | 
| 711 | 
            +
                        if self.cnt%3==0:
         | 
| 712 | 
            +
                            if not should_calc_cond:
         | 
| 713 | 
            +
                                x += self.previous_residual_cond
         | 
| 714 | 
            +
                            else:
         | 
| 715 | 
            +
                                ori_x = x.clone()
         | 
| 716 | 
            +
                                for block in self.blocks:
         | 
| 717 | 
            +
                                    x = block(x, **kwargs)
         | 
| 718 | 
            +
                                self.previous_residual_cond = x - ori_x
         | 
| 719 | 
            +
                        elif self.cnt%3==1:
         | 
| 720 | 
            +
                            if not should_calc_drop_text:
         | 
| 721 | 
            +
                                x += self.previous_residual_drop_text
         | 
| 722 | 
            +
                            else:
         | 
| 723 | 
            +
                                ori_x = x.clone()
         | 
| 724 | 
            +
                                for block in self.blocks:
         | 
| 725 | 
            +
                                    x = block(x, **kwargs)
         | 
| 726 | 
            +
                                self.previous_residual_drop_text = x - ori_x
         | 
| 727 | 
            +
                        else:
         | 
| 728 | 
            +
                            if not should_calc_uncond:
         | 
| 729 | 
            +
                                x += self.previous_residual_uncond
         | 
| 730 | 
            +
                            else:
         | 
| 731 | 
            +
                                ori_x = x.clone()
         | 
| 732 | 
            +
                                for block in self.blocks:
         | 
| 733 | 
            +
                                    x = block(x, **kwargs)
         | 
| 734 | 
            +
                                self.previous_residual_uncond = x - ori_x
         | 
| 735 | 
            +
                    else:
         | 
| 736 | 
            +
                        for block in self.blocks:
         | 
| 737 | 
            +
                            x = block(x, **kwargs)
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                    # head
         | 
| 740 | 
            +
                    x = self.head(x, e)
         | 
| 741 | 
            +
             | 
| 742 | 
            +
                    # unpatchify
         | 
| 743 | 
            +
                    x = self.unpatchify(x, grid_sizes)
         | 
| 744 | 
            +
                    if self.enable_teacache:
         | 
| 745 | 
            +
                        self.cnt += 1
         | 
| 746 | 
            +
                        if self.cnt >= self.num_steps:
         | 
| 747 | 
            +
                            self.cnt = 0
         | 
| 748 | 
            +
             | 
| 749 | 
            +
                    return torch.stack(x).float()
         | 
| 750 | 
            +
             | 
| 751 | 
            +
             | 
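The cnt % 3 bookkeeping above assumes the sampler calls forward() three times per denoising step (full conditioning, text dropped, fully unconditional). A hedged sketch of how such a triple is commonly blended with separate audio and text guidance scales; this is illustrative, not necessarily the exact formula used by this repo's sampler, and the scale values are assumptions:

def combine_guidance(pred_cond, pred_drop_text, pred_uncond,
                     audio_scale=4.0, text_scale=5.0):
    # two-scale classifier-free guidance over the three predictions
    return (pred_uncond
            + audio_scale * (pred_drop_text - pred_uncond)
            + text_scale * (pred_cond - pred_drop_text))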
| 752 | 
            +
                def unpatchify(self, x, grid_sizes):
         | 
| 753 | 
            +
                    r"""
         | 
| 754 | 
            +
                    Reconstruct video tensors from patch embeddings.
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                    Args:
         | 
| 757 | 
            +
                        x (List[Tensor]):
         | 
| 758 | 
            +
                            List of patchified features, each with shape [L, C_out * prod(patch_size)]
         | 
| 759 | 
            +
                        grid_sizes (Tensor):
         | 
| 760 | 
            +
                            Original spatial-temporal grid dimensions before patching,
         | 
| 761 | 
            +
                                shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
         | 
| 762 | 
            +
             | 
| 763 | 
            +
                    Returns:
         | 
| 764 | 
            +
                        List[Tensor]:
         | 
| 765 | 
            +
                            Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
         | 
| 766 | 
            +
                    """
         | 
| 767 | 
            +
             | 
| 768 | 
            +
                    c = self.out_dim
         | 
| 769 | 
            +
                    out = []
         | 
| 770 | 
            +
                    for u, v in zip(x, grid_sizes.tolist()):
         | 
| 771 | 
            +
                        u = u[:math.prod(v)].view(*v, *self.patch_size, c)
         | 
| 772 | 
            +
                        u = torch.einsum('fhwpqrc->cfphqwr', u)
         | 
| 773 | 
            +
                        u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
         | 
| 774 | 
            +
                        out.append(u)
         | 
| 775 | 
            +
                    return out
         | 
| 776 | 
            +
             | 
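Shape sanity check for unpatchify() above, assuming a 21x45x80 patch grid, patch_size (1, 2, 2) and out_dim 16 (i.e. a 720x1280 video after 8x spatial VAE downsampling):

import math
import torch

v, patch, c = (21, 45, 80), (1, 2, 2), 16
u = torch.randn(math.prod(v), c * math.prod(patch))       # one patchified sample
u = u.view(*v, *patch, c)
u = torch.einsum('fhwpqrc->cfphqwr', u)
u = u.reshape(c, *[i * j for i, j in zip(v, patch)])
print(u.shape)                                            # torch.Size([16, 21, 90, 160])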
| 777 | 
            +
                def init_weights(self):
         | 
| 778 | 
            +
                    r"""
         | 
| 779 | 
            +
                    Initialize model parameters using Xavier initialization.
         | 
| 780 | 
            +
                    """
         | 
| 781 | 
            +
             | 
| 782 | 
            +
                    # basic init
         | 
| 783 | 
            +
                    for m in self.modules():
         | 
| 784 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 785 | 
            +
                            nn.init.xavier_uniform_(m.weight)
         | 
| 786 | 
            +
                            if m.bias is not None:
         | 
| 787 | 
            +
                                nn.init.zeros_(m.bias)
         | 
| 788 | 
            +
             | 
| 789 | 
            +
                    # init embeddings
         | 
| 790 | 
            +
                    nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
         | 
| 791 | 
            +
                    for m in self.text_embedding.modules():
         | 
| 792 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 793 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 794 | 
            +
                    for m in self.time_embedding.modules():
         | 
| 795 | 
            +
                        if isinstance(m, nn.Linear):
         | 
| 796 | 
            +
                            nn.init.normal_(m.weight, std=.02)
         | 
| 797 | 
            +
             | 
| 798 | 
            +
                    # init output layer
         | 
| 799 | 
            +
                    nn.init.zeros_(self.head.head.weight)
         | 
    	
        wan/modules/t5.py
    ADDED
    
    | @@ -0,0 +1,513 @@ | |
| 1 | 
            +
            # Modified from transformers.models.t5.modeling_t5
         | 
| 2 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 3 | 
            +
            import logging
         | 
| 4 | 
            +
            import math
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import torch.nn as nn
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from .tokenizers import HuggingfaceTokenizer
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            __all__ = [
         | 
| 13 | 
            +
                'T5Model',
         | 
| 14 | 
            +
                'T5Encoder',
         | 
| 15 | 
            +
                'T5Decoder',
         | 
| 16 | 
            +
                'T5EncoderModel',
         | 
| 17 | 
            +
            ]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            def fp16_clamp(x):
         | 
| 21 | 
            +
                if x.dtype == torch.float16 and torch.isinf(x).any():
         | 
| 22 | 
            +
                    clamp = torch.finfo(x.dtype).max - 1000
         | 
| 23 | 
            +
                    x = torch.clamp(x, min=-clamp, max=clamp)
         | 
| 24 | 
            +
                return x
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
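Quick usage check for fp16_clamp() above: float16 overflows to inf just past 65504, and the clamp pulls such values back into range so later residual adds and softmaxes stay finite.

import torch

x = torch.tensor([70000.0, -70000.0]).half()   # overflows to [inf, -inf] in fp16
print(fp16_clamp(x))                           # clamped near +/-(65504 - 1000)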
| 27 | 
            +
            def init_weights(m):
         | 
| 28 | 
            +
                if isinstance(m, T5LayerNorm):
         | 
| 29 | 
            +
                    nn.init.ones_(m.weight)
         | 
| 30 | 
            +
                elif isinstance(m, T5Model):
         | 
| 31 | 
            +
                    nn.init.normal_(m.token_embedding.weight, std=1.0)
         | 
| 32 | 
            +
                elif isinstance(m, T5FeedForward):
         | 
| 33 | 
            +
                    nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
         | 
| 34 | 
            +
                    nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
         | 
| 35 | 
            +
                    nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
         | 
| 36 | 
            +
                elif isinstance(m, T5Attention):
         | 
| 37 | 
            +
                    nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
         | 
| 38 | 
            +
                    nn.init.normal_(m.k.weight, std=m.dim**-0.5)
         | 
| 39 | 
            +
                    nn.init.normal_(m.v.weight, std=m.dim**-0.5)
         | 
| 40 | 
            +
                    nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
         | 
| 41 | 
            +
                elif isinstance(m, T5RelativeEmbedding):
         | 
| 42 | 
            +
                    nn.init.normal_(
         | 
| 43 | 
            +
                        m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
         | 
| 44 | 
            +
             | 
| 45 | 
            +
             | 
| 46 | 
            +
            class GELU(nn.Module):
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                def forward(self, x):
         | 
| 49 | 
            +
                    return 0.5 * x * (1.0 + torch.tanh(
         | 
| 50 | 
            +
                        math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
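The GELU above is the tanh approximation; recent PyTorch exposes the same formula directly, so the two agree to float precision:

import torch
import torch.nn.functional as F

x = torch.randn(8)
print(torch.allclose(GELU()(x), F.gelu(x, approximate='tanh'), atol=1e-6))   # True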
| 53 | 
            +
            class T5LayerNorm(nn.Module):
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def __init__(self, dim, eps=1e-6):
         | 
| 56 | 
            +
                    super(T5LayerNorm, self).__init__()
         | 
| 57 | 
            +
                    self.dim = dim
         | 
| 58 | 
            +
                    self.eps = eps
         | 
| 59 | 
            +
                    self.weight = nn.Parameter(torch.ones(dim))
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                def forward(self, x):
         | 
| 62 | 
            +
                    x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
         | 
| 63 | 
            +
                                        self.eps)
         | 
| 64 | 
            +
                    if self.weight.dtype in [torch.float16, torch.bfloat16]:
         | 
| 65 | 
            +
                        x = x.type_as(self.weight)
         | 
| 66 | 
            +
                    return self.weight * x
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
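T5LayerNorm above is an RMSNorm: no mean subtraction and no bias, i.e. y = weight * x / sqrt(mean(x^2) + eps), with the statistics computed in float32. A quick equivalence check (the weight is initialized to ones):

import torch

norm = T5LayerNorm(16)
x = torch.randn(2, 3, 16)
ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(norm(x), ref))   # True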
| 69 | 
            +
            class T5Attention(nn.Module):
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
         | 
| 72 | 
            +
                    assert dim_attn % num_heads == 0
         | 
| 73 | 
            +
                    super(T5Attention, self).__init__()
         | 
| 74 | 
            +
                    self.dim = dim
         | 
| 75 | 
            +
                    self.dim_attn = dim_attn
         | 
| 76 | 
            +
                    self.num_heads = num_heads
         | 
| 77 | 
            +
                    self.head_dim = dim_attn // num_heads
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    # layers
         | 
| 80 | 
            +
                    self.q = nn.Linear(dim, dim_attn, bias=False)
         | 
| 81 | 
            +
                    self.k = nn.Linear(dim, dim_attn, bias=False)
         | 
| 82 | 
            +
                    self.v = nn.Linear(dim, dim_attn, bias=False)
         | 
| 83 | 
            +
                    self.o = nn.Linear(dim_attn, dim, bias=False)
         | 
| 84 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def forward(self, x, context=None, mask=None, pos_bias=None):
         | 
| 87 | 
            +
                    """
         | 
| 88 | 
            +
                    x:          [B, L1, C].
         | 
| 89 | 
            +
                    context:    [B, L2, C] or None.
         | 
| 90 | 
            +
                    mask:       [B, L2] or [B, L1, L2] or None.
         | 
| 91 | 
            +
                    """
         | 
| 92 | 
            +
                    # check inputs
         | 
| 93 | 
            +
                    context = x if context is None else context
         | 
| 94 | 
            +
                    b, n, c = x.size(0), self.num_heads, self.head_dim
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    # compute query, key, value
         | 
| 97 | 
            +
                    q = self.q(x).view(b, -1, n, c)
         | 
| 98 | 
            +
                    k = self.k(context).view(b, -1, n, c)
         | 
| 99 | 
            +
                    v = self.v(context).view(b, -1, n, c)
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                    # attention bias
         | 
| 102 | 
            +
                    attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
         | 
| 103 | 
            +
                    if pos_bias is not None:
         | 
| 104 | 
            +
                        attn_bias += pos_bias
         | 
| 105 | 
            +
                    if mask is not None:
         | 
| 106 | 
            +
                        assert mask.ndim in [2, 3]
         | 
| 107 | 
            +
                        mask = mask.view(b, 1, 1,
         | 
| 108 | 
            +
                                         -1) if mask.ndim == 2 else mask.unsqueeze(1)
         | 
| 109 | 
            +
                        attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                    # compute attention (T5 does not use scaling)
         | 
| 112 | 
            +
                    attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
         | 
| 113 | 
            +
                    attn = F.softmax(attn.float(), dim=-1).type_as(attn)
         | 
| 114 | 
            +
                    x = torch.einsum('bnij,bjnc->binc', attn, v)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    # output
         | 
| 117 | 
            +
                    x = x.reshape(b, -1, n * c)
         | 
| 118 | 
            +
                    x = self.o(x)
         | 
| 119 | 
            +
                    x = self.dropout(x)
         | 
| 120 | 
            +
                    return x
         | 
| 121 | 
            +
             | 
| 122 | 
            +
             | 
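Usage sketch for T5Attention above: self-attention when context is None, cross-attention otherwise. The scores are not scaled by 1/sqrt(head_dim) (T5 folds the scale into the weight init), and mask/pos_bias enter as additive biases. Sizes below are arbitrary:

import torch

attn = T5Attention(dim=64, dim_attn=64, num_heads=4, dropout=0.0)
x = torch.randn(2, 7, 64)
mask = torch.ones(2, 7)              # [B, L]; zero entries would be masked out
print(attn(x, mask=mask).shape)      # torch.Size([2, 7, 64])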
| 123 | 
            +
            class T5FeedForward(nn.Module):
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def __init__(self, dim, dim_ffn, dropout=0.1):
         | 
| 126 | 
            +
                    super(T5FeedForward, self).__init__()
         | 
| 127 | 
            +
                    self.dim = dim
         | 
| 128 | 
            +
                    self.dim_ffn = dim_ffn
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # layers
         | 
| 131 | 
            +
                    self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
         | 
| 132 | 
            +
                    self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
         | 
| 133 | 
            +
                    self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
         | 
| 134 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def forward(self, x):
         | 
| 137 | 
            +
                    x = self.fc1(x) * self.gate(x)
         | 
| 138 | 
            +
                    x = self.dropout(x)
         | 
| 139 | 
            +
                    x = self.fc2(x)
         | 
| 140 | 
            +
                    x = self.dropout(x)
         | 
| 141 | 
            +
                    return x
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
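The feed-forward above is T5's gated variant: fc1(x) * GELU(gate(x)), then a projection back to the model width. A minimal smoke test with arbitrary sizes:

import torch

ffn = T5FeedForward(dim=64, dim_ffn=256, dropout=0.0)
print(ffn(torch.randn(2, 7, 64)).shape)   # torch.Size([2, 7, 64])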
| 144 | 
            +
            class T5SelfAttention(nn.Module):
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                def __init__(self,
         | 
| 147 | 
            +
                             dim,
         | 
| 148 | 
            +
                             dim_attn,
         | 
| 149 | 
            +
                             dim_ffn,
         | 
| 150 | 
            +
                             num_heads,
         | 
| 151 | 
            +
                             num_buckets,
         | 
| 152 | 
            +
                             shared_pos=True,
         | 
| 153 | 
            +
                             dropout=0.1):
         | 
| 154 | 
            +
                    super(T5SelfAttention, self).__init__()
         | 
| 155 | 
            +
                    self.dim = dim
         | 
| 156 | 
            +
                    self.dim_attn = dim_attn
         | 
| 157 | 
            +
                    self.dim_ffn = dim_ffn
         | 
| 158 | 
            +
                    self.num_heads = num_heads
         | 
| 159 | 
            +
                    self.num_buckets = num_buckets
         | 
| 160 | 
            +
                    self.shared_pos = shared_pos
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    # layers
         | 
| 163 | 
            +
                    self.norm1 = T5LayerNorm(dim)
         | 
| 164 | 
            +
                    self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
         | 
| 165 | 
            +
                    self.norm2 = T5LayerNorm(dim)
         | 
| 166 | 
            +
                    self.ffn = T5FeedForward(dim, dim_ffn, dropout)
         | 
| 167 | 
            +
                    self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
         | 
| 168 | 
            +
                        num_buckets, num_heads, bidirectional=True)
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                def forward(self, x, mask=None, pos_bias=None):
         | 
| 171 | 
            +
                    e = pos_bias if self.shared_pos else self.pos_embedding(
         | 
| 172 | 
            +
                        x.size(1), x.size(1))
         | 
| 173 | 
            +
                    x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
         | 
| 174 | 
            +
                    x = fp16_clamp(x + self.ffn(self.norm2(x)))
         | 
| 175 | 
            +
                    return x
         | 
| 176 | 
            +
             | 
| 177 | 
            +
             | 
| 178 | 
            +
            class T5CrossAttention(nn.Module):
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def __init__(self,
         | 
| 181 | 
            +
                             dim,
         | 
| 182 | 
            +
                             dim_attn,
         | 
| 183 | 
            +
                             dim_ffn,
         | 
| 184 | 
            +
                             num_heads,
         | 
| 185 | 
            +
                             num_buckets,
         | 
| 186 | 
            +
                             shared_pos=True,
         | 
| 187 | 
            +
                             dropout=0.1):
         | 
| 188 | 
            +
                    super(T5CrossAttention, self).__init__()
         | 
| 189 | 
            +
                    self.dim = dim
         | 
| 190 | 
            +
                    self.dim_attn = dim_attn
         | 
| 191 | 
            +
                    self.dim_ffn = dim_ffn
         | 
| 192 | 
            +
                    self.num_heads = num_heads
         | 
| 193 | 
            +
                    self.num_buckets = num_buckets
         | 
| 194 | 
            +
                    self.shared_pos = shared_pos
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                    # layers
         | 
| 197 | 
            +
                    self.norm1 = T5LayerNorm(dim)
         | 
| 198 | 
            +
                    self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
         | 
| 199 | 
            +
                    self.norm2 = T5LayerNorm(dim)
         | 
| 200 | 
            +
                    self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
         | 
| 201 | 
            +
                    self.norm3 = T5LayerNorm(dim)
         | 
| 202 | 
            +
                    self.ffn = T5FeedForward(dim, dim_ffn, dropout)
         | 
| 203 | 
            +
                    self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
         | 
| 204 | 
            +
                        num_buckets, num_heads, bidirectional=False)
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                def forward(self,
         | 
| 207 | 
            +
                            x,
         | 
| 208 | 
            +
                            mask=None,
         | 
| 209 | 
            +
                            encoder_states=None,
         | 
| 210 | 
            +
                            encoder_mask=None,
         | 
| 211 | 
            +
                            pos_bias=None):
         | 
| 212 | 
            +
                    e = pos_bias if self.shared_pos else self.pos_embedding(
         | 
| 213 | 
            +
                        x.size(1), x.size(1))
         | 
| 214 | 
            +
                    x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
         | 
| 215 | 
            +
                    x = fp16_clamp(x + self.cross_attn(
         | 
| 216 | 
            +
                        self.norm2(x), context=encoder_states, mask=encoder_mask))
         | 
| 217 | 
            +
                    x = fp16_clamp(x + self.ffn(self.norm3(x)))
         | 
| 218 | 
            +
                    return x
         | 
| 219 | 
            +
             | 
| 220 | 
class T5RelativeEmbedding(nn.Module):

    def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
        super(T5RelativeEmbedding, self).__init__()
        self.num_buckets = num_buckets
        self.num_heads = num_heads
        self.bidirectional = bidirectional
        self.max_dist = max_dist

        # layers
        self.embedding = nn.Embedding(num_buckets, num_heads)

    def forward(self, lq, lk):
        device = self.embedding.weight.device
        # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
        #     torch.arange(lq).unsqueeze(1).to(device)
        rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
            torch.arange(lq, device=device).unsqueeze(1)
        rel_pos = self._relative_position_bucket(rel_pos)
        rel_pos_embeds = self.embedding(rel_pos)
        rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
            0)  # [1, N, Lq, Lk]
        return rel_pos_embeds.contiguous()

    def _relative_position_bucket(self, rel_pos):
        # preprocess
        if self.bidirectional:
            num_buckets = self.num_buckets // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        else:
            num_buckets = self.num_buckets
            rel_buckets = 0
            rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))

        # embeddings for small and large positions
        max_exact = num_buckets // 2
        rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
                                     math.log(self.max_dist / max_exact) *
                                     (num_buckets - max_exact)).long()
        rel_pos_large = torch.min(
            rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
        rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
        return rel_buckets


class T5Encoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Encoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=True) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                            shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None):
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Decoder(nn.Module):

    def __init__(self,
                 vocab,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 num_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Decoder, self).__init__()
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_buckets = num_buckets
        self.shared_pos = shared_pos

        # layers
        self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
            else nn.Embedding(vocab, dim)
        self.pos_embedding = T5RelativeEmbedding(
            num_buckets, num_heads, bidirectional=False) if shared_pos else None
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
                             shared_pos, dropout) for _ in range(num_layers)
        ])
        self.norm = T5LayerNorm(dim)

        # initialize weights
        self.apply(init_weights)

    def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
        b, s = ids.size()

        # causal mask
        if mask is None:
            mask = torch.tril(torch.ones(1, s, s).to(ids.device))
        elif mask.ndim == 2:
            mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))

        # layers
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1),
                               x.size(1)) if self.shared_pos else None
        for block in self.blocks:
            x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
        x = self.norm(x)
        x = self.dropout(x)
        return x


class T5Model(nn.Module):

    def __init__(self,
                 vocab_size,
                 dim,
                 dim_attn,
                 dim_ffn,
                 num_heads,
                 encoder_layers,
                 decoder_layers,
                 num_buckets,
                 shared_pos=True,
                 dropout=0.1):
        super(T5Model, self).__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.dim_attn = dim_attn
        self.dim_ffn = dim_ffn
        self.num_heads = num_heads
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.num_buckets = num_buckets

        # layers
        self.token_embedding = nn.Embedding(vocab_size, dim)
        self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, encoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
                                 num_heads, decoder_layers, num_buckets,
                                 shared_pos, dropout)
        self.head = nn.Linear(dim, vocab_size, bias=False)

        # initialize weights
        self.apply(init_weights)

    def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
        x = self.encoder(encoder_ids, encoder_mask)
        x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
        x = self.head(x)
        return x


def _t5(name,
        encoder_only=False,
        decoder_only=False,
        return_tokenizer=False,
        tokenizer_kwargs={},
        dtype=torch.float32,
        device='cpu',
        **kwargs):
    # sanity check
    assert not (encoder_only and decoder_only)

    # params
    if encoder_only:
        model_cls = T5Encoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('encoder_layers')
        _ = kwargs.pop('decoder_layers')
    elif decoder_only:
        model_cls = T5Decoder
        kwargs['vocab'] = kwargs.pop('vocab_size')
        kwargs['num_layers'] = kwargs.pop('decoder_layers')
        _ = kwargs.pop('encoder_layers')
    else:
        model_cls = T5Model

    # init model
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)

    # init tokenizer
    if return_tokenizer:
        from .tokenizers import HuggingfaceTokenizer
        tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
        return model, tokenizer
    else:
        return model


def umt5_xxl(**kwargs):
    cfg = dict(
        vocab_size=256384,
        dim=4096,
        dim_attn=4096,
        dim_ffn=10240,
        num_heads=64,
        encoder_layers=24,
        decoder_layers=24,
        num_buckets=32,
        shared_pos=False,
        dropout=0.1)
    cfg.update(**kwargs)
    return _t5('umt5-xxl', **cfg)


class T5EncoderModel:

    def __init__(
        self,
        text_len,
        dtype=torch.bfloat16,
        device=torch.cuda.current_device(),
        checkpoint_path=None,
        tokenizer_path=None,
        shard_fn=None,
    ):
        self.text_len = text_len
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        model = umt5_xxl(
            encoder_only=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device).eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        self.model = model
        if shard_fn is not None:
            self.model = shard_fn(self.model, sync_module_states=False)
        else:
            self.model.to(self.device)
        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path, seq_len=text_len, clean='whitespace')

    def __call__(self, texts, device):
        ids, mask = self.tokenizer(
            texts, return_mask=True, add_special_tokens=True)
        ids = ids.to(device)
        mask = mask.to(device)
        seq_lens = mask.gt(0).sum(dim=1).long()
        context = self.model(ids, mask)
        return [u[:v] for u, v in zip(context, seq_lens)]
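For orientation, a minimal usage sketch of the encoder wrapper above; the checkpoint and tokenizer names are placeholders for whatever the surrounding pipeline actually passes in, not values defined in this file:

# Hedged sketch: the paths and names below are illustrative placeholders.
import torch

text_encoder = T5EncoderModel(
    text_len=512,                                        # max tokens kept per prompt
    dtype=torch.bfloat16,
    device=torch.device('cuda'),
    checkpoint_path='models_t5_umt5-xxl-enc-bf16.pth',   # placeholder checkpoint
    tokenizer_path='google/umt5-xxl')                    # placeholder tokenizer id

# __call__ returns one [seq_len_i, 4096] tensor per prompt, already trimmed
# to the unpadded length computed from the attention mask.
context = text_encoder(['a person speaking to the camera'],
                       torch.device('cuda'))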
    	
        wan/modules/tokenizers.py
    ADDED
    
@@ -0,0 +1,82 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ['HuggingfaceTokenizer']


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace('_', ' ')
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(
            part.translate(str.maketrans('', '', string.punctuation))
            for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans('', '', string.punctuation))
    text = text.lower()
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


class HuggingfaceTokenizer:

    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, 'whitespace', 'lower', 'canonicalize')
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop('return_mask', False)

        # arguments
        _kwargs = {'return_tensors': 'pt'}
        if self.seq_len is not None:
            _kwargs.update({
                'padding': 'max_length',
                'truncation': True,
                'max_length': self.seq_len
            })
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
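A brief, hedged example of how this wrapper behaves; 'google/umt5-xxl' is only an illustrative model id, and any extra keyword arguments are forwarded straight to the underlying Hugging Face tokenizer:

# Illustrative sketch; the model id is an example, not a project constant.
tokenizer = HuggingfaceTokenizer(
    name='google/umt5-xxl', seq_len=512, clean='whitespace')

# return_mask=True yields padded input ids plus the attention mask,
# both as [batch, seq_len] PyTorch tensors.
ids, mask = tokenizer(['Hello   world!'], return_mask=True,
                      add_special_tokens=True)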
    	
        wan/modules/vace_model.py
    ADDED
    
@@ -0,0 +1,250 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.cuda.amp as amp
import torch.nn as nn
from diffusers.configuration_utils import register_to_config

from .model import WanAttentionBlock, WanModel, sinusoidal_embedding_1d


class VaceWanAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 block_id=0):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
                         qk_norm, cross_attn_norm, eps)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = nn.Linear(self.dim, self.dim)
            nn.init.zeros_(self.before_proj.weight)
            nn.init.zeros_(self.before_proj.bias)
        self.after_proj = nn.Linear(self.dim, self.dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x

        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)
        return c, c_skip


class BaseWanAttentionBlock(WanAttentionBlock):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 block_id=None):
        super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size,
                         qk_norm, cross_attn_norm, eps)
        self.block_id = block_id

    def forward(self, x, hints, context_scale=1.0, **kwargs):
        x = super().forward(x, **kwargs)
        if self.block_id is not None:
            x = x + hints[self.block_id] * context_scale
        return x


class VaceWanModel(WanModel):

    @register_to_config
    def __init__(self,
                 vace_layers=None,
                 vace_in_dim=None,
                 model_type='vace',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6):
        super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim,
                         freq_dim, text_dim, out_dim, num_heads, num_layers,
                         window_size, qk_norm, cross_attn_norm, eps)

        self.vace_layers = [i for i in range(0, self.num_layers, 2)
                           ] if vace_layers is None else vace_layers
        self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim

        assert 0 in self.vace_layers
        self.vace_layers_mapping = {
            i: n for n, i in enumerate(self.vace_layers)
        }

        # blocks
        self.blocks = nn.ModuleList([
            BaseWanAttentionBlock(
                't2v_cross_attn',
                self.dim,
                self.ffn_dim,
                self.num_heads,
                self.window_size,
                self.qk_norm,
                self.cross_attn_norm,
                self.eps,
                block_id=self.vace_layers_mapping[i]
                if i in self.vace_layers else None)
            for i in range(self.num_layers)
        ])

        # vace blocks
        self.vace_blocks = nn.ModuleList([
            VaceWanAttentionBlock(
                't2v_cross_attn',
                self.dim,
                self.ffn_dim,
                self.num_heads,
                self.window_size,
                self.qk_norm,
                self.cross_attn_norm,
                self.eps,
                block_id=i) for i in self.vace_layers
        ])

        # vace patch embeddings
        self.vace_patch_embedding = nn.Conv3d(
            self.vace_in_dim,
            self.dim,
            kernel_size=self.patch_size,
            stride=self.patch_size)

    def forward_vace(self, x, vace_context, seq_len, kwargs):
        # embeddings
        c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
        c = [u.flatten(2).transpose(1, 2) for u in c]
        c = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in c
        ])

        # arguments
        new_kwargs = dict(x=x)
        new_kwargs.update(kwargs)

        hints = []
        for block in self.vace_blocks:
            c, c_skip = block(c, **new_kwargs)
            hints.append(c_skip)
        return hints

    def forward(
        self,
        x,
        t,
        vace_context,
        context,
        seq_len,
        vace_context_scale=1.0,
        clip_fea=None,
        y=None,
    ):
        r"""
        Forward pass through the diffusion model

        Args:
            x (List[Tensor]):
                List of input video tensors, each with shape [C_in, F, H, W]
            t (Tensor):
                Diffusion timesteps tensor of shape [B]
            context (List[Tensor]):
                List of text embeddings each with shape [L, C]
            seq_len (`int`):
                Maximum sequence length for positional encoding
            clip_fea (Tensor, *optional*):
                CLIP image features for image-to-video mode
            y (List[Tensor], *optional*):
                Conditional video inputs for image-to-video mode, same shape as x

        Returns:
            List[Tensor]:
                List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
        """
        # if self.model_type == 'i2v':
        #     assert clip_fea is not None and y is not None
        # params
        device = self.patch_embedding.weight.device
        if self.freqs.device != device:
            self.freqs = self.freqs.to(device)

        # if y is not None:
        #     x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        assert seq_lens.max() <= seq_len
        x = torch.cat([
            torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
                      dim=1) for u in x
        ])

        # time embeddings
        with amp.autocast(dtype=torch.float32):
            e = self.time_embedding(
                sinusoidal_embedding_1d(self.freq_dim, t).float())
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))
            assert e.dtype == torch.float32 and e0.dtype == torch.float32

        # context
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        # if clip_fea is not None:
        #     context_clip = self.img_emb(clip_fea)  # bs x 257 x dim
        #     context = torch.concat([context_clip, context], dim=1)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens)

        hints = self.forward_vace(x, vace_context, seq_len, kwargs)
        kwargs['hints'] = hints
        kwargs['context_scale'] = vace_context_scale

        for block in self.blocks:
            x = block(x, **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)
        return [u.float() for u in x]
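The control path above reduces to a zero-initialized residual: each VaceWanAttentionBlock emits a c_skip hint through a zero-initialized after_proj, and BaseWanAttentionBlock adds hints[block_id] * context_scale back into the main stream, so the pretrained backbone is untouched at initialization. A stripped-down sketch of that pattern (names and shapes here are illustrative, not the repo's API):

# Toy illustration of the zero-init hint injection used above.
import torch
import torch.nn as nn

class HintedBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.after_proj = nn.Linear(dim, dim)
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, x, hint, context_scale=1.0):
        # At initialization after_proj(hint) == 0, so the block is a no-op
        # residual; training gradually opens the control pathway.
        return x + self.after_proj(hint) * context_scale

block = HintedBlock(dim=2048)
x = torch.randn(1, 16, 2048)     # stand-in for token features
hint = torch.randn(1, 16, 2048)  # stand-in for VACE context features
out = block(x, hint)             # equals x at initialization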
    	
        wan/modules/vae.py
    ADDED
    
@@ -0,0 +1,663 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging

import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

__all__ = [
    'WanVAE',
]

CACHE_T = 2


class CausalConv3d(nn.Conv3d):
    """
    Causal 3d convolution.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._padding = (self.padding[2], self.padding[2], self.padding[1],
                         self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)

        return super().forward(x)


class RMS_norm(nn.Module):

    def __init__(self, dim, channel_first=True, images=True, bias=False):
        super().__init__()
        broadcastable_dims = (1, 1, 1) if not images else (1, 1)
        shape = (dim, *broadcastable_dims) if channel_first else (dim,)

        self.channel_first = channel_first
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(shape))
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.

    def forward(self, x):
        return F.normalize(
            x, dim=(1 if self.channel_first else
                    -1)) * self.scale * self.gamma + self.bias


class Upsample(nn.Upsample):

    def forward(self, x):
        """
        Fix bfloat16 support for nearest neighbor interpolation.
        """
        return super().forward(x.float()).type_as(x)


class Resample(nn.Module):

    def __init__(self, dim, mode):
        assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
                        'downsample3d')
        super().__init__()
        self.dim = dim
        self.mode = mode

        # layers
        if mode == 'upsample2d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
        elif mode == 'upsample3d':
            self.resample = nn.Sequential(
                Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                nn.Conv2d(dim, dim // 2, 3, padding=1))
            self.time_conv = CausalConv3d(
                dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))

        elif mode == 'downsample2d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
        elif mode == 'downsample3d':
            self.resample = nn.Sequential(
                nn.ZeroPad2d((0, 1, 0, 1)),
                nn.Conv2d(dim, dim, 3, stride=(2, 2)))
            self.time_conv = CausalConv3d(
                dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))

        else:
            self.resample = nn.Identity()

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        b, c, t, h, w = x.size()
        if self.mode == 'upsample3d':
            if feat_cache is not None:
                idx = feat_idx[0]
                if feat_cache[idx] is None:
         | 
| 107 | 
            +
                                feat_cache[idx] = 'Rep'
         | 
| 108 | 
            +
                                feat_idx[0] += 1
         | 
| 109 | 
            +
                            else:
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                                cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 112 | 
            +
                                if cache_x.shape[2] < 2 and feat_cache[
         | 
| 113 | 
            +
                                        idx] is not None and feat_cache[idx] != 'Rep':
         | 
| 114 | 
            +
                                    # cache the last frame of the last two chunks
         | 
| 115 | 
            +
                                    cache_x = torch.cat([
         | 
| 116 | 
            +
                                        feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 117 | 
            +
                                            cache_x.device), cache_x
         | 
| 118 | 
            +
                                    ],
         | 
| 119 | 
            +
                                                        dim=2)
         | 
| 120 | 
            +
                                if cache_x.shape[2] < 2 and feat_cache[
         | 
| 121 | 
            +
                                        idx] is not None and feat_cache[idx] == 'Rep':
         | 
| 122 | 
            +
                                    cache_x = torch.cat([
         | 
| 123 | 
            +
                                        torch.zeros_like(cache_x).to(cache_x.device),
         | 
| 124 | 
            +
                                        cache_x
         | 
| 125 | 
            +
                                    ],
         | 
| 126 | 
            +
                                                        dim=2)
         | 
| 127 | 
            +
                                if feat_cache[idx] == 'Rep':
         | 
| 128 | 
            +
                                    x = self.time_conv(x)
         | 
| 129 | 
            +
                                else:
         | 
| 130 | 
            +
                                    x = self.time_conv(x, feat_cache[idx])
         | 
| 131 | 
            +
                                feat_cache[idx] = cache_x
         | 
| 132 | 
            +
                                feat_idx[0] += 1
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                                x = x.reshape(b, 2, c, t, h, w)
         | 
| 135 | 
            +
                                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
         | 
| 136 | 
            +
                                                3)
         | 
| 137 | 
            +
                                x = x.reshape(b, c, t * 2, h, w)
         | 
| 138 | 
            +
                    t = x.shape[2]
         | 
| 139 | 
            +
                    x = rearrange(x, 'b c t h w -> (b t) c h w')
         | 
| 140 | 
            +
                    x = self.resample(x)
         | 
| 141 | 
            +
                    x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                    if self.mode == 'downsample3d':
         | 
| 144 | 
            +
                        if feat_cache is not None:
         | 
| 145 | 
            +
                            idx = feat_idx[0]
         | 
| 146 | 
            +
                            if feat_cache[idx] is None:
         | 
| 147 | 
            +
                                feat_cache[idx] = x.clone()
         | 
| 148 | 
            +
                                feat_idx[0] += 1
         | 
| 149 | 
            +
                            else:
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                                cache_x = x[:, :, -1:, :, :].clone()
         | 
| 152 | 
            +
                                # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
         | 
| 153 | 
            +
                                #     # cache last frame of last two chunk
         | 
| 154 | 
            +
                                #     cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
         | 
| 155 | 
            +
             | 
| 156 | 
            +
                                x = self.time_conv(
         | 
| 157 | 
            +
                                    torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
         | 
| 158 | 
            +
                                feat_cache[idx] = cache_x
         | 
| 159 | 
            +
                                feat_idx[0] += 1
         | 
| 160 | 
            +
                    return x
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def init_weight(self, conv):
         | 
| 163 | 
            +
                    conv_weight = conv.weight
         | 
| 164 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 165 | 
            +
                    c1, c2, t, h, w = conv_weight.size()
         | 
| 166 | 
            +
                    one_matrix = torch.eye(c1, c2)
         | 
| 167 | 
            +
                    init_matrix = one_matrix
         | 
| 168 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 169 | 
            +
                    #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
         | 
| 170 | 
            +
                    conv_weight.data[:, :, 1, 0, 0] = init_matrix  #* 0.5
         | 
| 171 | 
            +
                    conv.weight.data.copy_(conv_weight)
         | 
| 172 | 
            +
                    nn.init.zeros_(conv.bias.data)
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def init_weight2(self, conv):
         | 
| 175 | 
            +
                    conv_weight = conv.weight.data
         | 
| 176 | 
            +
                    nn.init.zeros_(conv_weight)
         | 
| 177 | 
            +
                    c1, c2, t, h, w = conv_weight.size()
         | 
| 178 | 
            +
                    init_matrix = torch.eye(c1 // 2, c2)
         | 
| 179 | 
            +
                    #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
         | 
| 180 | 
            +
                    conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
         | 
| 181 | 
            +
                    conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
         | 
| 182 | 
            +
                    conv.weight.data.copy_(conv_weight)
         | 
| 183 | 
            +
                    nn.init.zeros_(conv.bias.data)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            class ResidualBlock(nn.Module):
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                def __init__(self, in_dim, out_dim, dropout=0.0):
         | 
| 189 | 
            +
                    super().__init__()
         | 
| 190 | 
            +
                    self.in_dim = in_dim
         | 
| 191 | 
            +
                    self.out_dim = out_dim
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    # layers
         | 
| 194 | 
            +
                    self.residual = nn.Sequential(
         | 
| 195 | 
            +
                        RMS_norm(in_dim, images=False), nn.SiLU(),
         | 
| 196 | 
            +
                        CausalConv3d(in_dim, out_dim, 3, padding=1),
         | 
| 197 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
         | 
| 198 | 
            +
                        CausalConv3d(out_dim, out_dim, 3, padding=1))
         | 
| 199 | 
            +
                    self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
         | 
| 200 | 
            +
                        if in_dim != out_dim else nn.Identity()
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 203 | 
            +
                    h = self.shortcut(x)
         | 
| 204 | 
            +
                    for layer in self.residual:
         | 
| 205 | 
            +
                        if isinstance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 206 | 
            +
                            idx = feat_idx[0]
         | 
| 207 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 208 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 209 | 
            +
                                # cache the last frame of the last two chunks
         | 
| 210 | 
            +
                                cache_x = torch.cat([
         | 
| 211 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 212 | 
            +
                                        cache_x.device), cache_x
         | 
| 213 | 
            +
                                ],
         | 
| 214 | 
            +
                                                    dim=2)
         | 
| 215 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 216 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 217 | 
            +
                            feat_idx[0] += 1
         | 
| 218 | 
            +
                        else:
         | 
| 219 | 
            +
                            x = layer(x)
         | 
| 220 | 
            +
                    return x + h
         | 
| 221 | 
            +
             | 
| 222 | 
            +
             | 
| 223 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 224 | 
            +
                """
         | 
| 225 | 
            +
    Per-frame self-attention with a single head.
         | 
| 226 | 
            +
                """
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                def __init__(self, dim):
         | 
| 229 | 
            +
                    super().__init__()
         | 
| 230 | 
            +
                    self.dim = dim
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                    # layers
         | 
| 233 | 
            +
                    self.norm = RMS_norm(dim)
         | 
| 234 | 
            +
                    self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
         | 
| 235 | 
            +
                    self.proj = nn.Conv2d(dim, dim, 1)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                    # zero out the last layer params
         | 
| 238 | 
            +
                    nn.init.zeros_(self.proj.weight)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                def forward(self, x):
         | 
| 241 | 
            +
                    identity = x
         | 
| 242 | 
            +
                    b, c, t, h, w = x.size()
         | 
| 243 | 
            +
                    x = rearrange(x, 'b c t h w -> (b t) c h w')
         | 
| 244 | 
            +
                    x = self.norm(x)
         | 
| 245 | 
            +
                    # compute query, key, value
         | 
| 246 | 
            +
                    q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
         | 
| 247 | 
            +
                                                     -1).permute(0, 1, 3,
         | 
| 248 | 
            +
                                                                 2).contiguous().chunk(
         | 
| 249 | 
            +
                                                                     3, dim=-1)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                    # apply attention
         | 
| 252 | 
            +
                    x = F.scaled_dot_product_attention(
         | 
| 253 | 
            +
                        q,
         | 
| 254 | 
            +
                        k,
         | 
| 255 | 
            +
                        v,
         | 
| 256 | 
            +
                    )
         | 
| 257 | 
            +
                    x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    # output
         | 
| 260 | 
            +
                    x = self.proj(x)
         | 
| 261 | 
            +
                    x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
         | 
| 262 | 
            +
                    return x + identity
         | 
| 263 | 
            +
             | 
| 264 | 
            +
             | 
| 265 | 
            +
            class Encoder3d(nn.Module):
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                def __init__(self,
         | 
| 268 | 
            +
                             dim=128,
         | 
| 269 | 
            +
                             z_dim=4,
         | 
| 270 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 271 | 
            +
                             num_res_blocks=2,
         | 
| 272 | 
            +
                             attn_scales=[],
         | 
| 273 | 
            +
                             temperal_downsample=[True, True, False],
         | 
| 274 | 
            +
                             dropout=0.0):
         | 
| 275 | 
            +
                    super().__init__()
         | 
| 276 | 
            +
                    self.dim = dim
         | 
| 277 | 
            +
                    self.z_dim = z_dim
         | 
| 278 | 
            +
                    self.dim_mult = dim_mult
         | 
| 279 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 280 | 
            +
                    self.attn_scales = attn_scales
         | 
| 281 | 
            +
                    self.temperal_downsample = temperal_downsample
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    # dimensions
         | 
| 284 | 
            +
                    dims = [dim * u for u in [1] + dim_mult]
         | 
| 285 | 
            +
                    scale = 1.0
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                    # init block
         | 
| 288 | 
            +
                    self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    # downsample blocks
         | 
| 291 | 
            +
                    downsamples = []
         | 
| 292 | 
            +
                    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
         | 
| 293 | 
            +
                        # residual (+attention) blocks
         | 
| 294 | 
            +
                        for _ in range(num_res_blocks):
         | 
| 295 | 
            +
                            downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
         | 
| 296 | 
            +
                            if scale in attn_scales:
         | 
| 297 | 
            +
                                downsamples.append(AttentionBlock(out_dim))
         | 
| 298 | 
            +
                            in_dim = out_dim
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                        # downsample block
         | 
| 301 | 
            +
                        if i != len(dim_mult) - 1:
         | 
| 302 | 
            +
                            mode = 'downsample3d' if temperal_downsample[
         | 
| 303 | 
            +
                                i] else 'downsample2d'
         | 
| 304 | 
            +
                            downsamples.append(Resample(out_dim, mode=mode))
         | 
| 305 | 
            +
                            scale /= 2.0
         | 
| 306 | 
            +
                    self.downsamples = nn.Sequential(*downsamples)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                    # middle blocks
         | 
| 309 | 
            +
                    self.middle = nn.Sequential(
         | 
| 310 | 
            +
                        ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
         | 
| 311 | 
            +
                        ResidualBlock(out_dim, out_dim, dropout))
         | 
| 312 | 
            +
             | 
| 313 | 
            +
                    # output blocks
         | 
| 314 | 
            +
                    self.head = nn.Sequential(
         | 
| 315 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(),
         | 
| 316 | 
            +
                        CausalConv3d(out_dim, z_dim, 3, padding=1))
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 319 | 
            +
                    if feat_cache is not None:
         | 
| 320 | 
            +
                        idx = feat_idx[0]
         | 
| 321 | 
            +
                        cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 322 | 
            +
                        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 323 | 
            +
                            # cache the last frame of the last two chunks
         | 
| 324 | 
            +
                            cache_x = torch.cat([
         | 
| 325 | 
            +
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 326 | 
            +
                                    cache_x.device), cache_x
         | 
| 327 | 
            +
                            ],
         | 
| 328 | 
            +
                                                dim=2)
         | 
| 329 | 
            +
                        x = self.conv1(x, feat_cache[idx])
         | 
| 330 | 
            +
                        feat_cache[idx] = cache_x
         | 
| 331 | 
            +
                        feat_idx[0] += 1
         | 
| 332 | 
            +
                    else:
         | 
| 333 | 
            +
                        x = self.conv1(x)
         | 
| 334 | 
            +
             | 
| 335 | 
            +
                    ## downsamples
         | 
| 336 | 
            +
                    for layer in self.downsamples:
         | 
| 337 | 
            +
                        if feat_cache is not None:
         | 
| 338 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 339 | 
            +
                        else:
         | 
| 340 | 
            +
                            x = layer(x)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    ## middle
         | 
| 343 | 
            +
                    for layer in self.middle:
         | 
| 344 | 
            +
                        if isinstance(layer, ResidualBlock) and feat_cache is not None:
         | 
| 345 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 346 | 
            +
                        else:
         | 
| 347 | 
            +
                            x = layer(x)
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                    ## head
         | 
| 350 | 
            +
                    for layer in self.head:
         | 
| 351 | 
            +
                        if isinstance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 352 | 
            +
                            idx = feat_idx[0]
         | 
| 353 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 354 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 355 | 
            +
                                # cache the last frame of the last two chunks
         | 
| 356 | 
            +
                                cache_x = torch.cat([
         | 
| 357 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 358 | 
            +
                                        cache_x.device), cache_x
         | 
| 359 | 
            +
                                ],
         | 
| 360 | 
            +
                                                    dim=2)
         | 
| 361 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 362 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 363 | 
            +
                            feat_idx[0] += 1
         | 
| 364 | 
            +
                        else:
         | 
| 365 | 
            +
                            x = layer(x)
         | 
| 366 | 
            +
                    return x
         | 
| 367 | 
            +
             | 
| 368 | 
            +
             | 
| 369 | 
            +
            class Decoder3d(nn.Module):
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                def __init__(self,
         | 
| 372 | 
            +
                             dim=128,
         | 
| 373 | 
            +
                             z_dim=4,
         | 
| 374 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 375 | 
            +
                             num_res_blocks=2,
         | 
| 376 | 
            +
                             attn_scales=[],
         | 
| 377 | 
            +
                             temperal_upsample=[False, True, True],
         | 
| 378 | 
            +
                             dropout=0.0):
         | 
| 379 | 
            +
                    super().__init__()
         | 
| 380 | 
            +
                    self.dim = dim
         | 
| 381 | 
            +
                    self.z_dim = z_dim
         | 
| 382 | 
            +
                    self.dim_mult = dim_mult
         | 
| 383 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 384 | 
            +
                    self.attn_scales = attn_scales
         | 
| 385 | 
            +
                    self.temperal_upsample = temperal_upsample
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    # dimensions
         | 
| 388 | 
            +
                    dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
         | 
| 389 | 
            +
                    scale = 1.0 / 2**(len(dim_mult) - 2)
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    # init block
         | 
| 392 | 
            +
                    self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                    # middle blocks
         | 
| 395 | 
            +
                    self.middle = nn.Sequential(
         | 
| 396 | 
            +
                        ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
         | 
| 397 | 
            +
                        ResidualBlock(dims[0], dims[0], dropout))
         | 
| 398 | 
            +
             | 
| 399 | 
            +
                    # upsample blocks
         | 
| 400 | 
            +
                    upsamples = []
         | 
| 401 | 
            +
                    for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
         | 
| 402 | 
            +
                        # residual (+attention) blocks
         | 
| 403 | 
            +
                        if i == 1 or i == 2 or i == 3:
         | 
| 404 | 
            +
                            in_dim = in_dim // 2
         | 
| 405 | 
            +
                        for _ in range(num_res_blocks + 1):
         | 
| 406 | 
            +
                            upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
         | 
| 407 | 
            +
                            if scale in attn_scales:
         | 
| 408 | 
            +
                                upsamples.append(AttentionBlock(out_dim))
         | 
| 409 | 
            +
                            in_dim = out_dim
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                        # upsample block
         | 
| 412 | 
            +
                        if i != len(dim_mult) - 1:
         | 
| 413 | 
            +
                            mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
         | 
| 414 | 
            +
                            upsamples.append(Resample(out_dim, mode=mode))
         | 
| 415 | 
            +
                            scale *= 2.0
         | 
| 416 | 
            +
                    self.upsamples = nn.Sequential(*upsamples)
         | 
| 417 | 
            +
             | 
| 418 | 
            +
                    # output blocks
         | 
| 419 | 
            +
                    self.head = nn.Sequential(
         | 
| 420 | 
            +
                        RMS_norm(out_dim, images=False), nn.SiLU(),
         | 
| 421 | 
            +
                        CausalConv3d(out_dim, 3, 3, padding=1))
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                def forward(self, x, feat_cache=None, feat_idx=[0]):
         | 
| 424 | 
            +
                    ## conv1
         | 
| 425 | 
            +
                    if feat_cache is not None:
         | 
| 426 | 
            +
                        idx = feat_idx[0]
         | 
| 427 | 
            +
                        cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 428 | 
            +
                        if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 429 | 
            +
                            # cache the last frame of the last two chunks
         | 
| 430 | 
            +
                            cache_x = torch.cat([
         | 
| 431 | 
            +
                                feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 432 | 
            +
                                    cache_x.device), cache_x
         | 
| 433 | 
            +
                            ],
         | 
| 434 | 
            +
                                                dim=2)
         | 
| 435 | 
            +
                        x = self.conv1(x, feat_cache[idx])
         | 
| 436 | 
            +
                        feat_cache[idx] = cache_x
         | 
| 437 | 
            +
                        feat_idx[0] += 1
         | 
| 438 | 
            +
                    else:
         | 
| 439 | 
            +
                        x = self.conv1(x)
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                    ## middle
         | 
| 442 | 
            +
                    for layer in self.middle:
         | 
| 443 | 
            +
                        if isinstance(layer, ResidualBlock) and feat_cache is not None:
         | 
| 444 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 445 | 
            +
                        else:
         | 
| 446 | 
            +
                            x = layer(x)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    ## upsamples
         | 
| 449 | 
            +
                    for layer in self.upsamples:
         | 
| 450 | 
            +
                        if feat_cache is not None:
         | 
| 451 | 
            +
                            x = layer(x, feat_cache, feat_idx)
         | 
| 452 | 
            +
                        else:
         | 
| 453 | 
            +
                            x = layer(x)
         | 
| 454 | 
            +
             | 
| 455 | 
            +
                    ## head
         | 
| 456 | 
            +
                    for layer in self.head:
         | 
| 457 | 
            +
                        if isinstance(layer, CausalConv3d) and feat_cache is not None:
         | 
| 458 | 
            +
                            idx = feat_idx[0]
         | 
| 459 | 
            +
                            cache_x = x[:, :, -CACHE_T:, :, :].clone()
         | 
| 460 | 
            +
                            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
         | 
| 461 | 
            +
                                # cache the last frame of the last two chunks
         | 
| 462 | 
            +
                                cache_x = torch.cat([
         | 
| 463 | 
            +
                                    feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
         | 
| 464 | 
            +
                                        cache_x.device), cache_x
         | 
| 465 | 
            +
                                ],
         | 
| 466 | 
            +
                                                    dim=2)
         | 
| 467 | 
            +
                            x = layer(x, feat_cache[idx])
         | 
| 468 | 
            +
                            feat_cache[idx] = cache_x
         | 
| 469 | 
            +
                            feat_idx[0] += 1
         | 
| 470 | 
            +
                        else:
         | 
| 471 | 
            +
                            x = layer(x)
         | 
| 472 | 
            +
                    return x
         | 
| 473 | 
            +
             | 
| 474 | 
            +
             | 
| 475 | 
            +
            def count_conv3d(model):
         | 
| 476 | 
            +
                count = 0
         | 
| 477 | 
            +
                for m in model.modules():
         | 
| 478 | 
            +
                    if isinstance(m, CausalConv3d):
         | 
| 479 | 
            +
                        count += 1
         | 
| 480 | 
            +
                return count
         | 
| 481 | 
            +
             | 
| 482 | 
            +
             | 
| 483 | 
            +
            class WanVAE_(nn.Module):
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                def __init__(self,
         | 
| 486 | 
            +
                             dim=128,
         | 
| 487 | 
            +
                             z_dim=4,
         | 
| 488 | 
            +
                             dim_mult=[1, 2, 4, 4],
         | 
| 489 | 
            +
                             num_res_blocks=2,
         | 
| 490 | 
            +
                             attn_scales=[],
         | 
| 491 | 
            +
                             temperal_downsample=[True, True, False],
         | 
| 492 | 
            +
                             dropout=0.0):
         | 
| 493 | 
            +
                    super().__init__()
         | 
| 494 | 
            +
                    self.dim = dim
         | 
| 495 | 
            +
                    self.z_dim = z_dim
         | 
| 496 | 
            +
                    self.dim_mult = dim_mult
         | 
| 497 | 
            +
                    self.num_res_blocks = num_res_blocks
         | 
| 498 | 
            +
                    self.attn_scales = attn_scales
         | 
| 499 | 
            +
                    self.temperal_downsample = temperal_downsample
         | 
| 500 | 
            +
                    self.temperal_upsample = temperal_downsample[::-1]
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                    # modules
         | 
| 503 | 
            +
                    self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
         | 
| 504 | 
            +
                                             attn_scales, self.temperal_downsample, dropout)
         | 
| 505 | 
            +
                    self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
         | 
| 506 | 
            +
                    self.conv2 = CausalConv3d(z_dim, z_dim, 1)
         | 
| 507 | 
            +
                    self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
         | 
| 508 | 
            +
                                             attn_scales, self.temperal_upsample, dropout)
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                def forward(self, x):
         | 
| 511 | 
            +
                    mu, log_var = self.encode(x)
         | 
| 512 | 
            +
                    z = self.reparameterize(mu, log_var)
         | 
| 513 | 
            +
                    x_recon = self.decode(z)
         | 
| 514 | 
            +
                    return x_recon, mu, log_var
         | 
| 515 | 
            +
             | 
| 516 | 
            +
                def encode(self, x, scale):
         | 
| 517 | 
            +
                    self.clear_cache()
         | 
| 518 | 
            +
                    ## cache
         | 
| 519 | 
            +
                    t = x.shape[2]
         | 
| 520 | 
            +
                    iter_ = 1 + (t - 1) // 4
         | 
| 521 | 
            +
                    ## Split the input x for encoding along the time axis into chunks of 1, 4, 4, 4, ...
         | 
| 522 | 
            +
                    for i in range(iter_):
         | 
| 523 | 
            +
                        self._enc_conv_idx = [0]
         | 
| 524 | 
            +
                        if i == 0:
         | 
| 525 | 
            +
                            out = self.encoder(
         | 
| 526 | 
            +
                                x[:, :, :1, :, :],
         | 
| 527 | 
            +
                                feat_cache=self._enc_feat_map,
         | 
| 528 | 
            +
                                feat_idx=self._enc_conv_idx)
         | 
| 529 | 
            +
                        else:
         | 
| 530 | 
            +
                            out_ = self.encoder(
         | 
| 531 | 
            +
                                x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
         | 
| 532 | 
            +
                                feat_cache=self._enc_feat_map,
         | 
| 533 | 
            +
                                feat_idx=self._enc_conv_idx)
         | 
| 534 | 
            +
                            out = torch.cat([out, out_], 2)
         | 
| 535 | 
            +
                    mu, log_var = self.conv1(out).chunk(2, dim=1)
         | 
| 536 | 
            +
                    if isinstance(scale[0], torch.Tensor):
         | 
| 537 | 
            +
                        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
         | 
| 538 | 
            +
                            1, self.z_dim, 1, 1, 1)
         | 
| 539 | 
            +
                    else:
         | 
| 540 | 
            +
                        mu = (mu - scale[0]) * scale[1]
         | 
| 541 | 
            +
                    self.clear_cache()
         | 
| 542 | 
            +
                    return mu
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                def decode(self, z, scale):
         | 
| 545 | 
            +
                    self.clear_cache()
         | 
| 546 | 
            +
                    # z: [b,c,t,h,w]
         | 
| 547 | 
            +
                    if isinstance(scale[0], torch.Tensor):
         | 
| 548 | 
            +
                        z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
         | 
| 549 | 
            +
                            1, self.z_dim, 1, 1, 1)
         | 
| 550 | 
            +
                    else:
         | 
| 551 | 
            +
                        z = z / scale[1] + scale[0]
         | 
| 552 | 
            +
                    iter_ = z.shape[2]
         | 
| 553 | 
            +
                    x = self.conv2(z)
         | 
| 554 | 
            +
                    for i in range(iter_):
         | 
| 555 | 
            +
                        self._conv_idx = [0]
         | 
| 556 | 
            +
                        if i == 0:
         | 
| 557 | 
            +
                            out = self.decoder(
         | 
| 558 | 
            +
                                x[:, :, i:i + 1, :, :],
         | 
| 559 | 
            +
                                feat_cache=self._feat_map,
         | 
| 560 | 
            +
                                feat_idx=self._conv_idx)
         | 
| 561 | 
            +
                        else:
         | 
| 562 | 
            +
                            out_ = self.decoder(
         | 
| 563 | 
            +
                                x[:, :, i:i + 1, :, :],
         | 
| 564 | 
            +
                                feat_cache=self._feat_map,
         | 
| 565 | 
            +
                                feat_idx=self._conv_idx)
         | 
| 566 | 
            +
                            out = torch.cat([out, out_], 2)
         | 
| 567 | 
            +
                    self.clear_cache()
         | 
| 568 | 
            +
                    return out
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                def reparameterize(self, mu, log_var):
         | 
| 571 | 
            +
                    std = torch.exp(0.5 * log_var)
         | 
| 572 | 
            +
                    eps = torch.randn_like(std)
         | 
| 573 | 
            +
                    return eps * std + mu
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                def sample(self, imgs, deterministic=False):
         | 
| 576 | 
            +
                    mu, log_var = self.encode(imgs)
         | 
| 577 | 
            +
                    if deterministic:
         | 
| 578 | 
            +
                        return mu
         | 
| 579 | 
            +
                    std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
         | 
| 580 | 
            +
                    return mu + std * torch.randn_like(std)
         | 
| 581 | 
            +
             | 
| 582 | 
            +
                def clear_cache(self):
         | 
| 583 | 
            +
                    self._conv_num = count_conv3d(self.decoder)
         | 
| 584 | 
            +
                    self._conv_idx = [0]
         | 
| 585 | 
            +
                    self._feat_map = [None] * self._conv_num
         | 
| 586 | 
            +
                    #cache encode
         | 
| 587 | 
            +
                    self._enc_conv_num = count_conv3d(self.encoder)
         | 
| 588 | 
            +
                    self._enc_conv_idx = [0]
         | 
| 589 | 
            +
                    self._enc_feat_map = [None] * self._enc_conv_num
         | 
| 590 | 
            +
             | 
| 591 | 
            +
             | 
| 592 | 
            +
            def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
         | 
| 593 | 
            +
                """
         | 
| 594 | 
            +
                Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
         | 
| 595 | 
            +
                """
         | 
| 596 | 
            +
                # params
         | 
| 597 | 
            +
                cfg = dict(
         | 
| 598 | 
            +
                    dim=96,
         | 
| 599 | 
            +
                    z_dim=z_dim,
         | 
| 600 | 
            +
                    dim_mult=[1, 2, 4, 4],
         | 
| 601 | 
            +
                    num_res_blocks=2,
         | 
| 602 | 
            +
                    attn_scales=[],
         | 
| 603 | 
            +
                    temperal_downsample=[False, True, True],
         | 
| 604 | 
            +
                    dropout=0.0)
         | 
| 605 | 
            +
                cfg.update(**kwargs)
         | 
| 606 | 
            +
             | 
| 607 | 
            +
                # init model
         | 
| 608 | 
            +
                with torch.device('meta'):
         | 
| 609 | 
            +
                    model = WanVAE_(**cfg)
         | 
| 610 | 
            +
             | 
| 611 | 
            +
                # load checkpoint
         | 
| 612 | 
            +
                logging.info(f'loading {pretrained_path}')
         | 
| 613 | 
            +
                model.load_state_dict(
         | 
| 614 | 
            +
                    torch.load(pretrained_path, map_location=device), assign=True)
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                return model
         | 
| 617 | 
            +
             | 
| 618 | 
            +
             | 
| 619 | 
            +
            class WanVAE:
         | 
| 620 | 
            +
             | 
| 621 | 
            +
                def __init__(self,
         | 
| 622 | 
            +
                             z_dim=16,
         | 
| 623 | 
            +
                             vae_pth='cache/vae_step_411000.pth',
         | 
| 624 | 
            +
                             dtype=torch.float,
         | 
| 625 | 
            +
                             device="cuda"):
         | 
| 626 | 
            +
                    self.dtype = dtype
         | 
| 627 | 
            +
                    self.device = device
         | 
| 628 | 
            +
             | 
| 629 | 
            +
                    mean = [
         | 
| 630 | 
            +
                        -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
         | 
| 631 | 
            +
                        0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
         | 
| 632 | 
            +
                    ]
         | 
| 633 | 
            +
                    std = [
         | 
| 634 | 
            +
                        2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
         | 
| 635 | 
            +
                        3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
         | 
| 636 | 
            +
                    ]
         | 
| 637 | 
            +
                    self.mean = torch.tensor(mean, dtype=dtype, device=device)
         | 
| 638 | 
            +
                    self.std = torch.tensor(std, dtype=dtype, device=device)
         | 
| 639 | 
            +
                    self.scale = [self.mean, 1.0 / self.std]
         | 
| 640 | 
            +
             | 
| 641 | 
            +
                    # init model
         | 
| 642 | 
            +
                    self.model = _video_vae(
         | 
| 643 | 
            +
                        pretrained_path=vae_pth,
         | 
| 644 | 
            +
                        z_dim=z_dim,
         | 
| 645 | 
            +
                    ).eval().requires_grad_(False).to(device)
         | 
| 646 | 
            +
             | 
| 647 | 
            +
                def encode(self, videos):
         | 
| 648 | 
            +
                    """
         | 
| 649 | 
            +
                    videos: A list of videos each with shape [C, T, H, W].
         | 
| 650 | 
            +
                    """
         | 
| 651 | 
            +
                    with amp.autocast(dtype=self.dtype):
         | 
| 652 | 
            +
                        return [
         | 
| 653 | 
            +
                            self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
         | 
| 654 | 
            +
                            for u in videos
         | 
| 655 | 
            +
                        ]
         | 
| 656 | 
            +
             | 
| 657 | 
            +
                def decode(self, zs):
         | 
| 658 | 
            +
                    with amp.autocast(dtype=self.dtype):
         | 
| 659 | 
            +
                        return [
         | 
| 660 | 
            +
                            self.model.decode(u.unsqueeze(0),
         | 
| 661 | 
            +
                                              self.scale).float().clamp_(-1, 1).squeeze(0)
         | 
| 662 | 
            +
                            for u in zs
         | 
| 663 | 
            +
                        ]
         | 
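
For orientation, here is a minimal usage sketch of the WanVAE wrapper defined above. It is illustrative only: the import path, checkpoint file name and tensor shapes are assumptions drawn from the defaults in this file, not something this diff prescribes.

    # Sketch; assumes this repo's package layout and a locally downloaded checkpoint.
    import torch
    from wan.modules.vae import WanVAE

    vae = WanVAE(
        z_dim=16,
        vae_pth='cache/vae_step_411000.pth',  # assumed local checkpoint path
        dtype=torch.float,
        device='cuda')

    # One RGB video [C, T, H, W]; real inputs are expected in [-1, 1], and
    # T = 1 + 4k matches the 1, 4, 4, 4, ... temporal chunking used by encode().
    video = torch.randn(3, 17, 256, 256, device='cuda')  # stand-in data

    latents = vae.encode([video])   # list with one latent, here about [16, 5, 32, 32]
    recon = vae.decode(latents)     # list with one video clamped to [-1, 1]

Note that encode() returns latents already normalized with the per-channel mean/std stored in self.scale, and decode() expects latents in that same normalized space.
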
    	
        wan/modules/xlm_roberta.py
    ADDED
    
    | @@ -0,0 +1,170 @@ | |
| 1 | 
            +
            # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
         | 
| 2 | 
            +
            # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            import torch.nn as nn
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            __all__ = ['XLMRoberta', 'xlm_roberta_large']
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class SelfAttention(nn.Module):
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
         | 
| 13 | 
            +
                    assert dim % num_heads == 0
         | 
| 14 | 
            +
                    super().__init__()
         | 
| 15 | 
            +
                    self.dim = dim
         | 
| 16 | 
            +
                    self.num_heads = num_heads
         | 
| 17 | 
            +
                    self.head_dim = dim // num_heads
         | 
| 18 | 
            +
                    self.eps = eps
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    # layers
         | 
| 21 | 
            +
                    self.q = nn.Linear(dim, dim)
         | 
| 22 | 
            +
                    self.k = nn.Linear(dim, dim)
         | 
| 23 | 
            +
                    self.v = nn.Linear(dim, dim)
         | 
| 24 | 
            +
                    self.o = nn.Linear(dim, dim)
         | 
| 25 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                def forward(self, x, mask):
         | 
| 28 | 
            +
                    """
         | 
| 29 | 
            +
                    x:   [B, L, C].
         | 
| 30 | 
            +
                    """
         | 
| 31 | 
            +
                    b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    # compute query, key, value
         | 
| 34 | 
            +
                    q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 35 | 
            +
                    k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 36 | 
            +
                    v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    # compute attention
         | 
| 39 | 
            +
                    p = self.dropout.p if self.training else 0.0
         | 
| 40 | 
            +
                    x = F.scaled_dot_product_attention(q, k, v, mask, p)
         | 
| 41 | 
            +
                    x = x.permute(0, 2, 1, 3).reshape(b, s, c)
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                    # output
         | 
| 44 | 
            +
                    x = self.o(x)
         | 
| 45 | 
            +
                    x = self.dropout(x)
         | 
| 46 | 
            +
                    return x
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            class AttentionBlock(nn.Module):
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
         | 
| 52 | 
            +
                    super().__init__()
         | 
| 53 | 
            +
                    self.dim = dim
         | 
| 54 | 
            +
                    self.num_heads = num_heads
         | 
| 55 | 
            +
                    self.post_norm = post_norm
         | 
| 56 | 
            +
                    self.eps = eps
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    # layers
         | 
| 59 | 
            +
                    self.attn = SelfAttention(dim, num_heads, dropout, eps)
         | 
| 60 | 
            +
                    self.norm1 = nn.LayerNorm(dim, eps=eps)
         | 
| 61 | 
            +
                    self.ffn = nn.Sequential(
         | 
| 62 | 
            +
                        nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
         | 
| 63 | 
            +
                        nn.Dropout(dropout))
         | 
| 64 | 
            +
                    self.norm2 = nn.LayerNorm(dim, eps=eps)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                def forward(self, x, mask):
         | 
| 67 | 
            +
                    if self.post_norm:
         | 
| 68 | 
            +
                        x = self.norm1(x + self.attn(x, mask))
         | 
| 69 | 
            +
                        x = self.norm2(x + self.ffn(x))
         | 
| 70 | 
            +
                    else:
         | 
| 71 | 
            +
                        x = x + self.attn(self.norm1(x), mask)
         | 
| 72 | 
            +
                        x = x + self.ffn(self.norm2(x))
         | 
| 73 | 
            +
                    return x
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([
            AttentionBlock(dim, num_heads, post_norm, dropout, eps)
            for _ in range(num_layers)
        ])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        # position ids count non-pad tokens cumulatively, offset by pad_id,
        # so padding positions map to the embedding's padding_idx
        x = self.token_embedding(ids) + \
            self.type_embedding(torch.zeros_like(ids)) + \
            self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        # convert the boolean pad mask into an additive attention mask:
        # 0.0 at valid tokens, a large negative value at padding
        mask = torch.where(
            mask.view(b, 1, 1, s).gt(0), 0.0,
            torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x

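# ---------------------------------------------------------------------------
# Illustrative note: XLMRoberta maps padded token ids of shape [B, L]
# (padding value pad_id, 1 by default) to contextual features of shape
# [B, L, dim]; there is no pooler or LM head, so any pooling is left to the
# caller. A minimal sketch, assuming the default configuration:
#
#   model = XLMRoberta()
#   ids = torch.tensor([[0, 100, 200, 2, 1, 1]])   # one right-padded sequence
#   feats = model(ids)                             # -> [1, 6, 1024]
# ---------------------------------------------------------------------------
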
def xlm_roberta_large(pretrained=False,
                      return_tokenizer=False,
                      device='cpu',
                      **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
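

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming the imports at the top of this module
# (`torch`, `torch.nn as nn`). It only builds the architecture and checks
# output shapes; no pretrained weights are loaded here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # default XLM-RoBERTa-large configuration, built on CPU
    model = xlm_roberta_large(device='cpu').eval()

    # dummy batch: two sequences of length 8, right-padded with pad_id (=1)
    ids = torch.full((2, 8), model.pad_id, dtype=torch.long)
    ids[0, :6] = torch.randint(2, model.vocab_size, (6,))
    ids[1, :4] = torch.randint(2, model.vocab_size, (4,))

    with torch.no_grad():
        feats = model(ids)
    print(feats.shape)  # expected: torch.Size([2, 8, 1024])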
