Spaces:
Running
on
L40S
Running
on
L40S
Upload 83 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- LICENSE.txt +38 -0
- README.md +194 -14
- assets/.DS_Store +0 -0
- assets/demo.gif +3 -0
- assets/driving_audio/1.wav +3 -0
- assets/driving_audio/2.wav +3 -0
- assets/driving_audio/3.wav +3 -0
- assets/driving_audio/4.wav +3 -0
- assets/driving_audio/5.wav +3 -0
- assets/driving_audio/6.wav +3 -0
- assets/driving_video/.DS_Store +0 -0
- assets/driving_video/1.mp4 +3 -0
- assets/driving_video/2.mp4 +3 -0
- assets/driving_video/3.mp4 +3 -0
- assets/driving_video/4.mp4 +3 -0
- assets/driving_video/5.mp4 +3 -0
- assets/driving_video/6.mp4 +3 -0
- assets/driving_video/7.mp4 +3 -0
- assets/driving_video/8.mp4 +3 -0
- assets/logo.png +0 -0
- assets/ref_images/1.png +3 -0
- assets/ref_images/10.png +3 -0
- assets/ref_images/11.png +3 -0
- assets/ref_images/12.png +3 -0
- assets/ref_images/13.png +3 -0
- assets/ref_images/14.png +3 -0
- assets/ref_images/15.png +3 -0
- assets/ref_images/16.png +3 -0
- assets/ref_images/17.png +3 -0
- assets/ref_images/18.png +3 -0
- assets/ref_images/19.png +3 -0
- assets/ref_images/2.png +3 -0
- assets/ref_images/20.png +0 -0
- assets/ref_images/3.png +3 -0
- assets/ref_images/4.png +3 -0
- assets/ref_images/5.png +3 -0
- assets/ref_images/6.png +3 -0
- assets/ref_images/7.png +3 -0
- assets/ref_images/8.png +3 -0
- diffposetalk/common.py +46 -0
- diffposetalk/diff_talking_head.py +536 -0
- diffposetalk/diffposetalk.py +228 -0
- diffposetalk/hubert.py +51 -0
- diffposetalk/utils/__init__.py +1 -0
- diffposetalk/utils/common.py +378 -0
- diffposetalk/utils/media.py +35 -0
- diffposetalk/utils/renderer.py +147 -0
- diffposetalk/utils/rotation_conversions.py +569 -0
- diffposetalk/wav2vec2.py +119 -0
.gitattributes
CHANGED
@@ -33,3 +33,38 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/demo.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/driving_audio/1.wav filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/driving_audio/2.wav filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/driving_audio/3.wav filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/driving_audio/4.wav filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/driving_audio/5.wav filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/driving_audio/6.wav filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/driving_video/1.mp4 filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/driving_video/2.mp4 filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/driving_video/3.mp4 filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/driving_video/4.mp4 filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/driving_video/5.mp4 filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/driving_video/6.mp4 filter=lfs diff=lfs merge=lfs -text
|
49 |
+
assets/driving_video/7.mp4 filter=lfs diff=lfs merge=lfs -text
|
50 |
+
assets/driving_video/8.mp4 filter=lfs diff=lfs merge=lfs -text
|
51 |
+
assets/ref_images/1.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
assets/ref_images/10.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
assets/ref_images/11.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
assets/ref_images/12.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
assets/ref_images/13.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
assets/ref_images/14.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
assets/ref_images/15.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
assets/ref_images/16.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
assets/ref_images/17.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
assets/ref_images/18.png filter=lfs diff=lfs merge=lfs -text
|
61 |
+
assets/ref_images/19.png filter=lfs diff=lfs merge=lfs -text
|
62 |
+
assets/ref_images/2.png filter=lfs diff=lfs merge=lfs -text
|
63 |
+
assets/ref_images/3.png filter=lfs diff=lfs merge=lfs -text
|
64 |
+
assets/ref_images/4.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
assets/ref_images/5.png filter=lfs diff=lfs merge=lfs -text
|
66 |
+
assets/ref_images/6.png filter=lfs diff=lfs merge=lfs -text
|
67 |
+
assets/ref_images/7.png filter=lfs diff=lfs merge=lfs -text
|
68 |
+
assets/ref_images/8.png filter=lfs diff=lfs merge=lfs -text
|
69 |
+
skyreels_a1/src/media_pipe/mp_models/face_landmarker_v2_with_blendshapes.task filter=lfs diff=lfs merge=lfs -text
|
70 |
+
skyreels_a1/src/media_pipe/mp_models/pose_landmarker_heavy.task filter=lfs diff=lfs merge=lfs -text
|
LICENSE.txt
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language:
|
3 |
+
- en
|
4 |
+
- zh
|
5 |
+
license: other
|
6 |
+
tasks:
|
7 |
+
- text-generation
|
8 |
+
|
9 |
+
---
|
10 |
+
|
11 |
+
<!-- markdownlint-disable first-line-h1 -->
|
12 |
+
<!-- markdownlint-disable html -->
|
13 |
+
|
14 |
+
# <span id="Terms">声明与协议/Terms and Conditions</span>
|
15 |
+
|
16 |
+
## 声明
|
17 |
+
|
18 |
+
我们在此声明,不要利用Skywork模型进行任何危害国家社会安全或违法的活动。另外,我们也要求使用者不要将 Skywork 模型用于未经适当安全审查和备案的互联网服务。我们希望所有的使用者都能遵守这个原则,确保科技的发展能在规范和合法的环境下进行。
|
19 |
+
|
20 |
+
我们已经尽我们所能,来确保模型训练过程中使用的数据的合规性。然而,尽管我们已经做出了巨大的努力,但由于模型和数据的复杂性,仍有可能存在一些无法预见的问题。因此,如果由于使用skywork开源模型而导致的任何问题,包括但不限于数据安全问题、公共舆论风险,或模型被误导、滥用、传播或不当利用所带来的任何风险和问题,我们将不承担任何责任。
|
21 |
+
|
22 |
+
We hereby declare that the Skywork model should not be used for any activities that pose a threat to national or societal security or engage in unlawful actions. Additionally, we request users not to deploy the Skywork model for internet services without appropriate security reviews and records. We hope that all users will adhere to this principle to ensure that technological advancements occur in a regulated and lawful environment.
|
23 |
+
|
24 |
+
We have done our utmost to ensure the compliance of the data used during the model's training process. However, despite our extensive efforts, due to the complexity of the model and data, there may still be unpredictable risks and issues. Therefore, if any problems arise as a result of using the Skywork open-source model, including but not limited to data security issues, public opinion risks, or any risks and problems arising from the model being misled, abused, disseminated, or improperly utilized, we will not assume any responsibility.
|
25 |
+
|
26 |
+
## 协议
|
27 |
+
|
28 |
+
社区使用Skywork模型需要遵循[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)。Skywork模型支持商业用途,如果您计划将Skywork模型或其衍生品用于商业目的,无需再次申请, 但请您仔细阅读[《Skywork 模型社区许可协议》](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf)并严格遵守相关条款。
|
29 |
+
|
30 |
+
|
31 |
+
The community usage of Skywork model requires [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf). The Skywork model supports commercial use. If you plan to use the Skywork model or its derivatives for commercial purposes, you must abide by terms and conditions within [Skywork Community License](https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20Community%20License.pdf).
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
[《Skywork 模型社区许可协议》》]:https://github.com/SkyworkAI/Skywork/blob/main/Skywork%20模型社区许可协议.pdf
|
36 |
+
|
37 |
+
|
38 |
+
[[email protected]]: mailto:[email protected]
|
README.md
CHANGED
@@ -1,14 +1,194 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<img src="assets/logo.png" alt="Skyreels Logo" width="50%">
|
3 |
+
</p>
|
4 |
+
|
5 |
+
|
6 |
+
<h1 align="center">SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers</h1>
|
7 |
+
<div align='center'>
|
8 |
+
<a href='https://scholar.google.com/citations?user=6D_nzucAAAAJ&hl=en' target='_blank'>Di Qiu</a> 
|
9 |
+
<a href='https://scholar.google.com/citations?user=_43YnBcAAAAJ&hl=zh-CN' target='_blank'>Zhengcong Fei</a> 
|
10 |
+
<a href='' target='_blank'>Rui Wang</a> 
|
11 |
+
<a href='' target='_blank'>Jialin Bai</a> 
|
12 |
+
<a href='https://scholar.google.com/citations?user=Hv-vj2sAAAAJ&hl=en' target='_blank'>Changqian Yu</a> 
|
13 |
+
</div>
|
14 |
+
|
15 |
+
<div align='center'>
|
16 |
+
<a href='https://scholar.google.com.au/citations?user=ePIeVuUAAAAJ&hl=en' target='_blank'>Mingyuan Fan</a> 
|
17 |
+
<a href='https://scholar.google.com/citations?user=HukWSw4AAAAJ&hl=en' target='_blank'>Guibin Chen</a> 
|
18 |
+
<a href='https://scholar.google.com.tw/citations?user=RvAuMk0AAAAJ&hl=zh-CN' target='_blank'>Xiang Wen</a> 
|
19 |
+
</div>
|
20 |
+
|
21 |
+
<div align='center'>
|
22 |
+
<small><strong>Skywork AI</strong></small>
|
23 |
+
</div>
|
24 |
+
|
25 |
+
<br>
|
26 |
+
|
27 |
+
<div align="center">
|
28 |
+
<!-- <a href='LICENSE'><img src='https://img.shields.io/badge/license-MIT-yellow'></a> -->
|
29 |
+
<a href='https://arxiv.org/abs/2502.10841'><img src='https://img.shields.io/badge/arXiv-SkyReels A1-red'></a>
|
30 |
+
<a href='https://skyworkai.github.io/skyreels-a1.github.io/'><img src='https://img.shields.io/badge/Project-SkyReels A1-green'></a>
|
31 |
+
<a href='https://huggingface.co/Skywork/SkyReels-A1'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue'></a>
|
32 |
+
<a href='https://www.skyreels.ai/home?utm_campaign=github_A1'><img src='https://img.shields.io/badge/Playground-Spaces-yellow'></a>
|
33 |
+
<br>
|
34 |
+
</div>
|
35 |
+
<br>
|
36 |
+
|
37 |
+
|
38 |
+
<p align="center">
|
39 |
+
<img src="./assets/demo.gif" alt="showcase">
|
40 |
+
<br>
|
41 |
+
🔥 For more results, visit our <a href="https://skyworkai.github.io/skyreels-a1.github.io/"><strong>homepage</strong></a> 🔥
|
42 |
+
</p>
|
43 |
+
|
44 |
+
<p align="center">
|
45 |
+
👋 Join our <a href="https://discord.gg/PwM6NYtccQ" target="_blank"><strong>Discord</strong></a>
|
46 |
+
</p>
|
47 |
+
|
48 |
+
|
49 |
+
This repo, named **SkyReels-A1**, contains the official PyTorch implementation of our paper [SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers](https://arxiv.org).
|
50 |
+
|
51 |
+
|
52 |
+
## 🔥🔥🔥 News!!
|
53 |
+
* Mar 4, 2025: 🔥 We release audio-driven portrait image animation pipeline.
|
54 |
+
* Feb 18, 2025: 👋 We release the inference code and model weights of SkyReels-A1. [Download](https://huggingface.co/Skywork/SkyReels-A1)
|
55 |
+
* Feb 18, 2025: 🎉 We have made our technical report available as open source. [Read](https://skyworkai.github.io/skyreels-a1.github.io/report.pdf)
|
56 |
+
* Feb 18, 2025: 🔥 Our online demo of LipSync is available on SkyReels now! Try out [LipSync](https://www.skyreels.ai/home/tools/lip-sync?refer=navbar).
|
57 |
+
* Feb 18, 2025: 🔥 We have open-sourced I2V video generation model [SkyReels-V1](https://github.com/SkyworkAI/SkyReels-V1). This is the first and most advanced open-source human-centric video foundation model.
|
58 |
+
|
59 |
+
## 📑 TODO List
|
60 |
+
- [x] Checkpoints
|
61 |
+
- [x] Inference Code
|
62 |
+
- [x] Web Demo (Gradio)
|
63 |
+
- [x] Audio-driven Portrait Image Animation Pipeline
|
64 |
+
- [ ] Inference Code for Long Videos
|
65 |
+
- [ ] User-Level GPU Inference on RTX4090
|
66 |
+
- [ ] ComfyUI
|
67 |
+
|
68 |
+
|
69 |
+
## Getting Started 🏁
|
70 |
+
|
71 |
+
### 1. Clone the code and prepare the environment 🛠️
|
72 |
+
First git clone the repository with code:
|
73 |
+
```bash
|
74 |
+
git clone https://github.com/SkyworkAI/SkyReels-A1.git
|
75 |
+
cd SkyReels-A1
|
76 |
+
|
77 |
+
# create env using conda
|
78 |
+
conda create -n skyreels-a1 python=3.10
|
79 |
+
conda activate skyreels-a1
|
80 |
+
```
|
81 |
+
Then, install the remaining dependencies:
|
82 |
+
```bash
|
83 |
+
pip install -r requirements.txt
|
84 |
+
```
|
85 |
+
|
86 |
+
|
87 |
+
### 2. Download pretrained weights 📥
|
88 |
+
You can download the pretrained weights is from HuggingFace:
|
89 |
+
```bash
|
90 |
+
# !pip install -U "huggingface_hub[cli]"
|
91 |
+
huggingface-cli download SkyReels-A1 --local-dir local_path --exclude "*.git*" "README.md" "docs"
|
92 |
+
```
|
93 |
+
|
94 |
+
The FLAME, mediapipe, and smirk models are located in the SkyReels-A1/extra_models folder.
|
95 |
+
|
96 |
+
The directory structure of our SkyReels-A1 code is formulated as:
|
97 |
+
```text
|
98 |
+
pretrained_models
|
99 |
+
├── FLAME
|
100 |
+
├── SkyReels-A1-5B
|
101 |
+
│ ├── pose_guider
|
102 |
+
│ ├── scheduler
|
103 |
+
│ ├── tokenizer
|
104 |
+
│ ├── siglip-so400m-patch14-384
|
105 |
+
│ ├── transformer
|
106 |
+
│ ├── vae
|
107 |
+
│ └── text_encoder
|
108 |
+
├── mediapipe
|
109 |
+
└── smirk
|
110 |
+
|
111 |
+
```
|
112 |
+
|
113 |
+
#### Download DiffposeTalk assets and pretrained weights (For Audio-driven)
|
114 |
+
|
115 |
+
- We use [diffposetalk](https://github.com/DiffPoseTalk/DiffPoseTalk/tree/main) to generate flame coefficients from audio, thereby constructing motion signals.
|
116 |
+
|
117 |
+
- Download the diffposetalk code and follow its README to download the weights and related data.
|
118 |
+
|
119 |
+
- Then place them in the specified directory.
|
120 |
+
|
121 |
+
```bash
|
122 |
+
cp -r ${diffposetalk_root}/style pretrained_models/diffposetalk
|
123 |
+
cp ${diffposetalk_root}/experiments/DPT/head-SA-hubert-WM/checkpoints/iter_0110000.pt pretrained_models/diffposetalk
|
124 |
+
cp ${diffposetalk_root}/datasets/HDTF_TFHP/lmdb/stats_train.npz pretrained_models/diffposetalk
|
125 |
+
```
|
126 |
+
|
127 |
+
```text
|
128 |
+
pretrained_models
|
129 |
+
├── FLAME
|
130 |
+
├── SkyReels-A1-5B
|
131 |
+
├── mediapipe
|
132 |
+
├── diffposetalk
|
133 |
+
│ ├── style
|
134 |
+
│ ├── iter_0110000.pt
|
135 |
+
│ ├── states_train.npz
|
136 |
+
└── smirk
|
137 |
+
|
138 |
+
```
|
139 |
+
|
140 |
+
|
141 |
+
### 3. Inference 🚀
|
142 |
+
You can simply run the inference scripts as:
|
143 |
+
```bash
|
144 |
+
python inference.py
|
145 |
+
|
146 |
+
# inference audio to video
|
147 |
+
python inference_audio.py
|
148 |
+
```
|
149 |
+
|
150 |
+
If the script runs successfully, you will get an output mp4 file. This file includes the following results: driving video, input image or video, and generated result.
|
151 |
+
|
152 |
+
|
153 |
+
## Gradio Interface 🤗
|
154 |
+
|
155 |
+
We provide a [Gradio](https://huggingface.co/docs/hub/spaces-sdks-gradio) interface for a better experience, just run by:
|
156 |
+
|
157 |
+
```bash
|
158 |
+
python app.py
|
159 |
+
```
|
160 |
+
|
161 |
+
The graphical interactive interface is shown as below:
|
162 |
+
|
163 |
+

|
164 |
+
|
165 |
+
|
166 |
+
## Metric Evaluation 👓
|
167 |
+
|
168 |
+
We also provide all scripts for automatically calculating the metrics, including SimFace, FID, and L1 distance between expression and motion, reported in the paper.
|
169 |
+
|
170 |
+
All codes can be found in the ```eval``` folder. After setting the video result path, run the following commands in sequence:
|
171 |
+
|
172 |
+
```bash
|
173 |
+
python arc_score.py
|
174 |
+
python expression_score.py
|
175 |
+
python pose_score.py
|
176 |
+
```
|
177 |
+
|
178 |
+
|
179 |
+
## Acknowledgements 💐
|
180 |
+
We would like to thank the contributors of [CogvideoX](https://github.com/THUDM/CogVideo), [finetrainers](https://github.com/a-r-r-o-w/finetrainers) and [DiffPoseTalk](https://github.com/DiffPoseTalk/DiffPoseTalk)repositories, for their open research and contributions.
|
181 |
+
|
182 |
+
## Citation 💖
|
183 |
+
If you find SkyReels-A1 useful for your research, welcome to 🌟 this repo and cite our work using the following BibTeX:
|
184 |
+
```bibtex
|
185 |
+
@article{qiu2025skyreels,
|
186 |
+
title={SkyReels-A1: Expressive Portrait Animation in Video Diffusion Transformers},
|
187 |
+
author={Qiu, Di and Fei, Zhengcong and Wang, Rui and Bai, Jialin and Yu, Changqian and Fan, Mingyuan and Chen, Guibin and Wen, Xiang},
|
188 |
+
journal={arXiv preprint arXiv:2502.10841},
|
189 |
+
year={2025}
|
190 |
+
}
|
191 |
+
```
|
192 |
+
|
193 |
+
|
194 |
+
|
assets/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/demo.gif
ADDED
![]() |
Git LFS Details
|
assets/driving_audio/1.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:38dab65a002455f4c186d4b0bde848c964415441d9636d94a48d5b32f23b0f6f
|
3 |
+
size 575850
|
assets/driving_audio/2.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:15b998fd3fbabf22e9fde210f93df4f7b1647e19fe81b2d33b2b74470fea32b5
|
3 |
+
size 3891278
|
assets/driving_audio/3.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:448576f545e18c4cb33cb934a00c0bc3f331896eba4d1cb6a077c1b9382d0628
|
3 |
+
size 910770
|
assets/driving_audio/4.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a3e0d37fc6e235b5a09eb4b7e3b0b5d5f566e2204c5c815c23ca2215dcbf9c93
|
3 |
+
size 553038
|
assets/driving_audio/5.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95275c9299919e38e52789cb3af17ddc4691b7afea82f26c7edc640addce057d
|
3 |
+
size 856142
|
assets/driving_audio/6.wav
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7558a976b8b64d214f33c65e503c77271d3be0cd116a00ddadcb2b2fc53a6396
|
3 |
+
size 2641742
|
assets/driving_video/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
assets/driving_video/1.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b7da4f10cf9e692ba8c75848bacceb3c4d30ee8d3b07719435560c44a8da6544
|
3 |
+
size 306996
|
assets/driving_video/2.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e795b7be655c4b5ae8cac0733a32e8d321ccebd13f2cac07cc15dfc8f61a547
|
3 |
+
size 2875843
|
assets/driving_video/3.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:02f5ee85c1028c9673c70682b533a4f22e203173eddd40de42bad0cb57f18abb
|
3 |
+
size 1020948
|
assets/driving_video/4.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6f7ddbb17b198a580f658d57f4d83bee7489aa4d8a677f2c45b76b1ec01ae461
|
3 |
+
size 215144
|
assets/driving_video/5.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9637fea5ef83b494a0aa8b7c526ae1efc6ec94d79dfa94381de8d6f38eec238e
|
3 |
+
size 556047
|
assets/driving_video/6.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac7ee3c2419046f11dc230b6db33c2391a98334eba2b1d773e7eb9627992622f
|
3 |
+
size 1064930
|
assets/driving_video/7.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1dc94c1fec7ef7dc831c8a49f0e1788ae568812cb68e62f6875d9070f573d02a
|
3 |
+
size 187263
|
assets/driving_video/8.mp4
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3047ba66296d96b8a4584e412e61493d7bc0fa5149c77b130e7feea375e698bd
|
3 |
+
size 232859
|
assets/logo.png
ADDED
![]() |
assets/ref_images/1.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/10.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/11.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/12.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/13.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/14.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/15.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/16.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/17.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/18.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/19.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/2.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/20.png
ADDED
![]() |
assets/ref_images/3.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/4.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/5.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/6.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/7.png
ADDED
![]() |
Git LFS Details
|
assets/ref_images/8.png
ADDED
![]() |
Git LFS Details
|
diffposetalk/common.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class PositionalEncoding(nn.Module):
|
9 |
+
def __init__(self, d_model, dropout=0.1, max_len=600):
|
10 |
+
super().__init__()
|
11 |
+
self.dropout = nn.Dropout(p=dropout)
|
12 |
+
# vanilla sinusoidal encoding
|
13 |
+
pe = torch.zeros(max_len, d_model)
|
14 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
15 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
16 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
17 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
18 |
+
pe = pe.unsqueeze(0)
|
19 |
+
self.register_buffer('pe', pe)
|
20 |
+
|
21 |
+
def forward(self, x):
|
22 |
+
x = x + self.pe[:, x.shape[1], :]
|
23 |
+
return self.dropout(x)
|
24 |
+
|
25 |
+
|
26 |
+
def enc_dec_mask(T, S, frame_width=2, expansion=0, device='cuda'):
|
27 |
+
mask = torch.ones(T, S)
|
28 |
+
for i in range(T):
|
29 |
+
mask[i, max(0, (i - expansion) * frame_width):(i + expansion + 1) * frame_width] = 0
|
30 |
+
return (mask == 1).to(device=device)
|
31 |
+
|
32 |
+
|
33 |
+
def pad_audio(audio, audio_unit=320, pad_threshold=80):
|
34 |
+
batch_size, audio_len = audio.shape
|
35 |
+
n_units = audio_len // audio_unit
|
36 |
+
side_len = math.ceil((audio_unit * n_units + pad_threshold - audio_len) / 2)
|
37 |
+
if side_len >= 0:
|
38 |
+
reflect_len = side_len // 2
|
39 |
+
replicate_len = side_len % 2
|
40 |
+
if reflect_len > 0:
|
41 |
+
audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
|
42 |
+
audio = F.pad(audio, (reflect_len, reflect_len), mode='reflect')
|
43 |
+
if replicate_len > 0:
|
44 |
+
audio = F.pad(audio, (1, 1), mode='replicate')
|
45 |
+
|
46 |
+
return audio
|
diffposetalk/diff_talking_head.py
ADDED
@@ -0,0 +1,536 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .common import PositionalEncoding, enc_dec_mask, pad_audio
|
6 |
+
|
7 |
+
|
8 |
+
class DiffusionSchedule(nn.Module):
|
9 |
+
def __init__(self, num_steps, mode='linear', beta_1=1e-4, beta_T=0.02, s=0.008):
|
10 |
+
super().__init__()
|
11 |
+
|
12 |
+
if mode == 'linear':
|
13 |
+
betas = torch.linspace(beta_1, beta_T, num_steps)
|
14 |
+
elif mode == 'quadratic':
|
15 |
+
betas = torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
|
16 |
+
elif mode == 'sigmoid':
|
17 |
+
betas = torch.sigmoid(torch.linspace(-5, 5, num_steps)) * (beta_T - beta_1) + beta_1
|
18 |
+
elif mode == 'cosine':
|
19 |
+
steps = num_steps + 1
|
20 |
+
x = torch.linspace(0, num_steps, steps)
|
21 |
+
alpha_bars = torch.cos(((x / num_steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
22 |
+
alpha_bars = alpha_bars / alpha_bars[0]
|
23 |
+
betas = 1 - (alpha_bars[1:] / alpha_bars[:-1])
|
24 |
+
betas = torch.clip(betas, 0.0001, 0.999)
|
25 |
+
else:
|
26 |
+
raise ValueError(f'Unknown diffusion schedule {mode}!')
|
27 |
+
betas = torch.cat([torch.zeros(1), betas], dim=0) # Padding beta_0 = 0
|
28 |
+
|
29 |
+
alphas = 1 - betas
|
30 |
+
log_alphas = torch.log(alphas)
|
31 |
+
for i in range(1, log_alphas.shape[0]): # 1 to T
|
32 |
+
log_alphas[i] += log_alphas[i - 1]
|
33 |
+
alpha_bars = log_alphas.exp()
|
34 |
+
|
35 |
+
sigmas_flex = torch.sqrt(betas)
|
36 |
+
sigmas_inflex = torch.zeros_like(sigmas_flex)
|
37 |
+
for i in range(1, sigmas_flex.shape[0]):
|
38 |
+
sigmas_inflex[i] = ((1 - alpha_bars[i - 1]) / (1 - alpha_bars[i])) * betas[i]
|
39 |
+
sigmas_inflex = torch.sqrt(sigmas_inflex)
|
40 |
+
|
41 |
+
self.num_steps = num_steps
|
42 |
+
self.register_buffer('betas', betas)
|
43 |
+
self.register_buffer('alphas', alphas)
|
44 |
+
self.register_buffer('alpha_bars', alpha_bars)
|
45 |
+
self.register_buffer('sigmas_flex', sigmas_flex)
|
46 |
+
self.register_buffer('sigmas_inflex', sigmas_inflex)
|
47 |
+
|
48 |
+
def uniform_sample_t(self, batch_size):
|
49 |
+
ts = torch.randint(1, self.num_steps + 1, (batch_size,))
|
50 |
+
return ts.tolist()
|
51 |
+
|
52 |
+
def get_sigmas(self, t, flexibility=0):
|
53 |
+
assert 0 <= flexibility <= 1
|
54 |
+
sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility)
|
55 |
+
return sigmas
|
56 |
+
|
57 |
+
|
58 |
+
class DiffTalkingHead(nn.Module):
|
59 |
+
def __init__(self, args, device='cuda'):
|
60 |
+
super().__init__()
|
61 |
+
|
62 |
+
# Model parameters
|
63 |
+
self.target = args.target
|
64 |
+
self.architecture = args.architecture
|
65 |
+
self.use_style = args.style_enc_ckpt is not None
|
66 |
+
|
67 |
+
self.motion_feat_dim = 50
|
68 |
+
if args.rot_repr == 'aa':
|
69 |
+
self.motion_feat_dim += 1 if args.no_head_pose else 4
|
70 |
+
else:
|
71 |
+
raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
|
72 |
+
|
73 |
+
self.fps = args.fps
|
74 |
+
self.n_motions = args.n_motions
|
75 |
+
self.n_prev_motions = args.n_prev_motions
|
76 |
+
if self.use_style:
|
77 |
+
self.style_feat_dim = args.d_style
|
78 |
+
|
79 |
+
# Audio encoder
|
80 |
+
self.audio_model = args.audio_model
|
81 |
+
if self.audio_model == 'wav2vec2':
|
82 |
+
from .wav2vec2 import Wav2Vec2Model
|
83 |
+
self.audio_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
|
84 |
+
# wav2vec 2.0 weights initialization
|
85 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
86 |
+
elif self.audio_model == 'hubert':
|
87 |
+
from .hubert import HubertModel
|
88 |
+
self.audio_encoder = HubertModel.from_pretrained('facebook/hubert-base-ls960')
|
89 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
90 |
+
|
91 |
+
frozen_layers = [0, 1]
|
92 |
+
for name, param in self.audio_encoder.named_parameters():
|
93 |
+
if name.startswith("feature_projection"):
|
94 |
+
param.requires_grad = False
|
95 |
+
if name.startswith("encoder.layers"):
|
96 |
+
layer = int(name.split(".")[2])
|
97 |
+
if layer in frozen_layers:
|
98 |
+
param.requires_grad = False
|
99 |
+
else:
|
100 |
+
raise ValueError(f'Unknown audio model {self.audio_model}!')
|
101 |
+
|
102 |
+
if args.architecture == 'decoder':
|
103 |
+
self.audio_feature_map = nn.Linear(768, args.feature_dim)
|
104 |
+
self.start_audio_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, args.feature_dim))
|
105 |
+
else:
|
106 |
+
raise ValueError(f'Unknown architecture {args.architecture}!')
|
107 |
+
|
108 |
+
self.start_motion_feat = nn.Parameter(torch.randn(1, self.n_prev_motions, self.motion_feat_dim))
|
109 |
+
|
110 |
+
# Diffusion model
|
111 |
+
self.denoising_net = DenoisingNetwork(args, device)
|
112 |
+
# diffusion schedule
|
113 |
+
self.diffusion_sched = DiffusionSchedule(args.n_diff_steps, args.diff_schedule)
|
114 |
+
|
115 |
+
# Classifier-free settings
|
116 |
+
self.cfg_mode = args.cfg_mode
|
117 |
+
guiding_conditions = args.guiding_conditions.split(',') if args.guiding_conditions else []
|
118 |
+
self.guiding_conditions = [cond for cond in guiding_conditions if cond in ['style', 'audio']]
|
119 |
+
if 'style' in self.guiding_conditions:
|
120 |
+
if not self.use_style:
|
121 |
+
raise ValueError('Cannot use style guiding without enabling it!')
|
122 |
+
self.null_style_feat = nn.Parameter(torch.randn(1, 1, self.style_feat_dim))
|
123 |
+
if 'audio' in self.guiding_conditions:
|
124 |
+
audio_feat_dim = args.feature_dim
|
125 |
+
self.null_audio_feat = nn.Parameter(torch.randn(1, 1, audio_feat_dim))
|
126 |
+
|
127 |
+
self.to(device)
|
128 |
+
|
129 |
+
@property
|
130 |
+
def device(self):
|
131 |
+
return next(self.parameters()).device
|
132 |
+
|
133 |
+
def forward(self, motion_feat, audio_or_feat, shape_feat, style_feat=None,
|
134 |
+
prev_motion_feat=None, prev_audio_feat=None, time_step=None, indicator=None):
|
135 |
+
"""
|
136 |
+
Args:
|
137 |
+
motion_feat: (N, L, d_coef) motion coefficients or features
|
138 |
+
audio_or_feat: (N, L_audio) raw audio or audio feature
|
139 |
+
shape_feat: (N, d_shape) or (N, 1, d_shape)
|
140 |
+
style_feat: (N, d_style)
|
141 |
+
prev_motion_feat: (N, n_prev_motions, d_motion) previous motion coefficients or feature
|
142 |
+
prev_audio_feat: (N, n_prev_motions, d_audio) previous audio features
|
143 |
+
time_step: (N,)
|
144 |
+
indicator: (N, L) 0/1 indicator of real (unpadded) motion coefficients
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
motion_feat_noise: (N, L, d_motion)
|
148 |
+
"""
|
149 |
+
if self.use_style:
|
150 |
+
assert style_feat is not None, 'Missing style features!'
|
151 |
+
|
152 |
+
batch_size = motion_feat.shape[0]
|
153 |
+
|
154 |
+
if audio_or_feat.ndim == 2:
|
155 |
+
# Extract audio features
|
156 |
+
assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
|
157 |
+
f'Incorrect audio length {audio_or_feat.shape[1]}'
|
158 |
+
audio_feat_saved = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
|
159 |
+
elif audio_or_feat.ndim == 3:
|
160 |
+
assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
|
161 |
+
audio_feat_saved = audio_or_feat
|
162 |
+
else:
|
163 |
+
raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
|
164 |
+
audio_feat = audio_feat_saved.clone()
|
165 |
+
|
166 |
+
if shape_feat.ndim == 2:
|
167 |
+
shape_feat = shape_feat.unsqueeze(1) # (N, 1, d_shape)
|
168 |
+
if style_feat is not None and style_feat.ndim == 2:
|
169 |
+
style_feat = style_feat.unsqueeze(1) # (N, 1, d_style)
|
170 |
+
|
171 |
+
if prev_motion_feat is None:
|
172 |
+
prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
|
173 |
+
if prev_audio_feat is None:
|
174 |
+
# (N, n_prev_motions, feature_dim)
|
175 |
+
prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
|
176 |
+
|
177 |
+
# Classifier-free guidance
|
178 |
+
if len(self.guiding_conditions) > 0:
|
179 |
+
assert len(self.guiding_conditions) <= 2, 'Only support 1 or 2 CFG conditions!'
|
180 |
+
if len(self.guiding_conditions) == 1 or self.cfg_mode == 'independent':
|
181 |
+
null_cond_prob = 0.5 if len(self.guiding_conditions) >= 2 else 0.1
|
182 |
+
if 'style' in self.guiding_conditions:
|
183 |
+
mask_style = torch.rand(batch_size, device=self.device) < null_cond_prob
|
184 |
+
style_feat = torch.where(mask_style.view(-1, 1, 1),
|
185 |
+
self.null_style_feat.expand(batch_size, -1, -1),
|
186 |
+
style_feat)
|
187 |
+
if 'audio' in self.guiding_conditions:
|
188 |
+
mask_audio = torch.rand(batch_size, device=self.device) < null_cond_prob
|
189 |
+
audio_feat = torch.where(mask_audio.view(-1, 1, 1),
|
190 |
+
self.null_audio_feat.expand(batch_size, self.n_motions, -1),
|
191 |
+
audio_feat)
|
192 |
+
else:
|
193 |
+
# len(self.guiding_conditions) > 1 and self.cfg_mode == 'incremental'
|
194 |
+
# full (0.45), w/o style (0.45), w/o style or audio (0.1)
|
195 |
+
mask_flag = torch.rand(batch_size, device=self.device)
|
196 |
+
if 'style' in self.guiding_conditions:
|
197 |
+
mask_style = mask_flag > 0.55
|
198 |
+
style_feat = torch.where(mask_style.view(-1, 1, 1),
|
199 |
+
self.null_style_feat.expand(batch_size, -1, -1),
|
200 |
+
style_feat)
|
201 |
+
if 'audio' in self.guiding_conditions:
|
202 |
+
mask_audio = mask_flag > 0.9
|
203 |
+
audio_feat = torch.where(mask_audio.view(-1, 1, 1),
|
204 |
+
self.null_audio_feat.expand(batch_size, self.n_motions, -1),
|
205 |
+
audio_feat)
|
206 |
+
|
207 |
+
if style_feat is None:
|
208 |
+
# The model only accepts audio and shape features, i.e., self.use_style = False
|
209 |
+
person_feat = shape_feat
|
210 |
+
else:
|
211 |
+
person_feat = torch.cat([shape_feat, style_feat], dim=-1)
|
212 |
+
|
213 |
+
if time_step is None:
|
214 |
+
# Sample time step
|
215 |
+
time_step = self.diffusion_sched.uniform_sample_t(batch_size) # (N,)
|
216 |
+
|
217 |
+
# The forward diffusion process
|
218 |
+
alpha_bar = self.diffusion_sched.alpha_bars[time_step] # (N,)
|
219 |
+
c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (N, 1, 1)
|
220 |
+
c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (N, 1, 1)
|
221 |
+
|
222 |
+
eps = torch.randn_like(motion_feat) # (N, L, d_motion)
|
223 |
+
motion_feat_noisy = c0 * motion_feat + c1 * eps
|
224 |
+
|
225 |
+
# The reverse diffusion process
|
226 |
+
motion_feat_target = self.denoising_net(motion_feat_noisy, audio_feat, person_feat,
|
227 |
+
prev_motion_feat, prev_audio_feat, time_step, indicator)
|
228 |
+
|
229 |
+
return eps, motion_feat_target, motion_feat.detach(), audio_feat_saved.detach()
|
230 |
+
|
231 |
+
def extract_audio_feature(self, audio, frame_num=None):
|
232 |
+
frame_num = frame_num or self.n_motions
|
233 |
+
|
234 |
+
# # Strategy 1: resample during audio feature extraction
|
235 |
+
# hidden_states = self.audio_encoder(pad_audio(audio), self.fps, frame_num=frame_num).last_hidden_state # (N, L, 768)
|
236 |
+
|
237 |
+
# Strategy 2: resample after audio feature extraction (BackResample)
|
238 |
+
hidden_states = self.audio_encoder(pad_audio(audio), self.fps,
|
239 |
+
frame_num=frame_num * 2).last_hidden_state # (N, 2L, 768)
|
240 |
+
hidden_states = hidden_states.transpose(1, 2) # (N, 768, 2L)
|
241 |
+
hidden_states = F.interpolate(hidden_states, size=frame_num, align_corners=False, mode='linear') # (N, 768, L)
|
242 |
+
hidden_states = hidden_states.transpose(1, 2) # (N, L, 768)
|
243 |
+
|
244 |
+
audio_feat = self.audio_feature_map(hidden_states) # (N, L, feature_dim)
|
245 |
+
return audio_feat
|
246 |
+
|
247 |
+
@torch.no_grad()
|
248 |
+
def sample(self, audio_or_feat, shape_feat, style_feat=None, prev_motion_feat=None, prev_audio_feat=None,
|
249 |
+
motion_at_T=None, indicator=None, cfg_mode=None, cfg_cond=None, cfg_scale=1.15, flexibility=0,
|
250 |
+
dynamic_threshold=None, ret_traj=False):
|
251 |
+
# Check and convert inputs
|
252 |
+
batch_size = audio_or_feat.shape[0]
|
253 |
+
|
254 |
+
# Check CFG conditions
|
255 |
+
if cfg_mode is None: # Use default CFG mode
|
256 |
+
cfg_mode = self.cfg_mode
|
257 |
+
if cfg_cond is None: # Use default CFG conditions
|
258 |
+
cfg_cond = self.guiding_conditions
|
259 |
+
cfg_cond = [c for c in cfg_cond if c in ['audio', 'style']]
|
260 |
+
|
261 |
+
if not isinstance(cfg_scale, list):
|
262 |
+
cfg_scale = [cfg_scale] * len(cfg_cond)
|
263 |
+
|
264 |
+
# sort cfg_cond and cfg_scale
|
265 |
+
if len(cfg_cond) > 0:
|
266 |
+
cfg_cond, cfg_scale = zip(*sorted(zip(cfg_cond, cfg_scale), key=lambda x: ['audio', 'style'].index(x[0])))
|
267 |
+
else:
|
268 |
+
cfg_cond, cfg_scale = [], []
|
269 |
+
|
270 |
+
if 'style' in cfg_cond:
|
271 |
+
assert self.use_style and style_feat is not None
|
272 |
+
|
273 |
+
if self.use_style:
|
274 |
+
if style_feat is None: # use null style feature
|
275 |
+
style_feat = self.null_style_feat.expand(batch_size, -1, -1)
|
276 |
+
else:
|
277 |
+
assert style_feat is None, 'This model does not support style feature input!'
|
278 |
+
|
279 |
+
if audio_or_feat.ndim == 2:
|
280 |
+
# Extract audio features
|
281 |
+
assert audio_or_feat.shape[1] == 16000 * self.n_motions / self.fps, \
|
282 |
+
f'Incorrect audio length {audio_or_feat.shape[1]}'
|
283 |
+
audio_feat = self.extract_audio_feature(audio_or_feat) # (N, L, feature_dim)
|
284 |
+
elif audio_or_feat.ndim == 3:
|
285 |
+
assert audio_or_feat.shape[1] == self.n_motions, f'Incorrect audio feature length {audio_or_feat.shape[1]}'
|
286 |
+
audio_feat = audio_or_feat
|
287 |
+
else:
|
288 |
+
raise ValueError(f'Incorrect audio input shape {audio_or_feat.shape}')
|
289 |
+
|
290 |
+
if shape_feat.ndim == 2:
|
291 |
+
shape_feat = shape_feat.unsqueeze(1) # (N, 1, d_shape)
|
292 |
+
if style_feat is not None and style_feat.ndim == 2:
|
293 |
+
style_feat = style_feat.unsqueeze(1) # (N, 1, d_style)
|
294 |
+
|
295 |
+
if prev_motion_feat is None:
|
296 |
+
prev_motion_feat = self.start_motion_feat.expand(batch_size, -1, -1) # (N, n_prev_motions, d_motion)
|
297 |
+
if prev_audio_feat is None:
|
298 |
+
# (N, n_prev_motions, feature_dim)
|
299 |
+
prev_audio_feat = self.start_audio_feat.expand(batch_size, -1, -1)
|
300 |
+
|
301 |
+
if motion_at_T is None:
|
302 |
+
motion_at_T = torch.randn((batch_size, self.n_motions, self.motion_feat_dim)).to(self.device)
|
303 |
+
|
304 |
+
# Prepare input for the reverse diffusion process (including optional classifier-free guidance)
|
305 |
+
if 'audio' in cfg_cond:
|
306 |
+
audio_feat_null = self.null_audio_feat.expand(batch_size, self.n_motions, -1)
|
307 |
+
else:
|
308 |
+
audio_feat_null = audio_feat
|
309 |
+
|
310 |
+
if 'style' in cfg_cond:
|
311 |
+
person_feat_null = torch.cat([shape_feat, self.null_style_feat.expand(batch_size, -1, -1)], dim=-1)
|
312 |
+
else:
|
313 |
+
if self.use_style:
|
314 |
+
person_feat_null = torch.cat([shape_feat, style_feat], dim=-1)
|
315 |
+
else:
|
316 |
+
person_feat_null = shape_feat
|
317 |
+
|
318 |
+
audio_feat_in = [audio_feat_null]
|
319 |
+
person_feat_in = [person_feat_null]
|
320 |
+
for cond in cfg_cond:
|
321 |
+
if cond == 'audio':
|
322 |
+
audio_feat_in.append(audio_feat)
|
323 |
+
person_feat_in.append(person_feat_null)
|
324 |
+
elif cond == 'style':
|
325 |
+
if cfg_mode == 'independent':
|
326 |
+
audio_feat_in.append(audio_feat_null)
|
327 |
+
elif cfg_mode == 'incremental':
|
328 |
+
audio_feat_in.append(audio_feat)
|
329 |
+
else:
|
330 |
+
raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
|
331 |
+
person_feat_in.append(torch.cat([shape_feat, style_feat], dim=-1))
|
332 |
+
|
333 |
+
n_entries = len(audio_feat_in)
|
334 |
+
audio_feat_in = torch.cat(audio_feat_in, dim=0)
|
335 |
+
person_feat_in = torch.cat(person_feat_in, dim=0)
|
336 |
+
prev_motion_feat_in = torch.cat([prev_motion_feat] * n_entries, dim=0)
|
337 |
+
prev_audio_feat_in = torch.cat([prev_audio_feat] * n_entries, dim=0)
|
338 |
+
indicator_in = torch.cat([indicator] * n_entries, dim=0) if indicator is not None else None
|
339 |
+
|
340 |
+
traj = {self.diffusion_sched.num_steps: motion_at_T}
|
341 |
+
for t in range(self.diffusion_sched.num_steps, 0, -1):
|
342 |
+
if t > 1:
|
343 |
+
z = torch.randn_like(motion_at_T)
|
344 |
+
else:
|
345 |
+
z = torch.zeros_like(motion_at_T)
|
346 |
+
|
347 |
+
alpha = self.diffusion_sched.alphas[t]
|
348 |
+
alpha_bar = self.diffusion_sched.alpha_bars[t]
|
349 |
+
alpha_bar_prev = self.diffusion_sched.alpha_bars[t - 1]
|
350 |
+
sigma = self.diffusion_sched.get_sigmas(t, flexibility)
|
351 |
+
|
352 |
+
motion_at_t = traj[t]
|
353 |
+
motion_in = torch.cat([motion_at_t] * n_entries, dim=0)
|
354 |
+
step_in = torch.tensor([t] * batch_size, device=self.device)
|
355 |
+
step_in = torch.cat([step_in] * n_entries, dim=0)
|
356 |
+
|
357 |
+
results = self.denoising_net(motion_in, audio_feat_in, person_feat_in, prev_motion_feat_in,
|
358 |
+
prev_audio_feat_in, step_in, indicator_in)
|
359 |
+
|
360 |
+
# Apply thresholding if specified
|
361 |
+
if dynamic_threshold:
|
362 |
+
dt_ratio, dt_min, dt_max = dynamic_threshold
|
363 |
+
abs_results = results[:, -self.n_motions:].reshape(batch_size * n_entries, -1).abs()
|
364 |
+
s = torch.quantile(abs_results, dt_ratio, dim=1)
|
365 |
+
s = torch.clamp(s, min=dt_min, max=dt_max)
|
366 |
+
s = s[..., None, None]
|
367 |
+
results = torch.clamp(results, min=-s, max=s)
|
368 |
+
|
369 |
+
results = results.chunk(n_entries)
|
370 |
+
|
371 |
+
# Unconditional target (CFG) or the conditional target (non-CFG)
|
372 |
+
target_theta = results[0][:, -self.n_motions:]
|
373 |
+
# Classifier-free Guidance (optional)
|
374 |
+
for i in range(0, n_entries - 1):
|
375 |
+
if cfg_mode == 'independent':
|
376 |
+
target_theta += cfg_scale[i] * (
|
377 |
+
results[i + 1][:, -self.n_motions:] - results[0][:, -self.n_motions:])
|
378 |
+
elif cfg_mode == 'incremental':
|
379 |
+
target_theta += cfg_scale[i] * (
|
380 |
+
results[i + 1][:, -self.n_motions:] - results[i][:, -self.n_motions:])
|
381 |
+
else:
|
382 |
+
raise NotImplementedError(f'Unknown cfg_mode {cfg_mode}')
|
383 |
+
|
384 |
+
if self.target == 'noise':
|
385 |
+
c0 = 1 / torch.sqrt(alpha)
|
386 |
+
c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar)
|
387 |
+
motion_next = c0 * (motion_at_t - c1 * target_theta) + sigma * z
|
388 |
+
elif self.target == 'sample':
|
389 |
+
c0 = (1 - alpha_bar_prev) * torch.sqrt(alpha) / (1 - alpha_bar)
|
390 |
+
c1 = (1 - alpha) * torch.sqrt(alpha_bar_prev) / (1 - alpha_bar)
|
391 |
+
motion_next = c0 * motion_at_t + c1 * target_theta + sigma * z
|
392 |
+
else:
|
393 |
+
raise ValueError('Unknown target type: {}'.format(self.target))
|
394 |
+
|
395 |
+
traj[t - 1] = motion_next.detach() # Stop gradient and save trajectory.
|
396 |
+
traj[t] = traj[t].cpu() # Move previous output to CPU memory.
|
397 |
+
if not ret_traj:
|
398 |
+
del traj[t]
|
399 |
+
|
400 |
+
if ret_traj:
|
401 |
+
return traj, motion_at_T, audio_feat
|
402 |
+
else:
|
403 |
+
return traj[0], motion_at_T, audio_feat
|
404 |
+
|
405 |
+
|
406 |
+
class DenoisingNetwork(nn.Module):
|
407 |
+
def __init__(self, args, device='cuda'):
|
408 |
+
super().__init__()
|
409 |
+
|
410 |
+
# Model parameters
|
411 |
+
self.use_style = args.style_enc_ckpt is not None
|
412 |
+
self.motion_feat_dim = 50
|
413 |
+
if args.rot_repr == 'aa':
|
414 |
+
self.motion_feat_dim += 1 if args.no_head_pose else 4
|
415 |
+
else:
|
416 |
+
raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
|
417 |
+
self.shape_feat_dim = 100
|
418 |
+
if self.use_style:
|
419 |
+
self.style_feat_dim = args.d_style
|
420 |
+
self.person_feat_dim = self.shape_feat_dim + self.style_feat_dim
|
421 |
+
else:
|
422 |
+
self.person_feat_dim = self.shape_feat_dim
|
423 |
+
self.use_indicator = args.use_indicator
|
424 |
+
|
425 |
+
# Transformer
|
426 |
+
self.architecture = args.architecture
|
427 |
+
self.feature_dim = args.feature_dim
|
428 |
+
self.n_heads = args.n_heads
|
429 |
+
self.n_layers = args.n_layers
|
430 |
+
self.mlp_ratio = args.mlp_ratio
|
431 |
+
self.align_mask_width = args.align_mask_width
|
432 |
+
self.use_learnable_pe = not args.no_use_learnable_pe
|
433 |
+
# sequence length
|
434 |
+
self.n_prev_motions = args.n_prev_motions
|
435 |
+
self.n_motions = args.n_motions
|
436 |
+
|
437 |
+
# Temporal embedding for the diffusion time step
|
438 |
+
self.TE = PositionalEncoding(self.feature_dim, max_len=args.n_diff_steps + 1)
|
439 |
+
self.diff_step_map = nn.Sequential(
|
440 |
+
nn.Linear(self.feature_dim, self.feature_dim),
|
441 |
+
nn.GELU(),
|
442 |
+
nn.Linear(self.feature_dim, self.feature_dim)
|
443 |
+
)
|
444 |
+
|
445 |
+
if self.use_learnable_pe:
|
446 |
+
# Learnable positional encoding
|
447 |
+
self.PE = nn.Parameter(torch.randn(1, 1 + self.n_prev_motions + self.n_motions, self.feature_dim))
|
448 |
+
else:
|
449 |
+
self.PE = PositionalEncoding(self.feature_dim)
|
450 |
+
|
451 |
+
self.person_proj = nn.Linear(self.person_feat_dim, self.feature_dim)
|
452 |
+
|
453 |
+
# Transformer decoder
|
454 |
+
if self.architecture == 'decoder':
|
455 |
+
self.feature_proj = nn.Linear(self.motion_feat_dim + (1 if self.use_indicator else 0),
|
456 |
+
self.feature_dim)
|
457 |
+
decoder_layer = nn.TransformerDecoderLayer(
|
458 |
+
d_model=self.feature_dim, nhead=self.n_heads, dim_feedforward=self.mlp_ratio * self.feature_dim,
|
459 |
+
activation='gelu', batch_first=True
|
460 |
+
)
|
461 |
+
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=self.n_layers)
|
462 |
+
if self.align_mask_width > 0:
|
463 |
+
motion_len = self.n_prev_motions + self.n_motions
|
464 |
+
alignment_mask = enc_dec_mask(motion_len, motion_len, 1, self.align_mask_width - 1)
|
465 |
+
alignment_mask = F.pad(alignment_mask, (0, 0, 1, 0), value=False)
|
466 |
+
self.register_buffer('alignment_mask', alignment_mask)
|
467 |
+
else:
|
468 |
+
self.alignment_mask = None
|
469 |
+
else:
|
470 |
+
raise ValueError(f'Unknown architecture: {self.architecture}')
|
471 |
+
|
472 |
+
# Motion decoder
|
473 |
+
self.motion_dec = nn.Sequential(
|
474 |
+
nn.Linear(self.feature_dim, self.feature_dim // 2),
|
475 |
+
nn.GELU(),
|
476 |
+
nn.Linear(self.feature_dim // 2, self.motion_feat_dim)
|
477 |
+
)
|
478 |
+
|
479 |
+
self.to(device)
|
480 |
+
|
481 |
+
@property
|
482 |
+
def device(self):
|
483 |
+
return next(self.parameters()).device
|
484 |
+
|
485 |
+
def forward(self, motion_feat, audio_feat, person_feat, prev_motion_feat, prev_audio_feat, step, indicator=None):
|
486 |
+
"""
|
487 |
+
Args:
|
488 |
+
motion_feat: (N, L, d_motion). Noisy motion feature
|
489 |
+
audio_feat: (N, L, feature_dim)
|
490 |
+
person_feat: (N, 1, d_person)
|
491 |
+
prev_motion_feat: (N, L_p, d_motion). Padded previous motion coefficients or feature
|
492 |
+
prev_audio_feat: (N, L_p, d_audio). Padded previous motion coefficients or feature
|
493 |
+
step: (N,)
|
494 |
+
indicator: (N, L). 0/1 indicator for the real (unpadded) motion feature
|
495 |
+
|
496 |
+
Returns:
|
497 |
+
motion_feat_target: (N, L_p + L, d_motion)
|
498 |
+
"""
|
499 |
+
# Diffusion time step embedding
|
500 |
+
diff_step_embedding = self.diff_step_map(self.TE.pe[0, step]).unsqueeze(1) # (N, 1, diff_step_dim)
|
501 |
+
|
502 |
+
person_feat = self.person_proj(person_feat) # (N, 1, feature_dim)
|
503 |
+
person_feat = person_feat + diff_step_embedding
|
504 |
+
|
505 |
+
if indicator is not None:
|
506 |
+
indicator = torch.cat([torch.zeros((indicator.shape[0], self.n_prev_motions), device=indicator.device),
|
507 |
+
indicator], dim=1) # (N, L_p + L)
|
508 |
+
indicator = indicator.unsqueeze(-1) # (N, L_p + L, 1)
|
509 |
+
|
510 |
+
# Concat features and embeddings
|
511 |
+
if self.architecture == 'decoder':
|
512 |
+
feats_in = torch.cat([prev_motion_feat, motion_feat], dim=1) # (N, L_p + L, d_motion)
|
513 |
+
else:
|
514 |
+
raise ValueError(f'Unknown architecture: {self.architecture}')
|
515 |
+
if self.use_indicator:
|
516 |
+
feats_in = torch.cat([feats_in, indicator], dim=-1) # (N, L_p + L, d_motion + d_audio + 1)
|
517 |
+
|
518 |
+
feats_in = self.feature_proj(feats_in) # (N, L_p + L, feature_dim)
|
519 |
+
feats_in = torch.cat([person_feat, feats_in], dim=1) # (N, 1 + L_p + L, feature_dim)
|
520 |
+
|
521 |
+
if self.use_learnable_pe:
|
522 |
+
feats_in = feats_in + self.PE
|
523 |
+
else:
|
524 |
+
feats_in = self.PE(feats_in)
|
525 |
+
|
526 |
+
# Transformer
|
527 |
+
if self.architecture == 'decoder':
|
528 |
+
audio_feat_in = torch.cat([prev_audio_feat, audio_feat], dim=1) # (N, L_p + L, d_audio)
|
529 |
+
feat_out = self.transformer(feats_in, audio_feat_in, memory_mask=self.alignment_mask)
|
530 |
+
else:
|
531 |
+
raise ValueError(f'Unknown architecture: {self.architecture}')
|
532 |
+
|
533 |
+
# Decode predicted motion feature noise / sample
|
534 |
+
motion_feat_target = self.motion_dec(feat_out[:, 1:]) # (N, L_p + L, d_motion)
|
535 |
+
|
536 |
+
return motion_feat_target
|
diffposetalk/diffposetalk.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import tempfile
|
3 |
+
import warnings
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import librosa
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from tqdm import tqdm
|
12 |
+
from pydantic import BaseModel
|
13 |
+
|
14 |
+
from .diff_talking_head import DiffTalkingHead
|
15 |
+
from .utils import NullableArgs, coef_dict_to_vertices, get_coef_dict
|
16 |
+
from .utils.media import combine_video_and_audio, convert_video, reencode_audio
|
17 |
+
|
18 |
+
warnings.filterwarnings('ignore', message='PySoundFile failed. Trying audioread instead.')
|
19 |
+
|
20 |
+
class DiffPoseTalkConfig(BaseModel):
|
21 |
+
no_context_audio_feat: bool = False
|
22 |
+
model_path: str = "pretrained_models/diffposetalk/iter_0110000.pt" # DPT/head-SA-hubert-WM
|
23 |
+
coef_stats: str = "pretrained_models/diffposetalk/stats_train.npz"
|
24 |
+
style_path: str = "pretrained_models/diffposetalk/style/L4H4-T0.1-BS32/iter_0034000/normal.npy"
|
25 |
+
dynamic_threshold_ratio: float = 0.99
|
26 |
+
dynamic_threshold_min: float = 1.0
|
27 |
+
dynamic_threshold_max: float = 4.0
|
28 |
+
scale_audio: float = 1.15
|
29 |
+
scale_style: float = 3.0
|
30 |
+
|
31 |
+
class DiffPoseTalk:
|
32 |
+
def __init__(self, config: DiffPoseTalkConfig = DiffPoseTalkConfig(), device="cuda"):
|
33 |
+
self.cfg = config
|
34 |
+
self.device = device
|
35 |
+
|
36 |
+
self.no_context_audio_feat = self.cfg.no_context_audio_feat
|
37 |
+
model_data = torch.load(self.cfg.model_path, map_location=self.device)
|
38 |
+
|
39 |
+
self.model_args = NullableArgs(model_data['args'])
|
40 |
+
self.model = DiffTalkingHead(self.model_args, self.device)
|
41 |
+
model_data['model'].pop('denoising_net.TE.pe')
|
42 |
+
self.model.load_state_dict(model_data['model'], strict=False)
|
43 |
+
self.model.to(self.device)
|
44 |
+
self.model.eval()
|
45 |
+
|
46 |
+
self.use_indicator = self.model_args.use_indicator
|
47 |
+
self.rot_repr = self.model_args.rot_repr
|
48 |
+
self.predict_head_pose = not self.model_args.no_head_pose
|
49 |
+
if self.model.use_style:
|
50 |
+
style_dir = Path(self.model_args.style_enc_ckpt)
|
51 |
+
style_dir = Path(*style_dir.with_suffix('').parts[-3::2])
|
52 |
+
self.style_dir = style_dir
|
53 |
+
|
54 |
+
# sequence
|
55 |
+
self.n_motions = self.model_args.n_motions
|
56 |
+
self.n_prev_motions = self.model_args.n_prev_motions
|
57 |
+
self.fps = self.model_args.fps
|
58 |
+
self.audio_unit = 16000. / self.fps # num of samples per frame
|
59 |
+
self.n_audio_samples = round(self.audio_unit * self.n_motions)
|
60 |
+
self.pad_mode = self.model_args.pad_mode
|
61 |
+
|
62 |
+
self.coef_stats = dict(np.load(self.cfg.coef_stats))
|
63 |
+
self.coef_stats = {k: torch.from_numpy(v).to(self.device) for k, v in self.coef_stats.items()}
|
64 |
+
|
65 |
+
if self.cfg.dynamic_threshold_ratio > 0:
|
66 |
+
self.dynamic_threshold = (self.cfg.dynamic_threshold_ratio, self.cfg.dynamic_threshold_min,
|
67 |
+
self.cfg.dynamic_threshold_max)
|
68 |
+
else:
|
69 |
+
self.dynamic_threshold = None
|
70 |
+
|
71 |
+
|
72 |
+
def infer_from_file(self, audio_path, shape_coef):
|
73 |
+
n_repetitions = 1
|
74 |
+
cfg_mode = None
|
75 |
+
cfg_cond = self.model.guiding_conditions
|
76 |
+
cfg_scale = []
|
77 |
+
for cond in cfg_cond:
|
78 |
+
if cond == 'audio':
|
79 |
+
cfg_scale.append(self.cfg.scale_audio)
|
80 |
+
elif cond == 'style':
|
81 |
+
cfg_scale.append(self.cfg.scale_style)
|
82 |
+
|
83 |
+
coef_dict = self.infer_coeffs(audio_path, shape_coef, self.cfg.style_path, n_repetitions,
|
84 |
+
cfg_mode, cfg_cond, cfg_scale, include_shape=True)
|
85 |
+
return coef_dict
|
86 |
+
|
87 |
+
@torch.no_grad()
|
88 |
+
def infer_coeffs(self, audio, shape_coef, style_feat=None, n_repetitions=1,
|
89 |
+
cfg_mode=None, cfg_cond=None, cfg_scale=1.15, include_shape=False):
|
90 |
+
# Returns dict[str, (n_repetitions, L, *)]
|
91 |
+
# Step 1: Preprocessing
|
92 |
+
# Preprocess audio
|
93 |
+
if isinstance(audio, (str, Path)):
|
94 |
+
audio, _ = librosa.load(audio, sr=16000, mono=True)
|
95 |
+
if isinstance(audio, np.ndarray):
|
96 |
+
audio = torch.from_numpy(audio).to(self.device)
|
97 |
+
assert audio.ndim == 1, 'Audio must be 1D tensor.'
|
98 |
+
audio_mean, audio_std = torch.mean(audio), torch.std(audio)
|
99 |
+
audio = (audio - audio_mean) / (audio_std + 1e-5)
|
100 |
+
|
101 |
+
# Preprocess shape coefficient
|
102 |
+
if isinstance(shape_coef, (str, Path)):
|
103 |
+
shape_coef = np.load(shape_coef)
|
104 |
+
if not isinstance(shape_coef, np.ndarray):
|
105 |
+
shape_coef = shape_coef['shape']
|
106 |
+
if isinstance(shape_coef, np.ndarray):
|
107 |
+
shape_coef = torch.from_numpy(shape_coef).float().to(self.device)
|
108 |
+
assert shape_coef.ndim <= 2, 'Shape coefficient must be 1D or 2D tensor.'
|
109 |
+
if shape_coef.ndim > 1:
|
110 |
+
# use the first frame as the shape coefficient
|
111 |
+
shape_coef = shape_coef[0]
|
112 |
+
original_shape_coef = shape_coef.clone()
|
113 |
+
if self.coef_stats is not None:
|
114 |
+
shape_coef = (shape_coef - self.coef_stats['shape_mean']) / self.coef_stats['shape_std']
|
115 |
+
shape_coef = shape_coef.unsqueeze(0).expand(n_repetitions, -1)
|
116 |
+
|
117 |
+
# Preprocess style feature if given
|
118 |
+
if style_feat is not None:
|
119 |
+
assert self.model.use_style
|
120 |
+
if isinstance(style_feat, (str, Path)):
|
121 |
+
style_feat = Path(style_feat)
|
122 |
+
if not style_feat.exists() and not style_feat.is_absolute():
|
123 |
+
style_feat = style_feat.parent / self.style_dir / style_feat.name
|
124 |
+
style_feat = np.load(style_feat)
|
125 |
+
if not isinstance(style_feat, np.ndarray):
|
126 |
+
style_feat = style_feat['style']
|
127 |
+
if isinstance(style_feat, np.ndarray):
|
128 |
+
style_feat = torch.from_numpy(style_feat).float().to(self.device)
|
129 |
+
assert style_feat.ndim == 1, 'Style feature must be 1D tensor.'
|
130 |
+
style_feat = style_feat.unsqueeze(0).expand(n_repetitions, -1)
|
131 |
+
|
132 |
+
# Step 2: Predict motion coef
|
133 |
+
# divide into synthesize units and do synthesize
|
134 |
+
clip_len = int(len(audio) / 16000 * self.fps)
|
135 |
+
stride = self.n_motions
|
136 |
+
if clip_len <= self.n_motions:
|
137 |
+
n_subdivision = 1
|
138 |
+
else:
|
139 |
+
n_subdivision = math.ceil(clip_len / stride)
|
140 |
+
|
141 |
+
# Prepare audio input
|
142 |
+
n_padding_audio_samples = self.n_audio_samples * n_subdivision - len(audio)
|
143 |
+
n_padding_frames = math.ceil(n_padding_audio_samples / self.audio_unit)
|
144 |
+
if n_padding_audio_samples > 0:
|
145 |
+
if self.pad_mode == 'zero':
|
146 |
+
padding_value = 0
|
147 |
+
elif self.pad_mode == 'replicate':
|
148 |
+
padding_value = audio[-1]
|
149 |
+
else:
|
150 |
+
raise ValueError(f'Unknown pad mode: {self.pad_mode}')
|
151 |
+
audio = F.pad(audio, (0, n_padding_audio_samples), value=padding_value)
|
152 |
+
|
153 |
+
if not self.no_context_audio_feat:
|
154 |
+
audio_feat = self.model.extract_audio_feature(audio.unsqueeze(0), self.n_motions * n_subdivision)
|
155 |
+
|
156 |
+
# Generate `self.n_motions` new frames at one time, and use the last `self.n_prev_motions` frames
|
157 |
+
# from the previous generation as the initial motion condition
|
158 |
+
coef_list = []
|
159 |
+
for i in range(0, n_subdivision):
|
160 |
+
start_idx = i * stride
|
161 |
+
end_idx = start_idx + self.n_motions
|
162 |
+
indicator = torch.ones((n_repetitions, self.n_motions)).to(self.device) if self.use_indicator else None
|
163 |
+
if indicator is not None and i == n_subdivision - 1 and n_padding_frames > 0:
|
164 |
+
indicator[:, -n_padding_frames:] = 0
|
165 |
+
if not self.no_context_audio_feat:
|
166 |
+
audio_in = audio_feat[:, start_idx:end_idx].expand(n_repetitions, -1, -1)
|
167 |
+
else:
|
168 |
+
audio_in = audio[round(start_idx * self.audio_unit):round(end_idx * self.audio_unit)].unsqueeze(0)
|
169 |
+
|
170 |
+
# generate motion coefficients
|
171 |
+
if i == 0:
|
172 |
+
# -> (N, L, d_motion=n_code_per_frame * code_dim)
|
173 |
+
motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
|
174 |
+
indicator=indicator, cfg_mode=cfg_mode,
|
175 |
+
cfg_cond=cfg_cond, cfg_scale=cfg_scale,
|
176 |
+
dynamic_threshold=self.dynamic_threshold)
|
177 |
+
else:
|
178 |
+
motion_feat, noise, prev_audio_feat = self.model.sample(audio_in, shape_coef, style_feat,
|
179 |
+
prev_motion_feat, prev_audio_feat, noise,
|
180 |
+
indicator=indicator, cfg_mode=cfg_mode,
|
181 |
+
cfg_cond=cfg_cond, cfg_scale=cfg_scale,
|
182 |
+
dynamic_threshold=self.dynamic_threshold)
|
183 |
+
prev_motion_feat = motion_feat[:, -self.n_prev_motions:].clone()
|
184 |
+
prev_audio_feat = prev_audio_feat[:, -self.n_prev_motions:]
|
185 |
+
|
186 |
+
motion_coef = motion_feat
|
187 |
+
if i == n_subdivision - 1 and n_padding_frames > 0:
|
188 |
+
motion_coef = motion_coef[:, :-n_padding_frames] # delete padded frames
|
189 |
+
coef_list.append(motion_coef)
|
190 |
+
|
191 |
+
motion_coef = torch.cat(coef_list, dim=1)
|
192 |
+
|
193 |
+
# Step 3: restore to coef dict
|
194 |
+
coef_dict = get_coef_dict(motion_coef, None, self.coef_stats, self.predict_head_pose, self.rot_repr)
|
195 |
+
if include_shape:
|
196 |
+
coef_dict['shape'] = original_shape_coef[None, None].expand(n_repetitions, motion_coef.shape[1], -1)
|
197 |
+
return self.coef_to_a1_format(coef_dict)
|
198 |
+
|
199 |
+
def coef_to_a1_format(self, coef_dict):
|
200 |
+
n_frames = coef_dict['exp'].shape[1]
|
201 |
+
new_coef_dict = []
|
202 |
+
for i in range(n_frames):
|
203 |
+
|
204 |
+
new_coef_dict.append({
|
205 |
+
"expression_params": coef_dict["exp"][0, i:i+1],
|
206 |
+
"jaw_params": coef_dict["pose"][0, i:i+1, 3:],
|
207 |
+
"eye_pose_params": torch.zeros(1, 6).type_as(coef_dict["pose"]),
|
208 |
+
"pose_params": coef_dict["pose"][0, i:i+1, :3],
|
209 |
+
"eyelid_params": None
|
210 |
+
})
|
211 |
+
return new_coef_dict
|
212 |
+
|
213 |
+
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
@staticmethod
|
218 |
+
def _pad_coef(coef, n_frames, elem_ndim=1):
|
219 |
+
if coef.ndim == elem_ndim:
|
220 |
+
coef = coef[None]
|
221 |
+
elem_shape = coef.shape[1:]
|
222 |
+
if coef.shape[0] >= n_frames:
|
223 |
+
new_coef = coef[:n_frames]
|
224 |
+
else:
|
225 |
+
# repeat the last coef frame
|
226 |
+
new_coef = torch.cat([coef, coef[[-1]].expand(n_frames - coef.shape[0], *elem_shape)], dim=0)
|
227 |
+
return new_coef # (n_frames, *elem_shape)
|
228 |
+
|
diffposetalk/hubert.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import HubertModel
|
2 |
+
from transformers.modeling_outputs import BaseModelOutput
|
3 |
+
|
4 |
+
from .wav2vec2 import linear_interpolation
|
5 |
+
|
6 |
+
_CONFIG_FOR_DOC = 'HubertConfig'
|
7 |
+
|
8 |
+
|
9 |
+
class HubertModel(HubertModel):
|
10 |
+
def __init__(self, config):
|
11 |
+
super().__init__(config)
|
12 |
+
|
13 |
+
def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
|
14 |
+
output_hidden_states=None, return_dict=None, frame_num=None):
|
15 |
+
self.config.output_attentions = True
|
16 |
+
|
17 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
18 |
+
output_hidden_states = (
|
19 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
20 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
21 |
+
|
22 |
+
extract_features = self.feature_extractor(input_values) # (N, C, L)
|
23 |
+
# Resample the audio feature @ 50 fps to `output_fps`.
|
24 |
+
if frame_num is not None:
|
25 |
+
extract_features_len = round(frame_num * 50 / output_fps)
|
26 |
+
extract_features = extract_features[:, :, :extract_features_len]
|
27 |
+
extract_features = linear_interpolation(extract_features, 50, output_fps, output_len=frame_num)
|
28 |
+
extract_features = extract_features.transpose(1, 2) # (N, L, C)
|
29 |
+
|
30 |
+
if attention_mask is not None:
|
31 |
+
# compute reduced attention_mask corresponding to feature vectors
|
32 |
+
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
33 |
+
|
34 |
+
hidden_states = self.feature_projection(extract_features)
|
35 |
+
hidden_states = self._mask_hidden_states(hidden_states)
|
36 |
+
|
37 |
+
encoder_outputs = self.encoder(
|
38 |
+
hidden_states,
|
39 |
+
attention_mask=attention_mask,
|
40 |
+
output_attentions=output_attentions,
|
41 |
+
output_hidden_states=output_hidden_states,
|
42 |
+
return_dict=return_dict,
|
43 |
+
)
|
44 |
+
|
45 |
+
hidden_states = encoder_outputs[0]
|
46 |
+
|
47 |
+
if not return_dict:
|
48 |
+
return (hidden_states,) + encoder_outputs[1:]
|
49 |
+
|
50 |
+
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
|
51 |
+
attentions=encoder_outputs.attentions, )
|
diffposetalk/utils/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .common import *
|
diffposetalk/utils/common.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from functools import reduce
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
class NullableArgs:
|
9 |
+
def __init__(self, namespace):
|
10 |
+
for key, value in namespace.__dict__.items():
|
11 |
+
setattr(self, key, value)
|
12 |
+
|
13 |
+
def __getattr__(self, key):
|
14 |
+
# when an attribute lookup has not found the attribute
|
15 |
+
if key == 'align_mask_width':
|
16 |
+
if 'use_alignment_mask' in self.__dict__:
|
17 |
+
return 1 if self.use_alignment_mask else 0
|
18 |
+
else:
|
19 |
+
return 0
|
20 |
+
if key == 'no_head_pose':
|
21 |
+
return not self.predict_head_pose
|
22 |
+
if key == 'no_use_learnable_pe':
|
23 |
+
return not self.use_learnable_pe
|
24 |
+
|
25 |
+
return None
|
26 |
+
|
27 |
+
|
28 |
+
def count_parameters(model):
|
29 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
30 |
+
|
31 |
+
|
32 |
+
def get_option_text(args, parser):
|
33 |
+
message = ''
|
34 |
+
for k, v in sorted(vars(args).items()):
|
35 |
+
comment = ''
|
36 |
+
default = parser.get_default(k)
|
37 |
+
if v != default:
|
38 |
+
comment = f'\t[default: {str(default)}]'
|
39 |
+
message += f'{str(k):>30}: {str(v):<30}{comment}\n'
|
40 |
+
return message
|
41 |
+
|
42 |
+
|
43 |
+
def get_model_path(exp_name, iteration, model_type='DPT'):
|
44 |
+
exp_root_dir = Path(__file__).parent.parent / 'experiments' / model_type
|
45 |
+
exp_dir = exp_root_dir / exp_name
|
46 |
+
if not exp_dir.exists():
|
47 |
+
exp_dir = next(exp_root_dir.glob(f'{exp_name}*'))
|
48 |
+
model_path = exp_dir / f'checkpoints/iter_{iteration:07}.pt'
|
49 |
+
return model_path, exp_dir.relative_to(exp_root_dir)
|
50 |
+
|
51 |
+
|
52 |
+
def get_pose_input(coef_dict, rot_repr, with_global_pose):
|
53 |
+
if rot_repr == 'aa':
|
54 |
+
pose_input = coef_dict['pose'] if with_global_pose else coef_dict['pose'][..., -3:]
|
55 |
+
# Remove mouth rotation round y, z axis
|
56 |
+
pose_input = pose_input[..., :-2]
|
57 |
+
else:
|
58 |
+
raise ValueError(f'Unknown rotation representation: {rot_repr}')
|
59 |
+
return pose_input
|
60 |
+
|
61 |
+
|
62 |
+
def get_motion_coef(coef_dict, rot_repr, with_global_pose=False, norm_stats=None):
|
63 |
+
if norm_stats is not None:
|
64 |
+
if rot_repr == 'aa':
|
65 |
+
keys = ['exp', 'pose']
|
66 |
+
else:
|
67 |
+
raise ValueError(f'Unknown rotation representation {rot_repr}!')
|
68 |
+
|
69 |
+
coef_dict = {k: (coef_dict[k] - norm_stats[f'{k}_mean']) / norm_stats[f'{k}_std'] for k in keys}
|
70 |
+
pose_coef = get_pose_input(coef_dict, rot_repr, with_global_pose)
|
71 |
+
return torch.cat([coef_dict['exp'], pose_coef], dim=-1)
|
72 |
+
|
73 |
+
|
74 |
+
def get_coef_dict(motion_coef, shape_coef=None, denorm_stats=None, with_global_pose=False, rot_repr='aa'):
|
75 |
+
coef_dict = {
|
76 |
+
'exp': motion_coef[..., :50]
|
77 |
+
}
|
78 |
+
if rot_repr == 'aa':
|
79 |
+
if with_global_pose:
|
80 |
+
coef_dict['pose'] = motion_coef[..., 50:]
|
81 |
+
else:
|
82 |
+
placeholder = torch.zeros_like(motion_coef[..., :3])
|
83 |
+
coef_dict['pose'] = torch.cat([placeholder, motion_coef[..., -1:]], dim=-1)
|
84 |
+
# Add back rotation around y, z axis
|
85 |
+
coef_dict['pose'] = torch.cat([coef_dict['pose'], torch.zeros_like(motion_coef[..., :2])], dim=-1)
|
86 |
+
else:
|
87 |
+
raise ValueError(f'Unknown rotation representation {rot_repr}!')
|
88 |
+
|
89 |
+
if shape_coef is not None:
|
90 |
+
if motion_coef.ndim == 3:
|
91 |
+
if shape_coef.ndim == 2:
|
92 |
+
shape_coef = shape_coef.unsqueeze(1)
|
93 |
+
if shape_coef.shape[1] == 1:
|
94 |
+
shape_coef = shape_coef.expand(-1, motion_coef.shape[1], -1)
|
95 |
+
|
96 |
+
coef_dict['shape'] = shape_coef
|
97 |
+
|
98 |
+
if denorm_stats is not None:
|
99 |
+
coef_dict = {k: coef_dict[k] * denorm_stats[f'{k}_std'] + denorm_stats[f'{k}_mean'] for k in coef_dict}
|
100 |
+
|
101 |
+
if not with_global_pose:
|
102 |
+
if rot_repr == 'aa':
|
103 |
+
coef_dict['pose'][..., :3] = 0
|
104 |
+
else:
|
105 |
+
raise ValueError(f'Unknown rotation representation {rot_repr}!')
|
106 |
+
|
107 |
+
return coef_dict
|
108 |
+
|
109 |
+
|
110 |
+
def coef_dict_to_vertices(coef_dict, flame, rot_repr='aa', ignore_global_rot=False, flame_batch_size=512):
|
111 |
+
shape = coef_dict['exp'].shape[:-1]
|
112 |
+
coef_dict = {k: v.view(-1, v.shape[-1]) for k, v in coef_dict.items()}
|
113 |
+
n_samples = reduce(lambda x, y: x * y, shape, 1)
|
114 |
+
|
115 |
+
# Convert to vertices
|
116 |
+
vert_list = []
|
117 |
+
for i in range(0, n_samples, flame_batch_size):
|
118 |
+
batch_coef_dict = {k: v[i:i + flame_batch_size] for k, v in coef_dict.items()}
|
119 |
+
if rot_repr == 'aa':
|
120 |
+
vert, _, _ = flame(
|
121 |
+
batch_coef_dict['shape'], batch_coef_dict['exp'], batch_coef_dict['pose'],
|
122 |
+
pose2rot=True, ignore_global_rot=ignore_global_rot, return_lm2d=False, return_lm3d=False)
|
123 |
+
else:
|
124 |
+
raise ValueError(f'Unknown rot_repr: {rot_repr}')
|
125 |
+
vert_list.append(vert)
|
126 |
+
|
127 |
+
vert_list = torch.cat(vert_list, dim=0) # (n_samples, 5023, 3)
|
128 |
+
vert_list = vert_list.view(*shape, -1, 3) # (..., 5023, 3)
|
129 |
+
|
130 |
+
return vert_list
|
131 |
+
|
132 |
+
|
133 |
+
def compute_loss(args, is_starting_sample, shape_coef, motion_coef_gt, noise, target, prev_motion_coef, coef_stats,
|
134 |
+
flame, end_idx=None):
|
135 |
+
if args.criterion.lower() == 'l2':
|
136 |
+
criterion_func = F.mse_loss
|
137 |
+
elif args.criterion.lower() == 'l1':
|
138 |
+
criterion_func = F.l1_loss
|
139 |
+
else:
|
140 |
+
raise NotImplementedError(f'Criterion {args.criterion} not implemented.')
|
141 |
+
|
142 |
+
loss_vert = None
|
143 |
+
loss_vel = None
|
144 |
+
loss_smooth = None
|
145 |
+
loss_head_angle = None
|
146 |
+
loss_head_vel = None
|
147 |
+
loss_head_smooth = None
|
148 |
+
loss_head_trans_vel = None
|
149 |
+
loss_head_trans_accel = None
|
150 |
+
loss_head_trans = None
|
151 |
+
if args.target == 'noise':
|
152 |
+
loss_noise = criterion_func(noise, target[:, args.n_prev_motions:], reduction='none')
|
153 |
+
elif args.target == 'sample':
|
154 |
+
if is_starting_sample:
|
155 |
+
target = target[:, args.n_prev_motions:]
|
156 |
+
else:
|
157 |
+
motion_coef_gt = torch.cat([prev_motion_coef, motion_coef_gt], dim=1)
|
158 |
+
if args.no_constrain_prev:
|
159 |
+
target = torch.cat([prev_motion_coef, target[:, args.n_prev_motions:]], dim=1)
|
160 |
+
|
161 |
+
loss_noise = criterion_func(motion_coef_gt, target, reduction='none')
|
162 |
+
|
163 |
+
if args.l_vert > 0 or args.l_vel > 0:
|
164 |
+
coef_gt = get_coef_dict(motion_coef_gt, shape_coef, coef_stats, with_global_pose=False,
|
165 |
+
rot_repr=args.rot_repr)
|
166 |
+
coef_pred = get_coef_dict(target, shape_coef, coef_stats, with_global_pose=False,
|
167 |
+
rot_repr=args.rot_repr)
|
168 |
+
seq_len = target.shape[1]
|
169 |
+
|
170 |
+
if args.rot_repr == 'aa':
|
171 |
+
verts_gt, _, _ = flame(coef_gt['shape'].view(-1, 100), coef_gt['exp'].view(-1, 50),
|
172 |
+
coef_gt['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
|
173 |
+
verts_pred, _, _ = flame(coef_pred['shape'].view(-1, 100), coef_pred['exp'].view(-1, 50),
|
174 |
+
coef_pred['pose'].view(-1, 6), return_lm2d=False, return_lm3d=False)
|
175 |
+
else:
|
176 |
+
raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
|
177 |
+
verts_gt = verts_gt.view(-1, seq_len, 5023, 3)
|
178 |
+
verts_pred = verts_pred.view(-1, seq_len, 5023, 3)
|
179 |
+
|
180 |
+
if args.l_vert > 0:
|
181 |
+
loss_vert = criterion_func(verts_gt, verts_pred, reduction='none')
|
182 |
+
|
183 |
+
if args.l_vel > 0:
|
184 |
+
vel_gt = verts_gt[:, 1:] - verts_gt[:, :-1]
|
185 |
+
vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
|
186 |
+
loss_vel = criterion_func(vel_gt, vel_pred, reduction='none')
|
187 |
+
|
188 |
+
if args.l_smooth > 0:
|
189 |
+
vel_pred = verts_pred[:, 1:] - verts_pred[:, :-1]
|
190 |
+
loss_smooth = criterion_func(vel_pred[:, 1:], vel_pred[:, :-1], reduction='none')
|
191 |
+
|
192 |
+
# head pose
|
193 |
+
if not args.no_head_pose:
|
194 |
+
if args.rot_repr == 'aa':
|
195 |
+
head_pose_gt = motion_coef_gt[:, :, 50:53]
|
196 |
+
head_pose_pred = target[:, :, 50:53]
|
197 |
+
else:
|
198 |
+
raise ValueError(f'Unknown rotation representation {args.rot_repr}!')
|
199 |
+
|
200 |
+
if args.l_head_angle > 0:
|
201 |
+
loss_head_angle = criterion_func(head_pose_gt, head_pose_pred, reduction='none')
|
202 |
+
|
203 |
+
if args.l_head_vel > 0:
|
204 |
+
head_vel_gt = head_pose_gt[:, 1:] - head_pose_gt[:, :-1]
|
205 |
+
head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
|
206 |
+
loss_head_vel = criterion_func(head_vel_gt, head_vel_pred, reduction='none')
|
207 |
+
|
208 |
+
if args.l_head_smooth > 0:
|
209 |
+
head_vel_pred = head_pose_pred[:, 1:] - head_pose_pred[:, :-1]
|
210 |
+
loss_head_smooth = criterion_func(head_vel_pred[:, 1:], head_vel_pred[:, :-1], reduction='none')
|
211 |
+
|
212 |
+
if not is_starting_sample and args.l_head_trans > 0:
|
213 |
+
# # version 1: constrain both the predicted previous and current motions (x_{-3} ~ x_{2})
|
214 |
+
# head_pose_trans = head_pose_pred[:, args.n_prev_motions - 3:args.n_prev_motions + 3]
|
215 |
+
# head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
|
216 |
+
# head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]
|
217 |
+
|
218 |
+
# version 2: constrain only the predicted current motions (x_{0} ~ x_{2})
|
219 |
+
head_pose_trans = torch.cat([head_pose_gt[:, args.n_prev_motions - 3:args.n_prev_motions],
|
220 |
+
head_pose_pred[:, args.n_prev_motions:args.n_prev_motions + 3]], dim=1)
|
221 |
+
head_vel_pred = head_pose_trans[:, 1:] - head_pose_trans[:, :-1]
|
222 |
+
head_accel_pred = head_vel_pred[:, 1:] - head_vel_pred[:, :-1]
|
223 |
+
|
224 |
+
# will constrain x_{-2|0} ~ x_{1}
|
225 |
+
loss_head_trans_vel = criterion_func(head_vel_pred[:, 2:4], head_vel_pred[:, 1:3], reduction='none')
|
226 |
+
# will constrain x_{-3|0} ~ x_{2}
|
227 |
+
loss_head_trans_accel = criterion_func(head_accel_pred[:, 1:], head_accel_pred[:, :-1],
|
228 |
+
reduction='none')
|
229 |
+
else:
|
230 |
+
raise ValueError(f'Unknown diffusion target: {args.target}')
|
231 |
+
|
232 |
+
if end_idx is None:
|
233 |
+
mask = torch.ones((target.shape[0], args.n_motions), dtype=torch.bool, device=target.device)
|
234 |
+
else:
|
235 |
+
mask = torch.arange(args.n_motions, device=target.device).expand(target.shape[0], -1) < end_idx.unsqueeze(1)
|
236 |
+
|
237 |
+
if args.target == 'sample' and not is_starting_sample:
|
238 |
+
if args.no_constrain_prev:
|
239 |
+
# Warning: this option will be deprecated in the future
|
240 |
+
mask = torch.cat([torch.zeros_like(mask[:, :args.n_prev_motions]), mask], dim=1)
|
241 |
+
else:
|
242 |
+
mask = torch.cat([torch.ones_like(mask[:, :args.n_prev_motions]), mask], dim=1)
|
243 |
+
|
244 |
+
loss_noise = loss_noise[mask].mean()
|
245 |
+
if loss_vert is not None:
|
246 |
+
loss_vert = loss_vert[mask].mean()
|
247 |
+
if loss_vel is not None:
|
248 |
+
loss_vel = loss_vel[mask[:, 1:]]
|
249 |
+
loss_vel = loss_vel.mean() if torch.numel(loss_vel) > 0 else None
|
250 |
+
if loss_smooth is not None:
|
251 |
+
loss_smooth = loss_smooth[mask[:, 2:]]
|
252 |
+
loss_smooth = loss_smooth.mean() if torch.numel(loss_smooth) > 0 else None
|
253 |
+
if loss_head_angle is not None:
|
254 |
+
loss_head_angle = loss_head_angle[mask].mean()
|
255 |
+
if loss_head_vel is not None:
|
256 |
+
loss_head_vel = loss_head_vel[mask[:, 1:]]
|
257 |
+
loss_head_vel = loss_head_vel.mean() if torch.numel(loss_head_vel) > 0 else None
|
258 |
+
if loss_head_smooth is not None:
|
259 |
+
loss_head_smooth = loss_head_smooth[mask[:, 2:]]
|
260 |
+
loss_head_smooth = loss_head_smooth.mean() if torch.numel(loss_head_smooth) > 0 else None
|
261 |
+
if loss_head_trans_vel is not None:
|
262 |
+
vel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 2]
|
263 |
+
accel_mask = mask[:, args.n_prev_motions:args.n_prev_motions + 3]
|
264 |
+
loss_head_trans_vel = loss_head_trans_vel[vel_mask].mean()
|
265 |
+
loss_head_trans_accel = loss_head_trans_accel[accel_mask].mean()
|
266 |
+
loss_head_trans = loss_head_trans_vel + loss_head_trans_accel
|
267 |
+
|
268 |
+
return loss_noise, loss_vert, loss_vel, loss_smooth, loss_head_angle, loss_head_vel, loss_head_smooth, \
|
269 |
+
loss_head_trans
|
270 |
+
|
271 |
+
|
272 |
+
def _truncate_audio(audio, end_idx, pad_mode='zero'):
|
273 |
+
batch_size = audio.shape[0]
|
274 |
+
audio_trunc = audio.clone()
|
275 |
+
if pad_mode == 'replicate':
|
276 |
+
for i in range(batch_size):
|
277 |
+
audio_trunc[i, end_idx[i]:] = audio_trunc[i, end_idx[i] - 1]
|
278 |
+
elif pad_mode == 'zero':
|
279 |
+
for i in range(batch_size):
|
280 |
+
audio_trunc[i, end_idx[i]:] = 0
|
281 |
+
else:
|
282 |
+
raise ValueError(f'Unknown pad mode {pad_mode}!')
|
283 |
+
|
284 |
+
return audio_trunc
|
285 |
+
|
286 |
+
|
287 |
+
def _truncate_coef_dict(coef_dict, end_idx, pad_mode='zero'):
|
288 |
+
batch_size = coef_dict['exp'].shape[0]
|
289 |
+
coef_dict_trunc = {k: v.clone() for k, v in coef_dict.items()}
|
290 |
+
if pad_mode == 'replicate':
|
291 |
+
for i in range(batch_size):
|
292 |
+
for k in coef_dict_trunc:
|
293 |
+
coef_dict_trunc[k][i, end_idx[i]:] = coef_dict_trunc[k][i, end_idx[i] - 1]
|
294 |
+
elif pad_mode == 'zero':
|
295 |
+
for i in range(batch_size):
|
296 |
+
for k in coef_dict:
|
297 |
+
coef_dict_trunc[k][i, end_idx[i]:] = 0
|
298 |
+
else:
|
299 |
+
raise ValueError(f'Unknown pad mode: {pad_mode}!')
|
300 |
+
|
301 |
+
return coef_dict_trunc
|
302 |
+
|
303 |
+
|
304 |
+
def truncate_coef_dict_and_audio(audio, coef_dict, n_motions, audio_unit=640, pad_mode='zero'):
|
305 |
+
batch_size = audio.shape[0]
|
306 |
+
end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
|
307 |
+
audio_end_idx = (end_idx * audio_unit).long()
|
308 |
+
# mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)
|
309 |
+
|
310 |
+
# truncate audio
|
311 |
+
audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
|
312 |
+
|
313 |
+
# truncate coef dict
|
314 |
+
coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
|
315 |
+
|
316 |
+
return audio_trunc, coef_dict_trunc, end_idx
|
317 |
+
|
318 |
+
|
319 |
+
def truncate_motion_coef_and_audio(audio, motion_coef, n_motions, audio_unit=640, pad_mode='zero'):
|
320 |
+
batch_size = audio.shape[0]
|
321 |
+
end_idx = torch.randint(1, n_motions, (batch_size,), device=audio.device)
|
322 |
+
audio_end_idx = (end_idx * audio_unit).long()
|
323 |
+
# mask = torch.arange(n_motions, device=audio.device).expand(batch_size, -1) < end_idx.unsqueeze(1)
|
324 |
+
|
325 |
+
# truncate audio
|
326 |
+
audio_trunc = _truncate_audio(audio, audio_end_idx, pad_mode=pad_mode)
|
327 |
+
|
328 |
+
# prepare coef dict and stats
|
329 |
+
coef_dict = {'exp': motion_coef[..., :50], 'pose_any': motion_coef[..., 50:]}
|
330 |
+
|
331 |
+
# truncate coef dict
|
332 |
+
coef_dict_trunc = _truncate_coef_dict(coef_dict, end_idx, pad_mode=pad_mode)
|
333 |
+
motion_coef_trunc = torch.cat([coef_dict_trunc['exp'], coef_dict_trunc['pose_any']], dim=-1)
|
334 |
+
|
335 |
+
return audio_trunc, motion_coef_trunc, end_idx
|
336 |
+
|
337 |
+
|
338 |
+
def nt_xent_loss(feature_a, feature_b, temperature):
|
339 |
+
"""
|
340 |
+
Normalized temperature-scaled cross entropy loss.
|
341 |
+
|
342 |
+
(Adapted from https://github.com/sthalles/SimCLR/blob/master/simclr.py)
|
343 |
+
|
344 |
+
Args:
|
345 |
+
feature_a (torch.Tensor): shape (batch_size, feature_dim)
|
346 |
+
feature_b (torch.Tensor): shape (batch_size, feature_dim)
|
347 |
+
temperature (float): temperature scaling factor
|
348 |
+
|
349 |
+
Returns:
|
350 |
+
torch.Tensor: scalar
|
351 |
+
"""
|
352 |
+
batch_size = feature_a.shape[0]
|
353 |
+
device = feature_a.device
|
354 |
+
|
355 |
+
features = torch.cat([feature_a, feature_b], dim=0)
|
356 |
+
|
357 |
+
labels = torch.cat([torch.arange(batch_size), torch.arange(batch_size)], dim=0)
|
358 |
+
labels = (labels.unsqueeze(0) == labels.unsqueeze(1))
|
359 |
+
labels = labels.to(device)
|
360 |
+
|
361 |
+
features = F.normalize(features, dim=1)
|
362 |
+
similarity_matrix = torch.matmul(features, features.T)
|
363 |
+
|
364 |
+
# discard the main diagonal from both: labels and similarities matrix
|
365 |
+
mask = torch.eye(labels.shape[0], dtype=torch.bool).to(device)
|
366 |
+
labels = labels[~mask].view(labels.shape[0], -1)
|
367 |
+
similarity_matrix = similarity_matrix[~mask].view(labels.shape[0], -1)
|
368 |
+
|
369 |
+
# select the positives and negatives
|
370 |
+
positives = similarity_matrix[labels].view(labels.shape[0], -1)
|
371 |
+
negatives = similarity_matrix[~labels].view(labels.shape[0], -1)
|
372 |
+
|
373 |
+
logits = torch.cat([positives, negatives], dim=1)
|
374 |
+
logits = logits / temperature
|
375 |
+
labels = torch.zeros(labels.shape[0], dtype=torch.long).to(device)
|
376 |
+
|
377 |
+
loss = F.cross_entropy(logits, labels)
|
378 |
+
return loss
|
diffposetalk/utils/media.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shlex
|
2 |
+
import subprocess
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
|
6 |
+
def combine_video_and_audio(video_file, audio_file, output, quality=17, copy_audio=True):
|
7 |
+
audio_codec = '-c:a copy' if copy_audio else ''
|
8 |
+
cmd = f'ffmpeg -i {video_file} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
|
9 |
+
f'{audio_codec} -fflags +shortest -y -hide_banner -loglevel error {output}'
|
10 |
+
assert subprocess.run(shlex.split(cmd)).returncode == 0
|
11 |
+
|
12 |
+
|
13 |
+
def combine_frames_and_audio(frame_files, audio_file, fps, output, quality=17):
|
14 |
+
cmd = f'ffmpeg -framerate {fps} -i {frame_files} -i {audio_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
|
15 |
+
f'-c:a copy -fflags +shortest -y -hide_banner -loglevel error {output}'
|
16 |
+
assert subprocess.run(shlex.split(cmd)).returncode == 0
|
17 |
+
|
18 |
+
|
19 |
+
def convert_video(video_file, output, quality=17):
|
20 |
+
cmd = f'ffmpeg -i {video_file} -c:v libx264 -crf {quality} -pix_fmt yuv420p ' \
|
21 |
+
f'-fflags +shortest -y -hide_banner -loglevel error {output}'
|
22 |
+
assert subprocess.run(shlex.split(cmd)).returncode == 0
|
23 |
+
|
24 |
+
|
25 |
+
def reencode_audio(audio_file, output):
|
26 |
+
cmd = f'ffmpeg -i {audio_file} -y -hide_banner -loglevel error {output}'
|
27 |
+
assert subprocess.run(shlex.split(cmd)).returncode == 0
|
28 |
+
|
29 |
+
|
30 |
+
def extract_frames(filename, output_dir, quality=1):
|
31 |
+
output_dir = Path(output_dir)
|
32 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
33 |
+
cmd = f'ffmpeg -i {filename} -qmin 1 -qscale:v {quality} -y -start_number 0 -hide_banner -loglevel error ' \
|
34 |
+
f'{output_dir / "%06d.jpg"}'
|
35 |
+
assert subprocess.run(shlex.split(cmd)).returncode == 0
|
diffposetalk/utils/renderer.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import tempfile
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import kiui.mesh
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
# os.environ['PYOPENGL_PLATFORM'] = 'osmesa' # osmesa or egl
|
9 |
+
os.environ['PYOPENGL_PLATFORM'] = 'egl'
|
10 |
+
import pyrender
|
11 |
+
import trimesh
|
12 |
+
# from psbody.mesh import Mesh
|
13 |
+
|
14 |
+
|
15 |
+
class MeshRenderer:
|
16 |
+
def __init__(self, size, fov=16 / 180 * np.pi, camera_pose=None, light_pose=None, black_bg=False):
|
17 |
+
# Camera
|
18 |
+
self.frustum = {'near': 0.01, 'far': 3.0}
|
19 |
+
self.camera = pyrender.PerspectiveCamera(yfov=fov, znear=self.frustum['near'],
|
20 |
+
zfar=self.frustum['far'], aspectRatio=1.0)
|
21 |
+
|
22 |
+
# Material
|
23 |
+
self.primitive_material = pyrender.material.MetallicRoughnessMaterial(
|
24 |
+
alphaMode='BLEND',
|
25 |
+
baseColorFactor=[0.3, 0.3, 0.3, 1.0],
|
26 |
+
metallicFactor=0.8,
|
27 |
+
roughnessFactor=0.8
|
28 |
+
)
|
29 |
+
|
30 |
+
# Lighting
|
31 |
+
light_color = np.array([1., 1., 1.])
|
32 |
+
self.light = pyrender.DirectionalLight(color=light_color, intensity=2)
|
33 |
+
self.light_angle = np.pi / 6.0
|
34 |
+
|
35 |
+
# Scene
|
36 |
+
self.scene = None
|
37 |
+
self._init_scene(black_bg)
|
38 |
+
|
39 |
+
# add camera and lighting
|
40 |
+
self._init_camera(camera_pose)
|
41 |
+
self._init_lighting(light_pose)
|
42 |
+
|
43 |
+
# Renderer
|
44 |
+
self.renderer = pyrender.OffscreenRenderer(*size, point_size=1.0)
|
45 |
+
|
46 |
+
def _init_scene(self, black_bg=False):
|
47 |
+
if black_bg:
|
48 |
+
self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[0, 0, 0])
|
49 |
+
else:
|
50 |
+
self.scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
|
51 |
+
|
52 |
+
def _init_camera(self, camera_pose=None):
|
53 |
+
if camera_pose is None:
|
54 |
+
camera_pose = np.eye(4)
|
55 |
+
camera_pose[:3, 3] = np.array([0, 0, 1])
|
56 |
+
self.camera_pose = camera_pose.copy()
|
57 |
+
self.camera_node = self.scene.add(self.camera, pose=camera_pose)
|
58 |
+
|
59 |
+
def _init_lighting(self, light_pose=None):
|
60 |
+
if light_pose is None:
|
61 |
+
light_pose = np.eye(4)
|
62 |
+
light_pose[:3, 3] = np.array([0, 0, 1])
|
63 |
+
self.light_pose = light_pose.copy()
|
64 |
+
|
65 |
+
light_poses = self._get_light_poses(self.light_angle, light_pose)
|
66 |
+
self.light_nodes = [self.scene.add(self.light, pose=light_pose) for light_pose in light_poses]
|
67 |
+
|
68 |
+
def set_camera_pose(self, camera_pose):
|
69 |
+
self.camera_pose = camera_pose.copy()
|
70 |
+
self.scene.set_pose(self.camera_node, pose=camera_pose)
|
71 |
+
|
72 |
+
def set_lighting_pose(self, light_pose):
|
73 |
+
self.light_pose = light_pose.copy()
|
74 |
+
|
75 |
+
light_poses = self._get_light_poses(self.light_angle, light_pose)
|
76 |
+
for light_node, light_pose in zip(self.light_nodes, light_poses):
|
77 |
+
self.scene.set_pose(light_node, pose=light_pose)
|
78 |
+
|
79 |
+
def render_mesh(self, v, f, t_center, rot=np.zeros(3), tex_img=None, tex_uv=None,
|
80 |
+
camera_pose=None, light_pose=None):
|
81 |
+
# Prepare mesh
|
82 |
+
v[:] = cv2.Rodrigues(rot)[0].dot((v - t_center).T).T + t_center
|
83 |
+
if tex_img is not None:
|
84 |
+
tex = pyrender.Texture(source=tex_img, source_channels='RGB')
|
85 |
+
tex_material = pyrender.material.MetallicRoughnessMaterial(baseColorTexture=tex)
|
86 |
+
from kiui.mesh import Mesh
|
87 |
+
import torch
|
88 |
+
mesh = Mesh(
|
89 |
+
v=torch.from_numpy(v),
|
90 |
+
f=torch.from_numpy(f),
|
91 |
+
vt=tex_uv['vt'],
|
92 |
+
ft=tex_uv['ft']
|
93 |
+
)
|
94 |
+
with tempfile.NamedTemporaryFile(suffix='.obj') as f:
|
95 |
+
mesh.write_obj(f.name)
|
96 |
+
tri_mesh = trimesh.load(f.name, process=False)
|
97 |
+
return tri_mesh
|
98 |
+
# tri_mesh = self._pyrender_mesh_workaround(mesh)
|
99 |
+
render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=tex_material)
|
100 |
+
else:
|
101 |
+
tri_mesh = trimesh.Trimesh(vertices=v, faces=f)
|
102 |
+
render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=self.primitive_material, smooth=True)
|
103 |
+
mesh_node = self.scene.add(render_mesh, pose=np.eye(4))
|
104 |
+
|
105 |
+
# Change camera and lighting pose if necessary
|
106 |
+
if camera_pose is not None:
|
107 |
+
self.set_camera_pose(camera_pose)
|
108 |
+
if light_pose is not None:
|
109 |
+
self.set_lighting_pose(light_pose)
|
110 |
+
|
111 |
+
# Render
|
112 |
+
flags = pyrender.RenderFlags.SKIP_CULL_FACES
|
113 |
+
color, depth = self.renderer.render(self.scene, flags=flags)
|
114 |
+
|
115 |
+
# Remove mesh
|
116 |
+
self.scene.remove_node(mesh_node)
|
117 |
+
|
118 |
+
return color, depth
|
119 |
+
|
120 |
+
@staticmethod
|
121 |
+
def _get_light_poses(light_angle, light_pose):
|
122 |
+
light_poses = []
|
123 |
+
init_pos = light_pose[:3, 3].copy()
|
124 |
+
|
125 |
+
light_poses.append(light_pose.copy())
|
126 |
+
|
127 |
+
light_pose[:3, 3] = cv2.Rodrigues(np.array([light_angle, 0, 0]))[0].dot(init_pos)
|
128 |
+
light_poses.append(light_pose.copy())
|
129 |
+
|
130 |
+
light_pose[:3, 3] = cv2.Rodrigues(np.array([-light_angle, 0, 0]))[0].dot(init_pos)
|
131 |
+
light_poses.append(light_pose.copy())
|
132 |
+
|
133 |
+
light_pose[:3, 3] = cv2.Rodrigues(np.array([0, -light_angle, 0]))[0].dot(init_pos)
|
134 |
+
light_poses.append(light_pose.copy())
|
135 |
+
|
136 |
+
light_pose[:3, 3] = cv2.Rodrigues(np.array([0, light_angle, 0]))[0].dot(init_pos)
|
137 |
+
light_poses.append(light_pose.copy())
|
138 |
+
|
139 |
+
return light_poses
|
140 |
+
|
141 |
+
@staticmethod
|
142 |
+
def _pyrender_mesh_workaround(mesh):
|
143 |
+
# Workaround as pyrender requires number of vertices and uv coordinates to be the same
|
144 |
+
with tempfile.NamedTemporaryFile(suffix='.obj') as f:
|
145 |
+
mesh.write_obj(f.name)
|
146 |
+
tri_mesh = trimesh.load(f.name, process=False)
|
147 |
+
return tri_mesh
|
diffposetalk/utils/rotation_conversions.py
ADDED
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code is based on https://github.com/Mathux/ACTOR.git
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
3 |
+
# Check PYTORCH3D_LICENCE before use
|
4 |
+
|
5 |
+
import functools
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
"""
|
13 |
+
The transformation matrices returned from the functions in this file assume
|
14 |
+
the points on which the transformation will be applied are column vectors.
|
15 |
+
i.e. the R matrix is structured as
|
16 |
+
|
17 |
+
R = [
|
18 |
+
[Rxx, Rxy, Rxz],
|
19 |
+
[Ryx, Ryy, Ryz],
|
20 |
+
[Rzx, Rzy, Rzz],
|
21 |
+
] # (3, 3)
|
22 |
+
|
23 |
+
This matrix can be applied to column vectors by post multiplication
|
24 |
+
by the points e.g.
|
25 |
+
|
26 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
27 |
+
transformed_points = R * points
|
28 |
+
|
29 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
30 |
+
can be transposed and pre multiplied by the points:
|
31 |
+
|
32 |
+
e.g.
|
33 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
34 |
+
transformed_points = points * R.transpose(1, 0)
|
35 |
+
"""
|
36 |
+
|
37 |
+
|
38 |
+
def quaternion_to_matrix(quaternions):
|
39 |
+
"""
|
40 |
+
Convert rotations given as quaternions to rotation matrices.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
quaternions: quaternions with real part first,
|
44 |
+
as tensor of shape (..., 4).
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
48 |
+
"""
|
49 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
50 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
51 |
+
|
52 |
+
o = torch.stack(
|
53 |
+
(
|
54 |
+
1 - two_s * (j * j + k * k),
|
55 |
+
two_s * (i * j - k * r),
|
56 |
+
two_s * (i * k + j * r),
|
57 |
+
two_s * (i * j + k * r),
|
58 |
+
1 - two_s * (i * i + k * k),
|
59 |
+
two_s * (j * k - i * r),
|
60 |
+
two_s * (i * k - j * r),
|
61 |
+
two_s * (j * k + i * r),
|
62 |
+
1 - two_s * (i * i + j * j),
|
63 |
+
),
|
64 |
+
-1,
|
65 |
+
)
|
66 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
67 |
+
|
68 |
+
|
69 |
+
def _copysign(a, b):
|
70 |
+
"""
|
71 |
+
Return a tensor where each element has the absolute value taken from the,
|
72 |
+
corresponding element of a, with sign taken from the corresponding
|
73 |
+
element of b. This is like the standard copysign floating-point operation,
|
74 |
+
but is not careful about negative 0 and NaN.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
a: source tensor.
|
78 |
+
b: tensor whose signs will be used, of the same shape as a.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
Tensor of the same shape as a with the signs of b.
|
82 |
+
"""
|
83 |
+
signs_differ = (a < 0) != (b < 0)
|
84 |
+
return torch.where(signs_differ, -a, a)
|
85 |
+
|
86 |
+
|
87 |
+
def _sqrt_positive_part(x):
|
88 |
+
"""
|
89 |
+
Returns torch.sqrt(torch.max(0, x))
|
90 |
+
but with a zero subgradient where x is 0.
|
91 |
+
"""
|
92 |
+
ret = torch.zeros_like(x)
|
93 |
+
positive_mask = x > 0
|
94 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
95 |
+
return ret
|
96 |
+
|
97 |
+
|
98 |
+
def matrix_to_quaternion(matrix):
|
99 |
+
"""
|
100 |
+
Convert rotations given as rotation matrices to quaternions.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
107 |
+
"""
|
108 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
109 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
110 |
+
m00 = matrix[..., 0, 0]
|
111 |
+
m11 = matrix[..., 1, 1]
|
112 |
+
m22 = matrix[..., 2, 2]
|
113 |
+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
114 |
+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
115 |
+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
116 |
+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
117 |
+
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
118 |
+
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
119 |
+
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
120 |
+
return torch.stack((o0, o1, o2, o3), -1)
|
121 |
+
|
122 |
+
|
123 |
+
def _axis_angle_rotation(axis: str, angle):
|
124 |
+
"""
|
125 |
+
Return the rotation matrices for one of the rotations about an axis
|
126 |
+
of which Euler angles describe, for each value of the angle given.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
axis: Axis label "X" or "Y or "Z".
|
130 |
+
angle: any shape tensor of Euler angles in radians
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
134 |
+
"""
|
135 |
+
|
136 |
+
cos = torch.cos(angle)
|
137 |
+
sin = torch.sin(angle)
|
138 |
+
one = torch.ones_like(angle)
|
139 |
+
zero = torch.zeros_like(angle)
|
140 |
+
|
141 |
+
if axis == "X":
|
142 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
143 |
+
if axis == "Y":
|
144 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
145 |
+
if axis == "Z":
|
146 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
147 |
+
|
148 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
149 |
+
|
150 |
+
|
151 |
+
def euler_angles_to_matrix(euler_angles, convention: str):
|
152 |
+
"""
|
153 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
157 |
+
convention: Convention string of three uppercase letters from
|
158 |
+
{"X", "Y", and "Z"}.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
162 |
+
"""
|
163 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
164 |
+
raise ValueError("Invalid input euler angles.")
|
165 |
+
if len(convention) != 3:
|
166 |
+
raise ValueError("Convention must have 3 letters.")
|
167 |
+
if convention[1] in (convention[0], convention[2]):
|
168 |
+
raise ValueError(f"Invalid convention {convention}.")
|
169 |
+
for letter in convention:
|
170 |
+
if letter not in ("X", "Y", "Z"):
|
171 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
172 |
+
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
173 |
+
return functools.reduce(torch.matmul, matrices)
|
174 |
+
|
175 |
+
|
176 |
+
def _angle_from_tan(
|
177 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
178 |
+
):
|
179 |
+
"""
|
180 |
+
Extract the first or third Euler angle from the two members of
|
181 |
+
the matrix which are positive constant times its sine and cosine.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
185 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
186 |
+
convention.
|
187 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
188 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
189 |
+
which means the relevant entries are in the same row of the
|
190 |
+
rotation matrix. If not, they are in the same column.
|
191 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
Euler Angles in radians for each matrix in dataset as a tensor
|
195 |
+
of shape (...).
|
196 |
+
"""
|
197 |
+
|
198 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
199 |
+
if horizontal:
|
200 |
+
i2, i1 = i1, i2
|
201 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
202 |
+
if horizontal == even:
|
203 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
204 |
+
if tait_bryan:
|
205 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
206 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
207 |
+
|
208 |
+
|
209 |
+
def _index_from_letter(letter: str):
|
210 |
+
if letter == "X":
|
211 |
+
return 0
|
212 |
+
if letter == "Y":
|
213 |
+
return 1
|
214 |
+
if letter == "Z":
|
215 |
+
return 2
|
216 |
+
|
217 |
+
|
218 |
+
def matrix_to_euler_angles(matrix, convention: str):
|
219 |
+
"""
|
220 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
221 |
+
|
222 |
+
Args:
|
223 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
224 |
+
convention: Convention string of three uppercase letters.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
Euler angles in radians as tensor of shape (..., 3).
|
228 |
+
"""
|
229 |
+
if len(convention) != 3:
|
230 |
+
raise ValueError("Convention must have 3 letters.")
|
231 |
+
if convention[1] in (convention[0], convention[2]):
|
232 |
+
raise ValueError(f"Invalid convention {convention}.")
|
233 |
+
for letter in convention:
|
234 |
+
if letter not in ("X", "Y", "Z"):
|
235 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
236 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
237 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
238 |
+
i0 = _index_from_letter(convention[0])
|
239 |
+
i2 = _index_from_letter(convention[2])
|
240 |
+
tait_bryan = i0 != i2
|
241 |
+
if tait_bryan:
|
242 |
+
central_angle = torch.asin(
|
243 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
244 |
+
)
|
245 |
+
else:
|
246 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
247 |
+
|
248 |
+
o = (
|
249 |
+
_angle_from_tan(
|
250 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
251 |
+
),
|
252 |
+
central_angle,
|
253 |
+
_angle_from_tan(
|
254 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
255 |
+
),
|
256 |
+
)
|
257 |
+
return torch.stack(o, -1)
|
258 |
+
|
259 |
+
|
260 |
+
def random_quaternions(
|
261 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
262 |
+
):
|
263 |
+
"""
|
264 |
+
Generate random quaternions representing rotations,
|
265 |
+
i.e. versors with nonnegative real part.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
n: Number of quaternions in a batch to return.
|
269 |
+
dtype: Type to return.
|
270 |
+
device: Desired device of returned tensor. Default:
|
271 |
+
uses the current device for the default tensor type.
|
272 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
273 |
+
flag set.
|
274 |
+
|
275 |
+
Returns:
|
276 |
+
Quaternions as tensor of shape (N, 4).
|
277 |
+
"""
|
278 |
+
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
279 |
+
s = (o * o).sum(1)
|
280 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
281 |
+
return o
|
282 |
+
|
283 |
+
|
284 |
+
def random_rotations(
|
285 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
286 |
+
):
|
287 |
+
"""
|
288 |
+
Generate random rotations as 3x3 rotation matrices.
|
289 |
+
|
290 |
+
Args:
|
291 |
+
n: Number of rotation matrices in a batch to return.
|
292 |
+
dtype: Type to return.
|
293 |
+
device: Device of returned tensor. Default: if None,
|
294 |
+
uses the current device for the default tensor type.
|
295 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
296 |
+
flag set.
|
297 |
+
|
298 |
+
Returns:
|
299 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
300 |
+
"""
|
301 |
+
quaternions = random_quaternions(
|
302 |
+
n, dtype=dtype, device=device, requires_grad=requires_grad
|
303 |
+
)
|
304 |
+
return quaternion_to_matrix(quaternions)
|
305 |
+
|
306 |
+
|
307 |
+
def random_rotation(
|
308 |
+
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
309 |
+
):
|
310 |
+
"""
|
311 |
+
Generate a single random 3x3 rotation matrix.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
dtype: Type to return
|
315 |
+
device: Device of returned tensor. Default: if None,
|
316 |
+
uses the current device for the default tensor type
|
317 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
318 |
+
flag set
|
319 |
+
|
320 |
+
Returns:
|
321 |
+
Rotation matrix as tensor of shape (3, 3).
|
322 |
+
"""
|
323 |
+
return random_rotations(1, dtype, device, requires_grad)[0]
|
324 |
+
|
325 |
+
|
326 |
+
def standardize_quaternion(quaternions):
|
327 |
+
"""
|
328 |
+
Convert a unit quaternion to a standard form: one in which the real
|
329 |
+
part is non negative.
|
330 |
+
|
331 |
+
Args:
|
332 |
+
quaternions: Quaternions with real part first,
|
333 |
+
as tensor of shape (..., 4).
|
334 |
+
|
335 |
+
Returns:
|
336 |
+
Standardized quaternions as tensor of shape (..., 4).
|
337 |
+
"""
|
338 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
339 |
+
|
340 |
+
|
341 |
+
def quaternion_raw_multiply(a, b):
|
342 |
+
"""
|
343 |
+
Multiply two quaternions.
|
344 |
+
Usual torch rules for broadcasting apply.
|
345 |
+
|
346 |
+
Args:
|
347 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
348 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
352 |
+
"""
|
353 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
354 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
355 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
356 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
357 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
358 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
359 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
360 |
+
|
361 |
+
|
362 |
+
def quaternion_multiply(a, b):
|
363 |
+
"""
|
364 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
365 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
366 |
+
Usual torch rules for broadcasting apply.
|
367 |
+
|
368 |
+
Args:
|
369 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
370 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
374 |
+
"""
|
375 |
+
ab = quaternion_raw_multiply(a, b)
|
376 |
+
return standardize_quaternion(ab)
|
377 |
+
|
378 |
+
|
379 |
+
def quaternion_invert(quaternion):
|
380 |
+
"""
|
381 |
+
Given a quaternion representing rotation, get the quaternion representing
|
382 |
+
its inverse.
|
383 |
+
|
384 |
+
Args:
|
385 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
386 |
+
first, which must be versors (unit quaternions).
|
387 |
+
|
388 |
+
Returns:
|
389 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
390 |
+
"""
|
391 |
+
|
392 |
+
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
|
393 |
+
|
394 |
+
|
395 |
+
def quaternion_apply(quaternion, point):
|
396 |
+
"""
|
397 |
+
Apply the rotation given by a quaternion to a 3D point.
|
398 |
+
Usual torch rules for broadcasting apply.
|
399 |
+
|
400 |
+
Args:
|
401 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
402 |
+
point: Tensor of 3D points of shape (..., 3).
|
403 |
+
|
404 |
+
Returns:
|
405 |
+
Tensor of rotated points of shape (..., 3).
|
406 |
+
"""
|
407 |
+
if point.size(-1) != 3:
|
408 |
+
raise ValueError(f"Points are not in 3D, f{point.shape}.")
|
409 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
410 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
411 |
+
out = quaternion_raw_multiply(
|
412 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
413 |
+
quaternion_invert(quaternion),
|
414 |
+
)
|
415 |
+
return out[..., 1:]
|
416 |
+
|
417 |
+
|
418 |
+
def axis_angle_to_matrix(axis_angle):
|
419 |
+
"""
|
420 |
+
Convert rotations given as axis/angle to rotation matrices.
|
421 |
+
|
422 |
+
Args:
|
423 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
424 |
+
as a tensor of shape (..., 3), where the magnitude is
|
425 |
+
the angle turned anticlockwise in radians around the
|
426 |
+
vector's direction.
|
427 |
+
|
428 |
+
Returns:
|
429 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
430 |
+
"""
|
431 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
432 |
+
|
433 |
+
|
434 |
+
def matrix_to_axis_angle(matrix):
|
435 |
+
"""
|
436 |
+
Convert rotations given as rotation matrices to axis/angle.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
440 |
+
|
441 |
+
Returns:
|
442 |
+
Rotations given as a vector in axis angle form, as a tensor
|
443 |
+
of shape (..., 3), where the magnitude is the angle
|
444 |
+
turned anticlockwise in radians around the vector's
|
445 |
+
direction.
|
446 |
+
"""
|
447 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
448 |
+
|
449 |
+
|
450 |
+
def axis_angle_to_quaternion(axis_angle):
|
451 |
+
"""
|
452 |
+
Convert rotations given as axis/angle to quaternions.
|
453 |
+
|
454 |
+
Args:
|
455 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
456 |
+
as a tensor of shape (..., 3), where the magnitude is
|
457 |
+
the angle turned anticlockwise in radians around the
|
458 |
+
vector's direction.
|
459 |
+
|
460 |
+
Returns:
|
461 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
462 |
+
"""
|
463 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
464 |
+
half_angles = 0.5 * angles
|
465 |
+
eps = 1e-6
|
466 |
+
small_angles = angles.abs() < eps
|
467 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
468 |
+
sin_half_angles_over_angles[~small_angles] = (
|
469 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
470 |
+
)
|
471 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
472 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
473 |
+
sin_half_angles_over_angles[small_angles] = (
|
474 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
475 |
+
)
|
476 |
+
quaternions = torch.cat(
|
477 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
478 |
+
)
|
479 |
+
return quaternions
|
480 |
+
|
481 |
+
|
482 |
+
def quaternion_to_axis_angle(quaternions):
|
483 |
+
"""
|
484 |
+
Convert rotations given as quaternions to axis/angle.
|
485 |
+
|
486 |
+
Args:
|
487 |
+
quaternions: quaternions with real part first,
|
488 |
+
as tensor of shape (..., 4).
|
489 |
+
|
490 |
+
Returns:
|
491 |
+
Rotations given as a vector in axis angle form, as a tensor
|
492 |
+
of shape (..., 3), where the magnitude is the angle
|
493 |
+
turned anticlockwise in radians around the vector's
|
494 |
+
direction.
|
495 |
+
"""
|
496 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
497 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
498 |
+
angles = 2 * half_angles
|
499 |
+
eps = 1e-6
|
500 |
+
small_angles = angles.abs() < eps
|
501 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
502 |
+
sin_half_angles_over_angles[~small_angles] = (
|
503 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
504 |
+
)
|
505 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
506 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
507 |
+
sin_half_angles_over_angles[small_angles] = (
|
508 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
509 |
+
)
|
510 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
511 |
+
|
512 |
+
|
513 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
514 |
+
"""
|
515 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
516 |
+
using Gram--Schmidt orthogonalisation per Section B of [1].
|
517 |
+
Args:
|
518 |
+
d6: 6D rotation representation, of size (*, 6)
|
519 |
+
|
520 |
+
Returns:
|
521 |
+
batch of rotation matrices of size (*, 3, 3)
|
522 |
+
|
523 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
524 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
525 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
526 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
527 |
+
"""
|
528 |
+
|
529 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
530 |
+
b1 = F.normalize(a1, dim=-1)
|
531 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
532 |
+
b2 = F.normalize(b2, dim=-1)
|
533 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
534 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
535 |
+
|
536 |
+
|
537 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
538 |
+
"""
|
539 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
540 |
+
by dropping the last row. Note that 6D representation is not unique.
|
541 |
+
Args:
|
542 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
543 |
+
|
544 |
+
Returns:
|
545 |
+
6D rotation representation, of size (*, 6)
|
546 |
+
|
547 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
548 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
549 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
550 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
551 |
+
"""
|
552 |
+
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
553 |
+
|
554 |
+
|
555 |
+
def axis_angle_to_rotation_6d(axis_angle):
|
556 |
+
"""
|
557 |
+
Convert rotations given as axis/angle to 6D rotation representation by Zhou
|
558 |
+
et al. [1].
|
559 |
+
|
560 |
+
Args:
|
561 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
562 |
+
as a tensor of shape (..., 3), where the magnitude is
|
563 |
+
the angle turned anticlockwise in radians around the
|
564 |
+
vector's direction.
|
565 |
+
|
566 |
+
Returns:
|
567 |
+
6D rotation representation, of size (*, 6)
|
568 |
+
"""
|
569 |
+
return matrix_to_rotation_6d(axis_angle_to_matrix(axis_angle))
|
diffposetalk/wav2vec2.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from packaging import version
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import transformers
|
8 |
+
from transformers import Wav2Vec2Model
|
9 |
+
from transformers.modeling_outputs import BaseModelOutput
|
10 |
+
|
11 |
+
_CONFIG_FOR_DOC = 'Wav2Vec2Config'
|
12 |
+
|
13 |
+
|
14 |
+
# the implementation of Wav2Vec2Model is borrowed from
|
15 |
+
# https://huggingface.co/transformers/_modules/transformers/models/wav2vec2/modeling_wav2vec2.html#Wav2Vec2Model
|
16 |
+
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
|
17 |
+
def _compute_mask_indices(shape: Tuple[int, int], mask_prob: float, mask_length: int,
|
18 |
+
attention_mask: Optional[torch.Tensor] = None, min_masks: int = 0, ) -> np.ndarray:
|
19 |
+
bsz, all_sz = shape
|
20 |
+
mask = np.full((bsz, all_sz), False)
|
21 |
+
|
22 |
+
all_num_mask = int(mask_prob * all_sz / float(mask_length) + np.random.rand())
|
23 |
+
all_num_mask = max(min_masks, all_num_mask)
|
24 |
+
mask_idcs = []
|
25 |
+
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
26 |
+
for i in range(bsz):
|
27 |
+
if padding_mask is not None:
|
28 |
+
sz = all_sz - padding_mask[i].long().sum().item()
|
29 |
+
num_mask = int(mask_prob * sz / float(mask_length) + np.random.rand())
|
30 |
+
num_mask = max(min_masks, num_mask)
|
31 |
+
else:
|
32 |
+
sz = all_sz
|
33 |
+
num_mask = all_num_mask
|
34 |
+
|
35 |
+
lengths = np.full(num_mask, mask_length)
|
36 |
+
|
37 |
+
if sum(lengths) == 0:
|
38 |
+
lengths[0] = min(mask_length, sz - 1)
|
39 |
+
|
40 |
+
min_len = min(lengths)
|
41 |
+
if sz - min_len <= num_mask:
|
42 |
+
min_len = sz - num_mask - 1
|
43 |
+
|
44 |
+
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
45 |
+
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
46 |
+
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
47 |
+
|
48 |
+
min_len = min([len(m) for m in mask_idcs])
|
49 |
+
for i, mask_idc in enumerate(mask_idcs):
|
50 |
+
if len(mask_idc) > min_len:
|
51 |
+
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
52 |
+
mask[i, mask_idc] = True
|
53 |
+
return mask
|
54 |
+
|
55 |
+
|
56 |
+
# linear interpolation layer
|
57 |
+
def linear_interpolation(features, input_fps, output_fps, output_len=None):
|
58 |
+
# features: (N, C, L)
|
59 |
+
seq_len = features.shape[2] / float(input_fps)
|
60 |
+
if output_len is None:
|
61 |
+
output_len = int(seq_len * output_fps)
|
62 |
+
output_features = F.interpolate(features, size=output_len, align_corners=False, mode='linear')
|
63 |
+
return output_features
|
64 |
+
|
65 |
+
|
66 |
+
class Wav2Vec2Model(Wav2Vec2Model):
|
67 |
+
def __init__(self, config):
|
68 |
+
super().__init__(config)
|
69 |
+
self.is_old_version = version.parse(transformers.__version__) < version.parse('4.7.0')
|
70 |
+
|
71 |
+
def forward(self, input_values, output_fps=25, attention_mask=None, output_attentions=None,
|
72 |
+
output_hidden_states=None, return_dict=None, frame_num=None):
|
73 |
+
self.config.output_attentions = True
|
74 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
75 |
+
output_hidden_states = (
|
76 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
|
77 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
78 |
+
|
79 |
+
hidden_states = self.feature_extractor(input_values) # (N, C, L)
|
80 |
+
# Resample the audio feature @ 50 fps to `output_fps`.
|
81 |
+
if frame_num is not None:
|
82 |
+
hidden_states_len = round(frame_num * 50 / output_fps)
|
83 |
+
hidden_states = hidden_states[:, :, :hidden_states_len]
|
84 |
+
hidden_states = linear_interpolation(hidden_states, 50, output_fps, output_len=frame_num)
|
85 |
+
hidden_states = hidden_states.transpose(1, 2) # (N, L, C)
|
86 |
+
|
87 |
+
if attention_mask is not None:
|
88 |
+
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
89 |
+
attention_mask = torch.zeros(hidden_states.shape[:2], dtype=hidden_states.dtype,
|
90 |
+
device=hidden_states.device)
|
91 |
+
attention_mask[(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)] = 1
|
92 |
+
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
93 |
+
|
94 |
+
if self.is_old_version:
|
95 |
+
hidden_states = self.feature_projection(hidden_states)
|
96 |
+
else:
|
97 |
+
hidden_states = self.feature_projection(hidden_states)[0]
|
98 |
+
|
99 |
+
if self.config.apply_spec_augment and self.training:
|
100 |
+
batch_size, sequence_length, hidden_size = hidden_states.size()
|
101 |
+
if self.config.mask_time_prob > 0:
|
102 |
+
mask_time_indices = _compute_mask_indices((batch_size, sequence_length), self.config.mask_time_prob,
|
103 |
+
self.config.mask_time_length, attention_mask=attention_mask,
|
104 |
+
min_masks=2, )
|
105 |
+
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
|
106 |
+
if self.config.mask_feature_prob > 0:
|
107 |
+
mask_feature_indices = _compute_mask_indices((batch_size, hidden_size), self.config.mask_feature_prob,
|
108 |
+
self.config.mask_feature_length, )
|
109 |
+
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
|
110 |
+
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
111 |
+
encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask,
|
112 |
+
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
|
113 |
+
return_dict=return_dict, )
|
114 |
+
hidden_states = encoder_outputs[0]
|
115 |
+
if not return_dict:
|
116 |
+
return (hidden_states,) + encoder_outputs[1:]
|
117 |
+
|
118 |
+
return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_outputs.hidden_states,
|
119 |
+
attentions=encoder_outputs.attentions, )
|