multimodalart (HF staff) committed
Commit 38e20ed · verified · 1 Parent(s): e3cc724

Upload 83 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +35 -0
  2. LICENSE.txt +38 -0
  3. README.md +194 -14
  4. assets/.DS_Store +0 -0
  5. assets/demo.gif +3 -0
  6. assets/driving_audio/1.wav +3 -0
  7. assets/driving_audio/2.wav +3 -0
  8. assets/driving_audio/3.wav +3 -0
  9. assets/driving_audio/4.wav +3 -0
  10. assets/driving_audio/5.wav +3 -0
  11. assets/driving_audio/6.wav +3 -0
  12. assets/driving_video/.DS_Store +0 -0
  13. assets/driving_video/1.mp4 +3 -0
  14. assets/driving_video/2.mp4 +3 -0
  15. assets/driving_video/3.mp4 +3 -0
  16. assets/driving_video/4.mp4 +3 -0
  17. assets/driving_video/5.mp4 +3 -0
  18. assets/driving_video/6.mp4 +3 -0
  19. assets/driving_video/7.mp4 +3 -0
  20. assets/driving_video/8.mp4 +3 -0
  21. assets/logo.png +0 -0
  22. assets/ref_images/1.png +3 -0
  23. assets/ref_images/10.png +3 -0
  24. assets/ref_images/11.png +3 -0
  25. assets/ref_images/12.png +3 -0
  26. assets/ref_images/13.png +3 -0
  27. assets/ref_images/14.png +3 -0
  28. assets/ref_images/15.png +3 -0
  29. assets/ref_images/16.png +3 -0
  30. assets/ref_images/17.png +3 -0
  31. assets/ref_images/18.png +3 -0
  32. assets/ref_images/19.png +3 -0
  33. assets/ref_images/2.png +3 -0
  34. assets/ref_images/20.png +0 -0
  35. assets/ref_images/3.png +3 -0
  36. assets/ref_images/4.png +3 -0
  37. assets/ref_images/5.png +3 -0
  38. assets/ref_images/6.png +3 -0
  39. assets/ref_images/7.png +3 -0
  40. assets/ref_images/8.png +3 -0
  41. diffposetalk/common.py +46 -0
  42. diffposetalk/diff_talking_head.py +536 -0
  43. diffposetalk/diffposetalk.py +228 -0
  44. diffposetalk/hubert.py +51 -0
  45. diffposetalk/utils/__init__.py +1 -0
  46. diffposetalk/utils/common.py +378 -0
  47. diffposetalk/utils/media.py +35 -0
  48. diffposetalk/utils/renderer.py +147 -0
  49. diffposetalk/utils/rotation_conversions.py +569 -0
  50. diffposetalk/wav2vec2.py +119 -0
.gitattributes CHANGED
@@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/driving_audio/1.wav filter=lfs diff=lfs merge=lfs -text
38
+ assets/driving_audio/2.wav filter=lfs diff=lfs merge=lfs -text
39
+ assets/driving_audio/3.wav filter=lfs diff=lfs merge=lfs -text
40
+ assets/driving_audio/4.wav filter=lfs diff=lfs merge=lfs -text
41
+ assets/driving_audio/5.wav filter=lfs diff=lfs merge=lfs -text
42
+ assets/driving_audio/6.wav filter=lfs diff=lfs merge=lfs -text
43
+ assets/driving_video/1.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ assets/driving_video/2.mp4 filter=lfs diff=lfs merge=lfs -text
45
+ assets/driving_video/3.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ assets/driving_video/4.mp4 filter=lfs diff=lfs merge=lfs -text
47
+ assets/driving_video/5.mp4 filter=lfs diff=lfs merge=lfs -text
48
+ assets/driving_video/6.mp4 filter=lfs diff=lfs merge=lfs -text
49
+ assets/driving_video/7.mp4 filter=lfs diff=lfs merge=lfs -text
50
+ assets/driving_video/8.mp4 filter=lfs diff=lfs merge=lfs -text
51
+ assets/ref_images/1.png filter=lfs diff=lfs merge=lfs -text
52
+ assets/ref_images/10.png filter=lfs diff=lfs merge=lfs -text
53
+ assets/ref_images/11.png filter=lfs diff=lfs merge=lfs -text
54
+ assets/ref_images/12.png filter=lfs diff=lfs merge=lfs -text
55
+ assets/ref_images/13.png filter=lfs diff=lfs merge=lfs -text
56
+ assets/ref_images/14.png filter=lfs diff=lfs merge=lfs -text
57
+ assets/ref_images/15.png filter=lfs diff=lfs merge=lfs -text
58
+ assets/ref_images/16.png filter=lfs diff=lfs merge=lfs -text
59
+ assets/ref_images/17.png filter=lfs diff=lfs merge=lfs -text
60
+ assets/ref_images/18.png filter=lfs diff=lfs merge=lfs -text
61
+ assets/ref_images/19.png filter=lfs diff=lfs merge=lfs -text
62
+ assets/ref_images/2.png filter=lfs diff=lfs merge=lfs -text
63
+ assets/ref_images/3.png filter=lfs diff=lfs merge=lfs -text
64
+ assets/ref_images/4.png filter=lfs diff=lfs merge=lfs -text
65
+ assets/ref_images/5.png filter=lfs diff=lfs merge=lfs -text
66
+ assets/ref_images/6.png filter=lfs diff=lfs merge=lfs -text
67
+ assets/ref_images/7.png filter=lfs diff=lfs merge=lfs -text
68
+ assets/ref_images/8.png filter=lfs diff=lfs merge=lfs -text
69
+ skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text
70
+ skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,38 @@
1
+ ---
2
+ language:
3
+ - en
4
+ - zh
5
+ license: other
6
+ tasks:
7
+ - text-generation
8
+
9
+ ---
10
+
11
+ <!-- markdownlint-disable first-line-h1 -->
12
+ <!-- markdownlint-disable html -->
13
+
14
+ # <span id="Terms">声明与协议/Terms and Conditions</span>
15
+
16
+ ## 声明
17
+
18
+ 我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
19
+
20
+ 我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
21
+
22
+ We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
23
+
24
+ We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
25
+
26
+ ## 协议
27
+
28
+ 社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
29
+
30
+
31
+ The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
32
+
33
+
34
+
35
+ [《Skywork 模型社区许可协议》]: https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
36
+
37
+
38
README.md CHANGED
@@ -1,14 +1,194 @@
1
- ---
2
- title: Skyreels Talking Head
3
- emoji: 😻
4
- colorFrom: yellow
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.20.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- short_description: audio to talking face
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <p align="center">
2
+ <img src="assets/logo.png" alt="Skyreels Logo" width="50%">
3
+ </p>
4
+
5
+
6
+ <h1 align="center">SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers</h1>
7
+ <div align='center'>
8
+ <a href='https://scholar.google.com/citations?user=6D_nzucAAAAJ&hl=en' target='_blank'>Di Qiu</a>&emsp;
9
+ <a href='https://scholar.google.com/citations?user=_43YnBcAAAAJ&hl=zh-CN' target='_blank'>Zhengcong Fei</a>&emsp;
10
+ <a href='' target='_blank'>Rui Wang</a>&emsp;
11
+ <a href='' target='_blank'>Jialin Bai</a>&emsp;
12
+ <a href='https://scholar.google.com/citations?user=Hv-vj2sAAAAJ&hl=en' target='_blank'>Changqian Yu</a>&emsp;
13
+ </div>
14
+
15
+ <div align='center'>
16
+ <a href='https://scholar.google.com.au/citations?user=ePIeVuUAAAAJ&hl=en' target='_blank'>Mingyuan Fan</a>&emsp;
17
+ <a href='https://scholar.google.com/citations?user=HukWSw4AAAAJ&hl=en' target='_blank'>Guibin Chen</a>&emsp;
18
+ <a href='https://scholar.google.com.tw/citations?user=RvAuMk0AAAAJ&hl=zh-CN' target='_blank'>Xiang Wen</a>&emsp;
19
+ </div>
20
+
21
+ <div align='center'>
22
+ <small><strong>Skywork AI</strong></small>
23
+ </div>
24
+
25
+ <br>
26
+
27
+ <div align="center">
28
+ <!-- <a href='LICENSE'><img src='https://img.shields.io/badge/license-MIT-yellow'></a> -->
29
+ <a href='https://arxiv.org/abs/2502.10841'><img src='https://img.shields.io/badge/arXiv-SkyReels A1-red'></a>
30
+ <a href='https://skyworkai.github.io/skyreels-a1.github.io/'><img src='https://img.shields.io/badge/Project-SkyReels A1-green'></a>
31
+ <a href='https://huggingface.co/Skywork/SkyReels-A1'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue'></a>
32
+ <a href='https://www.skyreels.ai/home?utm_campaign=github_A1'><img src='https://img.shields.io/badge/Playground-Spaces-yellow'></a>
33
+ <br>
34
+ </div>
35
+ <br>
36
+
37
+
38
+ <p align="center">
39
+ <img src="./assets/demo.gif" alt="showcase">
40
+ <br>
41
+ 🔥 For more results, visit our <a href="https://skyworkai.github.io/skyreels-a1.github.io/"><strong>homepage</strong></a> 🔥
42
+ </p>
43
+
44
+ <p align="center">
45
+ 👋 Join our <a href="https://discord.gg/PwM6NYtccQ" target="_blank"><strong>Discord</strong></a>
46
+ </p>
47
+
48
+
49
+ This repo, named **SkyReels-A1**, contains the official PyTorch implementation of our paper [SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers](https://arxiv.org/abs/2502.10841).
50
+
51
+
52
+ ## 🔥🔥🔥 News!!
53
+ * Mar 4, 2025: 🔥 We release the audio-driven portrait image animation pipeline.
54
+ * Feb 18, 2025: 👋 We release the inference code and model weights of SkyReels-A1. [Download](https://huggingface.co/Skywork/SkyReels-A1)
55
+ * Feb 18, 2025: 🎉 We have made our technical report available as open source. [Read](https://skyworkai.github.io/skyreels-a1.github.io/report.pdf)
56
+ * Feb 18, 2025: 🔥 Our online demo of LipSync is available on SkyReels now! Try out [LipSync](https://www.skyreels.ai/home/tools/lip-sync?refer=navbar).
57
+ * Feb 18, 2025: 🔥 We have open-sourced I2V video generation model [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
58
+
59
+ ## 📑 TODO List
60
+ - [x] Checkpoints
61
+ - [x] Inference Code
62
+ - [x] Web Demo (Gradio)
63
+ - [x] Audio-driven Portrait Image Animation Pipeline
64
+ - [ ] Inference Code for Long Videos
65
+ - [ ] User-Level GPU Inference on RTX4090
66
+ - [ ] ComfyUI
67
+
68
+
69
+ ## Getting Started 🏁
70
+
71
+ ### 1. Clone the code and prepare the environment 🛠️
72
+ First, clone the repository and create the conda environment:
73
+ ```bash
74
+ git clone https://github.com/SkyworkAI/SkyReels-A1.git
75
+ cd SkyReels-A1
76
+
77
+ # create env using conda
78
+ conda create -n skyreels-a1 python=3.10
79
+ conda activate skyreels-a1
80
+ ```
81
+ Then, install the remaining dependencies:
82
+ ```bash
83
+ pip install -r requirements.txt
84
+ ```
85
+
86
+
87
+ ### 2. Download pretrained weights 📥
88
+ You can download the pretrained weights from Hugging Face:
89
+ ```bash
90
+ # !pip install -U "huggingface_hub[cli]"
91
+ huggingface-cli download Skywork/SkyReels-A1 --local-dir local_path --exclude "*.git*" "README.md" "docs"
92
+ ```
93
+
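If you prefer to script the download in Python, the same can be done with `huggingface_hub` (a minimal sketch; the repo id follows the Hugging Face badge above, and `local_dir` is set to `pretrained_models` to match the directory tree below; adjust both to your setup):

```python
# Sketch: download the SkyReels-A1 weights via the huggingface_hub Python API.
# Assumes the repo id "Skywork/SkyReels-A1" and the pretrained_models layout below.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Skywork/SkyReels-A1",
    local_dir="pretrained_models",
    ignore_patterns=["*.git*", "README.md", "docs/*"],
)
```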
94
+ The FLAME, mediapipe, and smirk models are located in the SkyReels-A1/extra_models folder.
95
+
96
+ The pretrained models should be organized in the following directory structure:
97
+ ```text
98
+ pretrained_models
99
+ ├── FLAME
100
+ ├── SkyReels-A1-5B
101
+ │ ├── pose_guider
102
+ │ ├── scheduler
103
+ │ ├── tokenizer
104
+ │ ├── siglip-so400m-patch14-384
105
+ │ ├── transformer
106
+ │ ├── vae
107
+ │ └── text_encoder
108
+ ├── mediapipe
109
+ └── smirk
110
+
111
+ ```
112
+
113
+ #### Download DiffPoseTalk assets and pretrained weights (for audio-driven animation)
114
+
115
+ - We use [diffposetalk](https://github.com/DiffPoseTalk/DiffPoseTalk/tree/main) to generate FLAME coefficients from audio, which serve as the motion signals (see the usage sketch at the end of this subsection).
116
+
117
+ - Download the diffposetalk code and follow its README to download the weights and related data.
118
+
119
+ - Then place them in the specified directory.
120
+
121
+ ```bash
122
+ cp -r ${diffposetalk_root}/style pretrained_models/diffposetalk
123
+ cp ${diffposetalk_root}/experiments/DPT/head-SA-hubert-WM/checkpoints/iter_0110000.pt pretrained_models/diffposetalk
124
+ cp ${diffposetalk_root}/datasets/HDTF_TFHP/lmdb/stats_train.npz pretrained_models/diffposetalk
125
+ ```
126
+
127
+ ```text
128
+ pretrained_models
129
+ ├── FLAME
130
+ ├── SkyReels-A1-5B
131
+ ├── mediapipe
132
+ ├── diffposetalk
133
+ │ ├── style
134
+ │ ├── iter_0110000.pt
135
+ │ ├── stats_train.npz
136
+ └── smirk
137
+
138
+ ```
139
+
140
+
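For reference, here is a minimal usage sketch of the `DiffPoseTalk` wrapper shipped in this commit (`diffposetalk/diffposetalk.py`). It assumes the weights above are in place under `pretrained_models/diffposetalk`; the zero-valued shape vector is only a placeholder for the FLAME shape coefficients of the reference identity.

```python
# Sketch: generate per-frame FLAME coefficients from a driving audio clip with
# the DiffPoseTalk wrapper from this repo. The zero shape vector is a placeholder.
import numpy as np
from diffposetalk.diffposetalk import DiffPoseTalk, DiffPoseTalkConfig

dpt = DiffPoseTalk(DiffPoseTalkConfig(), device="cuda")   # loads pretrained_models/diffposetalk/*
shape_coef = np.zeros(100, dtype=np.float32)              # placeholder 100-dim FLAME shape vector
coef_list = dpt.infer_from_file("assets/driving_audio/1.wav", shape_coef)

# One dict per generated frame with expression, jaw, eye and head-pose parameters.
print(len(coef_list), coef_list[0]["expression_params"].shape)
```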
141
+ ### 3. Inference 🚀
142
+ You can simply run the inference scripts as:
143
+ ```bash
144
+ python inference.py
145
+
146
+ # inference audio to video
147
+ python inference_audio.py
148
+ ```
149
+
150
+ If the script runs successfully, it produces an output MP4 file containing the driving video, the input image or video, and the generated result.
151
+
152
+
153
+ ## Gradio Interface 🤗
154
+
155
+ We provide a [Gradio](https://huggingface.co/docs/hub/spaces-sdks-gradio) interface for a better experience. Launch it with:
156
+
157
+ ```bash
158
+ python app.py
159
+ ```
160
+
161
+ The graphical interactive interface is shown below:
162
+
163
+ ![gradio](https://github.com/user-attachments/assets/ed56f08c-f31c-4fbe-ac1d-c4d4e87a8719)
164
+
165
+
166
+ ## Metric Evaluation 👓
167
+
168
+ We also provide all the scripts for automatically calculating the metrics reported in the paper, including SimFace, FID, and the L1 distances of expression and motion.
169
+
170
+ All code can be found in the `eval` folder. After setting the video result path, run the following commands in sequence:
171
+
172
+ ```bash
173
+ python arc_score.py
174
+ python expression_score.py
175
+ python pose_score.py
176
+ ```
177
+
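For illustration only, a minimal sketch of an L1 coefficient-distance metric is shown below. It assumes predicted and ground-truth coefficients saved as `(T, D)` NumPy arrays at hypothetical paths; the actual `expression_score.py` and `pose_score.py` scripts may differ in detail.

```python
# Illustrative sketch: mean L1 distance between two FLAME coefficient sequences
# (e.g. expression or pose), assuming both are stored as (T, D) NumPy arrays.
import numpy as np

def l1_coef_distance(pred: np.ndarray, gt: np.ndarray) -> float:
    T = min(len(pred), len(gt))                    # align sequence lengths
    return float(np.abs(pred[:T] - gt[:T]).mean())

pred = np.load("results/pred_expression.npy")      # hypothetical result paths
gt = np.load("results/gt_expression.npy")
print(f"L1 expression distance: {l1_coef_distance(pred, gt):.4f}")
```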
178
+
179
+ ## Acknowledgements 💐
180
+ We would like to thank the contributors of the [CogVideoX](https://github.com/THUDM/CogVideo), [finetrainers](https://github.com/a-r-r-o-w/finetrainers), and [DiffPoseTalk](https://github.com/DiffPoseTalk/DiffPoseTalk) repositories for their open research and contributions.
181
+
182
+ ## Citation 💖
183
+ If you find SkyReels-A1 useful for your research, please 🌟 this repo and cite our work using the following BibTeX:
184
+ ```bibtex
185
+ @article{qiu2025skyreels,
186
+ title={SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers},
187
+ author={Qiu, Di and Fei, Zhengcong and Wang, Rui and Bai, Jialin and Yu, Changqian and Fan, Mingyuan and Chen, Guibin and Wen, Xiang},
188
+ journal={arXiv preprint arXiv:2502.10841},
189
+ year={2025}
190
+ }
191
+ ```
192
+
193
+
194
+
assets/.DS_Store ADDED
Binary file (6.15 kB)
 
assets/demo.gif ADDED

Git LFS Details

  • SHA256: 1b8c13b7c718a9e2645dd4490dfe645a880781121b7d207ed43cc7cd3d0a35e4
  • Pointer size: 132 Bytes
  • Size of remote file: 3.08 MB
assets/driving_audio/1.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38dab65a002455f4c186d4b0bde848c964415441d9636d94a48d5b32f23b0f6f
3
+ size 575850
assets/driving_audio/2.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:15b998fd3fbabf22e9fde210f93df4f7b1647e19fe81b2d33b2b74470fea32b5
3
+ size 3891278
assets/driving_audio/3.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:448576f545e18c4cb33cb934a00c0bc3f331896eba4d1cb6a077c1b9382d0628
3
+ size 910770
assets/driving_audio/4.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a3e0d37fc6e235b5a09eb4b7e3b0b5d5f566e2204c5c815c23ca2215dcbf9c93
3
+ size 553038
assets/driving_audio/5.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95275c9299919e38e52789cb3af17ddc4691b7afea82f26c7edc640addce057d
3
+ size 856142
assets/driving_audio/6.wav ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7558a976b8b64d214f33c65e503c77271d3be0cd116a00ddadcb2b2fc53a6396
3
+ size 2641742
assets/driving_video/.DS_Store ADDED
Binary file (6.15 kB)
 
assets/driving_video/1.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b7da4f10cf9e692ba8c75848bacceb3c4d30ee8d3b07719435560c44a8da6544
3
+ size 306996
assets/driving_video/2.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e795b7be655c4b5ae8cac0733a32e8d321ccebd13f2cac07cc15dfc8f61a547
3
+ size 2875843
assets/driving_video/3.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02f5ee85c1028c9673c70682b533a4f22e203173eddd40de42bad0cb57f18abb
3
+ size 1020948
assets/driving_video/4.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f7ddbb17b198a580f658d57f4d83bee7489aa4d8a677f2c45b76b1ec01ae461
3
+ size 215144
assets/driving_video/5.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9637fea5ef83b494a0aa8b7c526ae1efc6ec94d79dfa94381de8d6f38eec238e
3
+ size 556047
assets/driving_video/6.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac7ee3c2419046f11dc230b6db33c2391a98334eba2b1d773e7eb9627992622f
3
+ size 1064930
assets/driving_video/7.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dc94c1fec7ef7dc831c8a49f0e1788ae568812cb68e62f6875d9070f573d02a
3
+ size 187263
assets/driving_video/8.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3047ba66296d96b8a4584e412e61493d7bc0fa5149c77b130e7feea375e698bd
3
+ size 232859
assets/logo.png ADDED
assets/ref_images/1.png ADDED

Git LFS Details

  • SHA256: 93429c6e7408723b04f3681cc06ac98072f8ce4fd69476ee612466a335ca152c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
assets/ref_images/10.png ADDED

Git LFS Details

  • SHA256: ef7456fd5eb3b31584f0933d1b71c25f92a8a9cb428466c1c4daf4eede2db9d3
  • Pointer size: 131 Bytes
  • Size of remote file: 508 kB
assets/ref_images/11.png ADDED

Git LFS Details

  • SHA256: 99bfeeecbefa2bf408d0f15688ae89fa4c71f881d78baced1591bef128367efc
  • Pointer size: 131 Bytes
  • Size of remote file: 634 kB
assets/ref_images/12.png ADDED

Git LFS Details

  • SHA256: c258cba0979585f3fac4d63d9ca0fc3e51604afde62a272f069146ae43d1a996
  • Pointer size: 131 Bytes
  • Size of remote file: 793 kB
assets/ref_images/13.png ADDED

Git LFS Details

  • SHA256: c31191bc70144def9c0de388483d0a9257b0e4eb72128474232bbaa234f5a0a5
  • Pointer size: 131 Bytes
  • Size of remote file: 633 kB
assets/ref_images/14.png ADDED

Git LFS Details

  • SHA256: 8058fc784284c59f1954269638f1ad937ac35cf58563b935736d3f34e6355045
  • Pointer size: 131 Bytes
  • Size of remote file: 517 kB
assets/ref_images/15.png ADDED

Git LFS Details

  • SHA256: 4c3e49512a2253b2a7291ad6b1636521e66b10050dba37a0b9d47c9a5666fb61
  • Pointer size: 131 Bytes
  • Size of remote file: 641 kB
assets/ref_images/16.png ADDED

Git LFS Details

  • SHA256: 5e65a2f40f5f971b0e91023e774ce8aff56a1da723c1f8ffdfc5ec616690cde2
  • Pointer size: 131 Bytes
  • Size of remote file: 392 kB
assets/ref_images/17.png ADDED

Git LFS Details

  • SHA256: 202b2a66e87de425c55e223554942da71a4b0a27757bc2f90ec4c8d51133934b
  • Pointer size: 131 Bytes
  • Size of remote file: 750 kB
assets/ref_images/18.png ADDED

Git LFS Details

  • SHA256: 06a756c3e0a0b5d786428b0968126281c292e1df2c286cb683bac059821c0122
  • Pointer size: 131 Bytes
  • Size of remote file: 183 kB
assets/ref_images/19.png ADDED

Git LFS Details

  • SHA256: 277fc3ecf3c0299f87cb59d056c5d484feb2fa7897c9d0f80ee0854eba2c3487
  • Pointer size: 131 Bytes
  • Size of remote file: 283 kB
assets/ref_images/2.png ADDED

Git LFS Details

  • SHA256: 5c972790d52fc6adf7e5bcb4611720570e260f56c52f063acfea5e4d2f52c07f
  • Pointer size: 131 Bytes
  • Size of remote file: 762 kB
assets/ref_images/20.png ADDED
assets/ref_images/3.png ADDED

Git LFS Details

  • SHA256: bce73675d41349d0792e9903d08ad12280d0e1b3af21e686720a7dac5dcaa649
  • Pointer size: 131 Bytes
  • Size of remote file: 737 kB
assets/ref_images/4.png ADDED

Git LFS Details

  • SHA256: 03ff23c5be3ff225969ddd97a26971bab40af4cc6012f0f859971a12cd8e9003
  • Pointer size: 131 Bytes
  • Size of remote file: 348 kB
assets/ref_images/5.png ADDED

Git LFS Details

  • SHA256: 6b9c2279c99ef4f354fa9e2ea8f1751e8f35ed2ed937e5a2b0b3c918fb49f947
  • Pointer size: 131 Bytes
  • Size of remote file: 375 kB
assets/ref_images/6.png ADDED

Git LFS Details

  • SHA256: d127961dece864d4000351c1c14a71d3c1bc54c51c2cce6d9dd1c74bdea0ec4c
  • Pointer size: 131 Bytes
  • Size of remote file: 370 kB
assets/ref_images/7.png ADDED

Git LFS Details

  • SHA256: c1e2c11b7f9832b2acbf454065b2beebf95f6817f623ee1fe56ff2fafc0caf1d
  • Pointer size: 131 Bytes
  • Size of remote file: 542 kB
assets/ref_images/8.png ADDED

Git LFS Details

  • SHA256: 8c8aa92c1bea3f5f0b1b3859b35ed801fc4022f064b3ebba09e621157a2ac4c6
  • Pointer size: 131 Bytes
  • Size of remote file: 358 kB
diffposetalk/common.py ADDED
@@ -0,0 +1,46 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class PositionalEncoding(nn.Module):
9
+ def __init__(self, d_model, dropout=0.1, max_len=600):
10
+ super().__init__()
11
+ self.dropout = nn.Dropout(p=dropout)
12
+ # vanilla sinusoidal encoding
13
+ pe = torch.zeros(max_len, d_model)
14
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
15
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
16
+ pe[:, 0::2] = torch.sin(position * div_term)
17
+ pe[:, 1::2] = torch.cos(position * div_term)
18
+ pe = pe.unsqueeze(0)
19
+ self.register_buffer('pe', pe)
20
+
21
+ def forward(self, x):
22
+ x = x + self.pe[:, :x.shape[1], :]  # add the encoding for each of the first x.shape[1] positions
23
+ return self.dropout(x)
24
+
25
+
26
+ def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
27
+ mask = torch.ones(T, S)
28
+ for i in range(T):
29
+ mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
30
+ return (mask == 1).to(device=device)
31
+
32
+
33
+ def pad_audio(audio, audio_unit=320, pad_threshold=80):
34
+ batch_size, audio_len = audio.shape
35
+ n_units = audio_len // audio_unit
36
+ side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
37
+ if side_len >= 0:
38
+ reflect_len = side_len // 2
39
+ replicate_len = side_len % 2
40
+ if reflect_len > 0:
41
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
42
+ audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
43
+ if replicate_len > 0:
44
+ audio = F.pad(audio, (1, 1), mode='replicate')
45
+
46
+ return audio
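A quick sanity check of the helpers above could look like the following (a sketch; it assumes the `diffposetalk` folder is importable as a package and that PyTorch is installed):

```python
# Sketch: exercise PositionalEncoding and enc_dec_mask from diffposetalk/common.py.
import torch
from diffposetalk.common import PositionalEncoding, enc_dec_mask

pe = PositionalEncoding(d_model=512, max_len=600)
x = torch.zeros(2, 100, 512)                               # (batch, seq_len, d_model)
print(pe(x).shape)                                         # torch.Size([2, 100, 512])

mask = enc_dec_mask(T=25, S=50, frame_width=2, expansion=0, device='cpu')
print(mask.shape, mask.dtype)                              # torch.Size([25, 50]) torch.bool
```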
diffposetalk/diff_talking_head.py ADDED
@@ -0,0 +1,536 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .common import PositionalEncoding, enc_dec_mask, pad_audio
6
+
7
+
8
+ class DiffusionSchedule(nn.Module):
9
+ def __init__(self, num_steps, mode='linear', beta_1=1e-4, beta_T=0.02, s=0.008):
10
+ super().__init__()
11
+
12
+ if mode == 'linear':
13
+ betas = torch.linspace(beta_1, beta_T, num_steps)
14
+ elif mode == 'quadratic':
15
+ betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
16
+ elif mode == 'sigmoid':
17
+ betas = torch.sigmoid(torch.linspace(-5, 5, num_steps)) * (beta_T - beta_1) + beta_1
18
+ elif mode == 'cosine':
19
+ steps = num_steps + 1
20
+ x = torch.linspace(0, num_steps, steps)
21
+ alpha_bars = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
22
+ alpha_bars = alpha_bars / alpha_bars[0]
23
+ betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
24
+ betas = torch.clip(betas, 0.0001, 0.999)
25
+ else:
26
+ raise ValueError(f'Unknown diffusion schedule {mode}!')
27
+ betas = torch.cat([torch.zeros(1), betas], dim=0) # Padding beta_0 = 0
28
+
29
+ alphas = 1 - betas
30
+ log_alphas = torch.log(alphas)
31
+ for i in range(1, log_alphas.shape[0]): # 1 to T
32
+ log_alphas[i] += log_alphas[i - 1]
33
+ alpha_bars = log_alphas.exp()
34
+
35
+ sigmas_flex = torch.sqrt(betas)
36
+ sigmas_inflex = torch.zeros_like(sigmas_flex)
37
+ for i in range(1, sigmas_flex.shape[0]):
38
+ sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[i]
39
+ sigmas_inflex = torch.sqrt(sigmas_inflex)
40
+
41
+ self.num_steps = num_steps
42
+ self.register_buffer('betas', betas)
43
+ self.register_buffer('alphas', alphas)
44
+ self.register_buffer('alpha_bars', alpha_bars)
45
+ self.register_buffer('sigmas_flex', sigmas_flex)
46
+ self.register_buffer('sigmas_inflex', sigmas_inflex)
47
+
48
+ def uniform_sample_t(self, batch_size):
49
+ ts = torch.randint(1, self.num_steps + 1, (batch_size,))
50
+ return ts.tolist()
51
+
52
+ def get_sigmas(self, t, flexibility=0):
53
+ assert 0 <= flexibility <= 1
54
+ sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
55
+ return sigmas
56
+
57
+
58
+ class DiffTalkingHead(nn.Module):
59
+ def __init__(self, args, device='cuda'):
60
+ super().__init__()
61
+
62
+ # Model parameters
63
+ self.target = args.target
64
+ self.architecture = args.architecture
65
+ self.use_style = args.style_enc_ckpt is not None
66
+
67
+ self.motion_feat_dim = 50
68
+ if args.rot_repr == 'aa':
69
+ self.motion_feat_dim += 1 if args.no_head_pose else 4
70
+ else:
71
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
72
+
73
+ self.fps = args.fps
74
+ self.n_motions = args.n_motions
75
+ self.n_prev_motions = args.n_prev_motions
76
+ if self.use_style:
77
+ self.style_feat_dim = args.d_style
78
+
79
+ # Audio encoder
80
+ self.audio_model = args.audio_model
81
+ if self.audio_model == 'wav2vec2':
82
+ from .wav2vec2 import Wav2Vec2Model
83
+ self.audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
84
+ # wav2vec 2.0 weights initialization
85
+ self.audio_encoder.feature_extractor._freeze_parameters()
86
+ elif self.audio_model == 'hubert':
87
+ from .hubert import HubertModel
88
+ self.audio_encoder = HubertModel.from_pretrained('facebook/hubert-base-ls960')
89
+ self.audio_encoder.feature_extractor._freeze_parameters()
90
+
91
+ frozen_layers = [0, 1]
92
+ for name, param in self.audio_encoder.named_parameters():
93
+ if name.startswith("feature_projection"):
94
+ param.requires_grad = False
95
+ if name.startswith("encoder.layers"):
96
+ layer = int(name.split(".")[2])
97
+ if layer in frozen_layers:
98
+ param.requires_grad = False
99
+ else:
100
+ raise ValueError(f'Unknown audio model {self.audio_model}!')
101
+
102
+ if args.architecture == 'decoder':
103
+ self.audio_feature_map = nn.Linear(768, args.feature_dim)
104
+ self.start_audio_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, args.feature_dim))
105
+ else:
106
+ raise ValueError(f'Unknown architecture {args.architecture}!')
107
+
108
+ self.start_motion_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, self.motion_feat_dim))
109
+
110
+ # Diffusion model
111
+ self.denoising_net = DenoisingNetwork(args, device)
112
+ # diffusion schedule
113
+ self.diffusion_sched = DiffusionSchedule(args.n_diff_steps, args.diff_schedule)
114
+
115
+ # Classifier-free settings
116
+ self.cfg_mode = args.cfg_mode
117
+ guiding_conditions = args.guiding_conditions.split(',') if args.guiding_conditions else []
118
+ self.guiding_conditions = [cond for cond in guiding_conditions if cond in ['style', 'audio']]
119
+ if 'style' in self.guiding_conditions:
120
+ if not self.use_style:
121
+ raise ValueError('Cannot use style guiding without enabling it!')
122
+ self.null_style_feat = nn.Parameter(torch.randn(1, 1, self.style_feat_dim))
123
+ if 'audio' in self.guiding_conditions:
124
+ audio_feat_dim = args.feature_dim
125
+ self.null_audio_feat = nn.Parameter(torch.randn(1, 1, audio_feat_dim))
126
+
127
+ self.to(device)
128
+
129
+ @property
130
+ def device(self):
131
+ return next(self.parameters()).device
132
+
133
+ def forward(self, motion_feat, audio_or_feat, shape_feat, style_feat=None,
134
+ prev_motion_feat=None, prev_audio_feat=None, time_step=None, indicator=None):
135
+ """
136
+ Args:
137
+ motion_feat: (N, L, d_coef) motion coefficients or features
138
+ audio_or_feat: (N, L_audio) raw audio or audio feature
139
+ shape_feat: (N, d_shape) or (N, 1, d_shape)
140
+ style_feat: (N, d_style)
141
+ prev_motion_feat: (N, n_prev_motions, d_motion) previous motion coefficients or feature
142
+ prev_audio_feat: (N, n_prev_motions, d_audio) previous audio features
143
+ time_step: (N,)
144
+ indicator: (N, L) 0/1 indicator of real (unpadded) motion coefficients
145
+
146
+ Returns:
147
+ motion_feat_noise: (N, L, d_motion)
148
+ """
149
+ if self.use_style:
150
+ assert style_feat is not None, 'Missing style features!'
151
+
152
+ batch_size = motion_feat.shape[0]
153
+
154
+ if audio_or_feat.ndim == 2:
155
+ # Extract audio features
156
+ assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
157
+ f'Incorrect audio length {audio_or_feat.shape[1]}'
158
+ audio_feat_saved = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
159
+ elif audio_or_feat.ndim == 3:
160
+ assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
161
+ audio_feat_saved = audio_or_feat
162
+ else:
163
+ raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
164
+ audio_feat = audio_feat_saved.clone()
165
+
166
+ if shape_feat.ndim == 2:
167
+ shape_feat = shape_feat.unsqueeze(1) # (N, 1, d_shape)
168
+ if style_feat is not None and style_feat.ndim == 2:
169
+ style_feat = style_feat.unsqueeze(1) # (N, 1, d_style)
170
+
171
+ if prev_motion_feat is None:
172
+ prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
173
+ if prev_audio_feat is None:
174
+ # (N, n_prev_motions, feature_dim)
175
+ prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
176
+
177
+ # Classifier-free guidance
178
+ if len(self.guiding_conditions) > 0:
179
+ assert len(self.guiding_conditions) <= 2, 'Only support 1 or 2 CFG conditions!'
180
+ if len(self.guiding_conditions) == 1 or self.cfg_mode == 'independent':
181
+ null_cond_prob = 0.5 if len(self.guiding_conditions) >= 2 else 0.1
182
+ if 'style' in self.guiding_conditions:
183
+ mask_style = torch.rand(batch_size, device=self.device) < null_cond_prob
184
+ style_feat = torch.where(mask_style.view(-1, 1, 1),
185
+ self.null_style_feat.expand(batch_size, -1, -1),
186
+ style_feat)
187
+ if 'audio' in self.guiding_conditions:
188
+ mask_audio = torch.rand(batch_size, device=self.device) < null_cond_prob
189
+ audio_feat = torch.where(mask_audio.view(-1, 1, 1),
190
+ self.null_audio_feat.expand(batch_size, self.n_motions, -1),
191
+ audio_feat)
192
+ else:
193
+ # len(self.guiding_conditions) > 1 and self.cfg_mode == 'incremental'
194
+ # full (0.45), w/o style (0.45), w/o style or audio (0.1)
195
+ mask_flag = torch.rand(batch_size, device=self.device)
196
+ if 'style' in self.guiding_conditions:
197
+ mask_style = mask_flag > 0.55
198
+ style_feat = torch.where(mask_style.view(-1, 1, 1),
199
+ self.null_style_feat.expand(batch_size, -1, -1),
200
+ style_feat)
201
+ if 'audio' in self.guiding_conditions:
202
+ mask_audio = mask_flag > 0.9
203
+ audio_feat = torch.where(mask_audio.view(-1, 1, 1),
204
+ self.null_audio_feat.expand(batch_size, self.n_motions, -1),
205
+ audio_feat)
206
+
207
+ if style_feat is None:
208
+ # The model only accepts audio and shape features, i.e., self.use_style = False
209
+ person_feat = shape_feat
210
+ else:
211
+ person_feat = torch.cat([shape_feat, style_feat], dim=-1)
212
+
213
+ if time_step is None:
214
+ # Sample time step
215
+ time_step = self.diffusion_sched.uniform_sample_t(batch_size) # (N,)
216
+
217
+ # The forward diffusion process
218
+ alpha_bar = self.diffusion_sched.alpha_bars[time_step] # (N,)
219
+ c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (N, 1, 1)
220
+ c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (N, 1, 1)
221
+
222
+ eps = torch.randn_like(motion_feat) # (N, L, d_motion)
223
+ motion_feat_noisy = c0 * motion_feat + c1 * eps
224
+
225
+ # The reverse diffusion process
226
+ motion_feat_target = self.denoising_net(motion_feat_noisy, audio_feat, person_feat,
227
+ prev_motion_feat, prev_audio_feat, time_step, indicator)
228
+
229
+ return eps, motion_feat_target, motion_feat.detach(), audio_feat_saved.detach()
230
+
231
+ def extract_audio_feature(self, audio, frame_num=None):
232
+ frame_num = frame_num or self.n_motions
233
+
234
+ # # Strategy 1: resample during audio feature extraction
235
+ # hidden_states = self.audio_encoder(pad_audio(audio), self.fps, frame_num=frame_num).last_hidden_state # (N, L, 768)
236
+
237
+ # Strategy 2: resample after audio feature extraction (BackResample)
238
+ hidden_states = self.audio_encoder(pad_audio(audio), self.fps,
239
+ frame_num=frame_num * 2).last_hidden_state # (N, 2L, 768)
240
+ hidden_states = hidden_states.transpose(1, 2) # (N, 768, 2L)
241
+ hidden_states = F.interpolate(hidden_states, size=frame_num, align_corners=False, mode='linear') # (N, 768, L)
242
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, 768)
243
+
244
+ audio_feat = self.audio_feature_map(hidden_states) # (N, L, feature_dim)
245
+ return audio_feat
246
+
247
+ @torch.no_grad()
248
+ def sample(self, audio_or_feat, shape_feat, style_feat=None, prev_motion_feat=None, prev_audio_feat=None,
249
+ motion_at_T=None, indicator=None, cfg_mode=None, cfg_cond=None, cfg_scale=1.15, flexibility=0,
250
+ dynamic_threshold=None, ret_traj=False):
251
+ # Check and convert inputs
252
+ batch_size = audio_or_feat.shape[0]
253
+
254
+ # Check CFG conditions
255
+ if cfg_mode is None: # Use default CFG mode
256
+ cfg_mode = self.cfg_mode
257
+ if cfg_cond is None: # Use default CFG conditions
258
+ cfg_cond = self.guiding_conditions
259
+ cfg_cond = [c for c in cfg_cond if c in ['audio', 'style']]
260
+
261
+ if not isinstance(cfg_scale, list):
262
+ cfg_scale = [cfg_scale] * len(cfg_cond)
263
+
264
+ # sort cfg_cond and cfg_scale
265
+ if len(cfg_cond) > 0:
266
+ cfg_cond, cfg_scale = zip(*sorted(zip(cfg_cond, cfg_scale), key=lambda x: ['audio', 'style'].index(x[0])))
267
+ else:
268
+ cfg_cond, cfg_scale = [], []
269
+
270
+ if 'style' in cfg_cond:
271
+ assert self.use_style and style_feat is not None
272
+
273
+ if self.use_style:
274
+ if style_feat is None: # use null style feature
275
+ style_feat = self.null_style_feat.expand(batch_size, -1, -1)
276
+ else:
277
+ assert style_feat is None, 'This model does not support style feature input!'
278
+
279
+ if audio_or_feat.ndim == 2:
280
+ # Extract audio features
281
+ assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
282
+ f'Incorrect audio length {audio_or_feat.shape[1]}'
283
+ audio_feat = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
284
+ elif audio_or_feat.ndim == 3:
285
+ assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
286
+ audio_feat = audio_or_feat
287
+ else:
288
+ raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
289
+
290
+ if shape_feat.ndim == 2:
291
+ shape_feat = shape_feat.unsqueeze(1) # (N, 1, d_shape)
292
+ if style_feat is not None and style_feat.ndim == 2:
293
+ style_feat = style_feat.unsqueeze(1) # (N, 1, d_style)
294
+
295
+ if prev_motion_feat is None:
296
+ prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
297
+ if prev_audio_feat is None:
298
+ # (N, n_prev_motions, feature_dim)
299
+ prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
300
+
301
+ if motion_at_T is None:
302
+ motion_at_T = torch.randn((batch_size, self.n_motions, self.motion_feat_dim)).to(self.device)
303
+
304
+ # Prepare input for the reverse diffusion process (including optional classifier-free guidance)
305
+ if 'audio' in cfg_cond:
306
+ audio_feat_null = self.null_audio_feat.expand(batch_size, self.n_motions, -1)
307
+ else:
308
+ audio_feat_null = audio_feat
309
+
310
+ if 'style' in cfg_cond:
311
+ person_feat_null = torch.cat([shape_feat, self.null_style_feat.expand(batch_size, -1, -1)], dim=-1)
312
+ else:
313
+ if self.use_style:
314
+ person_feat_null = torch.cat([shape_feat, style_feat], dim=-1)
315
+ else:
316
+ person_feat_null = shape_feat
317
+
318
+ audio_feat_in = [audio_feat_null]
319
+ person_feat_in = [person_feat_null]
320
+ for cond in cfg_cond:
321
+ if cond == 'audio':
322
+ audio_feat_in.append(audio_feat)
323
+ person_feat_in.append(person_feat_null)
324
+ elif cond == 'style':
325
+ if cfg_mode == 'independent':
326
+ audio_feat_in.append(audio_feat_null)
327
+ elif cfg_mode == 'incremental':
328
+ audio_feat_in.append(audio_feat)
329
+ else:
330
+ raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
331
+ person_feat_in.append(torch.cat([shape_feat, style_feat], dim=-1))
332
+
333
+ n_entries = len(audio_feat_in)
334
+ audio_feat_in = torch.cat(audio_feat_in, dim=0)
335
+ person_feat_in = torch.cat(person_feat_in, dim=0)
336
+ prev_motion_feat_in = torch.cat([prev_motion_feat] * n_entries, dim=0)
337
+ prev_audio_feat_in = torch.cat([prev_audio_feat] * n_entries, dim=0)
338
+ indicator_in = torch.cat([indicator] * n_entries, dim=0) if indicator is not None else None
339
+
340
+ traj = {self.diffusion_sched.num_steps: motion_at_T}
341
+ for t in range(self.diffusion_sched.num_steps, 0, -1):
342
+ if t > 1:
343
+ z = torch.randn_like(motion_at_T)
344
+ else:
345
+ z = torch.zeros_like(motion_at_T)
346
+
347
+ alpha = self.diffusion_sched.alphas[t]
348
+ alpha_bar = self.diffusion_sched.alpha_bars[t]
349
+ alpha_bar_prev = self.diffusion_sched.alpha_bars[t - 1]
350
+ sigma = self.diffusion_sched.get_sigmas(t, flexibility)
351
+
352
+ motion_at_t = traj[t]
353
+ motion_in = torch.cat([motion_at_t] * n_entries, dim=0)
354
+ step_in = torch.tensor([t] * batch_size, device=self.device)
355
+ step_in = torch.cat([step_in] * n_entries, dim=0)
356
+
357
+ results = self.denoising_net(motion_in, audio_feat_in, person_feat_in, prev_motion_feat_in,
358
+ prev_audio_feat_in, step_in, indicator_in)
359
+
360
+ # Apply thresholding if specified
361
+ if dynamic_threshold:
362
+ dt_ratio, dt_min, dt_max = dynamic_threshold
363
+ abs_results = results[:, -self.n_motions:].reshape(batch_size * n_entries, -1).abs()
364
+ s = torch.quantile(abs_results, dt_ratio, dim=1)
365
+ s = torch.clamp(s, min=dt_min, max=dt_max)
366
+ s = s[..., None, None]
367
+ results = torch.clamp(results, min=-s, max=s)
368
+
369
+ results = results.chunk(n_entries)
370
+
371
+ # Unconditional target (CFG) or the conditional target (non-CFG)
372
+ target_theta = results[0][:, -self.n_motions:]
373
+ # Classifier-free Guidance (optional)
374
+ for i in range(0, n_entries - 1):
375
+ if cfg_mode == 'independent':
376
+ target_theta += cfg_scale[i] * (
377
+ results[i + 1][:, -self.n_motions:] - results[0][:, -self.n_motions:])
378
+ elif cfg_mode == 'incremental':
379
+ target_theta += cfg_scale[i] * (
380
+ results[i + 1][:, -self.n_motions:] - results[i][:, -self.n_motions:])
381
+ else:
382
+ raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
383
+
384
+ if self.target == 'noise':
385
+ c0 = 1 / torch.sqrt(alpha)
386
+ c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
387
+ motion_next = c0 * (motion_at_t - c1 * target_theta) + sigma * z
388
+ elif self.target == 'sample':
389
+ c0 = (1 - alpha_bar_prev) * torch.sqrt(alpha) / (1 - alpha_bar)
390
+ c1 = (1 - alpha) * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar)
391
+ motion_next = c0 * motion_at_t + c1 * target_theta + sigma * z
392
+ else:
393
+ raise ValueError('Unknown target type: {}'.format(self.target))
394
+
395
+ traj[t - 1] = motion_next.detach() # Stop gradient and save trajectory.
396
+ traj[t] = traj[t].cpu() # Move previous output to CPU memory.
397
+ if not ret_traj:
398
+ del traj[t]
399
+
400
+ if ret_traj:
401
+ return traj, motion_at_T, audio_feat
402
+ else:
403
+ return traj[0], motion_at_T, audio_feat
404
+
405
+
406
+ class DenoisingNetwork(nn.Module):
407
+ def __init__(self, args, device='cuda'):
408
+ super().__init__()
409
+
410
+ # Model parameters
411
+ self.use_style = args.style_enc_ckpt is not None
412
+ self.motion_feat_dim = 50
413
+ if args.rot_repr == 'aa':
414
+ self.motion_feat_dim += 1 if args.no_head_pose else 4
415
+ else:
416
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
417
+ self.shape_feat_dim = 100
418
+ if self.use_style:
419
+ self.style_feat_dim = args.d_style
420
+ self.person_feat_dim = self.shape_feat_dim + self.style_feat_dim
421
+ else:
422
+ self.person_feat_dim = self.shape_feat_dim
423
+ self.use_indicator = args.use_indicator
424
+
425
+ # Transformer
426
+ self.architecture = args.architecture
427
+ self.feature_dim = args.feature_dim
428
+ self.n_heads = args.n_heads
429
+ self.n_layers = args.n_layers
430
+ self.mlp_ratio = args.mlp_ratio
431
+ self.align_mask_width = args.align_mask_width
432
+ self.use_learnable_pe = not args.no_use_learnable_pe
433
+ # sequence length
434
+ self.n_prev_motions = args.n_prev_motions
435
+ self.n_motions = args.n_motions
436
+
437
+ # Temporal embedding for the diffusion time step
438
+ self.TE = PositionalEncoding(self.feature_dim, max_len=args.n_diff_steps + 1)
439
+ self.diff_step_map = nn.Sequential(
440
+ nn.Linear(self.feature_dim, self.feature_dim),
441
+ nn.GELU(),
442
+ nn.Linear(self.feature_dim, self.feature_dim)
443
+ )
444
+
445
+ if self.use_learnable_pe:
446
+ # Learnable positional encoding
447
+ self.PE = nn.Parameter(torch.randn(1, 1 + self.n_prev_motions + self.n_motions, self.feature_dim))
448
+ else:
449
+ self.PE = PositionalEncoding(self.feature_dim)
450
+
451
+ self.person_proj = nn.Linear(self.person_feat_dim, self.feature_dim)
452
+
453
+ # Transformer decoder
454
+ if self.architecture == 'decoder':
455
+ self.feature_proj = nn.Linear(self.motion_feat_dim + (1 if self.use_indicator else 0),
456
+ self.feature_dim)
457
+ decoder_layer = nn.TransformerDecoderLayer(
458
+ d_model=self.feature_dim, nhead=self.n_heads, dim_feedforward=self.mlp_ratio * self.feature_dim,
459
+ activation='gelu', batch_first=True
460
+ )
461
+ self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=self.n_layers)
462
+ if self.align_mask_width > 0:
463
+ motion_len = self.n_prev_motions + self.n_motions
464
+ alignment_mask = enc_dec_mask(motion_len, motion_len, 1, self.align_mask_width - 1)
465
+ alignment_mask = F.pad(alignment_mask, (0, 0, 1, 0), value=False)
466
+ self.register_buffer('alignment_mask', alignment_mask)
467
+ else:
468
+ self.alignment_mask = None
469
+ else:
470
+ raise ValueError(f'Unknown architecture: {self.architecture}')
471
+
472
+ # Motion decoder
473
+ self.motion_dec = nn.Sequential(
474
+ nn.Linear(self.feature_dim, self.feature_dim // 2),
475
+ nn.GELU(),
476
+ nn.Linear(self.feature_dim // 2, self.motion_feat_dim)
477
+ )
478
+
479
+ self.to(device)
480
+
481
+ @property
482
+ def device(self):
483
+ return next(self.parameters()).device
484
+
485
+ def forward(self, motion_feat, audio_feat, person_feat, prev_motion_feat, prev_audio_feat, step, indicator=None):
486
+ """
487
+ Args:
488
+ motion_feat: (N, L, d_motion). Noisy motion feature
489
+ audio_feat: (N, L, feature_dim)
490
+ person_feat: (N, 1, d_person)
491
+ prev_motion_feat: (N, L_p, d_motion). Padded previous motion coefficients or feature
492
+ prev_audio_feat: (N, L_p, d_audio). Padded previous motion coefficients or feature
493
+ step: (N,)
494
+ indicator: (N, L). 0/1 indicator for the real (unpadded) motion feature
495
+
496
+ Returns:
497
+ motion_feat_target: (N, L_p + L, d_motion)
498
+ """
499
+ # Diffusion time step embedding
500
+ diff_step_embedding = self.diff_step_map(self.TE.pe[0, step]).unsqueeze(1) # (N, 1, diff_step_dim)
501
+
502
+ person_feat = self.person_proj(person_feat) # (N, 1, feature_dim)
503
+ person_feat = person_feat + diff_step_embedding
504
+
505
+ if indicator is not None:
506
+ indicator = torch.cat([torch.zeros((indicator.shape[0], self.n_prev_motions), device=indicator.device),
507
+ indicator], dim=1) # (N, L_p + L)
508
+ indicator = indicator.unsqueeze(-1) # (N, L_p + L, 1)
509
+
510
+ # Concat features and embeddings
511
+ if self.architecture == 'decoder':
512
+ feats_in = torch.cat([prev_motion_feat, motion_feat], dim=1) # (N, L_p + L, d_motion)
513
+ else:
514
+ raise ValueError(f'Unknown architecture: {self.architecture}')
515
+ if self.use_indicator:
516
+ feats_in = torch.cat([feats_in, indicator], dim=-1) # (N, L_p + L, d_motion + d_audio + 1)
517
+
518
+ feats_in = self.feature_proj(feats_in) # (N, L_p + L, feature_dim)
519
+ feats_in = torch.cat([person_feat, feats_in], dim=1) # (N, 1 + L_p + L, feature_dim)
520
+
521
+ if self.use_learnable_pe:
522
+ feats_in = feats_in + self.PE
523
+ else:
524
+ feats_in = self.PE(feats_in)
525
+
526
+ # Transformer
527
+ if self.architecture == 'decoder':
528
+ audio_feat_in = torch.cat([prev_audio_feat, audio_feat], dim=1) # (N, L_p + L, d_audio)
529
+ feat_out = self.transformer(feats_in, audio_feat_in, memory_mask=self.alignment_mask)
530
+ else:
531
+ raise ValueError(f'Unknown architecture: {self.architecture}')
532
+
533
+ # Decode predicted motion feature noise / sample
534
+ motion_feat_target = self.motion_dec(feat_out[:, 1:]) # (N, L_p + L, d_motion)
535
+
536
+ return motion_feat_target
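A small standalone sketch of the `DiffusionSchedule` defined above is given below (it assumes the `diffposetalk` package is importable; `DiffTalkingHead` itself additionally needs the full training-args object and the pretrained audio encoder):

```python
# Sketch: inspect the DiffusionSchedule buffers without building the full model.
import torch
from diffposetalk.diff_talking_head import DiffusionSchedule

sched = DiffusionSchedule(num_steps=500, mode='cosine')
print(sched.betas.shape, sched.alpha_bars.shape)   # torch.Size([501]); beta_0 = 0 is prepended
t = sched.uniform_sample_t(batch_size=4)           # e.g. [137, 402, 88, 251], random steps in [1, 500]
print(sched.get_sigmas(torch.tensor(t)))           # per-step posterior sigmas
```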
diffposetalk/diffposetalk.py ADDED
@@ -0,0 +1,228 @@
1
+ import math
2
+ import tempfile
3
+ import warnings
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import librosa
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from tqdm import tqdm
12
+ from pydantic import BaseModel
13
+
14
+ from .diff_talking_head import DiffTalkingHead
15
+ from .utils import NullableArgs, coef_dict_to_vertices, get_coef_dict
16
+ from .utils.media import combine_video_and_audio, convert_video, reencode_audio
17
+
18
+ warnings.filterwarnings('ignore', message='PySoundFile failed. Trying audioread instead.')
19
+
20
+ class DiffPoseTalkConfig(BaseModel):
21
+ no_context_audio_feat: bool = False
22
+ model_path: str = "pretrained_models/diffposetalk/iter_0110000.pt" # DPT/head-SA-hubert-WM
23
+ coef_stats: str = "pretrained_models/diffposetalk/stats_train.npz"
24
+ style_path: str = "pretrained_models/diffposetalk/style/L4H4-T0.1-BS32/iter_0034000/normal.npy"
25
+ dynamic_threshold_ratio: float = 0.99
26
+ dynamic_threshold_min: float = 1.0
27
+ dynamic_threshold_max: float = 4.0
28
+ scale_audio: float = 1.15
29
+ scale_style: float = 3.0
30
+
31
+ class DiffPoseTalk:
32
+ def __init__(self, config: DiffPoseTalkConfig = DiffPoseTalkConfig(), device="cuda"):
33
+ self.cfg = config
34
+ self.device = device
35
+
36
+ self.no_context_audio_feat = self.cfg.no_context_audio_feat
37
+ model_data = torch.load(self.cfg.model_path, map_location=self.device)
38
+
39
+ self.model_args = NullableArgs(model_data['args'])
40
+ self.model = DiffTalkingHead(self.model_args, self.device)
41
+ model_data['model'].pop('denoising_net.TE.pe')
42
+ self.model.load_state_dict(model_data['model'], strict=False)
43
+ self.model.to(self.device)
44
+ self.model.eval()
45
+
46
+ self.use_indicator = self.model_args.use_indicator
47
+ self.rot_repr = self.model_args.rot_repr
48
+ self.predict_head_pose = not self.model_args.no_head_pose
49
+ if self.model.use_style:
50
+ style_dir = Path(self.model_args.style_enc_ckpt)
51
+ style_dir = Path(*style_dir.with_suffix('').parts[-3::2])
52
+ self.style_dir = style_dir
53
+
54
+ # sequence
55
+ self.n_motions = self.model_args.n_motions
56
+ self.n_prev_motions = self.model_args.n_prev_motions
57
+ self.fps = self.model_args.fps
58
+ self.audio_unit = 16000. / self.fps # num of samples per frame
59
+ self.n_audio_samples = round(self.audio_unit * self.n_motions)
60
+ self.pad_mode = self.model_args.pad_mode
61
+
62
+ self.coef_stats = dict(np.load(self.cfg.coef_stats))
63
+ self.coef_stats = {k: torch.from_numpy(v).to(self.device) for k, v in self.coef_stats.items()}
64
+
65
+ if self.cfg.dynamic_threshold_ratio > 0:
66
+ self.dynamic_threshold = (self.cfg.dynamic_threshold_ratio, self.cfg.dynamic_threshold_min,
67
+ self.cfg.dynamic_threshold_max)
68
+ else:
69
+ self.dynamic_threshold = None
70
+
71
+
72
+ def infer_from_file(self, audio_path, shape_coef):
73
+ n_repetitions = 1
74
+ cfg_mode = None
75
+ cfg_cond = self.model.guiding_conditions
76
+ cfg_scale = []
77
+ for cond in cfg_cond:
78
+ if cond == 'audio':
79
+ cfg_scale.append(self.cfg.scale_audio)
80
+ elif cond == 'style':
81
+ cfg_scale.append(self.cfg.scale_style)
82
+
83
+ coef_dict = self.infer_coeffs(audio_path, shape_coef, self.cfg.style_path, n_repetitions,
84
+ cfg_mode, cfg_cond, cfg_scale, include_shape=True)
85
+ return coef_dict
86
+
87
+ @torch.no_grad()
88
+ def infer_coeffs(self, audio, shape_coef, style_feat=None, n_repetitions=1,
89
+ cfg_mode=None, cfg_cond=None, cfg_scale=1.15, include_shape=False):
90
+ # Returns dict[str, (n_repetitions, L, *)]
91
+ # Step 1: Preprocessing
92
+ # Preprocess audio
93
+ if isinstance(audio, (str, Path)):
94
+ audio, _ = librosa.load(audio, sr=16000, mono=True)
95
+ if isinstance(audio, np.ndarray):
96
+ audio = torch.from_numpy(audio).to(self.device)
97
+ assert audio.ndim == 1, 'Audio must be 1D tensor.'
98
+ audio_mean, audio_std = torch.mean(audio), torch.std(audio)
99
+ audio = (audio - audio_mean) / (audio_std + 1e-5)
100
+
101
+ # Preprocess shape coefficient
102
+ if isinstance(shape_coef, (str, Path)):
103
+ shape_coef = np.load(shape_coef)
104
+ if not isinstance(shape_coef, np.ndarray):
105
+ shape_coef = shape_coef['shape']
106
+ if isinstance(shape_coef, np.ndarray):
107
+ shape_coef = torch.from_numpy(shape_coef).float().to(self.device)
108
+ assert shape_coef.ndim <= 2, 'Shape coefficient must be 1D or 2D tensor.'
109
+ if shape_coef.ndim > 1:
110
+ # use the first frame as the shape coefficient
111
+ shape_coef = shape_coef[0]
112
+ original_shape_coef = shape_coef.clone()
113
+ if self.coef_stats is not None:
114
+ shape_coef = (shape_coef - self.coef_stats['shape_mean']) / self.coef_stats['shape_std']
115
+ shape_coef = shape_coef.unsqueeze(0).expand(n_repetitions, -1)
116
+
117
+ # Preprocess style feature if given
118
+ if style_feat is not None:
119
+ assert self.model.use_style
120
+ if isinstance(style_feat, (str, Path)):
121
+ style_feat = Path(style_feat)
122
+ if not style_feat.exists() and not style_feat.is_absolute():
123
+ style_feat = style_feat.parent / self.style_dir / style_feat.name
124
+ style_feat = np.load(style_feat)
125
+ if not isinstance(style_feat, np.ndarray):
126
+ style_feat = style_feat['style']
127
+ if isinstance(style_feat, np.ndarray):
128
+ style_feat = torch.from_numpy(style_feat).float().to(self.device)
129
+ assert style_feat.ndim == 1, 'Style feature must be 1D tensor.'
130
+ style_feat = style_feat.unsqueeze(0).expand(n_repetitions, -1)
131
+
132
+ # Step 2: Predict motion coef
133
+ # divide into synthesize units and do synthesize
134
+ clip_len = int(len(audio) / 16000 * self.fps)
135
+ stride = self.n_motions
136
+ if clip_len <= self.n_motions:
137
+ n_subdivision = 1
138
+ else:
139
+ n_subdivision = math.ceil(clip_len / stride)
140
+
141
+ # Prepare audio input
142
+ n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
143
+ n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
144
+ if n_padding_audio_samples > 0:
145
+ if self.pad_mode == 'zero':
146
+ padding_value = 0
147
+ elif self.pad_mode == 'replicate':
148
+ padding_value = audio[-1]
149
+ else:
150
+ raise ValueError(f'Unknown pad mode: {self.pad_mode}')
151
+ audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
152
+
153
+ if not self.no_context_audio_feat:
154
+ audio_feat = self.model.extract_audio_feature(audio.unsqueeze(0), self.n_motions * n_subdivision)
155
+
156
+ # Generate `self.n_motions` new frames at one time, and use the last `self.n_prev_motions` frames
157
+ # from the previous generation as the initial motion condition
158
+ coef_list = []
159
+ for i in range(0, n_subdivision):
160
+ start_idx = i * stride
161
+ end_idx = start_idx + self.n_motions
162
+ indicator = torch.ones((n_repetitions, self.n_motions)).to(self.device) if self.use_indicator else None
163
+ if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
164
+ indicator[:, -n_padding_frames:] = 0
165
+ if not self.no_context_audio_feat:
166
+ audio_in = audio_feat[:, start_idx:end_idx].expand(n_repetitions, -1, -1)
167
+ else:
168
+ audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
169
+
170
+ # generate motion coefficients
171
+ if i == 0:
172
+ # -> (N, L, d_motion=n_code_per_frame * code_dim)
173
+ motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
174
+ indicator=indicator, cfg_mode=cfg_mode,
175
+ cfg_cond=cfg_cond, cfg_scale=cfg_scale,
176
+ dynamic_threshold=self.dynamic_threshold)
177
+ else:
178
+ motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
179
+ prev_motion_feat, prev_audio_feat, noise,
180
+ indicator=indicator, cfg_mode=cfg_mode,
181
+ cfg_cond=cfg_cond, cfg_scale=cfg_scale,
182
+ dynamic_threshold=self.dynamic_threshold)
183
+ prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
184
+ prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
185
+
186
+ motion_coef = motion_feat
187
+ if i == n_subdivision - 1 and n_padding_frames > 0:
188
+ motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
189
+ coef_list.append(motion_coef)
190
+
191
+ motion_coef = torch.cat(coef_list, dim=1)
192
+
193
+ # Step 3: restore to coef dict
194
+ coef_dict = get_coef_dict(motion_coef, None, self.coef_stats, self.predict_head_pose, self.rot_repr)
195
+ if include_shape:
196
+ coef_dict['shape'] = original_shape_coef[None, None].expand(n_repetitions, motion_coef.shape[1], -1)
197
+ return self.coef_to_a1_format(coef_dict)
198
+
199
+ def coef_to_a1_format(self, coef_dict):
200
+ n_frames = coef_dict['exp'].shape[1]
201
+ new_coef_dict = []
202
+ for i in range(n_frames):
203
+
204
+ new_coef_dict.append({
205
+ "expression_params": coef_dict["exp"][0, i:i+1],
206
+ "jaw_params": coef_dict["pose"][0, i:i+1, 3:],
207
+ "eye_pose_params": torch.zeros(1, 6).type_as(coef_dict["pose"]),
208
+ "pose_params": coef_dict["pose"][0, i:i+1, :3],
209
+ "eyelid_params": None
210
+ })
211
+ return new_coef_dict
212
+
213
+
214
+
215
+
216
+
217
+ @staticmethod
218
+ def _pad_coef(coef, n_frames, elem_ndim=1):
219
+ if coef.ndim == elem_ndim:
220
+ coef = coef[None]
221
+ elem_shape = coef.shape[1:]
222
+ if coef.shape[0] >= n_frames:
223
+ new_coef = coef[:n_frames]
224
+ else:
225
+ # repeat the last coef frame
226
+ new_coef = torch.cat([coef, coef[[-1]].expand(n_frames - coef.shape[0], *elem_shape)], dim=0)
227
+ return new_coef # (n_frames, *elem_shape)
228
+
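For reference, a minimal sketch of the windowing/padding arithmetic performed in Step 2 above, assuming 16 kHz audio, fps = 25 and n_motions = 100 (these values are assumptions here; on the real class they come from the config):

    import math

    sr, fps, n_motions = 16000, 25, 100        # assumed defaults
    audio_unit = sr / fps                      # 640 audio samples per motion frame
    n_audio_samples = round(audio_unit * n_motions)

    audio_len = 88000                          # e.g. a 5.5 s clip
    clip_len = int(audio_len / sr * fps)       # 137 motion frames
    n_subdivision = max(1, math.ceil(clip_len / n_motions))        # 2 windows
    n_pad_samples = n_audio_samples * n_subdivision - audio_len    # 40000 padded samples
    n_pad_frames = math.ceil(n_pad_samples / audio_unit)           # 63 padded frames
    print(clip_len, n_subdivision, n_pad_frames)  # 137 2 63 -> the last 63 frames are trimmed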
diffposetalk/hubert.py ADDED
@@ -0,0 +1,51 @@
1
+ from transformers import HubertModel
2
+ from transformers.modeling_outputs import BaseModelOutput
3
+
4
+ from .wav2vec2 import linear_interpolation
5
+
6
+ _CONFIG_FOR_DOC = 'HubertConfig'
7
+
8
+
9
+ class HubertModel(HubertModel):
10
+ def __init__(self, config):
11
+ super().__init__(config)
12
+
13
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
14
+ output_hidden_states=None, return_dict=None, frame_num=None):
15
+ self.config.output_attentions = True
16
+
17
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
18
+ output_hidden_states = (
19
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
20
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
21
+
22
+ extract_features = self.feature_extractor(input_values) # (N, C, L)
23
+ # Resample the audio feature @ 50 fps to `output_fps`.
24
+ if frame_num is not None:
25
+ extract_features_len = round(frame_num * 50 / output_fps)
26
+ extract_features = extract_features[:, :, :extract_features_len]
27
+ extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
28
+ extract_features = extract_features.transpose(1, 2) # (N, L, C)
29
+
30
+ if attention_mask is not None:
31
+ # compute reduced attention_mask corresponding to feature vectors
32
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
33
+
34
+ hidden_states = self.feature_projection(extract_features)
35
+ hidden_states = self._mask_hidden_states(hidden_states)
36
+
37
+ encoder_outputs = self.encoder(
38
+ hidden_states,
39
+ attention_mask=attention_mask,
40
+ output_attentions=output_attentions,
41
+ output_hidden_states=output_hidden_states,
42
+ return_dict=return_dict,
43
+ )
44
+
45
+ hidden_states = encoder_outputs[0]
46
+
47
+ if not return_dict:
48
+ return (hidden_states,) + encoder_outputs[1:]
49
+
50
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
51
+ attentions=encoder_outputs.attentions, )
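To make the resampling step concrete, here is a hedged sketch using only the linear_interpolation helper; the random tensor stands in for the real feature extractor output (roughly 50 feature frames per second of 16 kHz audio):

    import torch
    from diffposetalk.wav2vec2 import linear_interpolation

    feat_50hz = torch.randn(1, 512, 100)   # stand-in for feature_extractor output of a 2 s clip
    feat_25fps = linear_interpolation(feat_50hz, input_fps=50, output_fps=25, output_len=50)
    print(feat_25fps.shape)                # torch.Size([1, 512, 50])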
diffposetalk/utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .common import *
diffposetalk/utils/common.py ADDED
@@ -0,0 +1,378 @@
1
+ from functools import reduce
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class NullableArgs:
9
+ def __init__(self, namespace):
10
+ for key, value in namespace.__dict__.items():
11
+ setattr(self, key, value)
12
+
13
+ def __getattr__(self, key):
14
+ # when an attribute lookup has not found the attribute
15
+ if key == 'align_mask_width':
16
+ if 'use_alignment_mask' in self.__dict__:
17
+ return 1 if self.use_alignment_mask else 0
18
+ else:
19
+ return 0
20
+ if key == 'no_head_pose':
21
+ return not self.predict_head_pose
22
+ if key == 'no_use_learnable_pe':
23
+ return not self.use_learnable_pe
24
+
25
+ return None
26
+
27
+
28
+ def count_parameters(model):
29
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
30
+
31
+
32
+ def get_option_text(args, parser):
33
+ message = ''
34
+ for k, v in sorted(vars(args).items()):
35
+ comment = ''
36
+ default = parser.get_default(k)
37
+ if v != default:
38
+ comment = f'\t[default: {str(default)}]'
39
+ message += f'{str(k):>30}: {str(v):<30}{comment}\n'
40
+ return message
41
+
42
+
43
+ def get_model_path(exp_name, iteration, model_type='DPT'):
44
+ exp_root_dir = Path(__file__).parent.parent / 'experiments' / model_type
45
+ exp_dir = exp_root_dir / exp_name
46
+ if not exp_dir.exists():
47
+ exp_dir = next(exp_root_dir.glob(f'{exp_name}*'))
48
+ model_path = exp_dir / f'checkpoints/iter_{iteration:07}.pt'
49
+ return model_path, exp_dir.relative_to(exp_root_dir)
50
+
51
+
52
+ def get_pose_input(coef_dict, rot_repr, with_global_pose):
53
+ if rot_repr == 'aa':
54
+ pose_input = coef_dict['pose'] if with_global_pose else coef_dict['pose'][..., -3:]
55
+ # Remove mouth rotation round y, z axis
56
+ pose_input = pose_input[..., :-2]
57
+ else:
58
+ raise ValueError(f'Unknown rotation representation: {rot_repr}')
59
+ return pose_input
60
+
61
+
62
+ def get_motion_coef(coef_dict, rot_repr, with_global_pose=False, norm_stats=None):
63
+ if norm_stats is not None:
64
+ if rot_repr == 'aa':
65
+ keys = ['exp', 'pose']
66
+ else:
67
+ raise ValueError(f'Unknown rotation representation {rot_repr}!')
68
+
69
+ coef_dict = {k: (coef_dict[k] - norm_stats[f'{k}_mean']) / norm_stats[f'{k}_std'] for k in keys}
70
+ pose_coef = get_pose_input(coef_dict, rot_repr, with_global_pose)
71
+ return torch.cat([coef_dict['exp'], pose_coef], dim=-1)
72
+
73
+
74
+ def get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None, with_global_pose=False, rot_repr='aa'):
75
+ coef_dict = {
76
+ 'exp': motion_coef[..., :50]
77
+ }
78
+ if rot_repr == 'aa':
79
+ if with_global_pose:
80
+ coef_dict['pose'] = motion_coef[..., 50:]
81
+ else:
82
+ placeholder = torch.zeros_like(motion_coef[..., :3])
83
+ coef_dict['pose'] = torch.cat([placeholder, motion_coef[..., -1:]], dim=-1)
84
+ # Add back rotation around y, z axis
85
+ coef_dict['pose'] = torch.cat([coef_dict['pose'], torch.zeros_like(motion_coef[..., :2])], dim=-1)
86
+ else:
87
+ raise ValueError(f'Unknown rotation representation {rot_repr}!')
88
+
89
+ if shape_coef is not None:
90
+ if motion_coef.ndim == 3:
91
+ if shape_coef.ndim == 2:
92
+ shape_coef = shape_coef.unsqueeze(1)
93
+ if shape_coef.shape[1] == 1:
94
+ shape_coef = shape_coef.expand(-1, motion_coef.shape[1], -1)
95
+
96
+ coef_dict['shape'] = shape_coef
97
+
98
+ if denorm_stats is not None:
99
+ coef_dict = {k: coef_dict[k] * denorm_stats[f'{k}_std'] + denorm_stats[f'{k}_mean'] for k in coef_dict}
100
+
101
+ if not with_global_pose:
102
+ if rot_repr == 'aa':
103
+ coef_dict['pose'][..., :3] = 0
104
+ else:
105
+ raise ValueError(f'Unknown rotation representation {rot_repr}!')
106
+
107
+ return coef_dict
108
+
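A hedged sanity check of the channel layout this function assumes (50 expression coefficients followed by the pose channels); the 51-channel input corresponds to the no-global-pose case, where only the jaw rotation channel is kept:

    import torch
    from diffposetalk.utils.common import get_coef_dict

    motion_coef = torch.randn(2, 10, 51)   # (batch, frames, 50 exp + 1 jaw channel)
    coefs = get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None,
                          with_global_pose=False, rot_repr='aa')
    print(coefs['exp'].shape, coefs['pose'].shape)  # (2, 10, 50) and (2, 10, 6)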
109
+
110
+ def coef_dict_to_vertices(coef_dict, flame, rot_repr='aa', ignore_global_rot=False, flame_batch_size=512):
111
+ shape = coef_dict['exp'].shape[:-1]
112
+ coef_dict = {k: v.view(-1, v.shape[-1]) for k, v in coef_dict.items()}
113
+ n_samples = reduce(lambda x, y: x * y, shape, 1)
114
+
115
+ # Convert to vertices
116
+ vert_list = []
117
+ for i in range(0, n_samples, flame_batch_size):
118
+ batch_coef_dict = {k: v[i:i + flame_batch_size] for k, v in coef_dict.items()}
119
+ if rot_repr == 'aa':
120
+ vert, _, _ = flame(
121
+ batch_coef_dict['shape'], batch_coef_dict['exp'], batch_coef_dict['pose'],
122
+ pose2rot=True, ignore_global_rot=ignore_global_rot, return_lm2d=False, return_lm3d=False)
123
+ else:
124
+ raise ValueError(f'Unknown rot_repr: {rot_repr}')
125
+ vert_list.append(vert)
126
+
127
+ vert_list = torch.cat(vert_list, dim=0) # (n_samples, 5023, 3)
128
+ vert_list = vert_list.view(*shape, -1, 3) # (..., 5023, 3)
129
+
130
+ return vert_list
131
+
132
+
133
+ def compute_loss(args, is_starting_sample, shape_coef, motion_coef_gt, noise, target, prev_motion_coef, coef_stats,
134
+ flame, end_idx=None):
135
+ if args.criterion.lower() == 'l2':
136
+ criterion_func = F.mse_loss
137
+ elif args.criterion.lower() == 'l1':
138
+ criterion_func = F.l1_loss
139
+ else:
140
+ raise NotImplementedError(f'Criterion {args.criterion} not implemented.')
141
+
142
+ loss_vert = None
143
+ loss_vel = None
144
+ loss_smooth = None
145
+ loss_head_angle = None
146
+ loss_head_vel = None
147
+ loss_head_smooth = None
148
+ loss_head_trans_vel = None
149
+ loss_head_trans_accel = None
150
+ loss_head_trans = None
151
+ if args.target == 'noise':
152
+ loss_noise = criterion_func(noise, target[:, args.n_prev_motions:], reduction='none')
153
+ elif args.target == 'sample':
154
+ if is_starting_sample:
155
+ target = target[:, args.n_prev_motions:]
156
+ else:
157
+ motion_coef_gt = torch.cat([prev_motion_coef, motion_coef_gt], dim=1)
158
+ if args.no_constrain_prev:
159
+ target = torch.cat([prev_motion_coef, target[:, args.n_prev_motions:]], dim=1)
160
+
161
+ loss_noise = criterion_func(motion_coef_gt, target, reduction='none')
162
+
163
+ if args.l_vert > 0 or args.l_vel > 0:
164
+ coef_gt = get_coef_dict(motion_coef_gt, shape_coef, coef_stats, with_global_pose=False,
165
+ rot_repr=args.rot_repr)
166
+ coef_pred = get_coef_dict(target, shape_coef, coef_stats, with_global_pose=False,
167
+ rot_repr=args.rot_repr)
168
+ seq_len = target.shape[1]
169
+
170
+ if args.rot_repr == 'aa':
171
+ verts_gt, _, _ = flame(coef_gt['shape'].view(-1, 100), coef_gt['exp'].view(-1, 50),
172
+ coef_gt['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
173
+ verts_pred, _, _ = flame(coef_pred['shape'].view(-1, 100), coef_pred['exp'].view(-1, 50),
174
+ coef_pred['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
175
+ else:
176
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
177
+ verts_gt = verts_gt.view(-1, seq_len, 5023, 3)
178
+ verts_pred = verts_pred.view(-1, seq_len, 5023, 3)
179
+
180
+ if args.l_vert > 0:
181
+ loss_vert = criterion_func(verts_gt, verts_pred, reduction='none')
182
+
183
+ if args.l_vel > 0:
184
+ vel_gt = verts_gt[:, 1:] - verts_gt[:, :-1]
185
+ vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
186
+ loss_vel = criterion_func(vel_gt, vel_pred, reduction='none')
187
+
188
+ if args.l_smooth > 0:
189
+ vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
190
+ loss_smooth = criterion_func(vel_pred[:, 1:], vel_pred[:, :-1], reduction='none')
191
+
192
+ # head pose
193
+ if not args.no_head_pose:
194
+ if args.rot_repr == 'aa':
195
+ head_pose_gt = motion_coef_gt[:, :, 50:53]
196
+ head_pose_pred = target[:, :, 50:53]
197
+ else:
198
+ raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
199
+
200
+ if args.l_head_angle > 0:
201
+ loss_head_angle = criterion_func(head_pose_gt, head_pose_pred, reduction='none')
202
+
203
+ if args.l_head_vel > 0:
204
+ head_vel_gt = head_pose_gt[:, 1:] - head_pose_gt[:, :-1]
205
+ head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
206
+ loss_head_vel = criterion_func(head_vel_gt, head_vel_pred, reduction='none')
207
+
208
+ if args.l_head_smooth > 0:
209
+ head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
210
+ loss_head_smooth = criterion_func(head_vel_pred[:, 1:], head_vel_pred[:, :-1], reduction='none')
211
+
212
+ if not is_starting_sample and args.l_head_trans > 0:
213
+ # # version 1: constrain both the predicted previous and current motions (x_{-3} ~ x_{2})
214
+ # head_pose_trans = head_pose_pred[:, args.n_prev_motions - 3:args.n_prev_motions + 3]
215
+ # head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
216
+ # head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]
217
+
218
+ # version 2: constrain only the predicted current motions (x_{0} ~ x_{2})
219
+ head_pose_trans = torch.cat([head_pose_gt[:, args.n_prev_motions - 3:args.n_prev_motions],
220
+ head_pose_pred[:, args.n_prev_motions:args.n_prev_motions + 3]], dim=1)
221
+ head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
222
+ head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]
223
+
224
+ # will constrain x_{-2|0} ~ x_{1}
225
+ loss_head_trans_vel = criterion_func(head_vel_pred[:, 2:4], head_vel_pred[:, 1:3], reduction='none')
226
+ # will constrain x_{-3|0} ~ x_{2}
227
+ loss_head_trans_accel = criterion_func(head_accel_pred[:, 1:], head_accel_pred[:, :-1],
228
+ reduction='none')
229
+ else:
230
+ raise ValueError(f'Unknown diffusion target: {args.target}')
231
+
232
+ if end_idx is None:
233
+ mask = torch.ones((target.shape[0], args.n_motions), dtype=torch.bool, device=target.device)
234
+ else:
235
+ mask = torch.arange(args.n_motions, device=target.device).expand(target.shape[0], -1) < end_idx.unsqueeze(1)
236
+
237
+ if args.target == 'sample' and not is_starting_sample:
238
+ if args.no_constrain_prev:
239
+ # Warning: this option will be deprecated in the future
240
+ mask = torch.cat([torch.zeros_like(mask[:, :args.n_prev_motions]), mask], dim=1)
241
+ else:
242
+ mask = torch.cat([torch.ones_like(mask[:, :args.n_prev_motions]), mask], dim=1)
243
+
244
+ loss_noise = loss_noise[mask].mean()
245
+ if loss_vert is not None:
246
+ loss_vert = loss_vert[mask].mean()
247
+ if loss_vel is not None:
248
+ loss_vel = loss_vel[mask[:, 1:]]
249
+ loss_vel = loss_vel.mean() if torch.numel(loss_vel) > 0 else None
250
+ if loss_smooth is not None:
251
+ loss_smooth = loss_smooth[mask[:, 2:]]
252
+ loss_smooth = loss_smooth.mean() if torch.numel(loss_smooth) > 0 else None
253
+ if loss_head_angle is not None:
254
+ loss_head_angle = loss_head_angle[mask].mean()
255
+ if loss_head_vel is not None:
256
+ loss_head_vel = loss_head_vel[mask[:, 1:]]
257
+ loss_head_vel = loss_head_vel.mean() if torch.numel(loss_head_vel) > 0 else None
258
+ if loss_head_smooth is not None:
259
+ loss_head_smooth = loss_head_smooth[mask[:, 2:]]
260
+ loss_head_smooth = loss_head_smooth.mean() if torch.numel(loss_head_smooth) > 0 else None
261
+ if loss_head_trans_vel is not None:
262
+ vel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 2]
263
+ accel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 3]
264
+ loss_head_trans_vel = loss_head_trans_vel[vel_mask].mean()
265
+ loss_head_trans_accel = loss_head_trans_accel[accel_mask].mean()
266
+ loss_head_trans = loss_head_trans_vel + loss_head_trans_accel
267
+
268
+ return loss_noise, loss_vert, loss_vel, loss_smooth, loss_head_angle, loss_head_vel, loss_head_smooth, \
269
+ loss_head_trans
270
+
271
+
272
+ def _truncate_audio(audio, end_idx, pad_mode='zero'):
273
+ batch_size = audio.shape[0]
274
+ audio_trunc = audio.clone()
275
+ if pad_mode == 'replicate':
276
+ for i in range(batch_size):
277
+ audio_trunc[i, end_idx[i]:] = audio_trunc[i, end_idx[i] - 1]
278
+ elif pad_mode == 'zero':
279
+ for i in range(batch_size):
280
+ audio_trunc[i, end_idx[i]:] = 0
281
+ else:
282
+ raise ValueError(f'Unknown pad mode {pad_mode}!')
283
+
284
+ return audio_trunc
285
+
286
+
287
+ def _truncate_coef_dict(coef_dict, end_idx, pad_mode='zero'):
288
+ batch_size = coef_dict['exp'].shape[0]
289
+ coef_dict_trunc = {k: v.clone() for k, v in coef_dict.items()}
290
+ if pad_mode == 'replicate':
291
+ for i in range(batch_size):
292
+ for k in coef_dict_trunc:
293
+ coef_dict_trunc[k][i, end_idx[i]:] = coef_dict_trunc[k][i, end_idx[i] - 1]
294
+ elif pad_mode == 'zero':
295
+ for i in range(batch_size):
296
+ for k in coef_dict:
297
+ coef_dict_trunc[k][i, end_idx[i]:] = 0
298
+ else:
299
+ raise ValueError(f'Unknown pad mode: {pad_mode}!')
300
+
301
+ return coef_dict_trunc
302
+
303
+
304
+ def truncate_coef_dict_and_audio(audio, coef_dict, n_motions, audio_unit=640, pad_mode='zero'):
305
+ batch_size = audio.shape[0]
306
+ end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
307
+ audio_end_idx = (end_idx * audio_unit).long()
308
+ # mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)
309
+
310
+ # truncate audio
311
+ audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
312
+
313
+ # truncate coef dict
314
+ coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
315
+
316
+ return audio_trunc, coef_dict_trunc, end_idx
317
+
318
+
319
+ def truncate_motion_coef_and_audio(audio, motion_coef, n_motions, audio_unit=640, pad_mode='zero'):
320
+ batch_size = audio.shape[0]
321
+ end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
322
+ audio_end_idx = (end_idx * audio_unit).long()
323
+ # mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)
324
+
325
+ # truncate audio
326
+ audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
327
+
328
+ # prepare coef dict and stats
329
+ coef_dict = {'exp': motion_coef[..., :50], 'pose_any': motion_coef[..., 50:]}
330
+
331
+ # truncate coef dict
332
+ coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
333
+ motion_coef_trunc = torch.cat([coef_dict_trunc['exp'], coef_dict_trunc['pose_any']], dim=-1)
334
+
335
+ return audio_trunc, motion_coef_trunc, end_idx
336
+
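A short usage sketch of the truncation helper above, with random stand-ins for a training batch (n_motions = 100 and audio_unit = 640 are the values assumed elsewhere in this file):

    import torch
    from diffposetalk.utils.common import truncate_motion_coef_and_audio

    audio = torch.randn(4, 100 * 640)     # 4 windows of 4 s of 16 kHz audio each
    motion = torch.randn(4, 100, 51)      # matching motion coefficients
    audio_t, motion_t, end_idx = truncate_motion_coef_and_audio(audio, motion, n_motions=100)
    # Shapes are unchanged; everything past the random cut point end_idx[i] is zero-padded.
    print(audio_t.shape, motion_t.shape, end_idx)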
337
+
338
+ def nt_xent_loss(feature_a, feature_b, temperature):
339
+ """
340
+ Normalized temperature-scaled cross entropy loss.
341
+
342
+ (Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py)
343
+
344
+ Args:
345
+ feature_a (torch.Tensor): shape (batch_size, feature_dim)
346
+ feature_b (torch.Tensor): shape (batch_size, feature_dim)
347
+ temperature (float): temperature scaling factor
348
+
349
+ Returns:
350
+ torch.Tensor: scalar
351
+ """
352
+ batch_size = feature_a.shape[0]
353
+ device = feature_a.device
354
+
355
+ features = torch.cat([feature_a, feature_b], dim=0)
356
+
357
+ labels = torch.cat([torch.arange(batch_size), torch.arange(batch_size)], dim=0)
358
+ labels = (labels.unsqueeze(0) == labels.unsqueeze(1))
359
+ labels = labels.to(device)
360
+
361
+ features = F.normalize(features, dim=1)
362
+ similarity_matrix = torch.matmul(features, features.T)
363
+
364
+ # discard the main diagonal from both: labels and similarities matrix
365
+ mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
366
+ labels = labels[~mask].view(labels.shape[0], -1)
367
+ similarity_matrix = similarity_matrix[~mask].view(labels.shape[0], -1)
368
+
369
+ # select the positives and negatives
370
+ positives = similarity_matrix[labels].view(labels.shape[0], -1)
371
+ negatives = similarity_matrix[~labels].view(labels.shape[0], -1)
372
+
373
+ logits = torch.cat([positives, negatives], dim=1)
374
+ logits = logits / temperature
375
+ labels = torch.zeros(labels.shape[0], dtype=torch.long).to(device)
376
+
377
+ loss = F.cross_entropy(logits, labels)
378
+ return loss
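And a minimal usage sketch of the contrastive loss above, with random features standing in for the two style embeddings of a positive pair:

    import torch
    from diffposetalk.utils.common import nt_xent_loss

    feat_a = torch.randn(8, 128)          # e.g. style features of clips from the same speaker
    feat_b = torch.randn(8, 128)
    loss = nt_xent_loss(feat_a, feat_b, temperature=0.1)
    print(loss.item())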
diffposetalk/utils/media.py ADDED
@@ -0,0 +1,35 @@
1
+ import shlex
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+
6
+ def combine_video_and_audio(video_file, audio_file, output, quality=17, copy_audio=True):
7
+ audio_codec = '-c:a copy' if copy_audio else ''
8
+ cmd = f'ffmpeg -i {video_file} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
9
+ f'{audio_codec} -fflags +shortest -y -hide_banner -loglevel error {output}'
10
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
11
+
12
+
13
+ def combine_frames_and_audio(frame_files, audio_file, fps, output, quality=17):
14
+ cmd = f'ffmpeg -framerate {fps} -i {frame_files} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
15
+ f'-c:a copy -fflags +shortest -y -hide_banner -loglevel error {output}'
16
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
17
+
18
+
19
+ def convert_video(video_file, output, quality=17):
20
+ cmd = f'ffmpeg -i {video_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
21
+ f'-fflags +shortest -y -hide_banner -loglevel error {output}'
22
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
23
+
24
+
25
+ def reencode_audio(audio_file, output):
26
+ cmd = f'ffmpeg -i {audio_file} -y -hide_banner -loglevel error {output}'
27
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
28
+
29
+
30
+ def extract_frames(filename, output_dir, quality=1):
31
+ output_dir = Path(output_dir)
32
+ output_dir.mkdir(parents=True, exist_ok=True)
33
+ cmd = f'ffmpeg -i {filename} -qmin 1 -qscale:v {quality} -y -start_number 0 -hide_banner -loglevel error ' \
34
+ f'{output_dir / "%06d.jpg"}'
35
+ assert subprocess.run(shlex.split(cmd)).returncode == 0
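These helpers simply shell out to ffmpeg, which must be available on PATH; a hedged example with hypothetical file names (the calls only succeed when the inputs actually exist):

    from diffposetalk.utils.media import combine_video_and_audio, extract_frames

    # Hypothetical paths for illustration only.
    combine_video_and_audio('render.mp4', 'speech.wav', 'result.mp4', quality=17, copy_audio=False)
    extract_frames('result.mp4', 'frames/', quality=1)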
diffposetalk/utils/renderer.py ADDED
@@ -0,0 +1,147 @@
1
+ import os
2
+ import tempfile
3
+
4
+ import cv2
5
+ import kiui.mesh
6
+ import numpy as np
7
+
8
+ # os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # osmesa or egl
9
+ os.environ['PYOPENGL_PLATFORM'] = 'egl'
10
+ import pyrender
11
+ import trimesh
12
+ # from psbody.mesh import Mesh
13
+
14
+
15
+ class MeshRenderer:
16
+ def __init__(self, size, fov=16 / 180 * np.pi, camera_pose=None, light_pose=None, black_bg=False):
17
+ # Camera
18
+ self.frustum = {'near': 0.01, 'far': 3.0}
19
+ self.camera = pyrender.PerspectiveCamera(yfov=fov, znear=self.frustum['near'],
20
+ zfar=self.frustum['far'], aspectRatio=1.0)
21
+
22
+ # Material
23
+ self.primitive_material = pyrender.material.MetallicRoughnessMaterial(
24
+ alphaMode='BLEND',
25
+ baseColorFactor=[0.3, 0.3, 0.3, 1.0],
26
+ metallicFactor=0.8,
27
+ roughnessFactor=0.8
28
+ )
29
+
30
+ # Lighting
31
+ light_color = np.array([1., 1., 1.])
32
+ self.light = pyrender.DirectionalLight(color=light_color, intensity=2)
33
+ self.light_angle = np.pi / 6.0
34
+
35
+ # Scene
36
+ self.scene = None
37
+ self._init_scene(black_bg)
38
+
39
+ # add camera and lighting
40
+ self._init_camera(camera_pose)
41
+ self._init_lighting(light_pose)
42
+
43
+ # Renderer
44
+ self.renderer = pyrender.OffscreenRenderer(*size, point_size=1.0)
45
+
46
+ def _init_scene(self, black_bg=False):
47
+ if black_bg:
48
+ self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
49
+ else:
50
+ self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
51
+
52
+ def _init_camera(self, camera_pose=None):
53
+ if camera_pose is None:
54
+ camera_pose = np.eye(4)
55
+ camera_pose[:3, 3] = np.array([0, 0, 1])
56
+ self.camera_pose = camera_pose.copy()
57
+ self.camera_node = self.scene.add(self.camera, pose=camera_pose)
58
+
59
+ def _init_lighting(self, light_pose=None):
60
+ if light_pose is None:
61
+ light_pose = np.eye(4)
62
+ light_pose[:3, 3] = np.array([0, 0, 1])
63
+ self.light_pose = light_pose.copy()
64
+
65
+ light_poses = self._get_light_poses(self.light_angle, light_pose)
66
+ self.light_nodes = [self.scene.add(self.light, pose=light_pose) for light_pose in light_poses]
67
+
68
+ def set_camera_pose(self, camera_pose):
69
+ self.camera_pose = camera_pose.copy()
70
+ self.scene.set_pose(self.camera_node, pose=camera_pose)
71
+
72
+ def set_lighting_pose(self, light_pose):
73
+ self.light_pose = light_pose.copy()
74
+
75
+ light_poses = self._get_light_poses(self.light_angle, light_pose)
76
+ for light_node, light_pose in zip(self.light_nodes, light_poses):
77
+ self.scene.set_pose(light_node, pose=light_pose)
78
+
79
+ def render_mesh(self, v, f, t_center, rot=np.zeros(3), tex_img=None, tex_uv=None,
80
+ camera_pose=None, light_pose=None):
81
+ # Prepare mesh
82
+ v[:] = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
83
+ if tex_img is not None:
84
+ tex = pyrender.Texture(source=tex_img, source_channels='RGB')
85
+ tex_material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
86
+ from kiui.mesh import Mesh
87
+ import torch
88
+ mesh = Mesh(
89
+ v=torch.from_numpy(v),
90
+ f=torch.from_numpy(f),
91
+ vt=tex_uv['vt'],
92
+ ft=tex_uv['ft']
93
+ )
94
+ with tempfile.NamedTemporaryFile(suffix='.obj') as f:
95
+ mesh.write_obj(f.name)
96
+ tri_mesh = trimesh.load(f.name, process=False)
97
+ return tri_mesh
98
+ # tri_mesh = self._pyrender_mesh_workaround(mesh)
99
+ render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=tex_material)
100
+ else:
101
+ tri_mesh = trimesh.Trimesh(vertices=v, faces=f)
102
+ render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=self.primitive_material, smooth=True)
103
+ mesh_node = self.scene.add(render_mesh, pose=np.eye(4))
104
+
105
+ # Change camera and lighting pose if necessary
106
+ if camera_pose is not None:
107
+ self.set_camera_pose(camera_pose)
108
+ if light_pose is not None:
109
+ self.set_lighting_pose(light_pose)
110
+
111
+ # Render
112
+ flags = pyrender.RenderFlags.SKIP_CULL_FACES
113
+ color, depth = self.renderer.render(self.scene, flags=flags)
114
+
115
+ # Remove mesh
116
+ self.scene.remove_node(mesh_node)
117
+
118
+ return color, depth
119
+
120
+ @staticmethod
121
+ def _get_light_poses(light_angle, light_pose):
122
+ light_poses = []
123
+ init_pos = light_pose[:3, 3].copy()
124
+
125
+ light_poses.append(light_pose.copy())
126
+
127
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([light_angle, 0, 0]))[0].dot(init_pos)
128
+ light_poses.append(light_pose.copy())
129
+
130
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([-light_angle, 0, 0]))[0].dot(init_pos)
131
+ light_poses.append(light_pose.copy())
132
+
133
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([0, -light_angle, 0]))[0].dot(init_pos)
134
+ light_poses.append(light_pose.copy())
135
+
136
+ light_pose[:3, 3] = cv2.Rodrigues(np.array([0, light_angle, 0]))[0].dot(init_pos)
137
+ light_poses.append(light_pose.copy())
138
+
139
+ return light_poses
140
+
141
+ @staticmethod
142
+ def _pyrender_mesh_workaround(mesh):
143
+ # Workaround as pyrender requires number of vertices and uv coordinates to be the same
144
+ with tempfile.NamedTemporaryFile(suffix='.obj') as f:
145
+ mesh.write_obj(f.name)
146
+ tri_mesh = trimesh.load(f.name, process=False)
147
+ return tri_mesh
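A hedged sketch of the untextured rendering path above; it needs an EGL-capable environment, and the tiny tetrahedron here is only a stand-in for the FLAME mesh (5023 vertices) used by the rest of the code:

    import numpy as np
    from diffposetalk.utils.renderer import MeshRenderer

    renderer = MeshRenderer(size=(512, 512))
    v = np.array([[0., 0., 0.], [.1, 0., 0.], [0., .1, 0.], [0., 0., .1]])  # stand-in geometry
    f = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]])
    # render_mesh modifies the vertex array in place, so pass a copy.
    color, depth = renderer.render_mesh(v.copy(), f, t_center=v.mean(0))
    print(color.shape)   # (512, 512, 3) RGB frame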
diffposetalk/utils/rotation_conversions.py ADDED
@@ -0,0 +1,569 @@
1
+ # This code is based on https://github.com/Mathux/ACTOR.git
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3
+ # Check PYTORCH3D_LICENCE before use
4
+
5
+ import functools
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ """
13
+ The transformation matrices returned from the functions in this file assume
14
+ the points on which the transformation will be applied are column vectors.
15
+ i.e. the R matrix is structured as
16
+
17
+ R = [
18
+ [Rxx, Rxy, Rxz],
19
+ [Ryx, Ryy, Ryz],
20
+ [Rzx, Rzy, Rzz],
21
+ ] # (3, 3)
22
+
23
+ This matrix can be applied to column vectors by post multiplication
24
+ by the points e.g.
25
+
26
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
27
+ transformed_points = R * points
28
+
29
+ To apply the same matrix to points which are row vectors, the R matrix
30
+ can be transposed and pre multiplied by the points:
31
+
32
+ e.g.
33
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
34
+ transformed_points = points * R.transpose(1, 0)
35
+ """
36
+
37
+
38
+ def quaternion_to_matrix(quaternions):
39
+ """
40
+ Convert rotations given as quaternions to rotation matrices.
41
+
42
+ Args:
43
+ quaternions: quaternions with real part first,
44
+ as tensor of shape (..., 4).
45
+
46
+ Returns:
47
+ Rotation matrices as tensor of shape (..., 3, 3).
48
+ """
49
+ r, i, j, k = torch.unbind(quaternions, -1)
50
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
51
+
52
+ o = torch.stack(
53
+ (
54
+ 1 - two_s * (j * j + k * k),
55
+ two_s * (i * j - k * r),
56
+ two_s * (i * k + j * r),
57
+ two_s * (i * j + k * r),
58
+ 1 - two_s * (i * i + k * k),
59
+ two_s * (j * k - i * r),
60
+ two_s * (i * k - j * r),
61
+ two_s * (j * k + i * r),
62
+ 1 - two_s * (i * i + j * j),
63
+ ),
64
+ -1,
65
+ )
66
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
67
+
68
+
69
+ def _copysign(a, b):
70
+ """
71
+ Return a tensor where each element has the absolute value taken from the
72
+ corresponding element of a, with sign taken from the corresponding
73
+ element of b. This is like the standard copysign floating-point operation,
74
+ but is not careful about negative 0 and NaN.
75
+
76
+ Args:
77
+ a: source tensor.
78
+ b: tensor whose signs will be used, of the same shape as a.
79
+
80
+ Returns:
81
+ Tensor of the same shape as a with the signs of b.
82
+ """
83
+ signs_differ = (a < 0) != (b < 0)
84
+ return torch.where(signs_differ, -a, a)
85
+
86
+
87
+ def _sqrt_positive_part(x):
88
+ """
89
+ Returns torch.sqrt(torch.max(0, x))
90
+ but with a zero subgradient where x is 0.
91
+ """
92
+ ret = torch.zeros_like(x)
93
+ positive_mask = x > 0
94
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
95
+ return ret
96
+
97
+
98
+ def matrix_to_quaternion(matrix):
99
+ """
100
+ Convert rotations given as rotation matrices to quaternions.
101
+
102
+ Args:
103
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
104
+
105
+ Returns:
106
+ quaternions with real part first, as tensor of shape (..., 4).
107
+ """
108
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
109
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
110
+ m00 = matrix[..., 0, 0]
111
+ m11 = matrix[..., 1, 1]
112
+ m22 = matrix[..., 2, 2]
113
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
114
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
115
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
116
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
117
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
118
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
119
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
120
+ return torch.stack((o0, o1, o2, o3), -1)
121
+
122
+
123
+ def _axis_angle_rotation(axis: str, angle):
124
+ """
125
+ Return the rotation matrices for a rotation about one of the coordinate axes
126
+ used by an Euler-angle convention, for each value of the angle given.
127
+
128
+ Args:
129
+ axis: Axis label "X", "Y", or "Z".
130
+ angle: any shape tensor of Euler angles in radians
131
+
132
+ Returns:
133
+ Rotation matrices as tensor of shape (..., 3, 3).
134
+ """
135
+
136
+ cos = torch.cos(angle)
137
+ sin = torch.sin(angle)
138
+ one = torch.ones_like(angle)
139
+ zero = torch.zeros_like(angle)
140
+
141
+ if axis == "X":
142
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
143
+ if axis == "Y":
144
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
145
+ if axis == "Z":
146
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
147
+
148
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
149
+
150
+
151
+ def euler_angles_to_matrix(euler_angles, convention: str):
152
+ """
153
+ Convert rotations given as Euler angles in radians to rotation matrices.
154
+
155
+ Args:
156
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
157
+ convention: Convention string of three uppercase letters from
158
+ {"X", "Y", and "Z"}.
159
+
160
+ Returns:
161
+ Rotation matrices as tensor of shape (..., 3, 3).
162
+ """
163
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
164
+ raise ValueError("Invalid input euler angles.")
165
+ if len(convention) != 3:
166
+ raise ValueError("Convention must have 3 letters.")
167
+ if convention[1] in (convention[0], convention[2]):
168
+ raise ValueError(f"Invalid convention {convention}.")
169
+ for letter in convention:
170
+ if letter not in ("X", "Y", "Z"):
171
+ raise ValueError(f"Invalid letter {letter} in convention string.")
172
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
173
+ return functools.reduce(torch.matmul, matrices)
174
+
175
+
176
+ def _angle_from_tan(
177
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
178
+ ):
179
+ """
180
+ Extract the first or third Euler angle from the two members of
181
+ the matrix which are positive constant times its sine and cosine.
182
+
183
+ Args:
184
+ axis: Axis label "X", "Y", or "Z" for the angle we are finding.
185
+ other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
186
+ convention.
187
+ data: Rotation matrices as tensor of shape (..., 3, 3).
188
+ horizontal: Whether we are looking for the angle for the third axis,
189
+ which means the relevant entries are in the same row of the
190
+ rotation matrix. If not, they are in the same column.
191
+ tait_bryan: Whether the first and third axes in the convention differ.
192
+
193
+ Returns:
194
+ Euler Angles in radians for each matrix in dataset as a tensor
195
+ of shape (...).
196
+ """
197
+
198
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
199
+ if horizontal:
200
+ i2, i1 = i1, i2
201
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
202
+ if horizontal == even:
203
+ return torch.atan2(data[..., i1], data[..., i2])
204
+ if tait_bryan:
205
+ return torch.atan2(-data[..., i2], data[..., i1])
206
+ return torch.atan2(data[..., i2], -data[..., i1])
207
+
208
+
209
+ def _index_from_letter(letter: str):
210
+ if letter == "X":
211
+ return 0
212
+ if letter == "Y":
213
+ return 1
214
+ if letter == "Z":
215
+ return 2
216
+
217
+
218
+ def matrix_to_euler_angles(matrix, convention: str):
219
+ """
220
+ Convert rotations given as rotation matrices to Euler angles in radians.
221
+
222
+ Args:
223
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
224
+ convention: Convention string of three uppercase letters.
225
+
226
+ Returns:
227
+ Euler angles in radians as tensor of shape (..., 3).
228
+ """
229
+ if len(convention) != 3:
230
+ raise ValueError("Convention must have 3 letters.")
231
+ if convention[1] in (convention[0], convention[2]):
232
+ raise ValueError(f"Invalid convention {convention}.")
233
+ for letter in convention:
234
+ if letter not in ("X", "Y", "Z"):
235
+ raise ValueError(f"Invalid letter {letter} in convention string.")
236
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
237
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
238
+ i0 = _index_from_letter(convention[0])
239
+ i2 = _index_from_letter(convention[2])
240
+ tait_bryan = i0 != i2
241
+ if tait_bryan:
242
+ central_angle = torch.asin(
243
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
244
+ )
245
+ else:
246
+ central_angle = torch.acos(matrix[..., i0, i0])
247
+
248
+ o = (
249
+ _angle_from_tan(
250
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
251
+ ),
252
+ central_angle,
253
+ _angle_from_tan(
254
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
255
+ ),
256
+ )
257
+ return torch.stack(o, -1)
258
+
259
+
260
+ def random_quaternions(
261
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
262
+ ):
263
+ """
264
+ Generate random quaternions representing rotations,
265
+ i.e. versors with nonnegative real part.
266
+
267
+ Args:
268
+ n: Number of quaternions in a batch to return.
269
+ dtype: Type to return.
270
+ device: Desired device of returned tensor. Default:
271
+ uses the current device for the default tensor type.
272
+ requires_grad: Whether the resulting tensor should have the gradient
273
+ flag set.
274
+
275
+ Returns:
276
+ Quaternions as tensor of shape (N, 4).
277
+ """
278
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
279
+ s = (o * o).sum(1)
280
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
281
+ return o
282
+
283
+
284
+ def random_rotations(
285
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
286
+ ):
287
+ """
288
+ Generate random rotations as 3x3 rotation matrices.
289
+
290
+ Args:
291
+ n: Number of rotation matrices in a batch to return.
292
+ dtype: Type to return.
293
+ device: Device of returned tensor. Default: if None,
294
+ uses the current device for the default tensor type.
295
+ requires_grad: Whether the resulting tensor should have the gradient
296
+ flag set.
297
+
298
+ Returns:
299
+ Rotation matrices as tensor of shape (n, 3, 3).
300
+ """
301
+ quaternions = random_quaternions(
302
+ n, dtype=dtype, device=device, requires_grad=requires_grad
303
+ )
304
+ return quaternion_to_matrix(quaternions)
305
+
306
+
307
+ def random_rotation(
308
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
309
+ ):
310
+ """
311
+ Generate a single random 3x3 rotation matrix.
312
+
313
+ Args:
314
+ dtype: Type to return
315
+ device: Device of returned tensor. Default: if None,
316
+ uses the current device for the default tensor type
317
+ requires_grad: Whether the resulting tensor should have the gradient
318
+ flag set
319
+
320
+ Returns:
321
+ Rotation matrix as tensor of shape (3, 3).
322
+ """
323
+ return random_rotations(1, dtype, device, requires_grad)[0]
324
+
325
+
326
+ def standardize_quaternion(quaternions):
327
+ """
328
+ Convert a unit quaternion to a standard form: one in which the real
329
+ part is non negative.
330
+
331
+ Args:
332
+ quaternions: Quaternions with real part first,
333
+ as tensor of shape (..., 4).
334
+
335
+ Returns:
336
+ Standardized quaternions as tensor of shape (..., 4).
337
+ """
338
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
339
+
340
+
341
+ def quaternion_raw_multiply(a, b):
342
+ """
343
+ Multiply two quaternions.
344
+ Usual torch rules for broadcasting apply.
345
+
346
+ Args:
347
+ a: Quaternions as tensor of shape (..., 4), real part first.
348
+ b: Quaternions as tensor of shape (..., 4), real part first.
349
+
350
+ Returns:
351
+ The product of a and b, a tensor of quaternions shape (..., 4).
352
+ """
353
+ aw, ax, ay, az = torch.unbind(a, -1)
354
+ bw, bx, by, bz = torch.unbind(b, -1)
355
+ ow = aw * bw - ax * bx - ay * by - az * bz
356
+ ox = aw * bx + ax * bw + ay * bz - az * by
357
+ oy = aw * by - ax * bz + ay * bw + az * bx
358
+ oz = aw * bz + ax * by - ay * bx + az * bw
359
+ return torch.stack((ow, ox, oy, oz), -1)
360
+
361
+
362
+ def quaternion_multiply(a, b):
363
+ """
364
+ Multiply two quaternions representing rotations, returning the quaternion
365
+ representing their composition, i.e. the versor with nonnegative real part.
366
+ Usual torch rules for broadcasting apply.
367
+
368
+ Args:
369
+ a: Quaternions as tensor of shape (..., 4), real part first.
370
+ b: Quaternions as tensor of shape (..., 4), real part first.
371
+
372
+ Returns:
373
+ The product of a and b, a tensor of quaternions of shape (..., 4).
374
+ """
375
+ ab = quaternion_raw_multiply(a, b)
376
+ return standardize_quaternion(ab)
377
+
378
+
379
+ def quaternion_invert(quaternion):
380
+ """
381
+ Given a quaternion representing rotation, get the quaternion representing
382
+ its inverse.
383
+
384
+ Args:
385
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
386
+ first, which must be versors (unit quaternions).
387
+
388
+ Returns:
389
+ The inverse, a tensor of quaternions of shape (..., 4).
390
+ """
391
+
392
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
393
+
394
+
395
+ def quaternion_apply(quaternion, point):
396
+ """
397
+ Apply the rotation given by a quaternion to a 3D point.
398
+ Usual torch rules for broadcasting apply.
399
+
400
+ Args:
401
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
402
+ point: Tensor of 3D points of shape (..., 3).
403
+
404
+ Returns:
405
+ Tensor of rotated points of shape (..., 3).
406
+ """
407
+ if point.size(-1) != 3:
408
+ raise ValueError(f"Points are not in 3D, {point.shape}.")
409
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
410
+ point_as_quaternion = torch.cat((real_parts, point), -1)
411
+ out = quaternion_raw_multiply(
412
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
413
+ quaternion_invert(quaternion),
414
+ )
415
+ return out[..., 1:]
416
+
417
+
418
+ def axis_angle_to_matrix(axis_angle):
419
+ """
420
+ Convert rotations given as axis/angle to rotation matrices.
421
+
422
+ Args:
423
+ axis_angle: Rotations given as a vector in axis angle form,
424
+ as a tensor of shape (..., 3), where the magnitude is
425
+ the angle turned anticlockwise in radians around the
426
+ vector's direction.
427
+
428
+ Returns:
429
+ Rotation matrices as tensor of shape (..., 3, 3).
430
+ """
431
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
432
+
433
+
434
+ def matrix_to_axis_angle(matrix):
435
+ """
436
+ Convert rotations given as rotation matrices to axis/angle.
437
+
438
+ Args:
439
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
440
+
441
+ Returns:
442
+ Rotations given as a vector in axis angle form, as a tensor
443
+ of shape (..., 3), where the magnitude is the angle
444
+ turned anticlockwise in radians around the vector's
445
+ direction.
446
+ """
447
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
448
+
449
+
450
+ def axis_angle_to_quaternion(axis_angle):
451
+ """
452
+ Convert rotations given as axis/angle to quaternions.
453
+
454
+ Args:
455
+ axis_angle: Rotations given as a vector in axis angle form,
456
+ as a tensor of shape (..., 3), where the magnitude is
457
+ the angle turned anticlockwise in radians around the
458
+ vector's direction.
459
+
460
+ Returns:
461
+ quaternions with real part first, as tensor of shape (..., 4).
462
+ """
463
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
464
+ half_angles = 0.5 * angles
465
+ eps = 1e-6
466
+ small_angles = angles.abs() < eps
467
+ sin_half_angles_over_angles = torch.empty_like(angles)
468
+ sin_half_angles_over_angles[~small_angles] = (
469
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
470
+ )
471
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
472
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
473
+ sin_half_angles_over_angles[small_angles] = (
474
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
475
+ )
476
+ quaternions = torch.cat(
477
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
478
+ )
479
+ return quaternions
480
+
481
+
482
+ def quaternion_to_axis_angle(quaternions):
483
+ """
484
+ Convert rotations given as quaternions to axis/angle.
485
+
486
+ Args:
487
+ quaternions: quaternions with real part first,
488
+ as tensor of shape (..., 4).
489
+
490
+ Returns:
491
+ Rotations given as a vector in axis angle form, as a tensor
492
+ of shape (..., 3), where the magnitude is the angle
493
+ turned anticlockwise in radians around the vector's
494
+ direction.
495
+ """
496
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
497
+ half_angles = torch.atan2(norms, quaternions[..., :1])
498
+ angles = 2 * half_angles
499
+ eps = 1e-6
500
+ small_angles = angles.abs() < eps
501
+ sin_half_angles_over_angles = torch.empty_like(angles)
502
+ sin_half_angles_over_angles[~small_angles] = (
503
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
504
+ )
505
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
506
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
507
+ sin_half_angles_over_angles[small_angles] = (
508
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
509
+ )
510
+ return quaternions[..., 1:] / sin_half_angles_over_angles
511
+
512
+
513
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
514
+ """
515
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
516
+ using Gram--Schmidt orthogonalisation per Section B of [1].
517
+ Args:
518
+ d6: 6D rotation representation, of size (*, 6)
519
+
520
+ Returns:
521
+ batch of rotation matrices of size (*, 3, 3)
522
+
523
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
524
+ On the Continuity of Rotation Representations in Neural Networks.
525
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
526
+ Retrieved from http://arxiv.org/abs/1812.07035
527
+ """
528
+
529
+ a1, a2 = d6[..., :3], d6[..., 3:]
530
+ b1 = F.normalize(a1, dim=-1)
531
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
532
+ b2 = F.normalize(b2, dim=-1)
533
+ b3 = torch.cross(b1, b2, dim=-1)
534
+ return torch.stack((b1, b2, b3), dim=-2)
535
+
536
+
537
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
538
+ """
539
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
540
+ by dropping the last row. Note that 6D representation is not unique.
541
+ Args:
542
+ matrix: batch of rotation matrices of size (*, 3, 3)
543
+
544
+ Returns:
545
+ 6D rotation representation, of size (*, 6)
546
+
547
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
548
+ On the Continuity of Rotation Representations in Neural Networks.
549
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
550
+ Retrieved from http://arxiv.org/abs/1812.07035
551
+ """
552
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
553
+
554
+
555
+ def axis_angle_to_rotation_6d(axis_angle):
556
+ """
557
+ Convert rotations given as axis/angle to 6D rotation representation by Zhou
558
+ et al. [1].
559
+
560
+ Args:
561
+ axis_angle: Rotations given as a vector in axis angle form,
562
+ as a tensor of shape (..., 3), where the magnitude is
563
+ the angle turned anticlockwise in radians around the
564
+ vector's direction.
565
+
566
+ Returns:
567
+ 6D rotation representation, of size (*, 6)
568
+ """
569
+ return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
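A quick consistency check of the conversions above (random rotations with angle below pi; tolerances are loose to allow for float32 rounding):

    import torch
    from diffposetalk.utils.rotation_conversions import (
        axis_angle_to_matrix, matrix_to_axis_angle, axis_angle_to_rotation_6d, rotation_6d_to_matrix)

    aa = torch.randn(4, 3) * 0.5                 # batch of axis-angle rotations
    R = axis_angle_to_matrix(aa)                 # (4, 3, 3)
    d6 = axis_angle_to_rotation_6d(aa)           # (4, 6)
    assert torch.allclose(matrix_to_axis_angle(R), aa, atol=1e-4)
    assert torch.allclose(rotation_6d_to_matrix(d6), R, atol=1e-4)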
diffposetalk/wav2vec2.py ADDED
@@ -0,0 +1,119 @@
1
+ from packaging import version
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import transformers
8
+ from transformers import Wav2Vec2Model
9
+ from transformers.modeling_outputs import BaseModelOutput
10
+
11
+ _CONFIG_FOR_DOC = 'Wav2Vec2Config'
12
+
13
+
14
+ # the implementation of Wav2Vec2Model is borrowed from
15
+ # https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
16
+ # so that we can initialize our encoder with the pre-trained wav2vec 2.0 weights.
17
+ def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
18
+ attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
19
+ bsz, all_sz = shape
20
+ mask = np.full((bsz, all_sz), False)
21
+
22
+ all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
23
+ all_num_mask = max(min_masks, all_num_mask)
24
+ mask_idcs = []
25
+ padding_mask = attention_mask.ne(1) if attention_mask is not None else None
26
+ for i in range(bsz):
27
+ if padding_mask is not None:
28
+ sz = all_sz - padding_mask[i].long().sum().item()
29
+ num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
30
+ num_mask = max(min_masks, num_mask)
31
+ else:
32
+ sz = all_sz
33
+ num_mask = all_num_mask
34
+
35
+ lengths = np.full(num_mask, mask_length)
36
+
37
+ if sum(lengths) == 0:
38
+ lengths[0] = min(mask_length, sz - 1)
39
+
40
+ min_len = min(lengths)
41
+ if sz - min_len <= num_mask:
42
+ min_len = sz - num_mask - 1
43
+
44
+ mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
45
+ mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
46
+ mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
47
+
48
+ min_len = min([len(m) for m in mask_idcs])
49
+ for i, mask_idc in enumerate(mask_idcs):
50
+ if len(mask_idc) > min_len:
51
+ mask_idc = np.random.choice(mask_idc, min_len, replace=False)
52
+ mask[i, mask_idc] = True
53
+ return mask
54
+
55
+
56
+ # linear interpolation layer
57
+ def linear_interpolation(features, input_fps, output_fps, output_len=None):
58
+ # features: (N, C, L)
59
+ seq_len = features.shape[2] / float(input_fps)
60
+ if output_len is None:
61
+ output_len = int(seq_len * output_fps)
62
+ output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
63
+ return output_features
64
+
65
+
66
+ class Wav2Vec2Model(Wav2Vec2Model):
67
+ def __init__(self, config):
68
+ super().__init__(config)
69
+ self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
70
+
71
+ def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
72
+ output_hidden_states=None, return_dict=None, frame_num=None):
73
+ self.config.output_attentions = True
74
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
75
+ output_hidden_states = (
76
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
77
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
78
+
79
+ hidden_states = self.feature_extractor(input_values) # (N, C, L)
80
+ # Resample the audio feature @ 50 fps to `output_fps`.
81
+ if frame_num is not None:
82
+ hidden_states_len = round(frame_num * 50 / output_fps)
83
+ hidden_states = hidden_states[:, :, :hidden_states_len]
84
+ hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
85
+ hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
86
+
87
+ if attention_mask is not None:
88
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
89
+ attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
90
+ device=hidden_states.device)
91
+ attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
92
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
93
+
94
+ if self.is_old_version:
95
+ hidden_states = self.feature_projection(hidden_states)
96
+ else:
97
+ hidden_states = self.feature_projection(hidden_states)[0]
98
+
99
+ if self.config.apply_spec_augment and self.training:
100
+ batch_size, sequence_length, hidden_size = hidden_states.size()
101
+ if self.config.mask_time_prob > 0:
102
+ mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
103
+ self.config.mask_time_length, attention_mask=attention_mask,
104
+ min_masks=2, )
105
+ hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
106
+ if self.config.mask_feature_prob > 0:
107
+ mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
108
+ self.config.mask_feature_length, )
109
+ mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
110
+ hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
111
+ encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
112
+ output_attentions=output_attentions, output_hidden_states=output_hidden_states,
113
+ return_dict=return_dict, )
114
+ hidden_states = encoder_outputs[0]
115
+ if not return_dict:
116
+ return (hidden_states,) + encoder_outputs[1:]
117
+
118
+ return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
119
+ attentions=encoder_outputs.attentions, )
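Finally, a hedged end-to-end sketch of how this audio encoder maps a waveform to per-frame features at the video frame rate; the checkpoint name below is only an example of a compatible wav2vec 2.0 base model:

    import torch
    from diffposetalk.wav2vec2 import Wav2Vec2Model

    model = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h').eval()
    audio = torch.randn(1, 32000)                       # 2 s of 16 kHz audio (placeholder)
    with torch.no_grad():
        out = model(audio, output_fps=25, frame_num=50)
    print(out.last_hidden_state.shape)                  # torch.Size([1, 50, 768])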