diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..45ad6d765774670f2a3fad149965a1c23184e251 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+TalkingHead-1KH/data_list/train_video_tubes.txt filter=lfs diff=lfs merge=lfs -text
+TalkingHead-1KH/teaser.gif filter=lfs diff=lfs merge=lfs -text
+data_gen/utils/mp_feature_extractors/face_landmarker.task filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a269d9ca55511e478e976768a1302f2121c17868
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,199 @@
+# big files
+data_util/face_tracking/3DMM/01_MorphableModel.mat
+data_util/face_tracking/3DMM/3DMM_info.npy
+
+!/deep_3drecon/BFM/.gitkeep
+deep_3drecon/BFM/Exp_Pca.bin
+deep_3drecon/BFM/01_MorphableModel.mat
+deep_3drecon/BFM/BFM_model_front.mat
+deep_3drecon/network/FaceReconModel.pb
+deep_3drecon/checkpoints/*
+
+.vscode
+### Project ignore
+/checkpoints/*
+!/checkpoints/.gitkeep
+/data/*
+!/data/.gitkeep
+infer_out
+rsync
+.idea
+.DS_Store
+bak
+tmp
+*.tar.gz
+mos
+nbs
+/configs_usr/*
+!/configs_usr/.gitkeep
+/egs_usr/*
+!/egs_usr/.gitkeep
+/rnnoise
+#/usr/*
+#!/usr/.gitkeep
+scripts_usr
+
+# Created by .ignore support plugin (hsz.mobi)
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+data_util/deepspeech_features/deepspeech-0.9.2-models.pbmm
+deep_3drecon/mesh_renderer/bazel-bin
+deep_3drecon/mesh_renderer/bazel-mesh_renderer
+deep_3drecon/mesh_renderer/bazel-out
+deep_3drecon/mesh_renderer/bazel-testlogs
+
+.nfs*
+infer_outs/*
+
+*.pth
+venv_113/*
+*.pt
+experiments/trials
+flame_3drecon/*
+
+temp/
+/kill.sh
+/datasets
+data_util/imagenet_classes.txt
+process_data_May.sh
+/env_prepare_reproduce.md
+/my_debug.py
+
+utils/metrics/shape_predictor_68_face_landmarks.dat
+*.mp4
+_torchshow/
+*.png
+*.jpg
+
+*.mrc
+
+deep_3drecon/BFM/BFM_exp_idx.mat
+deep_3drecon/BFM/BFM_front_idx.mat
+deep_3drecon/BFM/facemodel_info.mat
+deep_3drecon/BFM/index_mp468_from_mesh35709.npy
+deep_3drecon/BFM/mediapipe_in_bfm53201.npy
+deep_3drecon/BFM/std_exp.txt
+!data/raw/examples/*
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ff625085cf0dfdfd39a0a890cb3654a848349712
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 ZhenhuiYe
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README-zh.md b/README-zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa34b75697cad9603fd42aa6b0090a61040ddabf
--- /dev/null
+++ b/README-zh.md
@@ -0,0 +1,144 @@
+# Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
+[arXiv](https://arxiv.org/abs/2401.08503) | [GitHub](https://github.com/yerfor/Real3DPortrait) | [English Readme](./README.md)
+
+这个仓库是Real3D-Portrait的官方PyTorch实现, 用于实现单参考图(one-shot)、高视频真实度(video reality)的虚拟人视频合成。您可以访问我们的[项目页面](https://real3dportrait.github.io/)以观看Demo视频, 阅读我们的[论文](https://arxiv.org/pdf/2401.08503.pdf)以了解技术细节。
+
+## 您可能同样感兴趣
+- 我们发布了GeneFace++([https://github.com/yerfor/GeneFacePlusPlus](https://github.com/yerfor/GeneFacePlusPlus)), 一个专注于提升单个特定说话人效果的说话人合成系统,它实现了高嘴形对齐、高视频质量和高系统效率。
+
+
+# 快速上手!
+## 安装环境
+请参照[环境配置文档](docs/prepare_env/install_guide-zh.md),配置Conda环境`real3dportrait`
+## 下载预训练与第三方模型
+### 3DMM BFM模型
+下载3DMM BFM模型:[Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5 ) 提取码: m9q5
+
+
+下载完成后,放置全部的文件到`deep_3drecon/BFM`里,文件结构如下:
+```
+deep_3drecon/BFM/
+├── 01_MorphableModel.mat
+├── BFM_exp_idx.mat
+├── BFM_front_idx.mat
+├── BFM_model_front.mat
+├── Exp_Pca.bin
+├── facemodel_info.mat
+├── index_mp468_from_mesh35709.npy
+├── mediapipe_in_bfm53201.npy
+└── std_exp.txt
+```
+
+### 预训练模型
+下载预训练的Real3D-Portrait:[Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) 或 [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f ) 提取码: 6x4f
+
+下载完成后,放置全部的文件到`checkpoints`里并解压,文件结构如下:
+```
+checkpoints/
+├── 240210_real3dportrait_orig
+│ ├── audio2secc_vae
+│ │ ├── config.yaml
+│ │ └── model_ckpt_steps_400000.ckpt
+│ └── secc2plane_torso_orig
+│ ├── config.yaml
+│ └── model_ckpt_steps_100000.ckpt
+└── pretrained_ckpts
+ └── mit_b0.pth
+```
+
+## 推理测试
+我们目前提供了**命令行(CLI)**, **Gradio WebUI**与**Google Colab**推理方式。我们同时支持音频驱动(Audio-Driven)与视频驱动(Video-Driven):
+
+- 音频驱动场景下,需要至少提供`source image`与`driving audio`
+- 视频驱动场景下,需要至少提供`source image`与`driving expression video`
+
+### Gradio WebUI推理
+启动Gradio WebUI,按照提示上传素材,点击`Generate`按钮即可推理:
+```bash
+python inference/app_real3dportrait.py
+```
+
+### Google Colab推理
+运行这个[Colab](https://colab.research.google.com/github/yerfor/Real3DPortrait/blob/main/inference/real3dportrait_demo.ipynb)中的所有cell。
+
+### 命令行推理
+首先,切换至项目根目录并启用Conda环境:
+```bash
+cd
+conda activate real3dportrait
+export PYTHONPATH=./
+```
+音频驱动场景下,需要至少提供source image与driving audio,推理指令:
+```bash
+python inference/real3d_infer.py \
+--src_img <path_to_source_image> \
+--drv_aud <path_to_driving_audio> \
+--drv_pose <path_to_pose_video, 可选> \
+--bg_img <path_to_background_image, 可选> \
+--out_name <output_video_name, 可选>
+```
+视频驱动场景下,需要至少提供source image与driving expression video(作为drv_aud参数),推理指令:
+```bash
+python inference/real3d_infer.py \
+--src_img <path_to_source_image> \
+--drv_aud <path_to_driving_expression_video> \
+--drv_pose <path_to_pose_video, 可选> \
+--bg_img <path_to_background_image, 可选> \
+--out_name <output_video_name, 可选>
+```
+一些可选参数注释:
+- `--drv_pose` 指定时提供了运动pose信息,不指定则为静态运动
+- `--bg_img` 指定时提供了背景信息,不指定则为source image提取的背景
+- `--mouth_amp` 嘴部张幅参数,值越大张幅越大
+- `--map_to_init_pose` 值为`True`时,首帧的pose将被映射到source pose,后续帧也作相同变换
+- `--temperature` 代表audio2motion的采样温度,值越大结果越多样,但同时精确度越低
+- `--out_name` 不指定时,结果将保存在`infer_out/tmp/`中
+- `--out_mode` 值为`final`时,只输出说话人视频;值为`concat_debug`时,同时输出一些可视化的中间结果
+
+指令示例:
+```bash
+python inference/real3d_infer.py \
+--src_img data/raw/examples/Macron.png \
+--drv_aud data/raw/examples/Obama_5s.wav \
+--drv_pose data/raw/examples/May_5s.mp4 \
+--bg_img data/raw/examples/bg.png \
+--out_name output.mp4 \
+--out_mode concat_debug
+```
+
+## ToDo
+- [x] **Release Pre-trained weights of Real3D-Portrait.**
+- [x] **Release Inference Code of Real3D-Portrait.**
+- [x] **Release Gradio Demo of Real3D-Portrait.**
+- [x] **Release Google Colab of Real3D-Portrait.**
+- [ ] **Release Training Code of Real3D-Portrait.**
+
+# 引用我们
+如果这个仓库对你有帮助,请考虑引用我们的工作:
+```
+@article{ye2024real3d,
+ title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
+ author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
+ journal={arXiv preprint arXiv:2401.08503},
+ year={2024}
+}
+@article{ye2023geneface++,
+ title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
+ author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2305.00787},
+ year={2023}
+}
+@article{ye2023geneface,
+ title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
+ author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2301.13430},
+ year={2023}
+}
+```
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..97b38195a44dcb1437ad776b2a5d5bb010e92d8e
--- /dev/null
+++ b/README.md
@@ -0,0 +1,149 @@
+# Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis | ICLR 2024 Spotlight
+[arXiv](https://arxiv.org/abs/2401.08503) | [GitHub](https://github.com/yerfor/Real3DPortrait) | [中文文档](./README-zh.md)
+
+This is the official repo of Real3D-Portrait with a PyTorch implementation, for one-shot and realistic 3D talking portrait synthesis. You can visit our [Demo Page](https://real3dportrait.github.io/) to watch demo videos, and read our [Paper](https://arxiv.org/pdf/2401.08503.pdf) for technical details.
+
+## 🔥 Update
+- \[2024.07.02\] We release the training code of the whole system, including the audio-to-motion model, the image-to-plane model, the secc2plane model, and the secc2plane_torso model; please refer to `docs/train_models`. We also release the code to preprocess and binarize the dataset; please refer to `docs/process_data`. Thanks for your patience!
+
+## You may also be interested in
+- We release the code of GeneFace++ ([https://github.com/yerfor/GeneFacePlusPlus](https://github.com/yerfor/GeneFacePlusPlus)), a NeRF-based person-specific talking face system, which aims at producing high-quality talking face videos with high identity similarity to the target person.
+
+# Quick Start!
+## Environment Installation
+Please refer to the [Installation Guide](docs/prepare_env/install_guide.md) to prepare a Conda environment named `real3dportrait`.
+## Download Pre-trained & Third-Party Models
+### 3DMM BFM Model
+Download the 3DMM BFM model from [Google Drive](https://drive.google.com/drive/folders/1o4t5YIw7w4cMUN4bgU9nPf6IyWVG1bEk?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1aqv1z_qZ23Vp2VP4uxxblQ?pwd=m9q5) with Password m9q5.
+
+
+Put all the files in `deep_3drecon/BFM`; the file structure should look like this:
+```
+deep_3drecon/BFM/
+├── 01_MorphableModel.mat
+├── BFM_exp_idx.mat
+├── BFM_front_idx.mat
+├── BFM_model_front.mat
+├── Exp_Pca.bin
+├── facemodel_info.mat
+├── index_mp468_from_mesh35709.npy
+├── mediapipe_in_bfm53201.npy
+└── std_exp.txt
+```
+
+### Pre-trained Real3D-Portrait
+Download the pre-trained Real3D-Portrait checkpoints: [Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing) or [BaiduYun Disk](https://pan.baidu.com/s/1Mjmbn0UtA1Zm9owZ7zWNgQ?pwd=6x4f) with Password 6x4f.
+
+Put the zip files in `checkpoints` and unzip them; the file structure should look like this:
+```
+checkpoints/
+├── 240210_real3dportrait_orig
+│ ├── audio2secc_vae
+│ │ ├── config.yaml
+│ │ └── model_ckpt_steps_400000.ckpt
+│ └── secc2plane_torso_orig
+│ ├── config.yaml
+│ └── model_ckpt_steps_100000.ckpt
+└── pretrained_ckpts
+ └── mit_b0.pth
+```
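+
+Before running inference, it can be convenient to verify that everything above is in place. The snippet below is a minimal sketch that checks the two directory layouts listed in this section; the file names come from the trees above, and the script name `check_assets.py` is just a suggestion, not part of the repo.
+
+```python
+# check_assets.py -- sanity-check the BFM and checkpoint layouts listed above (sketch)
+from pathlib import Path
+
+BFM_FILES = [
+    "01_MorphableModel.mat", "BFM_exp_idx.mat", "BFM_front_idx.mat",
+    "BFM_model_front.mat", "Exp_Pca.bin", "facemodel_info.mat",
+    "index_mp468_from_mesh35709.npy", "mediapipe_in_bfm53201.npy", "std_exp.txt",
+]
+CKPT_FILES = [
+    "240210_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt",
+    "240210_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt",
+    "pretrained_ckpts/mit_b0.pth",
+]
+
+# Collect every expected path and report the ones that are missing.
+expected = [Path("deep_3drecon/BFM") / f for f in BFM_FILES] + \
+           [Path("checkpoints") / f for f in CKPT_FILES]
+missing = [p for p in expected if not p.exists()]
+print("All required files found." if not missing
+      else "Missing files:\n" + "\n".join(str(p) for p in missing))
+```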
+
+## Inference
+Currently, we provide **CLI**, **Gradio WebUI**, and **Google Colab** for inference. We support both audio-driven and video-driven modes:
+
+- For audio-driven inference, prepare at least a `source image` and a `driving audio`
+- For video-driven inference, prepare at least a `source image` and a `driving expression video`
+
+### Gradio WebUI
+Run the Gradio WebUI demo, upload your resources in the webpage, and click the `Generate` button to run inference:
+```bash
+python inference/app_real3dportrait.py
+```
+
+### Google Colab
+Run all the cells in this [Colab](https://colab.research.google.com/github/yerfor/Real3DPortrait/blob/main/inference/real3dportrait_demo.ipynb).
+
+### CLI Inference
+First, switch to the project folder and activate the conda environment:
+```bash
+cd
+conda activate real3dportrait
+export PYTHONPATH=./
+```
+For audio-driven inference, provide a source image and a driving audio (angle brackets denote placeholders):
+```bash
+python inference/real3d_infer.py \
+--src_img <path_to_source_image> \
+--drv_aud <path_to_driving_audio> \
+--drv_pose <path_to_pose_video, optional> \
+--bg_img <path_to_background_image, optional> \
+--out_name <output_video_name, optional>
+```
+For video-driven inference, provide a source image and a driving expression video (passed via the `--drv_aud` parameter):
+```bash
+python inference/real3d_infer.py \
+--src_img <path_to_source_image> \
+--drv_aud <path_to_driving_expression_video> \
+--drv_pose <path_to_pose_video, optional> \
+--bg_img <path_to_background_image, optional> \
+--out_name <output_video_name, optional>
+```
+Some optional parameters:
+- `--drv_pose` provides driving pose information; if not set, a static pose is used
+- `--bg_img` provides the background; if not set, the background is extracted from the source image
+- `--mouth_amp` mouth amplitude; a higher value leads to a wider mouth opening
+- `--map_to_init_pose` when set to `True`, the first driving pose is mapped to the source pose, and the following poses are transformed accordingly
+- `--temperature` the sampling temperature of audio2motion; higher values give more diverse results at the expense of lower accuracy
+- `--out_name` when not assigned, the result is stored at `infer_out/tmp/`
+- `--out_mode` when `final`, only the final talking-head video is output; when `concat_debug`, visualizations of several intermediate results are also output
+
+Commandline example:
+```bash
+python inference/real3d_infer.py \
+--src_img data/raw/examples/Macron.png \
+--drv_aud data/raw/examples/Obama_5s.wav \
+--drv_pose data/raw/examples/May_5s.mp4 \
+--bg_img data/raw/examples/bg.png \
+--out_name output.mp4 \
+--out_mode concat_debug
+```
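+
+If you want to drive one source image with several audio clips, a thin wrapper around the CLI can save some typing. This is only a sketch based on the flags documented above; the clip list and the output directory `infer_out/batch` are hypothetical and should be adapted to your own files.
+
+```python
+# batch_infer.py -- hypothetical batch wrapper around the CLI shown above (sketch)
+import subprocess
+from pathlib import Path
+
+SRC_IMG = "data/raw/examples/Macron.png"            # source image from the example above
+AUDIO_CLIPS = ["data/raw/examples/Obama_5s.wav"]    # extend with your own driving audios
+OUT_DIR = Path("infer_out/batch")                   # hypothetical output folder
+OUT_DIR.mkdir(parents=True, exist_ok=True)
+
+for aud in AUDIO_CLIPS:
+    out_name = OUT_DIR / (Path(aud).stem + ".mp4")
+    cmd = [
+        "python", "inference/real3d_infer.py",
+        "--src_img", SRC_IMG,
+        "--drv_aud", aud,
+        "--out_name", str(out_name),
+        "--out_mode", "final",
+    ]
+    subprocess.run(cmd, check=True)  # raises CalledProcessError if inference fails
+```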
+
+# ToDo
+- [x] **Release Pre-trained weights of Real3D-Portrait.**
+- [x] **Release Inference Code of Real3D-Portrait.**
+- [x] **Release Gradio Demo of Real3D-Portrait.**
+- [x] **Release Google Colab of Real3D-Portrait.**
+- [ ] **Release Training Code of Real3D-Portrait.**
+
+# Disclaimer
+Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's talking video without their consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
+
+# Citation
+If you find this repo helpful to your work, please consider citing us:
+```
+@article{ye2024real3d,
+ title={Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis},
+ author={Ye, Zhenhui and Zhong, Tianyun and Ren, Yi and Yang, Jiaqi and Li, Weichuang and Huang, Jiawei and Jiang, Ziyue and He, Jinzheng and Huang, Rongjie and Liu, Jinglin and others},
+ journal={arXiv preprint arXiv:2401.08503},
+ year={2024}
+}
+@article{ye2023geneface++,
+ title={GeneFace++: Generalized and Stable Real-Time Audio-Driven 3D Talking Face Generation},
+ author={Ye, Zhenhui and He, Jinzheng and Jiang, Ziyue and Huang, Rongjie and Huang, Jiawei and Liu, Jinglin and Ren, Yi and Yin, Xiang and Ma, Zejun and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2305.00787},
+ year={2023}
+}
+@article{ye2023geneface,
+ title={GeneFace: Generalized and High-Fidelity Audio-Driven 3D Talking Face Synthesis},
+ author={Ye, Zhenhui and Jiang, Ziyue and Ren, Yi and Liu, Jinglin and He, Jinzheng and Zhao, Zhou},
+ journal={arXiv preprint arXiv:2301.13430},
+ year={2023}
+}
+```
diff --git a/TalkingHead-1KH/.gitignore b/TalkingHead-1KH/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..69369055b187881d038bf178f48c74f11cf36a76
--- /dev/null
+++ b/TalkingHead-1KH/.gitignore
@@ -0,0 +1,2 @@
+data/
+.DS_Store
\ No newline at end of file
diff --git a/TalkingHead-1KH/LICENSE.txt b/TalkingHead-1KH/LICENSE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..0afe65a09fb0044be6f053792b62c4596ad7c8ae
--- /dev/null
+++ b/TalkingHead-1KH/LICENSE.txt
@@ -0,0 +1,42 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a
+# copy of this software and associated documentation files (the "Software"),
+# to deal in the Software without restriction, including without limitation
+# the rights to use, copy, modify, merge, publish, distribute, sublicense,
+# and/or sell copies of the Software, and to permit persons to whom the
+# Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+# DEALINGS IN THE SOFTWARE.
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+ * DEALINGS IN THE SOFTWARE.
+ */
diff --git a/TalkingHead-1KH/README.md b/TalkingHead-1KH/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..307ee02c5193dcb81d6ffb84c4e2631b9ccd3402
--- /dev/null
+++ b/TalkingHead-1KH/README.md
@@ -0,0 +1,84 @@
+## TalkingHead-1KH Dataset
+
+TalkingHead-1KH is a talking-head dataset consisting of YouTube videos, originally created as a benchmark for face-vid2vid:
+
+> **One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing**
+> Ting-Chun Wang (NVIDIA), Arun Mallya (NVIDIA), Ming-Yu Liu (NVIDIA)
+> https://nvlabs.github.io/face-vid2vid/
+> https://arxiv.org/abs/2011.15126.pdf
+
+The dataset consists of 500k video clips, of which about 80k have a resolution greater than 512x512. Only videos under permissive licenses are included. Note that the number of videos differs from that in the original paper because a more robust preprocessing script was used to split the videos.
+For business inquiries, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
+
+
+## Download
+### Unzip the video metadata
+First, unzip the metadata and put it under the root directory:
+```bash
+unzip data_list.zip
+```
+
+### Unit test
+This step downloads a small subset of the dataset to verify the scripts are working on your computer. You can also skip this step if you want to directly download the entire dataset.
+```bash
+bash videos_download_and_crop.sh small
+```
+The processed clips should appear in `small/cropped_clips`.
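+
+As a quick sanity check, you can count the clips that the unit test produced; the short sketch below only assumes the cropped clips are written as `.mp4` files somewhere under that folder.
+
+```python
+# count_clips.py -- count the clips produced by the unit test (sketch; assumes .mp4 output)
+from pathlib import Path
+
+clips = sorted(Path("small/cropped_clips").rglob("*.mp4"))
+print(f"Found {len(clips)} cropped clips")
+for clip in clips[:5]:   # show a few names as a spot check
+    print("  ", clip.name)
+```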
+
+### Download the entire dataset
+Please run
+```bash
+bash videos_download_and_crop.sh train
+```
+The script will automatically download the YouTube videos, split them into short clips, and then crop and trim them to include only the face regions. The final processed clips should appear in `train/cropped_clips`.
+
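+The tube files under `data_list/` (e.g. `small_video_tubes.txt`) appear to drive the cropping step: each line holds what looks like a clip name, the frame size, a frame range, and a face bounding box. The parser below is only a sketch under that assumption; the field names are guesses, since the column semantics are not documented in this README.
+
+```python
+# parse_tubes.py -- parse one line of data_list/*_video_tubes.txt (sketch; column meaning assumed)
+from dataclasses import dataclass
+
+@dataclass
+class Tube:
+    clip: str          # e.g. "--Y9imYnfBw_0000"
+    height: int        # assumed source frame height
+    width: int         # assumed source frame width
+    start_frame: int   # assumed first frame of the tube
+    end_frame: int     # assumed last frame of the tube
+    left: int          # assumed crop box (left, top, right, bottom)
+    top: int
+    right: int
+    bottom: int
+
+def parse_tube(line: str) -> Tube:
+    clip, *nums = [field.strip() for field in line.split(",")]
+    return Tube(clip, *map(int, nums))
+
+tube = parse_tube("--Y9imYnfBw_0000, 720, 1280, 0, 271, 504, 63, 792, 351")
+print(tube.clip, f"frames {tube.start_frame}-{tube.end_frame},",
+      f"box {tube.right - tube.left}x{tube.bottom - tube.top}")
+```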
+
+## Evaluation
+To download the evaluation set, which consists of only 1080p videos, please run:
+```bash
+bash videos_download_and_crop.sh val
+```
+The processed clips should appear in `val/cropped_clips`.
+
+We also provide the reconstruction results synthesized by our model [here](https://drive.google.com/file/d/1BX9zaNL_zowTDruvRB3KvebaSUi3WHWc/view?usp=sharing).
+For each video, we use only the first frame to reconstruct all the following frames.
+
+Furthermore, for models trained using the VoxCeleb2 dataset, we also provide comparisons using another model trained on the VoxCeleb2 dataset.
+Please find the reconstruction results [here](https://drive.google.com/file/d/1HVCFj7WOy9KHP1J76wn-ZExh-nQnff9g/view?usp=sharing).
+
+
+## Licenses
+The individual videos were published in YouTube by their respective authors under [Creative Commons BY 3.0](https://creativecommons.org/licenses/by/3.0/legalcode) license.
+The metadata file, the download script file, the processing script file, and the documentation file are made available under [MIT](LICENSE.txt) license. You can **use, redistribute, and adapt it**, as long as you (a) give appropriate credit by **citing our paper**, (b) **indicate any changes** that you've made, and (c) distribute any derivative works **under the same license**.
+
+
+## Privacy
+When collecting the data, we were careful to only include videos that – to the best of our knowledge – were intended for free use and redistribution by their respective authors. That said, we are committed to protecting the privacy of individuals who do not wish their videos to be included.
+
+If you would like to remove your video from the dataset, you can either
+
+1. Go to YouTube and change the license of your video, or remove your video entirely.
+2. Contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com). Please include your YouTube video link in the email.
+
+
+## Acknowledgements
+This webpage borrows heavily from the [FFHQ-dataset](https://github.com/NVlabs/ffhq-dataset) page.
+
+## Citation
+If you use this dataset for your work, please cite
+```
+@inproceedings{wang2021facevid2vid,
+ title={One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing},
+ author={Ting-Chun Wang and Arun Mallya and Ming-Yu Liu},
+ booktitle={CVPR},
+ year={2021}
+}
+```
diff --git a/TalkingHead-1KH/data_list.zip b/TalkingHead-1KH/data_list.zip
new file mode 100644
index 0000000000000000000000000000000000000000..cdc9e00dde3d5274874bf1086799c22cbc0cb25f
--- /dev/null
+++ b/TalkingHead-1KH/data_list.zip
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb6179ecac9f08c7c2d960cb8a36354e9fa2edde7215d1fc5a20d6a342050f1a
+size 6822806
diff --git a/TalkingHead-1KH/data_list/small_video_ids.txt b/TalkingHead-1KH/data_list/small_video_ids.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bae340182208f66e887bff891588806ba93540fc
--- /dev/null
+++ b/TalkingHead-1KH/data_list/small_video_ids.txt
@@ -0,0 +1,2 @@
+--Y9imYnfBw
+-7TMJtnhiPM
\ No newline at end of file
diff --git a/TalkingHead-1KH/data_list/small_video_tubes.txt b/TalkingHead-1KH/data_list/small_video_tubes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..75e9d521d1da4e9d18d52f1b4a2fa81096dc3c52
--- /dev/null
+++ b/TalkingHead-1KH/data_list/small_video_tubes.txt
@@ -0,0 +1,4 @@
+--Y9imYnfBw_0000, 720, 1280, 0, 271, 504, 63, 792, 351
+--Y9imYnfBw_0000, 720, 1280, 1015, 1107, 488, 23, 824, 359
+-7TMJtnhiPM_0000, 720, 1280, 1202, 1607, 345, 26, 857, 538
+-7TMJtnhiPM_0000, 720, 1280, 1608, 1674, 467, 52, 851, 436
\ No newline at end of file
diff --git a/TalkingHead-1KH/data_list/train_video_ids.txt b/TalkingHead-1KH/data_list/train_video_ids.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8790a101d52d84204164d489ff42061f390f4e7a
--- /dev/null
+++ b/TalkingHead-1KH/data_list/train_video_ids.txt
@@ -0,0 +1,2872 @@
+--Y9imYnfBw
+-5Lm-oibjJQ
+-7TMJtnhiPM
+-8lgFak3RwU
+-9qRhqjD7PY
+-A0jrCGS_TE
+-B1Z9vrjpgg
+-B9PTg7XynE
+-BNR218UFX4
+-BPwyGVD2ec
+-DiIOp80LOo
+-EM2xi5Jnd0
+-FbPk4SmV0M
+-GZFS7r2GJ0
+-GtLyhxJ4V4
+-II0hSGU52I
+-J3Z1m6vrCs
+-J9ECsMrNkY
+-KBiHKx54sg
+-KL3wSSFx10
+-Qi6_EcnYEI
+-RuxUoQQaT0
+-TNWOtcCbOU
+-UrrtxGsmoU
+-ZqVCfsj3yg
+-_A67wxucLo
+-_KOlZjGD9c
+-a7oWdSyePY
+-acp9Dk7Ve4
+-axp0KRQJYc
+-ayP3gFLy9s
+-cNe_z2qsGQ
+-dvajjXM2dg
+-dxGq0Mu0oY
+-flEiFmuh64
+-gnJtZFyzZA
+-hLsqlKm-A4
+-mingboa4sM
+-n-HQiT-mkw
+-o8Ns4hvhr0
+-qsTrNdfd1w
+-r-ckuF3JSg
+-tIO8GnrSJM
+-xlo-qTzC7s
+-yhJf04a1mI
+-zqAjBpG0Jw
+0-FQbhkWYuY
+04-JxYnzcq0
+04BN6UOVKOM
+04WJEEb33CY
+04fidgplUXU
+05e4nROsPHM
+06RuVZUbUc8
+07F4x2eeD1k
+0B5ojfWryqA
+0Fz4oferM-c
+0H-cSZFZq_4
+0ND-w_eEOQw
+0NLQfXOo0dM
+0OaM12UIeVA
+0OzxymzEo-A
+0P-9yxO2df4
+0PffDSC_cLU
+0PtUxdA35Zs
+0RQl9e11aeE
+0S269HFjKx0
+0SKHozldsNM
+0S_QNLx2S7E
+0X6ZN_eRaSI
+0_Jci9-d2VY
+0_udiqKlYuU
+0aHSV7nbEvE
+0aZJJHF4K8U
+0aiKZGqjvL4
+0bA8z8qYmdo
+0beHbJDqlOs
+0cSDEZ6NxM8
+0cfM5xNaaSQ
+0ez7IiBfCGE
+0hhbwkQ06EI
+0hmQbZUYG4o
+0kLlrDNq9RA
+0lOkJkh2Hwg
+0lcaZupVMmo
+0oE11MGQRZM
+0pQvRqU43eU
+0pe-7C_c-fY
+0rZJhc1FeU4
+0t9j7CfXTcM
+0uGzSXRi7qQ
+0ueGFLVFi80
+0zTkQXcIaqc
+1-YOlGoKxeM
+13FSE5dCvvM
+1653Kz-SnxA
+16uK2Gbyk2k
+18Kl4qkvGUU
+18Ts-YxUIts
+19baUOGVR2M
+1ADqe6vt0xE
+1BQNyk_BgDQ
+1BuXth0dz5s
+1C7do4tIrfo
+1DgBEF9zbvc
+1FBVcxb4PQA
+1FmSxmpitBA
+1HIHu4mxDLE
+1IJ3ISINhnU
+1JSocaAw--A
+1JwEvzz-kf0
+1MibRcX7KUM
+1MvtBdLyAKc
+1O3ghiyirvU
+1OUkXFsyyP0
+1PDs1mL7TG0
+1PV7Hy_8fhA
+1R5dPw4sYrE
+1SLkDhvOnNQ
+1UHlG0puRUM
+1UyhO1hMhrU
+1WOs6S0VrlA
+1_rRguH_Vx4
+1aQGw_-I25I
+1aY5n5H0ugE
+1bMYn6Vb-Mk
+1bwUe85Mg4U
+1h3J3fDsupc
+1hd5majTxeQ
+1i8607wMy68
+1iXuJgqZhik
+1izPOehU6zA
+1k3DrJOQ008
+1lPU8Tw928Q
+1mAMa9hn-18
+1oPgMsUmC6Y
+1q7yGJbX1sI
+1rxyrZUtm_g
+1tZaTf21D64
+1wKWC0EzK_s
+1wsOU17LXTU
+1yHWhXaXxFk
+1yP_MElLiKM
+1z-0ranXCyc
+1z2ZTPZEx1E
+1z8sqW5xIrA
+216-fEvtTag
+2250kOXVaGE
+22AUNaHGGz8
+22tF3y_epZs
+23eZ2Hjh3gY
+24paFjOVKoE
+264aaUUvB9U
+2CQgBtUJNBc
+2CRIlwxj6uo
+2ChRXYtCCXs
+2DXm-gBAbDQ
+2EgGqTDAiC4
+2H7qsNMyxvU
+2HPH4qtE95k
+2Ix4h6gaa1w
+2JgN7AZO_6Y
+2KTA78QK5-A
+2KmXllrBNHw
+2L-QZfbDm_s
+2LdAF2dHHFU
+2Li-PLW3Aew
+2NV_nn8apOA
+2O_SN4pg-lk
+2ObBpwabNmg
+2OfVzz2aZ8c
+2PejNMim-YQ
+2PqqZcC31ys
+2R3GNE3ob-4
+2Rtp8zWE6nQ
+2T2f2cENuiM
+2TmfeHf8Rqw
+2VLez1fkVq4
+2VdwqgI2l2o
+2bAuuK1uU1M
+2d9BcdLC4dA
+2dt6XZlI-Lg
+2e6a4eog8jc
+2eZgkkbNq54
+2eaxjEpQPkk
+2ep9Eb36OCY
+2gIkmyxa2ds
+2gMaIjvjCFo
+2k7VhqEg7ws
+2llLx_WkF3Q
+2llMolwKs88
+2mLsOuOve8k
+2nhDjEnUslE
+2nj6tpVKUxQ
+2qDH4bPvfMw
+2qVqaA1hn7A
+2rxU553pKgs
+30QwIdz8Vjc
+30gEiweaAVQ
+32DCEzpcVEY
+35ELLdyGmpc
+37QWwuNFVmU
+39Axc3FIu9A
+3AJdjwSsKoI
+3AoYXlMCioI
+3BkL0UXgNDs
+3C0T230n-mo
+3CGECvfakGM
+3EFiRh-y4L0
+3EJtdIIh43s
+3EYU3VTI3IU
+3EsqpF-W_wQ
+3FKX7kcU4hA
+3GUcVRz_JyA
+3I2jJpnw_xY
+3KNGPFM4d6c
+3KtIBPwaRkk
+3NJN-C1R9gA
+3Qlr9YoHBkg
+3RITVJy7ogI
+3Tyan5xgQ4Y
+3UYOtcC00L4
+3YE8YLldLu8
+3Yepui-bWyw
+3ZLYs6Wj51s
+3ZQL-3ElxpY
+3ZlXiEZb3k0
+3_celBcyBJE
+3fyphrZ-yUc
+3gCA0Z4y7wA
+3hPHW74cBRM
+3jESlqRuLw0
+3kyzHQL4s1c
+3l_puRAIDDU
+3lpnoDrt5Tw
+3nJexBvb5UY
+3oV_cljdDs8
+3pIR5nfFNQM
+3qBW_x5fatI
+3tKwbLJEYZg
+3v1tt9mD6cg
+3viekW3AnRA
+3vudNvflufc
+3ytLt3BDqFU
+3zhCbaVBEjg
+4-1Nci8AkAc
+4-7ABAFvCZE
+41Q97FiM2n8
+43wXnwHo_qU
+465Wt-eX2RY
+47OtSe4dffo
+47slfR-Knq0
+4A9OeOEnFyA
+4BPibf6C35E
+4DdTNSKleK4
+4Hg2Eu-F9mU
+4NXUGnLbl5Y
+4Q6NknbHNiQ
+4QQyjqtHwlY
+4QgmM1dcHMw
+4R4aN98Qrbo
+4RN-c0HdJOI
+4cz4pq_N--g
+4gxMVLXmBiE
+4kR-Fev28po
+4l515yNhqxg
+4lB_x0kLffU
+4lOnYQCxbOM
+4neIRMYUT0U
+4nqMZX542bU
+4oXZCuXwrXo
+4q7lHy-1U6o
+4qGWz9v-UXk
+4rJ_RB3Iwws
+4tbkGu8boyo
+4u7y4RPfV5c
+4udsKgSP7UA
+4ujhmvsSE9c
+4x234xeqtHg
+4xlKjVPFFv4
+4zzVjonyHcQ
+50zJWKotPP4
+51Qp9Z3ZlHs
+58GcsaJhTzM
+58K7XpkuMCM
+5JG5xwguT6I
+5K4dtC2A_bw
+5OXy29bFfPk
+5PGVFUADgGY
+5PUTMuJxqZQ
+5TfPoq3lkso
+5WZmSJYYxPk
+5Yz7ssYCa8Y
+5Z8x06jnjSk
+5ZYrc-3gQVo
+5db6EORa_0U
+5dsxC8M7vCM
+5gx3WNfa56g
+5icAa3G2X8E
+5kas2jBObUY
+5kuqWp4gjV4
+5l2j8k759-4
+5leH4t1V9LY
+5llwjtJDqMo
+5meC4Z61qGg
+5okG5zh9ePY
+5oybklNuCZU
+5qaBD538C9k
+5r2Y6QsFACY
+5s9UdpT0TOo
+5shU5ZVQuEc
+5t4cwPdsVgQ
+5ty6o1t2emk
+5u-Aw6NOIy0
+5vAj_wY8c7s
+5xeVIgWq7s4
+5yZEtG_r-wI
+6-6m2H_aE-s
+6-O4gWLQkgk
+60bkwhHzols
+60dSEN1RIbE
+627Ufg0PVkg
+62OgF2Pw09o
+62aDb2JgFQc
+63gG5qmwREo
+63iL92MjDT4
+63stDIiwraM
+64Yx3Odqo9c
+67zz9OzY07c
+681IyvC_JAg
+685CK3xr0jg
+68pNYnDuoQA
+69pi63RgeJ4
+6CIMXUDzKtQ
+6CS50XMGV0o
+6Dtn8MfMPVE
+6Fd-CsvZQX8
+6IA7AuXVSxQ
+6IdeRRqG1ak
+6JIUlIwX3Pw
+6JmuCmI7Iqo
+6KKDqqV8OUk
+6LljU1cDSbI
+6MhEveLVeO4
+6Nyz_yd_GJw
+6QgRw4lDN10
+6R08SayU3bU
+6SA9lGH3JN4
+6SsIE00duz0
+6T5iJ6TjWj0
+6V3kI3QBWK8
+6VVCVtxeq1g
+6YBDqpgpVck
+6aAa8M4vg2k
+6bm6Y9TaX1w
+6bnxY8W5otQ
+6frz5TaAIto
+6gyKLNQH44I
+6iXGqEq_cpI
+6lMj2CGW6u0
+6lzxVmBJIlk
+6mL4rjxEnbo
+6nDOJdyJDk0
+6nPPyCrkwSE
+6ng7_H8pdBo
+6oHmiJLrEAk
+6oLStjIxffE
+6p6boxrvbZk
+6pCBqmPozE4
+6ptI5B4a-ag
+6rKSZPwHTf8
+6tx3pU5F1x8
+6wyDTrAPV7s
+6xSxXiHwMrg
+6yt3pqCQn6s
+6zXMtU6jgc8
+7-AfS9rehcM
+7-ByRppD-EE
+7-QzoS-dW-c
+702SXH0JdaQ
+73hls2GdB5o
+75cv2PgOmX4
+75mMPO0x4Gs
+78UIF4JCcYc
+798S4UbhNE4
+7AjsnSEhZ-w
+7DJdGbYHll4
+7KcRJyXmuzo
+7MBqGLvoQG0
+7ONrAflL5Oc
+7PQ0QmUGpvw
+7Q2Qe-zid24
+7Qu_ETu3vi0
+7Rgh8v8Qmzg
+7STjD4eWMs8
+7UQMqOkGD8M
+7W9ACrwLn1A
+7_ppXSABYLY
+7bAPKKE_tzA
+7bC_8QTdbHM
+7bXAZfRTZQk
+7c0wYvntu8M
+7c7ccOHMK8o
+7dLu8wqYpJw
+7dmNo3X-Lus
+7eOwm6nBBGA
+7elaTVxAEX4
+7f6b7b_yzQY
+7flG_V2SHc0
+7g7RpRl-Pi8
+7hHOLdgvG-4
+7hfFxTbLFWs
+7jyokhjUCyk
+7l7GryUZGvY
+7mHwsV3Mb-Y
+7o8NcTiXGYI
+7oKjW1OIjuw
+7pg_Dgs0wUU
+7s_Sb4-mwes
+7uwUj4aX2YA
+7w68Up6F9ZI
+7zWswXer8i4
+8-RmR1XmaxE
+8-hAwTTVYFM
+80It12pD4j8
+829gaOsl7Z8
+86OqcZtLtgE
+8BZC5UBidm0
+8C5cn_Qj0G8
+8CTEl-Zhv38
+8DzVNa_BjOI
+8HbsSkuiPmg
+8J6FseCtEXc
+8JtWrFwQ-h0
+8KD2cPzxF3U
+8Kux1TQWdLU
+8LSK5wIjZMk
+8MPjG8RMYJY
+8N-x9QC50m0
+8NJ74aVTZMc
+8P3rViI8Xw8
+8SIWlQ0ZUqY
+8ToIwnP2a-g
+8TunLMoE9Xw
+8Uij8BYDuf8
+8WfV93go5TM
+8XJp2c05iVk
+8XlrNlfd-9M
+8_yb_nW5x6I
+8aafXYh_gHA
+8bQPi0ssTLw
+8dI7AzzZLXw
+8e1BMiU951c
+8gg-oKufUo4
+8gtpnlVb31U
+8lrlXoXGQo0
+8pcELaZV2b0
+8qUQwmwC7Oc
+8quGD9W7B2I
+8rwcfIrAXtA
+8uyPpy2ejA4
+8vzKZVqXmo4
+8xhqhC_PHzE
+8xo4s6tYzzs
+9-EtiWDJbJw
+9-a3sSZeXDE
+9-nsXlNXRDw
+90dka2zrP1Y
+90sJHDwKmSg
+91Zu4JRnfxc
+91z62p7t-AU
+95hYfRw1aHA
+95zvFw1VkqQ
+97C_sZ2821s
+9Dm1Sekkcdw
+9G4s_qIxJYk
+9I8mQpFAJ50
+9IGRsXq9Wis
+9ITi5C8vHpw
+9JJ3ullABD8
+9K4A7e3clMM
+9KPyflyHP6s
+9KisGZnflBc
+9MOleOgz5To
+9O2D9K7l-FU
+9Pj-QPlN2CU
+9QeqkUN0bNU
+9S9O2T1B6xE
+9SflFku0eKo
+9UaAyI-uI30
+9Wb409Nlhlw
+9WkYKiltzJU
+9XYx1vUd-v8
+9Zb8e6nE5QI
+9Zuk9Huqdrg
+9_VLbfRXTss
+9awtQRbMhG4
+9c7zctYTBLA
+9dCHp07it-Q
+9eHXo_KFvJ0
+9f0eB75r-Y8
+9h3m1jzWWEU
+9kUljb9G-MY
+9kgiApzKDMw
+9mAiyn6gMJw
+9mnTCYsbKfw
+9n7hn-M4GpI
+9oMmZIJijgY
+9p8z1A3TsxU
+9r8yZ-68pkY
+9s3BPDNEJek
+9t7ujBSH3WM
+9ujM2nAMK0g
+9uyYxs79EcY
+9ychKZIG8ms
+9ydk7zFmOxU
+9ydusnvBSys
+A-QOg-tFApA
+A0AViKj8EGk
+A2B-b-nfCqk
+A2DdRsUdFeU
+A2lEI0kaf3k
+A3uNIgDmqwI
+A6xjN8BqDjk
+A7ktYbVwr90
+AAprf4PLDM4
+ACsoFzXDE3Q
+ADQSqUBZjvE
+AEKZERIDiUk
+AHKGGtJ15o8
+AIXQngEnJgY
+AJa2DO_woJ4
+AKkYU-fExWc
+AL1pISpcG2Q
+AP0S21vT3Co
+AQe2ANirwW4
+AQuAGO9ceIU
+ARXYGhV5VFg
+ATgNzwHSfjw
+AUtRgfFUCl8
+AVc5fXa0oMs
+AWJcd3F-HPY
+AXfRrHD4Cps
+Ac8Fsu0WVKg
+AccPbM4JhFI
+Ad5je4UNgDw
+Adgif4D3ujk
+AfZNZ3bfJvA
+Afy7H04X9Us
+Ag-zqjX1TV8
+Ag1AKIl_2GM
+AglvA1tduMA
+AkU94JdXbQ0
+AnxrJiS5uKU
+AptPjGKdaDU
+Aqu7R_7vFKM
+AuLoMmjFONE
+AwvReatHB2g
+AygCTeXnJ6o
+AzECoalJ4WU
+AziRcPo6rm0
+B-8ovk81nNM
+B-qxGhkRojc
+B0KlNLkO3qE
+B5uSsp0Rbbc
+B7HzMw9rSMs
+BAj3fHStRGI
+BB4bpEKrOlM
+BBWIPL66Fpc
+BDLtYpLZfbU
+BE4Y5Uc53Nc
+BElqXjOG5Gk
+BErluP3jDjw
+BF7hcsKb1WA
+BFWbo17t-Ig
+BFat39XKT2E
+BFoqBnl5XNw
+BGJuAODr8Ks
+BH7-eZYnJkE
+BIuRA0GGIgk
+BK1VxSDsCu8
+BLcZAhQzQF0
+BM2mqrIXY2w
+BNgmYFwUjjw
+BNyCZJOTRZA
+BPLCiXRSBNk
+BQL5wkJS4y0
+BUouNsjhTTc
+BXT72YlkQrA
+BY2ADlTxdZI
+BYgCS4of0TI
+BYvQek24Kbs
+Bb3XWac-WuM
+Bc7WoDXhcjM
+BcTOSxcv2_o
+BcjJjiE4Ivs
+BfdcdAIJh4g
+BhEio_W1LU4
+Bio8ZpEFlqY
+BjrDmB15S-M
+BkTiMi0Owuw
+BlCoAfks8kk
+BoWXd-LNnm4
+BoXR8KrAfIQ
+Boj9eD0Wug8
+BowyM_Wlsd8
+Bq0vohH2pL4
+BsDYAcOsWqk
+BsgJuviWgJI
+Bv3taVzLZZU
+Bv7wdHRhifM
+BzAzXDgqKzc
+Bzb0FhMqaU8
+C2TLW8MS33E
+C2rVYklWl8I
+C3589eDY5rU
+C4bevJm-MbI
+C5WNr5vzUPs
+C7c1LAUbSho
+CA-0cn2Dbgo
+CBYhVcO4WgI
+CEJTBzMKcuk
+CFUGsvVC9mQ
+CG0OnKUqziA
+CIIsZt9c8nA
+CMAiUAvqIh4
+CNFTLmMYY1c
+CPckick_ioM
+CQUP1LhDLPM
+CWLEHmYHNro
+CZjbKqbYS-A
+CaD5TRQQNsI
+Cd4ZxGNbwTw
+CfA30p4X9g4
+CfKy85LA_bs
+CjnVq4zLT6s
+Ckrrk5oBneA
+ClVtccpCs6g
+Clsy0PuGl2A
+Cptzvn3nM0Q
+Cr0cBSpnn40
+Cr2VggKQrQg
+CsG0Or6-SiI
+CtpXTXjQk4o
+CuZnbR4fb_M
+CuoCrLWcsjI
+CuyAI82HFe0
+Cymmi8L0O1E
+CyyImnREpbg
+D0TgSpsBabY
+D17AYqYPFDk
+D2UT1AmyZFE
+D4T0Ffg1I_Y
+D5ZZUHKPC10
+D5gW_X-Db74
+D7wvFZjtVOw
+D7zgzpc_PVI
+D8TFETlLRdA
+D9aRZXIOX5k
+D9ocoySPGOk
+DA2mx0793uI
+DCUnruZponA
+DChlO5fNMGw
+DFqDbrTGTnY
+DGKY7K-pyqw
+DGoTdntvKfs
+DKnHYcKfz6Q
+DMVduTyjp1k
+DPDF4odJrq8
+DRHNdWF4Kho
+DSoKEFb8R_w
+DStiXbf9Mk0
+D_obsdDVv20
+Ddltu_Cq4E0
+DdtEInQEQ-s
+DeeRqZd7sCE
+DgGKe3kCx74
+DgpbuHAOgf8
+Di9AwmKtblo
+DjHPXW6Crac
+DnOLvKEYIQI
+DrDOMMmwPvI
+Dtfi2BHWwaY
+DtuJ55tmjps
+DuNW0KQ_GZM
+E-hVDqrQq6M
+E-qQNDCVSnk
+E1QbVtkza54
+E1gXzYA0tFA
+E2uUfyT64VY
+E3Q61rKXhrM
+E4zPYb7O2EU
+E7XkvbCu-jU
+E9tUCuAZ-LU
+EAc4WhNQZ30
+EB_d1jK1R44
+ECfLtssUZa0
+EEq8OK4BUyM
+EFLh9Vqr-YU
+EGCqtu8qujE
+EH6q2YGx45M
+EHp_I9ETmtQ
+EHrbAhrbw9Y
+EJMlBT6jptE
+EJZvz09LVa4
+EK3WoT1Gqvs
+ENXP9HEul98
+EQJN03a6M28
+EQLg-kHxwCA
+EQ_VlFQT9hg
+EUS1m5MSt9k
+EUqWrqb9Oug
+EVBOSk6kEwg
+EWSp2QMzKv8
+EYaKFlWd2MY
+E_LYrCtoTIA
+EaHZLUWwxfQ
+EaQOEXTkKYg
+EckpALTiYhE
+Ed4bxiP1RPM
+EeqwFjqFvJA
+EiW4lKrMXQ4
+EjB9J20nulw
+EjBkTt0LHbU
+EnPJk_9Ug7I
+EnU6HVRC4s0
+EoHwvsJcBNg
+Eojazns82hw
+Eotj0EeepoQ
+Eoxazjg1NUA
+Erso0HgtV5A
+Esu34JYC2YQ
+EtctGvH92Ww
+EuVBpqBgmvM
+Eup2Ca9Kiis
+EwXXD0uLj8I
+Eylcb4rbLSo
+F-GzNvvs-lU
+F-nsVjM7FmU
+F07fXd4vVlg
+F0InXG0ln4Y
+F2G-buBtp7w
+F2yK5VkHRaA
+F4xgvj4kSnU
+F5KV-iaMKK0
+F63B6wWXGtA
+F6ShWjU7GaI
+F7oit5SKxdw
+FAf0YtSelug
+FBQDiiEbknE
+FBbfCBOJOt4
+FEc-U45TzKM
+FGgMrNSmMn4
+FICT79cA3U4
+FL9qpSH5eKw
+FLI0WmBWWv4
+FLIippwrXSc
+FMIDAWVPq7c
+FNErh9EogUg
+FTN_93Px-Qc
+FUSU_WYPwx4
+FV4aEpanJ2I
+FV7tKSeGr3Y
+FXbC_3_8tGM
+FYo5E7zT-vM
+FZLwYiceIOU
+FaoVpVXcZsA
+FlCNvBBqIyU
+FloFzFl0jZc
+FnKhFaijBBI
+FoYce_3oUGs
+Fpo6nvSZirI
+Ft1Nw-Hy8Ao
+Fwh0r8YNLIU
+FxBLtp7UpTI
+FxQchGBQpZA
+FxoOE2dTCHE
+FxrCNf8utsE
+FyPFdBhklEw
+FzXWP7ZHs-Q
+G-6OXtSMyNI
+G0n970JRNII
+G2tfebkUbPo
+G4SMtaNDtfk
+G5hOJXZmqPc
+G66ClBEmWdY
+G7GI04txkOM
+G7T6dTs5YKw
+G9CHdvWwzQ0
+GAS_7760FSo
+GBvfnfwGq5Q
+GG0F0uXuIqQ
+GGIZzL-1gZ0
+GHjjXhd6WAc
+GJar81QVmQk
+GKCxRcuLm8o
+GLUdt99wUY8
+GMndqLvTqhA
+GNx4EbTu10w
+GOQWmjLkqOU
+GPEavB9GXHc
+GS1UT0mSks4
+GTsSe03hPxY
+GUZ4T_xFtwQ
+GVFq0_6imAA
+GXldrjxDZqQ
+GZRbKJMEdk0
+GZVxh8CQFkg
+GbUfu1wF02s
+GcNJkyYSmW8
+GcayBgPOr04
+Gd_zypjbv9E
+Gf8lUImdL3g
+GiShqIyw-_0
+GnwqktjWrVM
+GoZ6KwuAdT0
+GpAOMD6Z_Cs
+GsYl_thySnM
+GtzbWxb9nuQ
+GuVNl_oEMuQ
+GwU1VXsDXbc
+GxDeY_UiGBg
+GxRgKs4TpWo
+Gzo5PEGQpe8
+H2LhBAi-Q8I
+H450lZb-Mdg
+H4TYEoft_rM
+H6L04OEm71w
+H6ODJtcqyTg
+H8t_snz8B5A
+H91qCYIfZuQ
+HA8PjarK2mo
+HAFKRtBHFlQ
+HC-zhWIIC5w
+HCMZ4s2A_k8
+HFCmCsxt1xw
+HFTVP9qIMPE
+HLXDpFYmyqo
+HNdh6Valoys
+HOpKzDhCFtE
+HPHNakev2ss
+HQpZQj3TY80
+HSjvtpwKyhU
+HSu21_qc2kA
+HTJyZwYPQOI
+HUoUVlK_bHI
+HVMqm9jlUDM
+HWNwvBrUUGQ
+HXSxHJO6Srg
+HZ0pn4ijwnQ
+HZuhPDbZtcg
+HbKvcTZqNA0
+Hb_SVDUmWzo
+Hc6K7g7wqrs
+Hcy27nbeNWY
+HdzDCZ28cI4
+Hg0qN4cNJfY
+HhFPSCGRFHY
+HhNo_IOPOtU
+HhljUdMUbs0
+HkfLT86wPkM
+HmexkGBB428
+HsGbFpi3xtk
+Hy6QRP3ENl8
+HyzD8pNlpwI
+HztoBDblr8o
+HzzUW9y9FCQ
+I-9uLKZmxOw
+I-HFjHKJJ7E
+I0lqvxqEKhU
+I1fZdwFStnY
+I25TYNMclKk
+I2Z6LNkwijk
+I3RMF_9xW1o
+I68lZ9jptWU
+I8tkl9kVfaI
+I9taZpV2JfU
+IAYJhZS231s
+IB6_L9xsnAo
+IBESpBTIQTQ
+IC3EX6ipxFo
+IGEJo3QHvSI
+IGgpoI_0oPs
+IKMjg2fEGgE
+ILQii0-r-bE
+IM-TQHJKefA
+IM1xpnkmG7o
+IMEIBu0uULg
+IMFI8waM8rs
+IPPVWl-jPqk
+IRBAZJ4lF0U
+ISip9JRbYNs
+ITBbGDndjGM
+ITjHgkTTX_s
+IaHioK3Ljz0
+IahmVXN7xEQ
+Iam-aEiQOeI
+IatKu7sngG8
+IdZPa5vWdtc
+If2Fw0z6uxY
+IfnTz7vZyVg
+Iib7x8rYE7E
+ImCLzPvVKTI
+InL3YA_6P6s
+InP5DEpeVSU
+InSvBuHK4vI
+Io6JjgckHbg
+IsMFdiLsqbg
+Is_C4-xmayE
+IuuRvopzIf8
+IxYu1FAY5qc
+J1OGqF5Eo1k
+J3-ySSl7ceM
+J44SPYSVAAc
+J4vKu-s3OqM
+J71SYNCcRQI
+J78srd3-odQ
+J7BLQbvZyrU
+J7ei1-rYHMU
+J8ifUEgXF-o
+JC6pZ92y-hk
+JFkqEW-sz1c
+JGwMIlpgR2A
+JIEgN3las5E
+JIvGXG4z9X4
+JKg4o6SHbCY
+JLEPOAlZ7LU
+JP5ywOknF-8
+JPP_3u8gD-U
+JPmZtJ8vgAo
+JQTxw8OdBKw
+JTq75Se8vRA
+JUM_s4uQDW4
+JUSVl1JXYl0
+JVSwiQ0tNLM
+JZTlzyHnQeA
+Je69HPxSd_c
+JeZ5gAUnlFk
+JfjLKBO27nw
+JgwRwOWyR0Y
+JjSIRfrDKF4
+JmWnjvHEM38
+JmjxtqnhzHI
+JpL__knumpM
+JqQtOGCWrQM
+JrIwV6YniCg
+JrXU1owDEVs
+Jriq9eOSu7g
+JsFEIvtKCns
+Jtu5eFsIPmw
+Juen1iIJQTE
+Juv1TqsYiV4
+JvpxC406_jA
+JxR5EZ_GY1o
+JxUqtbpjpqg
+JxWjxAqCrCQ
+K-FG2oWl-2k
+K1hPyYEMp3s
+K2POkUf2EUQ
+K2QS3ZvjPMM
+K2bh3ZJOFnI
+K2pUtcVSXEo
+K3c6AMXsam8
+KAsvI2qAzlc
+KBQRetXolA8
+KCcenWMXQQ8
+KD6bh5ZfS2k
+KFDlZxR4yG4
+KIiKNpySv6w
+KK5120o36GM
+KNjgy1o65SA
+KP3ToVHnOZY
+KQQbMdsFsdQ
+KQkK8ThNsOY
+KTxhy419vto
+KZvz22uAVM4
+K_vUbBQzFjw
+KcdTad5fztE
+Kd9RcLW7knw
+Kgs0QbthCEU
+Khnx2cNTiu0
+Ki9AHohJpIc
+KihglmOX7j0
+KjMe0TXyQzM
+KjWVlz6cAyY
+KlC5HJFI40M
+KlcZwSsceL0
+KmmMz3NfoU4
+KoqaUANGvpA
+KqygzQmEuhA
+Kr3OvLakOgk
+Kr4Xe-BghFU
+KrEzYXCjJDQ
+KtHgPTAdfYM
+KtVfCgNJdeg
+KuqcVxRqQHI
+KxbRf-ZlSSk
+KxcNu6WZSCY
+KyBD2AeGXIE
+KyDTB0i_wQ0
+KzVEGlNNXuI
+KzVzRN5XoJs
+Kz_wvavZp6c
+L-bQmqP82oU
+L0LZDwNDqRQ
+L13WvYq8G68
+L37uYJnDxVY
+L5M93LmOmqw
+L6SiuhBZWDk
+LCsx9rVQjwE
+LDXrXC0cPBo
+LFBC0d4i3jE
+LG8TfyiQr7w
+LGumhl8-kiY
+LH0IOQrB-NU
+LH1WrGpM7p8
+LHFY1Vg97AQ
+LHNNYQ57V2c
+LJJgAqdxBdA
+LL-SFPRBBFw
+LObC_A4G-_c
+LR7ZZ5gw984
+LRXrEiBeOXo
+LTZPvLi3Hdc
+LUhyxjNub8I
+LUolzAltwKI
+LVYVlHr2FKU
+LXAcG9mITz0
+Lawz1Mc16Vo
+LbQo2tKGxV4
+LbR6jjbgbis
+Lbx1F4V_-8Q
+LgT5uYkkPE4
+Lgo_bkq9SWU
+LhEzvegA-Dg
+LhF_sfDfBd4
+LiCmCW6EGzA
+LiWszVY2lW8
+Li_m7BVja44
+Liv3vT9dGhU
+LkCrGs-XeYI
+Llq5mmhLy8s
+LosFY2otV8E
+Lq2icL5Y_FM
+Lv3WAxgaZqU
+Lx48775nwWs
+LylMvKFdwJU
+M5axFzT2_u0
+M5zDCmXSejU
+M6ZdYNFo6gM
+M76UHFsQp2U
+M8SkT5nE-0s
+M8ayDH5DuJA
+M9Xsci4JUy4
+MCQ3H2jfCBs
+MDyjY3uiWp0
+MFcNY_CyXk4
+MGGCuJdo8xs
+MJnxzdG5QwM
+MLUJVpk0BM8
+MLay5YHp48w
+MMoBEZ_d2g8
+MMwKA-Ku1mM
+MSr3hE3z2jw
+MT6uJni993A
+MTBfv2io-pQ
+MVI_yEqA2RE
+Makb_p6HcxE
+Me95iJdHO18
+MeM6r8Nj8G0
+MfS4oDLBpp0
+MgLHVw0tUBg
+Mh3Dvs7DwdM
+MimJZypAiy8
+Misd5Qrx_CI
+Ml1I6WEYSAY
+MmfiHdQ4Wfs
+MousuD_jX24
+MpGxOR50sn4
+Mrcn4Q50j5s
+Mro6gxnOfus
+Ms9K0eZLnFQ
+MsA2OJiYApw
+MsBU3uGpUGw
+N15YpxEHjVs
+N1cJakFhjNo
+N4t59MjWdsg
+N5esEarb5MQ
+N8XGCs0js30
+N9JXjNCR1EM
+NDsl1_vHHTc
+NISKpzp_QAM
+NKJfBsc5kHk
+NKKaWmoFdZA
+NL6zaIPoU-Y
+NMWqBL_Uhr4
+NM_SJwSMRT8
+NMpf6HNYIzc
+NNK4pvyOhAU
+NNp4yt9dHns
+NNwM_OMLa10
+NOYFz4DIfh8
+NSZvUu8Q8ZQ
+NUKIA9I6gRA
+N_kIz8R84jU
+Nbn1NJjbqoc
+Nc-HP2vyKoE
+Nchrj-dzVgs
+NdKNfaWpCj0
+Ndnq9Ofs2eA
+NhOwdlKHcAE
+NhfsI4jbWgk
+NjN07qsdh0w
+Nk6N7ieiphs
+Nm2nt6dxVv0
+NmBcsOMtKqM
+NnPyqGW2w2Y
+NoIHCm4wrpk
+NoZ7ujJhb3k
+Nom2-9WmsWU
+NpskXvrCNA0
+NysbSdox6zM
+O0D0E42AA4I
+O0hP4Hrek4s
+O1UBCHWPIqE
+O1d4dHwSZqs
+O1j79d0IuhU
+O2AnvMTbDCw
+O2YHi_g2JuY
+O3iPo__LYZQ
+O3jxAN3j_P4
+O3sFnc87STU
+O42jXkaQtQU
+O7dLxPQvIkI
+O8G_glgmBcA
+O8gZs9BCr4Y
+OBlw3eBxHvA
+OCQd02hORJQ
+ODx9C7kHmWs
+OGhvz1fwacA
+OGzPe8LXHeA
+OHHqegRBDWg
+OI4sCSsyS4s
+OLbdw1imnKY
+OLe9xNH4G6s
+OLfHpvJKNg0
+OMeIMC_s0GQ
+OMw8kl2kcZY
+ON3YD52Df5s
+ONvg9SbauMg
+OOJ7OYp1gjg
+OPzyc9rXx-I
+OR6qP-X5fcs
+ORObgzbi8Fc
+OSrCb8eDWWk
+OT_0JLIALxk
+OUE8rKxYGLY
+OUoQ4c1YyJM
+OWdqgZQdMgw
+OX0OARBqBp0
+OYJh1xJbQys
+OZWiY3xcr1w
+OcPW_rtcbio
+Odb1XiXpRTg
+Odw7NOmXeaU
+OgzTnofR1WY
+OhHBPhb8a1I
+OhsXUPDBX90
+Oj22JVi762E
+OmdRZ6ZM_pc
+OnPdfl9qKRg
+OpaFC283wJE
+Opx0cWUuaBk
+OqTI0KYkeJQ
+OsBgWxoAOf0
+OsLj4GyNvYQ
+Osp20p7mHLw
+OtTAH-lcO_M
+OuP7vIfN3GI
+OufFx-XmsLM
+OyNBayVulb4
+Oysv7K93B-w
+OzIUFdCRm4o
+Ozk5w3I6wlU
+P-9Z6WeojxU
+P2-SOq0SrmU
+P9M__yYbsZ4
+P9S15gsSywg
+PAFybFU4fzI
+PGQTyzsX7V8
+PJXZQrwDPdQ
+PJdbIALsMYQ
+PNJghHAUlLg
+PRgoisHRmUE
+PT0KwTjsMUo
+PTOZrIogdhs
+PTaeZggiMrM
+PUYoRT2EA5Q
+PVLxoTzL31U
+PXNdbOr8f9s
+PXf59S2kFag
+Pa-5lbobZpk
+PaEO7DFyzZY
+PiHMIYoV3OE
+PjlEL4poXaU
+Pnfh3Bxo4mE
+PnxpH92CJOU
+PpE-wfM5NhU
+PtV9uxQHnGY
+Puu_PXPw_H0
+Pw_4_f6PQno
+Pw_JdBkki_I
+PxVZGdKMmvc
+Pxg5mHeIoTA
+PxiRvHqbQoo
+PyNVMlDyNtg
+PzzNuCk-e0Y
+Q1MYDq7GgBc
+Q5iKAqZ9yVU
+Q71FNI--3vk
+Q9W5Lxr-7v4
+QEKRkIbCZEg
+QEz1caz0loM
+QFJZwvkJsGk
+QFbp5scBzys
+QGbL4LlrcIs
+QHeYQalwm8Q
+QIz-y9-jywM
+QJj7WiwcadU
+QKxRCmpAFKE
+QL2Hb-v1r0A
+QLunLzt4r4k
+QMcdgAriqy0
+QO7Jhl-r-BY
+QP2zJO0AtlA
+QPKKQnijnsM
+QTp2snIa-cU
+QWgbM9DoA7A
+QX1d52gmEZ4
+QZhzZtExnfc
+QaqyAdedM0Q
+Qd3DEZud65I
+QgscBSUsuNU
+Qh7rX2S4lFs
+Qi4P0mLkkLQ
+QkmGfY9iY9Y
+QkwYl6HJHdc
+Qlzca3efn6E
+Qmp2Z7wPR4Y
+QowV8kgwHX8
+QpXUjqjewGU
+QqsPM1rd688
+Qru-q3ykC48
+QspDZ-DYs0Y
+R0XyMhCdSkY
+R3QabWdSsxY
+R5XePwAO4m0
+R7ZX64ASSe8
+R9KOp8PKhpY
+R9P4_3GEjS8
+RAMbIz3Y2JA
+RCrTM0fHg-o
+RDeLBp_-3sM
+RDkMkH4drhc
+REUvXBK7ypQ
+REzffEzpiUM
+RFM-LCECtmk
+RPTk00TE3Ak
+RQXCRoVV9Hc
+RSJQ7iFntRA
+RVG1EXFbRb0
+RVQcpNgh6sI
+RVluy0cjHbA
+RWjDH6Od5xU
+RWoOC8KvHEA
+RXEUWvpmJaI
+RY8e2Ivu4Ak
+Raiw3nozIoc
+RaqHo26ohYs
+Rb2fCxGGcHE
+RdGOK7ZAHMc
+RdhwDd8PW0Y
+Rdi5ExhmqHM
+RdlMCh2idHI
+RfP1AjOOtSE
+RfmgkgzNhYU
+Rh5qNYU-_jI
+RhcnVBgKxEg
+Rhn0fatp9PI
+Ri-M_Vo3w5A
+RjzuefWqVY8
+Rm5MIya_48o
+RovZoquZGn4
+RsC_d7GnZtI
+RsX1lwPnPPQ
+RvjlBL4A_8U
+RweXbb_OzBU
+RyiqQnCd7qQ
+RyszHongpf0
+S-imxWoyMD0
+S0MWAAykFuc
+S1MV9j0dPAQ
+S4Fje5FUgfw
+S4nfcw632Oo
+S5OzDdLlsUI
+SBUwUUIVwHM
+SCjA7BJDnEM
+SD-LjOboaE0
+SDCICtm9zXQ
+SEDDSCUJxK8
+SFE7NNxfbM4
+SH80ZuySDW4
+SLaZEauuEU8
+SNBMdDaYhZA
+SNcjWH6ZhPI
+SOUME9xzIxk
+SSGEoCsFoH4
+SSfhcpWRUrM
+SSnj8kkmNDI
+STKRq8VXzjw
+STKb-ai6874
+STWrgFYmkL4
+SVR3ZmdAV-A
+SYriZ4xtdDM
+SaznCPVAiJc
+Sd3QWZ76IZg
+SfaKxqo1NfQ
+SfkSpHzhZb8
+Sg-HIZ3qgtk
+SjNRtrZjkfE
+SjbFjIeSCf0
+SlJh09MsJ7I
+SnwbCVxeEVU
+SoeKlf4DcSE
+SqRfNG6yLEk
+SrLPH5590RU
+SsIDNPoW7q8
+Su8Q8XMQzIM
+SuYeKcei7Zo
+SvWCVOGF6vs
+SzjmpNTVH6U
+Szx43_ah4ys
+T-CAP-ULW_A
+T0qjkPboQXY
+T4kVhUwJZdo
+T6P9TCdWE64
+T9X3YhUWsDc
+TAdw9R0ku2o
+TFPS-iX0L4s
+TGCJQR2BZhc
+TIgRgGQ2azQ
+TJXW9hpOlnI
+TMjlO7UUubU
+TNEWst_1m4s
+TOT4eRF4oCU
+TOzkrCUSlFg
+TP2wD0LX4T8
+TQ2NfO8grLs
+TQ7NqpFMbFs
+TQHxvPooKZc
+TQzJKL_4l44
+TS5BXMeG890
+TUVeg-x9Za4
+TX0v0pIVKUY
+TXjwmCoRmhM
+TckoXQoHj7c
+Tg2yHK2Hnag
+Tm1RbCh9YeA
+Tq-AsJ8M5yw
+Trq6vcUOeQE
+Ts04-23URYA
+Tu22Y0kIzJ0
+TuoGVwBkTEA
+TxC0dIBPzZg
+TzzhAQLRwT8
+U-kFZbOf6Nk
+U1VkpKvSn5w
+U2LvmqmOEZI
+U2eqvs_MZGg
+U4auzU8E2ms
+U5Ze75nn72M
+UA-U6m9O5OI
+UBVzrTJEbS0
+UDvCqeXCI-o
+UFXt6O5cxjw
+UGRFE3vTcRA
+UIV2gt1Jzno
+UIr0uGUWN6g
+UJeKNl461d8
+UJjixKei0ag
+UKBlZt_JL-I
+UKc54igtdXI
+UL42vZliknk
+UOGLLo60dzI
+UPRvrLsqN8U
+UPU_yi9Nv3M
+UQO5Qcl8aAk
+URC125wpMS4
+UT7n0WY_6Ww
+UTYuVdJskY0
+UUT_nhIlR_U
+UUoBfrbl1Cs
+UV1luTyPOiY
+UVTsXdLyIsk
+UVaN02jNqgM
+UXEr-xcC46E
+UXRKQjlYUZU
+UXZm1KDE_qA
+UbyxFZSZZ90
+UcwU4ghl9Gg
+UdpTELIW7bU
+UeMQ_al9lDQ
+UeO44STvnJw
+UjAxgnvxIJE
+UlKZ83REIkA
+Ul_ZfzfHRek
+UlsgLGCVKks
+UmA3FpFowh4
+UoVJllDh6rg
+UpScMViXT1s
+Ux-ExovbpsE
+V1frWTXGVN8
+V2DvDDzrFqI
+V2XB2l3aFvc
+V2x33UdHq4w
+V3nPA6doMBM
+V5FsNyk9rDo
+V5blsv5pn60
+V5uV1vR2M6g
+V6egGoCrbIo
+V7uG1tWVqHM
+VC7lRZTxDng
+VDLUQlOR_nI
+VF0hNn-Yfn0
+VFJsMQnqZZ4
+VFeOyT16oEI
+VFf_xowsMcQ
+VGOslZT-f-I
+VJjiUcZCkzw
+VKfVhH35RL4
+VKnNVP23toY
+VNtYIBI3BFc
+VOFQb8sVnxs
+VOomWcgrHis
+VRu86oG1hVY
+VSRuncwwJyQ
+VTsKS3ccd5M
+VVV3XeAevc8
+VWc3ezusNTg
+VWrYbF797LI
+VYE91m6Rli4
+VYKXayozHFU
+VZ0-gccLaNU
+VZki-LyHI0E
+V_lZ61OB0EI
+VcLWubfmJcM
+VcYUWSBo4i4
+Vcvl5piGlYg
+VePpQBCbKBw
+VfLSoyinXmo
+VfmcO0gQ8B4
+VfxI-6_LL54
+VgrW-fB3EXI
+Vm2mKdh_VUc
+VmPl5QClxYk
+VpKnSOS6NGo
+Vpw09qKQal0
+VsOismDNYjE
+VvXMMtldJU0
+Vw7hC4jpglg
+Vwxft_tqrcg
+Vy6PKQy2OI0
+VyJCLVAj-vE
+VzHAbg24fRs
+W-7t8Zho4AI
+W-_X1HZM7ys
+W0Mw5dfZoyg
+W0caVAMB7LQ
+W1tOTi9L7Hs
+W354uz6KPJE
+W3H5b9yZcOs
+W3XADagE6P8
+W42cARRDRe4
+W4PKfrQ2J5Q
+W8mQ41XaZvo
+W9okU7bAE-U
+WA1L8vXkSKQ
+WA9sMdQzdiA
+WAl7wIw5ReQ
+WBZ3das2pd0
+WDialA8RHEg
+WIiuU7Fd-KA
+WJBgKbe1ZyU
+WJFdyqLo-pM
+WJaW32ZTyKE
+WJlK8D9Vy1k
+WL5IJ4JVjrg
+WLVKRBitiY4
+WMl7MSsJMGI
+WP4jOqfpJx8
+WPIIgluc2vA
+WQ2QwrfZHTk
+WQ4W-UqaaMo
+WQXM3hU-7vk
+WQiGCGxYMm4
+WR4vkITJckY
+WT2aDvMsZF8
+WTxTntSi6Jw
+WUx-OqSLBak
+WUzV_mDcBmk
+WXKNBguF0a0
+WZQi3LOqiNc
+WZw5Vz_imPM
+WcLq2oABDhA
+WeHTlU8efMA
+WgWkOiDD4n4
+WiILKjGKveI
+WiNgIQ1pOlE
+Wj6MkLWm8dQ
+Wmt5FzPTuXE
+WnQCWE22AYs
+Wq5ZIu0UZmE
+Wqb2ZPi6KyI
+Wrb8tfqYPzw
+WsQ7ysVt-0A
+WssviVcyncg
+Wx9v_J34Fyo
+Wxp1QG7GIBs
+WysUzNnT4i0
+Wz6DxjLZSbc
+WzyI1gx_PHI
+X-fZ5fMfAv4
+X-irPLY6oLk
+X2ac5fHYZ3I
+X3bcvqrTTJI
+X3tI7SRuvhw
+X3ztiC2hUXM
+X7Jd6mI9iHM
+X7WI5_UUpys
+XAJMmm6sLQs
+XAhljQgieEs
+XAlAgC7rYug
+XBF-Pd_asag
+XEHxXWpF6qI
+XIXA-Kqb_Os
+XKCWJuWup08
+XL6m6Zl2ejc
+XLSJ7NIldNo
+XPgdPEYr648
+XPwYTyTbCag
+XQy5b5mlVZ4
+XUxgC9KYoPE
+XVtVVuER3W4
+XXKVwqZS1R8
+XXgJZXUPZXM
+XZ70Li8v68o
+XZxl6h-TxQ4
+XaGss5gNQEM
+XabAlfY1TyQ
+XaxSxkApUQc
+Xb32P_VWh7w
+XbDyn-2_xIc
+XbNcsgEX2jc
+XbtdjM8wGv4
+XcZaQG7wUpA
+Xdlo2HW5jwk
+Xf6xJiXV3AM
+XfMoY5WSlvc
+XkNofrvAt2s
+XkX-zJa_9Wo
+XlQzQ55BA0c
+XlYXYqwEeJw
+XoVZxH4CwMQ
+Xq-YIJA61W4
+XrGuIdJ0Ckg
+XvfpyrdeZ0s
+Xw2PEuhphPk
+Xwu5PSZfocc
+XxqpGCkrF8g
+XyvlWUQAkxM
+Y05wSSdslyI
+Y1QnypwCPZo
+Y2zmCgmDVqk
+Y3sbt-6ndq0
+Y65tbaBs88M
+Y6Mn-d9QMxY
+Y9anPwwkuU4
+Y9rkKtK1b44
+YC3Hhxqbof0
+YECIYrmXH0o
+YEgE-m8kzpM
+YGdNSqa_clI
+YGwLb86Uso8
+YIDhGzp8Z2s
+YJkS7QqeYPw
+YMoa5JpjEtM
+YMqKHYZHQD0
+YO-oamNqsA4
+YOD6jIFD5aw
+YP_35HEWcfc
+YUJa7_i3Cn8
+YUxb5mw96eQ
+YWpX7tlXLB0
+YXPEXwjYpGA
+YXfpaJVsVa4
+YYTcoNCWy8c
+YZBLFRe-G0A
+YZLGkfy0_oM
+YZME4lTWBHY
+YZp4fNMECVg
+YcGOV1iUBlE
+YeIn1rFbTuw
+Yex9CPZzezk
+Yf_jdF01azE
+YgCBzXR63l8
+YhoYie6_la0
+YiPEZWaxSXs
+YjTCns1fFmM
+YjwTPflZm70
+YlIl427ZdHc
+YnRgf2UNBXA
+Yo-CEXgHwkk
+YpmKbkU-X_U
+YtiVmh7UKOw
+YuLSfLZ9apM
+Yui2Msy8X1E
+YvRRXziXpMY
+YwqTTdUEJGc
+Z12xymgfH9k
+Z3r9tnbK_iU
+Z5SpCVtoOdw
+Z75PmkL_UaQ
+Z8iLMnTX6OQ
+ZAX4LKUeAVE
+ZAYK9cMiSbE
+ZBcOyv8LZ8s
+ZDfRJiIODMs
+ZDv9njERj0s
+ZEHCzjk0Hrk
+ZFnr_mTTL0E
+ZGzQUBDGd-g
+ZH6357PHRCc
+ZHLsDRxEMGs
+ZPHfpA_4uhY
+ZQPrD6cbqHQ
+ZQQwgYaQh1w
+ZSEV1ivNUWY
+ZSTo8stxfG0
+ZUlkRGgSLfI
+ZXMx5C9oFk8
+ZYSjPZUqLdk
+ZZIjgZYnP6Q
+Z_TNQsWm8SM
+ZcO5PLLXLy0
+Zg4zFxyeLMQ
+Zi2RoikofDs
+Zj7U7R-2fcs
+ZjmynEwugnk
+Zn8iU3-RNL0
+ZnpqfpPzQAE
+Zot-rd9KoYE
+ZoyUo4ZY-70
+ZpCK7G0LQAc
+ZpNQS6CX7FA
+ZpbEtTNDQBM
+ZptziXSxpx0
+ZsMJ8wO1YKs
+ZsY6kAHzdwM
+ZtXr7bckLyc
+ZvAaTKbPzH8
+ZyOqaUFxojc
+ZyUCcfMcmyE
+ZzVhwsetzDk
+_1B60L4m60M
+_4bzRoRn260
+_4r4LIX84TM
+_5O0_4kkKSg
+_5r3yqpaNuM
+_6AK5vp6nDE
+_6ck3EDlssw
+_7RMd0g9FtQ
+_A-pZH-S4jk
+_AvNT3vyzr0
+_BxhdmIGNys
+_CKFycGHzLo
+_F6bq0l18Ng
+_FcVS3nf3Qs
+_H71S7ar21o
+_ITluulB74Q
+_K-JEM9RNeA
+_KIjLRuf4NQ
+_LR1NtZOISg
+_OhqC6w4C28
+_QRahNSxQfc
+_SYvYBxt_Dg
+_UVuXcclWM8
+_WN-6t58HdM
+_Xo-Cy_okCE
+_ZZVgjbePRI
+_ZdfwflsC3U
+_aOcddJCrSQ
+_ak7BooJlnk
+_bBrj6QBPW0
+_cg2nSAfeUY
+_czQyx6JYgc
+_f2hDpHsQlg
+_gJN7I0a9XU
+_gR3p72XEMg
+_gm8Jp_4a5E
+_hE7CQmHPNI
+_hYhPQTHOMg
+_kUyGB9O_Hs
+_nwF3zYpfQ4
+_piu_0WXnOA
+_rWxLXmPdc0
+_sWO1yvsmmo
+_t61OZ5dN9Q
+_uV1LuL6yDQ
+_y9aHEd08LM
+_yKGZu_8gyI
+_zYWKwpARf0
+a3Ah-REb5vc
+a3mNe6zQS7Y
+a5CAVvzL2GA
+a5lUNW9wlpY
+a8-ySFmij_I
+aA1H4iLap44
+aB0Ku9J6jHg
+aDURTM8MwIE
+aGPO0OB2_Ak
+aJAQVletzdY
+aLPYGw9HubM
+aN5gtrkwgb4
+aQM3Hbvospw
+aSDRWUh7nqQ
+aSF8LfYNlSM
+aSp8Pf6zXvg
+aSuaW_0lLYY
+aTigt0mmyiM
+aValNp9ADko
+aVstEGR5jt0
+aVtOmrULE2w
+aZn-H9uoYk8
+abAitsu2-iA
+acWlXpQu83s
+ad7p5nSZbBg
+adXUEwBrrl0
+ae527aiePIg
+afqjRWOtCzM
+ajI8zq2SedY
+aluYFRd96Kk
+aqz-KE-bpKQ
+asFNxO3cXMY
+asyKEIv1pZw
+atd0dgRkl2k
+au5oFT_zkYU
+avITb-TZWOQ
+ays4JisCVAw
+azPa3uccg2I
+b-lZA6gHWg8
+b0dX-hS1mUY
+b3wy7s_QvOY
+b44IhiCuNw4
+b4jlwhPxTpo
+b552H9Gclo0
+b6DszEFrkSQ
+bBpQgQ416dY
+bCbKYU8D2QA
+bCqq2y9Ge-A
+bDjCnFvz11A
+bHeksIj7dpM
+bI0kawdOOXs
+bKqwRavOIVY
+bL2qA6Gt6Ko
+bMtTAQHLc6A
+bPnX7JwHQEg
+bQ5gacQY7y8
+bRh_KJ6yyGU
+bSySiLIcnzs
+bUAg8d3aQZU
+bWh4bNRqaK0
+b_PbouAwYIU
+babTVMWIosM
+bc15oBT5Nec
+be0sJh-pxOg
+bebrJ-o4d2E
+bg6f32E9vN0
+bhMSZQLSp2M
+bi3iqJykwEo
+bj77Ljq7kbA
+bjNvrJwE9Dc
+bk4b1P_IHZ8
+bkUDUnE6aP8
+bkZBYQ7mnrQ
+blCUScrxzdo
+bn5OSGM8oKc
+bqfA6lRrlaU
+bqfAc6X_uUY
+bsNztQbe4Ic
+bugT5lGfl9M
+bvCbqmHHQD8
+bzDZY35Vacw
+c0IrO0Rnlls
+c0XSC8b_9hU
+c1Fjrbt7peU
+c1H__IGGs8E
+c30s6OnCApw
+c41Mm_I_pBc
+c62KFa9jY8E
+c6gVQ2e1zEs
+c7baTdyHv8g
+cBXqfQzrsEY
+cCLOtOW3mIk
+cDJ_DVILBuk
+cFhONVldyE0
+cG_MFoGh_EM
+cH3DOOXFE24
+cIpXNRmZtVc
+cLMUQZr0TOg
+cLVQSb0l2Pw
+cLcmBWC4FBI
+cLofgTNATIQ
+cOtoozzRy8g
+cRiOZTv0dzs
+cSEqa8cLqss
+cSp0k6jEvY4
+cTFi0WEdtvo
+cUkZLttONGM
+cVezZhj7sC0
+cWZD_1CxIJw
+cWdvVQ7hIa8
+cXpb2qzMlJ8
+c_Ok8VutSbw
+caSPzISFhiM
+ca_kO_J5rWg
+cafMvJshI1c
+cbqshch5yAA
+ciV5Vi-3Adw
+cmA7Ze4KPd8
+cngWph0hiVo
+co-tsFhaeRQ
+coswCgQ_Zwg
+cpAUXdNPnQs
+crJ3KpoMxdk
+cvlLKmzsEBc
+cy9rx19dujU
+d0am7-lXaus
+d0npGPi3UPw
+d2b921R6Q7U
+d4gLpU_fqmE
+d9ShdxJRArw
+dBHgd6MNyyE
+dCmfHeROKIE
+dFxMzUUWLeM
+dGwu92AU4TU
+dHx4lFPqPiI
+dIFsZ9uTrpE
+dKfNJ2umK4o
+dMN2CeulM6o
+dMjan7yZ4gU
+dMxd2IrTF7w
+dMxrAgNbTRk
+dMyTRV4x50w
+dP4U1yI1WZ0
+dR4CGlYvT7g
+dSfIxFW_X1c
+dUKtnNKWV44
+dVeMgCpj-4c
+dZCCdLDE8gI
+d_CTrbVqWW0
+d_oCEre1NJ0
+daIt5KbxFNY
+daTulcsI2RU
+dbkgiCgapGc
+dd03Lj8KOjA
+dd9EimkKEZw
+dk2FVz_B5ls
+dnQnF9apUF8
+dnYUpixA8m0
+do2lClbRBmA
+dpHBJkblp2Y
+dt1MS1YmxyY
+dt1XiYVENuk
+dt9To7mMjl0
+dt9tqCG8neM
+dw6X5-D3tmY
+dwd_Z917XGI
+dxWIZ0ODlZE
+e0Q9GzQzBhE
+e1kPOlccTbM
+e3rjGvtVXeo
+e6MyukRaKOc
+e7jMlBiYy98
+e7z4uzyeRzA
+e82tT1ge9Ak
+e89XopekEJI
+eAQBVytibCc
+eCKD8fP2aao
+eDGFEV4CtB0
+eF7TdzaX_Jc
+eFUL4yP0vqo
+eFdjgis1E5I
+eGYkD0e411I
+eILksASDf9w
+eIXNlBbw54g
+eKl16S8gUXk
+eLX1KG3FnBg
+eNtpsXBrhk4
+eO8hhdMv7P4
+eOJ7XVQlnB8
+eObAEb5SBmg
+eQY39NuSXHU
+eQcmzGIKrzg
+eQjsCvqh1_Y
+eRU1XIwXCAE
+eRvOQ8MsyV0
+eSR9swuM61I
+eT5IGtWmQ-M
+eVXJaSgEf74
+eVz41Qm-5aU
+eX3PtJN52g8
+eXeWNtyjgPk
+eY5c3ce48Rs
+e_7T20B65hQ
+edRk4JmnLkM
+efIVRw-eVzU
+efooj84xalQ
+eg0bYuuHzkI
+ehMbGtUiYh4
+ehoNb1i6vY0
+ejrSSx2FMhQ
+eke0xp4rprM
+ekxjCWMq8Z0
+emqCYr2xVGw
+eoHw0W_CzX8
+eoN2Xrs6cxc
+eoWIjARiH6Y
+eowka-cMoxM
+epfVelQrqX8
+eppNMltmm-o
+err7vaMQ4z4
+esVois0s3qg
+esnMNkFAG1U
+estUu_XdB7k
+etnFlNJ-uys
+eub0JuAMSiI
+evORVtwC0zk
+evnEcyr9dIw
+ewQ3FI_u5nY
+exsaT4HrbhA
+exvJz3JPHuI
+eyBEIx7fCfY
+ezQjNJUSraY
+ezxm9w77JKI
+f5kTVt7QnK4
+f5mk7tpuEs4
+f5ud_4pbXPk
+f7HNmm8MHMI
+f9oet4cvYf4
+fAa51WqtOT0
+fAtWwadP6CY
+fDZ8fFOesDQ
+fJrNkspUMno
+fKHcxXH3f6k
+fLQ-pj8lteY
+fLc6CPIre_4
+fO1dl3Kzz40
+fREa-rfs8aU
+fSSzNXPVvYc
+fUQ-OTqjp9c
+fUzh2Hu4JoY
+fV1xRX3XGKk
+fWnDlFMfOA0
+fX_sxgqAKkg
+fXuqsnqsfGE
+fYHqUK0tqH8
+fZ8S5N2akUE
+f_LHuvCZLf0
+fb8d5RZqzA0
+fceagC5I1Yk
+fdANKfoxm6s
+ferxnIAT2Go
+ffVL9wvWfgs
+fgIfH_85nBw
+fgzKbL4HjCs
+fhXPsdrWJiA
+fhrwGiZrvio
+fjW_WL3KLUg
+fk3Cq0mR6_4
+fkJt1i4aO_8
+fkYqJFYFYXI
+fk_RkZortw8
+fnArfARx9wI
+fnW2Qy7uY9Y
+fo65xXZqfoA
+fsUkET-XhzM
+fu1ZVnk2Gqk
+fuRU10ocuds
+fuZ88PJ7Lsg
+fvtbdq3WiyU
+fxIAaegID9M
+fyFOwJRr3I0
+g0UDwv1vPgw
+g0cy2szFR4Y
+g3pERE3VsX8
+g4zf6FeVoxM
+g8DbS9Ochow
+gAjMBetApiA
+gB-ZlM-aUTE
+gB6vLAjfMdc
+gE9soy7r-Ms
+gFZ8NPmkPpg
+gFZj_2voVyA
+gGZq_aR_jyg
+gMZp4NeJq5w
+gOprZJ2CbR0
+gRO6JH35Cdc
+gROnZPk4dnk
+gTDprJn164c
+gTPaMD7qhek
+gTmG636eAYY
+gTzhAelL1go
+gV7XWdt72Vo
+gV8icTgcuyU
+gX0AEmOrKsY
+gX6r9BWwCOI
+gZyG2_WKQvc
+gaUt3gTwwzU
+gb13T2EYzvM
+gcsL5WUXoYU
+gdb3IbS2SGQ
+gksr08xM7dc
+gmKD-PSkATI
+gn_UhhjygOk
+gqxjTyxDC-Q
+gs2r0YmX7rk
+gvYPX1ja9gU
+gwysz5CtZLg
+gxYTI3KmFk0
+gzfKeagBy1k
+h-h1ceF9ToI
+h0Jgoc3tosQ
+h16CeeTSV2Y
+h3sdAzf02O8
+h5ZU5xukqyE
+h7HDtT_cOs4
+hB5cFvlTYWk
+hCc4qd-6h_Y
+hFj2tLk_poo
+hHHYqKqySy8
+hIkJ1Q69eJI
+hK4VPzJYU94
+hMmYNZgItKc
+hN4A9M4qXy8
+hNhlfyoVlLQ
+hO4aTFDeokA
+hOloWe3oYR4
+hPUs3kB39sA
+hPhJlvpNz2Y
+hPt5JG9kGUc
+hQyEUb5DS2o
+hQzC6_3S6xg
+hS-re7JJ8ts
+hS2fdP1bNV0
+hSUDE3-e5_0
+hUKXRN4RzpA
+hXtL9Oe2OTI
+halJEOj2RP4
+hasaLmOLJ8k
+hcl4vYMfvYQ
+hfMHcdPo7fk
+hihIrNRfnpQ
+hnAqGpkvYIM
+hoAGRdA2Zpo
+hpK6FZbxjgs
+hq-1BIaFjGc
+hqbW-EaOcUY
+hr5Px8dD_iw
+hrb-HL0UUpM
+hrhBvOz8ius
+hsfgj1OWX0Q
+htCPtKDBmKw
+htZAga_nSW8
+hv4NNsPaL44
+hw0VaaAhfjM
+hx0rraflBpg
+hxHjQZSANwI
+hxa7Kd8P-oM
+hykLNP-Cpsc
+i-vmCEgFW5g
+i0CP1dmr4ng
+i1-9WDuXJGY
+i17JKd_eCAU
+i1srQA7ocZs
+i2l72jO-otU
+i3DQ9UOs8nY
+i4jX18-l9pI
+i4oLqHRxoZ8
+i5HYNnHtcdY
+i5z06C9Mi00
+i9WbGqPeY8k
+iAWKPOyYdXo
+iCPxwkaCXLk
+iCd5W4gwJsI
+iDJVid5u0Mo
+iDWMIMCMNiQ
+iFJ2pJph-Do
+iIw0Hfm9h7E
+iJJxd-XGQMo
+iKAYJ-makHU
+iN0gsOIWQgw
+iNTSOGM5rCY
+iNhHBJkMz58
+iPDuJS2QUrc
+iSYvDl5iBFQ
+iT5-jVpP_TQ
+iTk4dutvfsE
+iVBD4cmI3_E
+iVbefrog1uI
+iVu7DNjFgWk
+iWRrLD7H98s
+iX1P6XAvtlM
+iXUwLs4kNvc
+iY2XmAG6Rl0
+i_QZJaPorpU
+ib2KHcsUS18
+id1m2sLzwrE
+id8Uq_339xM
+idAzpl13_sc
+ihpS0LTtWjw
+ijPyu-jUJ5Y
+ijlyytCZ_RI
+ikDYpq4uESA
+io3xwOBD4f0
+ioFH4kyF0kY
+ip0yFdCqbkY
+iqm-XEqpayc
+iqmkldsIxkI
+irKLUxa7jkA
+it0EYBBl5LI
+iuqe74uT7fw
+iy-uKpGZmrQ
+j-0L20FH2Zc
+j-nCvuPZoWo
+j2WVRXdVu_4
+j2k6EQOoRzU
+j3BVnXHPjyw
+j6FjjenQyko
+j7LhySzyh7k
+j7XDYxypFb0
+j921XOqiQ7o
+j9Q2J2MDctA
+jA4WY877mTk
+jBox8_bAoTM
+jHVnt6VROM8
+jLLVBH5ZZG0
+jLRyB9Hz7U4
+jNyPJ1LQmiA
+jOj_NSDXMz0
+jPxamG3f5TE
+jTaFavHwbiI
+jXqb7dbUBHs
+jYEIkpTjpwI
+jZa3FK_T_Z4
+j_k1TZLK2mQ
+jbzpyRmU3YA
+jcFgCdlsPjI
+jec9rBFqcwc
+jfcrY85C_-k
+jgT5X37XW3w
+jhUnM7jekNo
+jhaa9bIPuOY
+jpAig_5u60Y
+jv-dZw3ITJs
+jvRleA3lWBw
+jvr7UJI47UM
+jwBi_Q9ffys
+jwS75RtJJBc
+jzbqbu-0bxc
+k-pmfynqbko
+k5QJ8s3qUyA
+k6Hf0WGquNM
+k9Kb4Pwu3ok
+k9zYM3XqNps
+kBLs0CsEWEg
+kEQyvL6El-M
+kHBSmEfwZZs
+kLdPQfLv1g0
+kM8iZqP4GDk
+kPUqlB13Kn4
+kQLhEtSx0NA
+kYVd_HIkGw4
+kYXiegTXsEs
+kaBI3TvA7vA
+kdIG_ZsrdUs
+kdQPXCcV2K0
+kdo0b2eiXVI
+kfssQHnlHjM
+khm7702744s
+kl5eVK9kY_c
+klh8JFoo4Og
+kmTRmjeoKYA
+kp33ZprO0Ck
+kqFzRSwX3Fs
+kqsmCUo3xEQ
+kr5Ot5x8mmc
+kvOzjIRylsM
+kzgohxtf3hg
+l1wHRYkFhZg
+l29q7OxKMmM
+l5bkGI-CGFI
+l5hgKlBq2aE
+l5tn1sJ0fKg
+l6iQSCD9UeI
+l7OIqpr56NY
+l82uiWKCgF0
+lAzL_47Wkbc
+lB98olcaEm0
+lClz0KqwrFU
+lDJEC7OX6AE
+lHakoK0j9qA
+lIRpgkLK7r0
+lJYwN2X7IbI
+lK44daVVh2w
+lM1Zmq0Lcd4
+lOwZ3mwlenw
+lSGfxPD4nR0
+lZZgjJz-SEw
+ldiJbsXZgvg
+le3cBRlWSE8
+lfrBVdmeUE0
+lgWY-gjEYW0
+lh_ZUt-O34Y
+li7HdK2vIlw
+liKbDWr0RzI
+liVQZnSAv28
+llj7LzTULog
+losLXh9sDc8
+lpShGGv6iJI
+lqbSxAvUNpc
+ltPAsp71rmI
+lxBYJOEbU_g
+lxE9lqRrcuo
+lxQjwbUiM9w
+lz1RY_CNWeU
+lza-61vEfcw
+lzniYTLmFkc
+m03Bn3jgP48
+m0EwquC6wBU
+m0svbKxpLhI
+m0ttpozYW14
+m1lhGqNCZlA
+m3ZogsjvmLA
+m3q4bKFCkio
+m3q7itlxq14
+m4yEAVUxsFM
+m5_41I_BLc8
+m9qPsbp-l00
+m9y0Kt9UFYc
+mAUNqn8QAFU
+mCghaYzVDxw
+mD24h-bbdMU
+mEP7QCH2oYc
+mETDq4gwIkU
+mFRkTyvuamE
+mFu85lirRfw
+mLTM_vEz1jU
+mMiiyP2W-qQ
+mQu-g8K7qtc
+mQyL3LgJwXA
+mRzqtElhGsY
+mXBHUYIXQAU
+mXNd6zUVwP0
+m_D83abmA4Q
+mbKfbtQvUgQ
+mcZfJHqf4KY
+meka3_pUVqc
+mexPA0ocnnQ
+mf8vuUJoNaI
+mfuwTh61hMA
+mfzWMIEVrFk
+mg7CQGLjRzE
+mhZJxmbSSt4
+mhxEcl-85pM
+miGwjQa0txo
+mjI_IyzUye0
+mp_aq18zg38
+mqSucLmIFeg
+mtXfzd53wRQ
+mu1XN7ABANM
+mw4GQ5jhL6g
+myDLLgryIs8
+myGaZRw-0oU
+myaOG-5N12M
+n-EpKQ6xIJs
+n-Ptyqo4lE0
+n-mmxFdva4w
+n1WR4pPT1Jo
+n7YJE14tLLQ
+n7Z6KyTeouY
+n9Gl1hBgjIU
+nAyLvwIcnRc
+nHEr8-LtfUo
+nHU621TjCi0
+nISa7ahsQVY
+nMIE9IKTV6Y
+nN4fDhAcGTM
+nNgIi4eJduY
+nNuk96Rw2Ks
+nQeMM3bqM1M
+nQgkkH-o-aE
+nQmB8u7aBZs
+nQpuGwWyFQ0
+nSsQFG8QzcE
+nWo--dSu9bs
+nXcU8x_xK18
+nZQ_jnOVFeU
+n__HtO28X8M
+naFR4znnS_0
+ngdwHwnLCt8
+niKRw8zeYa8
+nihu7deKG9E
+nj-YK3JJCIU
+nkCgvEXj7FQ
+nky_VSE43Hs
+noYQhtNG-RA
+nor0P6jwoeg
+nsqOggafySQ
+nsrEmJ_19v4
+nu3p9JQ4ykM
+nujU66hnemE
+nvGGjCIIrm0
+nwyoo3oYYpE
+nxk8alj07sY
+nyes4M2CbtI
+nyxYzQVqo20
+nyyS0FSztKc
+o0lvhDX6DXc
+o1zo4BYFdtU
+o3z2CwnY_5M
+o4R1-TLkxBs
+o5q1ne_uNE0
+o5ufkJMr24c
+oBRZ8sD-OjE
+oCM304tbwcM
+oKC8ikzEFIA
+oMbvC_siQyc
+oNGKZrLIO0U
+oOdNHtF_s5o
+oRvkMz0FXtw
+oSBD9lej7oA
+oWVfSS5m5LY
+oWqLRywhozk
+oYKlQPSZr9A
+o_EvSqIz1EE
+oarLQY6VJzs
+oe45d8WFc20
+oe4VZ1YDL9Y
+oeFSoHWiLBI
+ofWA7ERRwzs
+ogOqpmvv6eU
+ol3tAxnNccY
+om1FNUWg4yo
+oo5gIrWn9tI
+oonc4u-Adbc
+oun9ZWMYYVQ
+owZYdzNJSUo
+owqMWjBPCW4
+oxQHiGgFQ2A
+oxsRS9fialY
+p-xcYP1ef6M
+p1lHlsF1LeM
+p2tdlkmZ8D4
+p4Q2mwsvlpQ
+p9ijo4bNOJI
+pBoDJqyONC4
+pD8z8Dior4A
+pD9cAOBZDj4
+pDySXCjwXJg
+pFXjD9J-JE0
+pHgSyR3qUas
+pI1feWHeUq4
+pIeCH_unCd4
+pJznoC1fZ_U
+pKu6GJdrKOM
+pL3u5ztU9kA
+pOBuggPY9_c
+pQ4C6LaLV48
+pQMiwYP8mq0
+pQcXM9HJXjs
+pRm187D8-MQ
+pSeYDrDjV48
+pUayltZ4PRk
+pXg1P_wX79Q
+pZ3OMf01xJs
+pb85Bw9sHY0
+pcOdWPkjUzY
+peYcx-xDiwg
+pfC0yuJtMc0
+phnXyHYku8k
+pidRKmq2iGI
+pk5WuA39qzc
+pkSnKgRHjU8
+pnmOhTbhKhc
+po-Dx2CgFGs
+pp71BH0UlvE
+ppPbjVv2XeQ
+pru-95YczT4
+psfAc33Ok94
+puDuigN0_I8
+puZZoKkv2Nk
+pujqknUWycM
+px-H1yOAQUo
+pzIQGk5nWDw
+pzev9CB3vuM
+q-abd-Tjf_g
+q0whherUIeo
+q1NTjsXYpK0
+q1SYOXJWXNg
+q2S5mFVGVec
+q3kU8MyAHhE
+q3zhx8M7mgI
+q5uELNh_Fro
+q6irtmJxniE
+q7Pfap5IRMc
+q7qZd-5PQec
+q8G-VO3gy5o
+q8PoT8bhElc
+qDzY7_qAP9E
+qEHexC4KaLA
+qGl8JMXmwL0
+qH8uMr3qb_M
+qKMd70Jr-ws
+qKnrUrYLlj0
+qLciglFWSBY
+qLxgp5VIx4A
+qLyCi_ARgfM
+qMoypa93-_M
+qMroyKN05zQ
+qOxbcVscVuQ
+qRiBE2DuGA4
+qTPgNaT-IyY
+qUbu1hTOKQQ
+qWd6VN9sseU
+qXBgNj7d2aI
+qYW_VBZoTsM
+qZdAl2dAL7k
+q_EJdzfnPSg
+qatNSkb_O3Y
+qav1y7G15JQ
+qb4Kx1MoxGQ
+qbaocv8MUJI
+qckQShckJI0
+qfKmOf3d0fc
+qj63Fyah8Jw
+qmwgjonregk
+qnQhGZNOIzE
+qnRbz82xzWk
+qnboyP15mi8
+qo_ZzjhbWLk
+qrmmPQC6o4I
+qsOjHdZtUM4
+qwMVYILJ7bc
+qxabzkWQ744
+r0i10UvjWBY
+r28QlaYxRDY
+r2REC5k-2AE
+r2vFxIWtQ-E
+r341ehyHft4
+r40wnEAEozQ
+r4yqeuWlJqM
+r59GRI81jQU
+r6xt8HZy1-k
+r7vzgexzXOk
+rA2g4mWai58
+rA4jmcefefQ
+rAnvXFlvV3M
+rEJBmcJlEy0
+rFEg1iwdHBM
+rFOs4b_WTYY
+rGBU08IrdrI
+rGqAwF16IOk
+rHZlSM-m_Xo
+rHgcNbOwnas
+rI8ccs3k0kE
+rJXB9VhK1eM
+rJc0Pf2B7Mg
+rKCT1mmxfb0
+rKvnvavLhKg
+rLUoN3LY6Fk
+rLnNaewOAbY
+rNOyRIU9WUk
+rNqiDGKZiZE
+rPa_Q-v7iPU
+rRexKBeW-LY
+rRsjDRqRvgA
+rS8oFARouuA
+rTgVxBHEKJY
+rWMEyWjaDug
+rZ5iulZ3rvk
+rc744Z9IjhY
+rcdFlgadpYQ
+rfzHohoFoZc
+rgj6AQ0KiXI
+rhFDM4YtAjM
+rizm_pzU7xo
+rjjFQ6onAJM
+rjlNd5leBDY
+rl4qZtDSLCs
+rl6K4-2q4JI
+rnzWJv_L13Y
+rpC49dRxPkM
+rqqgH8fMsOs
+rrzJQMo6QQ4
+ruAnF4N_8y4
+rw0AUfItOss
+ryJxN-ZRAGk
+rztwo4iUnqo
+s1QCigF6JOg
+s2Ya_pPoKQw
+s322l02OzWM
+s4b9dN-X32Q
+s5eZtpsX2Uw
+s6LlkCsQSq0
+s8i4K8MvQjg
+s9aNVQP9Bi8
+s9zkSyuL8eQ
+sCk_3RNWdZI
+sEnf_XAMuso
+sGvmS07G46I
+sIesHZHrHck
+sJIvQ5imN20
+sK-QPNTRj24
+sKFiaAcqTMY
+sKRcy61BFfM
+sNHDCXNLLlw
+sS7O_qMN_pM
+sSrSorPjUvU
+sTKAL3OfgSQ
+s_A2wZB7ZfA
+s_gSNpE1Z5Y
+sb7TL-252e0
+sboMcBqlShY
+sdeLg0on1Go
+sdnbXzi5yDQ
+sebF4SCrhgY
+sgGPHfZ6O00
+sgRTFPJ1MdE
+shkbyqqkkTA
+si4zS_Jx_uY
+skyd1JiJiYs
+slAdtDXiNgI
+sp4yAT08rqw
+sphx939ru4U
+spi7TCSQqns
+sqseuym5HKI
+ssHyZRRz6ek
+stqdcOSqbnI
+suK34prc56o
+swP3fNDUD0M
+sy-JNEKRe6w
+t-N3In2rLI4
+t1hyNFQq1Qc
+t1ok0e9gTRo
+t3SPY13b64M
+t43zNbKooJs
+t4xtn8b1Nk0
+t7caALtdVBM
+t7f-iQf3PRA
+t9nyOewqrU0
+tAbhaguKARw
+tBvd7OSDGgQ
+tC0vzGbKXWs
+tCskNV4zvIA
+tDNEWvgVQU8
+tE9_u7slTrc
+tFHxf3tCRr4
+tGQgcHMIq1g
+tHGLoklGaE0
+tIjr1ZMPEkc
+tMTZDgGaWqE
+tPX6j_22umA
+tQRaArBTpu0
+tRBf2g6XqCU
+tSUX6HReL6o
+tSf72G2vTU8
+tTyFNcuR8dw
+tUPzZFsH9NM
+tXwwOXxj3Wk
+tYS3l-lh39Q
+tZMO74qrrJk
+tZZyNqO9bQ0
+tZzyiJC7pqs
+teIlbfF2SE0
+tgKPxj4-0QE
+thhRZXvb0fE
+tj46eMPD8jw
+tj79bCk8Y7E
+tlPIi5sLOoI
+tlSl7JSQ35w
+tm50929wbaA
+tmAx6umELmU
+tnX_ibkICWk
+tng9I_Yfhg8
+trGNjNAxMMA
+tssS0XyUIi0
+tsxCu5Fom2A
+tu3g6qTyWiU
+tvsBIX44mQM
+tygH_zAKRvc
+u2gB0EGQxxA
+u5TYUHIGqzA
+u5eti_-WK9g
+u61kdTMLJTQ
+u7bLXr675gw
+u8B3lvaw9tw
+u9QPtZEhH3o
+uAWBQI7PVKA
+uAkCPrihEzQ
+uAsqM8rL8GA
+uB9UEkO77MU
+uCQMZfnM-a4
+uGUee74VHAc
+uHmQzfPWBG0
+uK0u0BHCF3s
+uKbvq15Pjw8
+uLycRSuHSAs
+uM9zA0QAvhA
+uOWT7-wqKS4
+uQ2I4dZh9O4
+uQEUqgEVQvU
+uQz8Y_jYzy0
+uSs1HeOShlY
+uTUY9W-eDZk
+uW6e50NYlWE
+uZEfW4QJot0
+ufnnbfJSZ3M
+ulMrIHCeLQU
+un5l34RRYR4
+uomUPcl17N8
+upw774ewNjg
+uqCecnqjL1g
+uqav8KWIBPM
+uvNun-YAzwk
+uwJtuNmBb90
+uxBzmTylbwg
+uxQbWbWVWug
+uxy82zQnQpo
+v-2yFMzxqwU
+v1DXBLkiqFM
+v1RFw8FpEJQ
+v1wFRiV9_v4
+v2DjHVhnjQI
+v2_u0yV5ua0
+v5vIyfkumj0
+v62wDFbK7iM
+v7-eoymQp3o
+v79U6LcewMc
+v7iWP-4rYg4
+v8R4spPNofA
+vC-uE8RAzbo
+vC9i_jB5h1k
+vKHi1XtBPQE
+vLb2GXb-bsE
+vM39qhXle4g
+vO87PpdQKV4
+vOFoltimf5o
+vQtLms02PFM
+vT4z4TCf4uk
+vTKx76zG23U
+vXPaPV2Ii-Y
+v_bumG-g0Dc
+vblot9RiULI
+vcL75ICzL0w
+vf4CrrsnVWs
+vhZqXOG9Sfo
+vi9kXfl4Nzo
+vjjNnix_rGQ
+vlddfDNoCS8
+vmGW-kulgy8
+vmhYHtGvqCI
+vmtEmebxSwo
+vpAyvSQdqvU
+vqqv18miTHs
+vsrK3zDsT4c
+vt6evh25rpY
+vtBXhiKZjWk
+vuHbp02YfKA
+vuThpe-Rgxs
+vueMpDhosM8
+vviufcserks
+vwuBZEo9eAE
+vyUiZfxuXXk
+vyWMhZ-NiHg
+vyrHk4_-epU
+w09mGrJy_h8
+w0ztlIAYTCU
+w4ZSe355UKY
+w4fzfZkZ2Ng
+w7zqBnrLxiw
+w9V5LKdTPRY
+w9rXlL5iJNk
+wAaHY_SXdgY
+wB1IOUjeiXU
+wB5lKhN4uoY
+wD-jLNmRVfw
+wDT446J2XME
+wF17z_DX7DA
+wFS3088FEUk
+wFkig7Vh8YU
+wHJB9u61btI
+wJKn7xElol8
+wKnbAj-2Hnw
+wLvCU_o362k
+wN-xfvQKBgc
+wNcTvkR4u8s
+wRQGTMig_P8
+wSDR6ymotHk
+wSO5y2Vq8Ss
+wTHo7j8Ruww
+wTYpjp9Bzrs
+wUnPxBv2h1E
+wV7TB-8PyhU
+wVnhS8BuE4U
+w_Ms7re3lYU
+waHHFLAet8A
+wbokl5PhXEg
+wfCZYvqVFKM
+whbyuy2nHBg
+wkYeYfLt8P4
+wm9QvbgryRg
+wmKPIm2HuXM
+wnH9X7Y8-fw
+wonfEPTNA0w
+wrtkQSU_Ekc
+wrtrwY2caSI
+wwBfqXf-1Xg
+wxNELCR-TMk
+wyM-7ra8uY4
+wz3TSWrZVjM
+x-YG6j9ZYsY
+x-ty9x6-P1o
+x1HMmnAI5sc
+x2Sh6TOWjkM
+x47V2e-4NmY
+x7jZG8cbqBI
+x8LIzP4vBrQ
+x8jdx-lf2Dw
+x99gUDpak34
+xCDQAlQfzNs
+xD1mRhx68eE
+xFFs9UgOAlE
+xGoCF5bk56A
+xITRYnvl0oc
+xIyfkI0jJMs
+xJUlD1q4bRw
+xL32w7DMwHM
+xLSzvPQpo5A
+xMhd22uQusw
+xSaGl8fiiYk
+xVuZZS93qtc
+xWrJD9-1o44
+xZ0Xdew--H4
+x_0F7lFtPRY
+xe21-4hVA20
+xfXkzZ62y8g
+xgpgHZ9iTMo
+xjwSr2qv1cY
+xkf9Gk1QuYo
+xmXiws6G8Mc
+xnLoToJVQH4
+xog_GEISGDE
+xokbf-np33Q
+xt7bRz9dX4A
+xuVTCXSW6WM
+xz7Ilex4McU
+y0jwPOra9Jg
+y2tCX54cDx4
+y5yxjABbdVQ
+y6Dh74INR_I
+y6Ge4RZsq4I
+y6zsW14YDBY
+y7bKFMO0d-w
+y9L5dA7pQxk
+yAe2kgW47G0
+yBEfb0XnMec
+yF8dBlsLn6k
+yHaY_d3QQR8
+yLkqI2UiZJU
+yMSj47Xs2_g
+yPqieOw41gc
+yRV2bb-tRMk
+yV6ZXGQDFfo
+yWJDUISGUp4
+yX5g1OHKE-E
+yXRXk5ubwW8
+yXhj9XnXpbI
+yXrJCyJBugg
+yY9QkbpDLaE
+yZ6vSn7PaPI
+y_-a5D-1_mE
+ya8cSGAiEyw
+yc0wPSjL8rc
+yc3l8Xsk4k4
+yeVl67bOTD4
+yew316-hSiM
+yg49p-i94rM
+yjD-_z6P75s
+yn_0j3hDnPI
+yqCnnSfXrRg
+yrqpVREbHfs
+ysCHqf1HfEQ
+ysogiWIjy10
+yu2iuSrpjHo
+yu6gBfUQFNI
+yuGOVGNkjAA
+yuNg5JugqrE
+yumyfeToWyg
+yvGUNb80FTI
+yvWV56LXTh8
+yvcUoZqSvbk
+yw--rqjGfI8
+ywljr9RKExQ
+yxPk1s6kjls
+yxuTqbqzzpM
+yz4H6fx15Uk
+z-KziTO_5so
+z2-YMvHXvkk
+z6kC2r9UrF8
+z83CZnYt72o
+z8X4DB-ecWM
+z8gNQpsYIco
+z93Kr9LmMCo
+zArkP0YCXw0
+zESeeaFDVSw
+zF-pekRLJ3M
+zF9VM-voMFQ
+zJIKjeOLOFw
+zJahR4BmskA
+zJsr4Tie-tk
+zJuuBn8mTP8
+zKkJB8Etq54
+zKwgR9IBovs
+zL_d7b1bsUE
+zLaE0SmZdMo
+zQpGf1gPY7M
+zUDEten_j9o
+zUkN2ILNflA
+zYYWeXjQn8M
+zZb8HVaO4Nc
+z_cAYz0Q5DI
+zb8B-vxmjTg
+zdIcTkGHFac
+zeO8n9byqg0
+zesVY6sbjoM
+zgs82z9t7Mw
+zhESYHHbzsc
+zhI__xPhoW4
+zhzlGFKXV2g
+zjnOeTL7Bz8
+zlHe6erNcQA
+zmn50KVgQak
+zotrHY6f1-4
+zpRuDsk_0vk
+zqFEtAUGItg
+zqVaXf-E8us
+zucFlVCiV-g
+zuk7TtRuR-U
+zwGcVuZCM4Y
+zxN1HYq4iEg
+zygTUJZ93dE
+zzQoGtjkKuU
+zzpYxy7lveg
\ No newline at end of file
diff --git a/TalkingHead-1KH/data_list/train_video_tubes.txt b/TalkingHead-1KH/data_list/train_video_tubes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e5631594b52d1d897848e7b59722cd6989c6c8b8
--- /dev/null
+++ b/TalkingHead-1KH/data_list/train_video_tubes.txt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:86b9a6514dc4abdc9484c5fa93a3ad5db3c5ecb965181e92d16b9fd95ff8222a
+size 29153463
diff --git a/TalkingHead-1KH/data_list/val_video_ids.txt b/TalkingHead-1KH/data_list/val_video_ids.txt
new file mode 100644
index 0000000000000000000000000000000000000000..949f833ec8ece0df9706564e96ed47fe8d819de7
--- /dev/null
+++ b/TalkingHead-1KH/data_list/val_video_ids.txt
@@ -0,0 +1,28 @@
+1lSejjfNHpw
+2Xu56MEC91w
+3y6Vjr45I34
+4hQi42Q9mcY
+5crEV5DbRyc
+85UEFVcmIjI
+A2800grpOzU
+c1DRo3tPDG4
+d_7s4huYOD4
+EGGsK7po68c
+eKFlMKp9Gs0
+EWKJprUrnPE
+gp4fg9PWuhM
+HBlkinewdHM
+jpCrKYWjYD8
+jxi_Cjc8T1w
+kMXhWN71Ar0
+m2ZmZflLryo
+npEcenV-Y08
+nUe3F8jYoZo
+NXpWIephX1o
+PAaWZTFRP9Q
+SmtJ5Cy4jCM
+SU8NSkuBkb0
+VkKnOEQlwl4
+WigprYZPaLc
+YsrzvkG5_KI
+Zel-zag38mQ
\ No newline at end of file
diff --git a/TalkingHead-1KH/data_list/val_video_tubes.txt b/TalkingHead-1KH/data_list/val_video_tubes.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8a13a74f738732dd82e43993a92afb1ba61061e2
--- /dev/null
+++ b/TalkingHead-1KH/data_list/val_video_tubes.txt
@@ -0,0 +1,38 @@
+1lSejjfNHpw_0075, 1080, 1920, 0, 728, 671, 47, 1471, 847
+1lSejjfNHpw_0075, 1080, 1920, 728, 1456, 671, 47, 1471, 847
+2Xu56MEC91w_0046, 1080, 1920, 80, 1105, 586, 86, 1314, 814
+3y6Vjr45I34_0004, 1080, 1920, 287, 1254, 568, 0, 1464, 896
+4hQi42Q9mcY_0002, 1080, 1920, 0, 605, 443, 0, 1515, 992
+4hQi42Q9mcY_0002, 1080, 1920, 605, 1209, 443, 0, 1515, 992
+5crEV5DbRyc_0009, 1080, 1920, 208, 1152, 1058, 102, 1712, 756
+85UEFVcmIjI_0014, 1080, 1920, 92, 627, 558, 134, 1294, 870
+85UEFVcmIjI_0014, 1080, 1920, 627, 1162, 558, 134, 1294, 870
+A2800grpOzU_0002, 1080, 1920, 812, 1407, 227, 7, 1139, 919
+EGGsK7po68c_0007, 1080, 1920, 0, 1024, 786, 50, 1598, 862
+EWKJprUrnPE_0005, 1080, 1920, 0, 1024, 84, 168, 702, 786
+HBlkinewdHM_0000, 1080, 1920, 319, 1344, 807, 149, 1347, 689
+NXpWIephX1o_0031, 1080, 1920, 0, 632, 357, 0, 1493, 1072
+NXpWIephX1o_0031, 1080, 1920, 632, 1264, 357, 0, 1493, 1072
+PAaWZTFRP9Q_0001, 1080, 1920, 0, 672, 624, 42, 1376, 794
+PAaWZTFRP9Q_0001, 1080, 1920, 926, 1425, 696, 101, 1464, 869
+SU8NSkuBkb0_0015, 1080, 1920, 826, 1397, 347, 69, 1099, 821
+SmtJ5Cy4jCM_0006, 1080, 1920, 0, 523, 524, 50, 1388, 914
+SmtJ5Cy4jCM_0006, 1080, 1920, 546, 1134, 477, 42, 1357, 922
+VkKnOEQlwl4_0010, 1080, 1920, 98, 818, 821, 22, 1733, 934
+VkKnOEQlwl4_0010, 1080, 1920, 818, 1537, 821, 22, 1733, 934
+WigprYZPaLc_0002, 1080, 1920, 234, 877, 802, 25, 1490, 713
+WigprYZPaLc_0002, 1080, 1920, 877, 1519, 802, 25, 1490, 713
+YsrzvkG5_KI_0018, 1080, 1920, 36, 1061, 591, 100, 1055, 564
+Zel-zag38mQ_0001, 1080, 1920, 0, 733, 591, 12, 1439, 860
+Zel-zag38mQ_0001, 1080, 1920, 733, 1466, 591, 12, 1439, 860
+c1DRo3tPDG4_0010, 1080, 1920, 0, 865, 432, 33, 1264, 865
+c1DRo3tPDG4_0010, 1080, 1920, 865, 1730, 432, 33, 1264, 865
+eKFlMKp9Gs0_0005, 1080, 1920, 0, 1024, 705, 118, 1249, 662
+gp4fg9PWuhM_0003, 1080, 1920, 0, 858, 526, 0, 1310, 768
+jpCrKYWjYD8_0002, 1080, 1920, 0, 768, 527, 68, 1215, 756
+jpCrKYWjYD8_0002, 1080, 1920, 768, 1535, 527, 68, 1215, 756
+jxi_Cjc8T1w_0061, 1080, 1920, 0, 1024, 660, 102, 1286, 728
+kMXhWN71Ar0_0001, 1080, 1920, 0, 656, 60, 0, 940, 832
+kMXhWN71Ar0_0001, 1080, 1920, 656, 1311, 60, 0, 940, 832
+m2ZmZflLryo_0009, 1080, 1920, 0, 1024, 678, 51, 1390, 763
+npEcenV-Y08_0011, 1080, 1920, 99, 1087, 625, 69, 1425, 869
\ No newline at end of file
diff --git a/TalkingHead-1KH/requirements.txt b/TalkingHead-1KH/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..931a34da59d6bec3b090315e6b057ea944710fdd
--- /dev/null
+++ b/TalkingHead-1KH/requirements.txt
@@ -0,0 +1,4 @@
+ffmpeg-python
+imageio
+git+https://github.com/nficano/pytube
+tqdm
\ No newline at end of file
diff --git a/TalkingHead-1KH/teaser.gif b/TalkingHead-1KH/teaser.gif
new file mode 100644
index 0000000000000000000000000000000000000000..7079bd5992428cea506bb2ba9fa077eb148e26a7
--- /dev/null
+++ b/TalkingHead-1KH/teaser.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b6408b915b161270fedbfc932daa9a6615b49dbc2d2691bb9d794fe91fbdb18d
+size 3926806
diff --git a/TalkingHead-1KH/videos_crop.py b/TalkingHead-1KH/videos_crop.py
new file mode 100644
index 0000000000000000000000000000000000000000..96b36837c3a02d6a36ce0cdec59f598455d593ea
--- /dev/null
+++ b/TalkingHead-1KH/videos_crop.py
@@ -0,0 +1,79 @@
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# This script is licensed under the MIT License.
+
+import argparse
+import multiprocessing as mp
+import os
+from functools import partial
+from time import time as timer
+
+import ffmpeg
+from tqdm import tqdm
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input_dir', type=str, required=True,
+ help='Dir containing youtube clips.')
+parser.add_argument('--clip_info_file', type=str, required=True,
+ help='File containing clip information.')
+parser.add_argument('--output_dir', type=str, required=True,
+ help='Location to dump outputs.')
+parser.add_argument('--num_workers', type=int, default=8,
+ help='How many multiprocessing workers?')
+args = parser.parse_args()
+
+
+def get_h_w(filepath):
+ probe = ffmpeg.probe(filepath)
+ video_stream = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)
+ height = int(video_stream['height'])
+ width = int(video_stream['width'])
+ return height, width
+
+
+def trim_and_crop(input_dir, output_dir, clip_params):
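+    # Each line of the clip info file is expected to look like (matching data_list/*_video_tubes.txt):
+    #   video_name, H, W, start_frame, end_frame, left, top, right, bottom
+    # where the box coordinates are given at the annotated resolution (H, W).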
+ video_name, H, W, S, E, L, T, R, B = clip_params.strip().split(',')
+ H, W, S, E, L, T, R, B = int(H), int(W), int(S), int(E), int(L), int(T), int(R), int(B)
+ output_filename = '{}_S{}_E{}_L{}_T{}_R{}_B{}.mp4'.format(video_name, S, E, L, T, R, B)
+ output_filepath = os.path.join(output_dir, output_filename)
+ if os.path.exists(output_filepath):
+ print('Output file %s exists, skipping' % (output_filepath))
+ return
+
+ input_filepath = os.path.join(input_dir, video_name + '.mp4')
+ if not os.path.exists(input_filepath):
+ print('Input file %s does not exist, skipping' % (input_filepath))
+ return
+
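+    # Rescale the crop box from the annotated resolution (H, W) to the actual stream resolution (h, w).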
+ h, w = get_h_w(input_filepath)
+ t = int(T / H * h)
+ b = int(B / H * h)
+ l = int(L / W * w)
+ r = int(R / W * w)
+ stream = ffmpeg.input(input_filepath)
+ stream = ffmpeg.trim(stream, start_frame=S, end_frame=E+1)
+ stream = ffmpeg.crop(stream, l, t, r-l, b-t)
+ stream = ffmpeg.output(stream, output_filepath)
+ ffmpeg.run(stream)
+
+
+if __name__ == '__main__':
+ # Read list of videos.
+ clip_info = []
+ with open(args.clip_info_file) as fin:
+ for line in fin:
+ clip_info.append(line.strip())
+
+ # Create output folder.
+ os.makedirs(args.output_dir, exist_ok=True)
+
+    # Crop the videos.
+ downloader = partial(trim_and_crop, args.input_dir, args.output_dir)
+
+ start = timer()
+ pool_size = args.num_workers
+ print('Using pool size of %d' % (pool_size))
+ with mp.Pool(processes=pool_size) as p:
+ _ = list(tqdm(p.imap_unordered(downloader, clip_info), total=len(clip_info)))
+ print('Elapsed time: %.2f' % (timer() - start))
diff --git a/TalkingHead-1KH/videos_download.py b/TalkingHead-1KH/videos_download.py
new file mode 100644
index 0000000000000000000000000000000000000000..b557fcceb57f86aff9ecf9e40c1bed6eabd4572b
--- /dev/null
+++ b/TalkingHead-1KH/videos_download.py
@@ -0,0 +1,56 @@
+import argparse
+import multiprocessing as mp
+import os
+from functools import partial
+from time import time as timer
+
+from pytube import YouTube
+from tqdm import tqdm
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--input_list', type=str, required=True,
+ help='List of youtube video ids')
+parser.add_argument('--output_dir', type=str, default='data/youtube_videos',
+ help='Location to download videos')
+parser.add_argument('--num_workers', type=int, default=8,
+ help='How many multiprocessing workers?')
+args = parser.parse_args()
+
+
+def download_video(output_dir, video_id):
+ r"""Download video."""
+ video_path = '%s/%s.mp4' % (output_dir, video_id)
+ if not os.path.isfile(video_path):
+ try:
+            # Download an adaptive (video-only) mp4 stream, falling back to any mp4 stream.
+ yt = YouTube('https://www.youtube.com/watch?v=%s' % (video_id))
+ stream = yt.streams.filter(subtype='mp4', only_video=True, adaptive=True).first()
+ if stream is None:
+ stream = yt.streams.filter(subtype='mp4').first()
+ stream.download(output_path=output_dir, filename=video_id + '.mp4')
+ except Exception as e:
+ print(e)
+ print('Failed to download %s' % (video_id))
+ else:
+ print('File exists: %s' % (video_id))
+
+
+if __name__ == '__main__':
+ # Read list of videos.
+ video_ids = []
+ with open(args.input_list) as fin:
+ for line in fin:
+ video_ids.append(line.strip())
+
+ # Create output folder.
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Download videos.
+ downloader = partial(download_video, args.output_dir)
+
+ start = timer()
+ pool_size = args.num_workers
+ print('Using pool size of %d' % (pool_size))
+ with mp.Pool(processes=pool_size) as p:
+ _ = list(tqdm(p.imap_unordered(downloader, video_ids), total=len(video_ids)))
+ print('Elapsed time: %.2f' % (timer() - start))
diff --git a/TalkingHead-1KH/videos_download_and_crop.sh b/TalkingHead-1KH/videos_download_and_crop.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a3eb3ca37c320ae341b929293586df8b1b792ff6
--- /dev/null
+++ b/TalkingHead-1KH/videos_download_and_crop.sh
@@ -0,0 +1,10 @@
+dataset=$1
+
+# Download the videos.
+python videos_download.py --input_list data_list/${dataset}_video_ids.txt --output_dir ${dataset}/raw_videos
+
+# Split the videos into 1-min chunks.
+./videos_split.sh ${dataset}/raw_videos ${dataset}/1min_clips
+
+# Extract the talking head clips.
+python videos_crop.py --input_dir ${dataset}/1min_clips/ --output_dir ${dataset}/cropped_clips --clip_info_file data_list/${dataset}_video_tubes.txt
\ No newline at end of file
diff --git a/TalkingHead-1KH/videos_split.sh b/TalkingHead-1KH/videos_split.sh
new file mode 100644
index 0000000000000000000000000000000000000000..de27f26fe0c3f287fbc3d66f67a8870aa18f20d5
--- /dev/null
+++ b/TalkingHead-1KH/videos_split.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+in_dir=$1
+out_dir=$2
+
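+# Split every mp4 in $in_dir into ~1-minute segments. Note: with `-c copy` ffmpeg can
+# only cut at keyframes, so segment boundaries are approximate.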
+mkdir -p "$out_dir";
+for f in "$in_dir"/*.mp4
+do
+  y=${f##*/};
+  ffmpeg -i "$f" -c copy -map 0 -segment_time 00:01:00 -f segment "$out_dir/${y/.mp4}_%04d.mp4";
+done
\ No newline at end of file
diff --git a/checkpoints/.gitkeep b/checkpoints/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data_gen/eg3d/convert_to_eg3d_convention.py b/data_gen/eg3d/convert_to_eg3d_convention.py
new file mode 100644
index 0000000000000000000000000000000000000000..45d4e4b11dc69aa82ac0194c0df1b30d0ff020a7
--- /dev/null
+++ b/data_gen/eg3d/convert_to_eg3d_convention.py
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+import copy
+from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+def _fix_intrinsics(intrinsics):
+ """
+ intrinsics: [3,3], not batch-wise
+ """
+ # unnormalized normalized
+
+ # [[ f_x, s=0, x_0] [[ f_x/size_x, s=0, x_0/size_x=0.5]
+ # [ 0, f_y, y_0] -> [ 0, f_y/size_y, y_0/size_y=0.5]
+ # [ 0, 0, 1 ]] [ 0, 0, 1 ]]
+ intrinsics = np.array(intrinsics).copy()
+ assert intrinsics.shape == (3, 3), intrinsics
+ intrinsics[0,0] = 2985.29/700
+ intrinsics[1,1] = 2985.29/700
+ intrinsics[0,2] = 1/2
+ intrinsics[1,2] = 1/2
+ assert intrinsics[0,1] == 0
+ assert intrinsics[2,2] == 1
+ assert intrinsics[1,0] == 0
+ assert intrinsics[2,0] == 0
+ assert intrinsics[2,1] == 0
+ return intrinsics
+
+# Used in original submission
+def _fix_pose_orig(pose):
+ """
+ pose: [4,4], not batch-wise
+ """
+ pose = np.array(pose).copy()
+ location = pose[:3, 3]
+ radius = np.linalg.norm(location)
+ pose[:3, 3] = pose[:3, 3]/radius * 2.7
+ return pose
+
+
+def get_eg3d_convention_camera_pose_intrinsic(item):
+ """
+ item: a dict during binarize
+
+ """
+ if item['euler'].ndim == 1:
+ angle = convert_to_tensor(copy.copy(item['euler']))
+ trans = copy.deepcopy(item['trans'])
+
+ # handle the difference of euler axis between eg3d and ours
+ # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
+ # angle += torch.tensor([0, 3.1415926535, 3.1415926535], device=angle.device)
+ R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
+ trans[2] += -10
+ c = -np.dot(R, trans)
+ pose = np.eye(4)
+ pose[:3,:3] = R
+ c *= 0.27 # normalize camera radius
+ c[1] += 0.006 # additional offset used in submission
+ c[2] += 0.161 # additional offset used in submission
+ pose[0,3] = c[0]
+ pose[1,3] = c[1]
+ pose[2,3] = c[2]
+
+        focal = 2985.29 # = 1015*1024/224*(300/466.285)
+        # TODO: if the camera intrinsics used in the 3DMM fitting stage are changed, this must be updated accordingly
+        pp = 512  # 112
+        w = 1024  # 224
+        h = 1024  # 224
+
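+        # Build pinhole intrinsics at the assumed 1024x1024 rendering resolution; _fix_intrinsics then replaces them with the normalized EG3D-convention values.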
+ K = np.eye(3)
+ K[0][0] = focal
+ K[1][1] = focal
+ K[0][2] = w/2.0
+ K[1][2] = h/2.0
+ convention_K = _fix_intrinsics(K)
+
+ Rot = np.eye(3)
+ Rot[0, 0] = 1
+ Rot[1, 1] = -1
+ Rot[2, 2] = -1
+ pose[:3, :3] = np.dot(pose[:3, :3], Rot) # permute axes
+ convention_pose = _fix_pose_orig(pose)
+
+ item['c2w'] = pose
+ item['convention_c2w'] = convention_pose
+ item['intrinsics'] = convention_K
+ return item
+ else:
+ num_samples = len(item['euler'])
+ eulers_all = convert_to_tensor(copy.deepcopy(item['euler'])) # [B, 3]
+ trans_all = copy.deepcopy(item['trans']) # [B, 3]
+
+ # handle the difference of euler axis between eg3d and ours
+ # see data_gen/process_ffhq_for_eg3d/transplant_eg3d_ckpt_into_our_convention.ipynb
+ # eulers_all += torch.tensor([0, 3.1415926535, 3.1415926535], device=eulers_all.device).unsqueeze(0).repeat([eulers_all.shape[0],1])
+
+ intrinsics = []
+ poses = []
+ convention_poses = []
+ for i in range(num_samples):
+ angle = eulers_all[i]
+ trans = trans_all[i]
+ R = ParametricFaceModel.compute_rotation(angle.unsqueeze(0))[0].cpu().numpy()
+ trans[2] += -10
+ c = -np.dot(R, trans)
+ pose = np.eye(4)
+ pose[:3,:3] = R
+ c *= 0.27 # normalize camera radius
+ c[1] += 0.006 # additional offset used in submission
+ c[2] += 0.161 # additional offset used in submission
+ pose[0,3] = c[0]
+ pose[1,3] = c[1]
+ pose[2,3] = c[2]
+
+            focal = 2985.29 # = 1015*1024/224*(300/466.285)
+            # TODO: if the camera intrinsics used in the 3DMM fitting stage are changed, this must be updated accordingly
+            pp = 512  # 112
+            w = 1024  # 224
+            h = 1024  # 224
+
+ K = np.eye(3)
+ K[0][0] = focal
+ K[1][1] = focal
+ K[0][2] = w/2.0
+ K[1][2] = h/2.0
+ convention_K = _fix_intrinsics(K)
+ intrinsics.append(convention_K)
+
+ Rot = np.eye(3)
+ Rot[0, 0] = 1
+ Rot[1, 1] = -1
+ Rot[2, 2] = -1
+ pose[:3, :3] = np.dot(pose[:3, :3], Rot)
+ convention_pose = _fix_pose_orig(pose)
+ convention_poses.append(convention_pose)
+ poses.append(pose)
+
+ intrinsics = np.stack(intrinsics) # [B, 3, 3]
+ poses = np.stack(poses) # [B, 4, 4]
+ convention_poses = np.stack(convention_poses) # [B, 4, 4]
+ item['intrinsics'] = intrinsics
+ item['c2w'] = poses
+ item['convention_c2w'] = convention_poses
+ return item
diff --git a/data_gen/runs/binarizer_nerf.py b/data_gen/runs/binarizer_nerf.py
new file mode 100644
index 0000000000000000000000000000000000000000..623cd17f6b52c9a981721a8ca14e24af1edfe202
--- /dev/null
+++ b/data_gen/runs/binarizer_nerf.py
@@ -0,0 +1,335 @@
+import os
+import numpy as np
+import math
+import json
+import imageio
+import torch
+import tqdm
+import cv2
+
+from data_util.face3d_helper import Face3DHelper
+from utils.commons.euler2rot import euler_trans_2_c2w, c2w_to_euler_trans
+from data_gen.utils.process_video.euler2quaterion import euler2quaterion, quaterion2euler
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+def euler2rot(euler_angle):
+ batch_size = euler_angle.shape[0]
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
+ one = torch.ones(batch_size, 1, 1).to(euler_angle.device)
+ zero = torch.zeros(batch_size, 1, 1).to(euler_angle.device)
+ rot_x = torch.cat((
+ torch.cat((one, zero, zero), 1),
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
+ ), 2)
+ rot_y = torch.cat((
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
+ torch.cat((zero, one, zero), 1),
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
+ ), 2)
+ rot_z = torch.cat((
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
+ torch.cat((zero, zero, one), 1)
+ ), 2)
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
+
+
+def rot2euler(rot_mat):
+ batch_size = len(rot_mat)
+    # we assume that y is in [-0.5*pi, 0.5*pi]
+ cos_y = torch.sqrt(rot_mat[:, 1, 2] * rot_mat[:, 1, 2] + rot_mat[:, 2, 2] * rot_mat[:, 2, 2])
+ theta_x = torch.atan2(-rot_mat[:, 1, 2], rot_mat[:, 2, 2])
+ theta_y = torch.atan2(rot_mat[:, 2, 0], cos_y)
+ theta_z = torch.atan2(rot_mat[:, 0, 1], rot_mat[:, 0, 0])
+ euler_angles = torch.zeros([batch_size, 3])
+ euler_angles[:, 0] = theta_x
+ euler_angles[:, 1] = theta_y
+ euler_angles[:, 2] = theta_z
+ return euler_angles
+
+index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
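+# maps MediaPipe's 468-point face mesh to the 68-point facial-landmark convention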
+
+def plot_lm2d(lm2d):
+ WH = 512
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+
+ for i in range(len(lm2d)):
+ x, y = lm2d[i]
+ color = (255,0,0)
+ img = cv2.circle(img, center=(int(x),int(y)), radius=3, color=color, thickness=-1)
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ for i in range(len(lm2d)):
+ x, y = lm2d[i]
+ img = cv2.putText(img, f"{i}", org=(int(x),int(y)), fontFace=font, fontScale=0.3, color=(255,0,0))
+ return img
+
+def get_face_rect(lms, h, w):
+ """
+ lms: [68, 2]
+ h, w: int
+ return: [4,]
+ """
+ assert len(lms) == 68
+ # min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
+ min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
+ cx = int((min_x+max_x)/2.0)
+ cy = int(lms[27, 1])
+ h_w = int((max_x-cx)*1.5)
+ h_h = int((lms[8, 1]-cy)*1.15)
+ rect_x = cx - h_w
+ rect_y = cy - h_h
+ if rect_x < 0:
+ rect_x = 0
+ if rect_y < 0:
+ rect_y = 0
+ rect_w = min(w-1-rect_x, 2*h_w)
+ rect_h = min(h-1-rect_y, 2*h_h)
+ # rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
+ # rect = [rect_x, rect_y, rect_w, rect_h]
+ rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
+ return rect # this x is width, y is height
+
+def get_lip_rect(lms, h, w):
+ """
+ lms: [68, 2]
+ h, w: int
+ return: [4,]
+ """
+ # this x is width, y is height
+ # for lms, lms[:, 0] is width, lms[:, 1] is height
+ assert len(lms) == 68
+ lips = slice(48, 60)
+ lms = lms[lips]
+ min_x, max_x = np.min(lms[:, 0]), np.max(lms[:, 0])
+ min_y, max_y = np.min(lms[:, 1]), np.max(lms[:, 1])
+ cx = int((min_x+max_x)/2.0)
+ cy = int((min_y+max_y)/2.0)
+ h_w = int((max_x-cx)*1.2)
+ h_h = int((max_y-cy)*1.2)
+
+ h_w = max(h_w, h_h)
+ h_h = h_w
+
+ rect_x = cx - h_w
+ rect_y = cy - h_h
+ rect_w = 2*h_w
+ rect_h = 2*h_h
+ if rect_x < 0:
+ rect_x = 0
+ if rect_y < 0:
+ rect_y = 0
+
+ if rect_x + rect_w > w:
+ rect_x = w - rect_w
+ if rect_y + rect_h > h:
+ rect_y = h - rect_h
+
+ rect = [rect_x, rect_x + rect_w, rect_y, rect_y + rect_h] # min_j, max_j, min_i, max_i
+ return rect # this x is width, y is height
+
+
+# def get_lip_rect(lms, h, w):
+# """
+# lms: [68, 2]
+# h, w: int
+# return: [4,]
+# """
+# assert len(lms) == 68
+# lips = slice(48, 60)
+# # this x is width, y is height
+# xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
+# ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())
+# # padding to H == W
+# cx = (xmin + xmax) // 2
+# cy = (ymin + ymax) // 2
+# l = max(xmax - xmin, ymax - ymin) // 2
+# xmin = max(0, cx - l)
+# xmax = min(h, cx + l)
+# ymin = max(0, cy - l)
+# ymax = min(w, cy + l)
+# lip_rect = [xmin, xmax, ymin, ymax]
+# return lip_rect
+
+def get_win_conds(conds, idx, smo_win_size=8, pad_option='zero'):
+ """
+ conds: [b, t=16, h=29]
+ idx: long, time index of the selected frame
+ """
+ idx = max(0, idx)
+ idx = min(idx, conds.shape[0]-1)
+ smo_half_win_size = smo_win_size//2
+ left_i = idx - smo_half_win_size
+ right_i = idx + (smo_win_size - smo_half_win_size)
+ pad_left, pad_right = 0, 0
+ if left_i < 0:
+ pad_left = -left_i
+ left_i = 0
+ if right_i > conds.shape[0]:
+ pad_right = right_i - conds.shape[0]
+ right_i = conds.shape[0]
+ conds_win = conds[left_i:right_i]
+ if pad_left > 0:
+ if pad_option == 'zero':
+ conds_win = np.concatenate([np.zeros_like(conds_win)[:pad_left], conds_win], axis=0)
+ elif pad_option == 'edge':
+ edge_value = conds[0][np.newaxis, ...]
+ conds_win = np.concatenate([edge_value] * pad_left + [conds_win], axis=0)
+ else:
+ raise NotImplementedError
+ if pad_right > 0:
+ if pad_option == 'zero':
+ conds_win = np.concatenate([conds_win, np.zeros_like(conds_win)[:pad_right]], axis=0)
+ elif pad_option == 'edge':
+ edge_value = conds[-1][np.newaxis, ...]
+ conds_win = np.concatenate([conds_win] + [edge_value] * pad_right , axis=0)
+ else:
+ raise NotImplementedError
+ assert conds_win.shape[0] == smo_win_size
+ return conds_win
+
+
+def load_processed_data(processed_dir):
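+    # NOTE: relies on the module-level `face_model` and `face3d_helper` created in the __main__ block below.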
+ # load necessary files
+ background_img_name = os.path.join(processed_dir, "bg.jpg")
+ assert os.path.exists(background_img_name)
+ head_img_dir = os.path.join(processed_dir, "head_imgs")
+ torso_img_dir = os.path.join(processed_dir, "inpaint_torso_imgs")
+ gt_img_dir = os.path.join(processed_dir, "gt_imgs")
+
+ hubert_npy_name = os.path.join(processed_dir, "aud_hubert.npy")
+ mel_f0_npy_name = os.path.join(processed_dir, "aud_mel_f0.npy")
+ coeff_npy_name = os.path.join(processed_dir, "coeff_fit_mp.npy")
+ lm2d_npy_name = os.path.join(processed_dir, "lms_2d.npy")
+
+ ret_dict = {}
+
+ ret_dict['bg_img'] = imageio.imread(background_img_name)
+ ret_dict['H'], ret_dict['W'] = ret_dict['bg_img'].shape[:2]
+ ret_dict['focal'], ret_dict['cx'], ret_dict['cy'] = face_model.focal, face_model.center, face_model.center
+
+ print("loading lm2d coeff ...")
+ lm2d_arr = np.load(lm2d_npy_name)
+ face_rect_lst = []
+ lip_rect_lst = []
+ for lm2d in lm2d_arr:
+ if len(lm2d) in [468, 478]:
+ lm2d = lm2d[index_lm68_from_lm468]
+ face_rect = get_face_rect(lm2d, ret_dict['H'], ret_dict['W'])
+ lip_rect = get_lip_rect(lm2d, ret_dict['H'], ret_dict['W'])
+ face_rect_lst.append(face_rect)
+ lip_rect_lst.append(lip_rect)
+ face_rects = np.stack(face_rect_lst, axis=0) # [T, 4]
+
+ print("loading fitted 3dmm coeff ...")
+ coeff_dict = np.load(coeff_npy_name, allow_pickle=True).tolist()
+ identity_arr = coeff_dict['id']
+ exp_arr = coeff_dict['exp']
+ ret_dict['id'] = identity_arr
+ ret_dict['exp'] = exp_arr
+ euler_arr = ret_dict['euler'] = coeff_dict['euler']
+ trans_arr = ret_dict['trans'] = coeff_dict['trans']
+ print("calculating lm3d ...")
+ idexp_lm3d_arr = face3d_helper.reconstruct_idexp_lm3d(torch.from_numpy(identity_arr), torch.from_numpy(exp_arr)).cpu().numpy().reshape([-1, 68*3])
+ len_motion = len(idexp_lm3d_arr)
+ video_idexp_lm3d_mean = idexp_lm3d_arr.mean(axis=0)
+ video_idexp_lm3d_std = idexp_lm3d_arr.std(axis=0)
+ ret_dict['idexp_lm3d'] = idexp_lm3d_arr
+ ret_dict['idexp_lm3d_mean'] = video_idexp_lm3d_mean
+ ret_dict['idexp_lm3d_std'] = video_idexp_lm3d_std
+
+ # now we convert the euler_trans from deep3d convention to adnerf convention
+ eulers = torch.FloatTensor(euler_arr)
+ trans = torch.FloatTensor(trans_arr)
+    rots = face_model.compute_rotation(eulers) # the rotation matrix is a better intermediate for convention transplant than euler angles
+
+    # convert the camera pose to GeneFace's convention
+    trans[:, 2] = 10 - trans[:, 2] # undo the to_camera operation applied in the fitting stage, i.e. trans[...,2] = 10 - trans[...,2]
+    rots = rots.permute(0, 2, 1)
+    trans[:, 2] = - trans[:,2] # because the intrinsic projection convention differs
+ # below is the NeRF camera preprocessing strategy, see `save_transforms` in data_util/process.py
+ trans = trans / 10.0
+ rots_inv = rots.permute(0, 2, 1)
+ trans_inv = - torch.bmm(rots_inv, trans.unsqueeze(2))
+
+ pose = torch.eye(4, dtype=torch.float32).unsqueeze(0).repeat([len_motion, 1, 1]) # [T, 4, 4]
+ pose[:, :3, :3] = rots_inv
+ pose[:, :3, 3] = trans_inv[:, :, 0]
+ c2w_transform_matrices = pose.numpy()
+
+ # process the audio features used for postnet training
+ print("loading hubert ...")
+ hubert_features = np.load(hubert_npy_name)
+ print("loading Mel and F0 ...")
+ mel_f0_features = np.load(mel_f0_npy_name, allow_pickle=True).tolist()
+
+ ret_dict['hubert'] = hubert_features
+ ret_dict['mel'] = mel_f0_features['mel']
+ ret_dict['f0'] = mel_f0_features['f0']
+
+ # obtaining train samples
+ frame_indices = list(range(len_motion))
+ num_train = len_motion // 11 * 10
+ train_indices = frame_indices[:num_train]
+ val_indices = frame_indices[num_train:]
+
+ for split in ['train', 'val']:
+ if split == 'train':
+ indices = train_indices
+ samples = []
+ ret_dict['train_samples'] = samples
+ elif split == 'val':
+ indices = val_indices
+ samples = []
+ ret_dict['val_samples'] = samples
+
+ for idx in indices:
+ sample = {}
+ sample['idx'] = idx
+ sample['head_img_fname'] = os.path.join(head_img_dir,f"{idx:08d}.png")
+ sample['torso_img_fname'] = os.path.join(torso_img_dir,f"{idx:08d}.png")
+ sample['gt_img_fname'] = os.path.join(gt_img_dir,f"{idx:08d}.jpg")
+ # assert os.path.exists(sample['head_img_fname']) and os.path.exists(sample['torso_img_fname']) and os.path.exists(sample['gt_img_fname'])
+ sample['face_rect'] = face_rects[idx]
+ sample['lip_rect'] = lip_rect_lst[idx]
+ sample['c2w'] = c2w_transform_matrices[idx]
+ samples.append(sample)
+ return ret_dict
+
+
+class Binarizer:
+ def __init__(self):
+ self.data_dir = 'data/'
+
+ def parse(self, video_id):
+ processed_dir = os.path.join(self.data_dir, 'processed/videos', video_id)
+ binary_dir = os.path.join(self.data_dir, 'binary/videos', video_id)
+ out_fname = os.path.join(binary_dir, "trainval_dataset.npy")
+ os.makedirs(binary_dir, exist_ok=True)
+ ret = load_processed_data(processed_dir)
+ mel_name = os.path.join(processed_dir, 'aud_mel_f0.npy')
+ mel_f0_dict = np.load(mel_name, allow_pickle=True).tolist()
+ ret.update(mel_f0_dict)
+ np.save(out_fname, ret, allow_pickle=True)
+
+
+
+if __name__ == '__main__':
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument('--video_id', type=str, default='May', help='')
+ args = parser.parse_args()
+ ### Process Single Long Audio for NeRF dataset
+ video_id = args.video_id
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015)
+ face_model.to("cpu")
+ face3d_helper = Face3DHelper()
+
+ binarizer = Binarizer()
+ binarizer.parse(video_id)
+ print(f"Binarization for {video_id} Done!")
diff --git a/data_gen/runs/binarizer_th1kh.py b/data_gen/runs/binarizer_th1kh.py
new file mode 100644
index 0000000000000000000000000000000000000000..63dfde574a3a5dcc084ea638678174bfb97cd3c8
--- /dev/null
+++ b/data_gen/runs/binarizer_th1kh.py
@@ -0,0 +1,100 @@
+import os
+import numpy as np
+import torch
+from tqdm import trange
+import pickle
+from copy import deepcopy
+
+from data_util.face3d_helper import Face3DHelper
+from utils.commons.indexed_datasets import IndexedDataset, IndexedDatasetBuilder
+
+
+def load_video_npy(fn):
+ assert fn.endswith("_coeff_fit_mp.npy")
+ ret_dict = np.load(fn,allow_pickle=True).item()
+ video_dict = {
+ 'euler': ret_dict['euler'], # [T, 3]
+ 'trans': ret_dict['trans'], # [T, 3]
+ 'id': ret_dict['id'], # [T, 80]
+ 'exp': ret_dict['exp'], # [T, 64]
+ }
+ return video_dict
+
+def cal_lm3d_in_video_dict(video_dict, face3d_helper):
+ identity = video_dict['id']
+ exp = video_dict['exp']
+ idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d(identity, exp).cpu().numpy()
+ video_dict['idexp_lm3d'] = idexp_lm3d
+
+
+def load_audio_npy(fn):
+ assert fn.endswith(".npy")
+ ret_dict = np.load(fn,allow_pickle=True).item()
+ audio_dict = {
+ "mel": ret_dict['mel'], # [T, 80]
+ "f0": ret_dict['f0'], # [T,1]
+ }
+ return audio_dict
+
+
+if __name__ == '__main__':
+ face3d_helper = Face3DHelper(use_gpu=False)
+
+ import glob,tqdm
+ prefixs = ['val', 'train']
+ binarized_ds_path = "data/binary/th1kh"
+ os.makedirs(binarized_ds_path, exist_ok=True)
+ for prefix in prefixs:
+ databuilder = IndexedDatasetBuilder(os.path.join(binarized_ds_path, prefix), gzip=False, default_idx_size=1024*1024*1024*2)
+ raw_base_dir = '/mnt/bn/ailabrenyi/entries/yezhenhui/datasets/raw/TH1KH_512/video'
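+        # The raw dataset is expected to provide sibling folders next to `video/`:
+        # `hubert/`, `mel_f0/`, `coeff_fit_mp/` and `com_imgs/`, matching the path replacements below.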
+ mp4_names = glob.glob(os.path.join(raw_base_dir, '*.mp4'))
+ mp4_names = mp4_names[:1000]
+ cnt = 0
+ scnt = 0
+ pbar = tqdm.tqdm(enumerate(mp4_names), total=len(mp4_names))
+ for i, mp4_name in pbar:
+ cnt += 1
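+            # every 100th video is held out for the val split; the rest go to train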
+ if prefix == 'train':
+ if i % 100 == 0:
+ continue
+ else:
+ if i % 100 != 0:
+ continue
+ hubert_npy_name = mp4_name.replace("/video/", "/hubert/").replace(".mp4", "_hubert.npy")
+ audio_npy_name = mp4_name.replace("/video/", "/mel_f0/").replace(".mp4", "_mel_f0.npy")
+ video_npy_name = mp4_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
+            if not os.path.exists(audio_npy_name):
+                print(f"Skip item: audio npy not found at {audio_npy_name}.")
+                continue
+            if not os.path.exists(video_npy_name):
+                print(f"Skip item: video npy not found at {video_npy_name}.")
+                continue
+            if not os.path.exists(hubert_npy_name):
+                print(f"Skip item: hubert npy not found at {hubert_npy_name}.")
+                continue
+ audio_dict = load_audio_npy(audio_npy_name)
+ hubert = np.load(hubert_npy_name)
+ video_dict = load_video_npy(video_npy_name)
+ com_img_dir = mp4_name.replace("/video/", "/com_imgs/").replace(".mp4", "")
+ num_com_imgs = len(glob.glob(os.path.join(com_img_dir, '*')))
+ num_frames = len(video_dict['exp'])
+            if num_com_imgs != num_frames:
+                print(f"Skip item: number of com_imgs ({num_com_imgs}) does not match number of coeff frames ({num_frames}).")
+                continue
+            mel = audio_dict['mel']
+            if mel.shape[0] < 32: # the video is shorter than 0.6s
+                print("Skip item: audio too short.")
+                continue
+
+ audio_dict.update(video_dict)
+ audio_dict['item_id'] = os.path.basename(mp4_name)[:-4]
+ audio_dict['hubert'] = hubert # [T_x, hid=1024]
+ audio_dict['img_dir'] = com_img_dir
+
+
+ databuilder.add_item(audio_dict)
+ scnt += 1
+ pbar.set_postfix({'success': scnt, 'success rate': scnt / cnt})
+ databuilder.finalize()
+    print(f"{prefix} set has {scnt} samples!")
\ No newline at end of file
diff --git a/data_gen/runs/nerf/process_guide.md b/data_gen/runs/nerf/process_guide.md
new file mode 100644
index 0000000000000000000000000000000000000000..2312d416fcd50cee8656803fe2fdba141e62e86f
--- /dev/null
+++ b/data_gen/runs/nerf/process_guide.md
@@ -0,0 +1,49 @@
+# Tip: for your first run, step through the commands below one by one to make sure the environment works; after that, you can simply run run.sh in this directory to complete all the steps in one go.
+
+# Step0. Crop the video to 512x512 resolution at 25 FPS, and make sure the target face is visible in every frame
+```
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 data/raw/videos/${VIDEO_ID}_512.mp4
+mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
+```
+# Step1. Extract audio features such as mel, f0, hubert and esperanto
+```
+export CUDA_VISIBLE_DEVICES=0
+export VIDEO_ID=May
+mkdir -p data/processed/videos/${VIDEO_ID}
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 data/processed/videos/${VIDEO_ID}/aud.wav
+python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+```
+
+# Step2. Extract images
+```
+export VIDEO_ID=May
+export CUDA_VISIBLE_DEVICES=0
+mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+```
+
+# Step3. Extract lm2d_mediapipe
+### Extract 2D landmarks for the later 3DMM fitting
+### num_workers is the number of CPU workers on this machine; total_process is the number of machines used; process_id is the index of this machine
+
+```
+export VIDEO_ID=May
+python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+```
+
+# Step4. Fit the 3DMM
+```
+export VIDEO_ID=May
+export CUDA_VISIBLE_DEVICES=0
+python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+```
+
+# Step5. Binarize
+```
+export VIDEO_ID=May
+python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+```
+You should now see the binarized dataset under `data/binary/videos/May`.
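+
+As a quick sanity check, you can load the binarized dataset in Python (a minimal sketch; the keys follow what `data_gen/runs/binarizer_nerf.py` saves):
+```
+import numpy as np
+
+# the binarizer saves a pickled dict via np.save, so load it with allow_pickle and unwrap it
+ds = np.load("data/binary/videos/May/trainval_dataset.npy", allow_pickle=True).tolist()
+print(sorted(ds.keys()))  # e.g. 'hubert', 'idexp_lm3d', 'train_samples', 'val_samples', ...
+print(len(ds['train_samples']), len(ds['val_samples']))  # number of train / val frames
+```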
\ No newline at end of file
diff --git a/data_gen/runs/nerf/run.sh b/data_gen/runs/nerf/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f028ad9e061c925e51946ff83c27e99c35cbb15c
--- /dev/null
+++ b/data_gen/runs/nerf/run.sh
@@ -0,0 +1,51 @@
+# usage: CUDA_VISIBLE_DEVICES=0 bash data_gen/runs/nerf/run.sh
+# please place video to data/raw/videos/${VIDEO_ID}.mp4
+VIDEO_ID=$1
+echo Processing $VIDEO_ID
+
+echo Resizing the video to 512x512
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -y data/raw/videos/${VIDEO_ID}_512.mp4
+mv data/raw/videos/${VIDEO_ID}.mp4 data/raw/videos/${VIDEO_ID}_to_rm.mp4
+mv data/raw/videos/${VIDEO_ID}_512.mp4 data/raw/videos/${VIDEO_ID}.mp4
+echo Done
+echo The original video has been moved to data/raw/videos/${VIDEO_ID}_to_rm.mp4 and replaced by the resized data/raw/videos/${VIDEO_ID}.mp4
+
+echo mkdir -p data/processed/videos/${VIDEO_ID}
+mkdir -p data/processed/videos/${VIDEO_ID}
+echo Done
+
+# extract audio file from the training video
+echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -f wav -ar 16000 -v quiet -y data/processed/videos/${VIDEO_ID}/aud.wav
+echo Done
+
+# extract hubert_mel_f0 from audio
+echo python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+python data_gen/utils/process_audio/extract_hubert.py --video_id=${VIDEO_ID}
+echo python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+python data_gen/utils/process_audio/extract_mel_f0.py --video_id=${VIDEO_ID}
+echo Done
+
+# extract segment images
+echo mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+mkdir -p data/processed/videos/${VIDEO_ID}/gt_imgs
+echo ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+ffmpeg -i data/raw/videos/${VIDEO_ID}.mp4 -vf fps=25,scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet data/processed/videos/${VIDEO_ID}/gt_imgs/%08d.jpg
+echo Done
+
+echo python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+python data_gen/utils/process_video/extract_segment_imgs.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 # extract image, segmap, and background
+echo Done
+
+echo python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+python data_gen/utils/process_video/extract_lm2d.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4
+echo Done
+
+pkill -f void*
+echo python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name=nerf --vid_dir=data/raw/videos/${VIDEO_ID}.mp4 --reset --debug --id_mode=global
+echo Done
+
+echo python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+python data_gen/runs/binarizer_nerf.py --video_id=${VIDEO_ID}
+echo Done
\ No newline at end of file
diff --git a/data_gen/utils/mp_feature_extractors/face_landmarker.py b/data_gen/utils/mp_feature_extractors/face_landmarker.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b5904a46809352ef08fd1b3d6948ec4fbc6b7fd
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/face_landmarker.py
@@ -0,0 +1,130 @@
+import mediapipe as mp
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+import numpy as np
+import cv2
+import os
+import copy
+
+# simplified mediapipe ldm at https://github.com/k-m-irfan/simplified_mediapipe_face_landmarks
+index_lm141_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [468,469,470,471,472] + [473,474,475,476,477] + [64,4,294]
+# lm141 without iris
+index_lm131_from_lm478 = [70,63,105,66,107,55,65,52,53,46] + [300,293,334,296,336,285,295,282,283,276] + [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249] + [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95] + [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146] + [10,338,297,332,284,251,389,356,454,323,361,288,397,365,379,378,400,377,152,148,176,149,150,136,172,58,132,93,234,127,162,21,54,103,67,109] + [64,4,294]
+
+# face alignment lm68
+index_lm68_from_lm478 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
+# used for weights for key parts
+unmatch_mask_from_lm478 = [ 93, 127, 132, 234, 323, 356, 361, 454]
+index_eye_from_lm478 = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
+index_innerlip_from_lm478 = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
+index_outerlip_from_lm478 = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
+index_withinmouth_from_lm478 = [76, 62] + [184, 183, 74, 72, 73, 41, 72, 38, 11, 12, 302, 268, 303, 271, 304, 272, 408, 407] + [292, 306] + [325, 307, 319, 320, 403, 404, 316, 315, 15, 16, 86, 85, 179, 180, 89, 90, 96, 77]
+index_mouth_from_lm478 = index_innerlip_from_lm478 + index_outerlip_from_lm478 + index_withinmouth_from_lm478
+
+index_yaw_from_lm68 = list(range(0, 17))
+index_brow_from_lm68 = list(range(17, 27))
+index_nose_from_lm68 = list(range(27, 36))
+index_eye_from_lm68 = list(range(36, 48))
+index_mouth_from_lm68 = list(range(48, 68))
+
+
+def read_video_to_frames(video_name):
+ frames = []
+ cap = cv2.VideoCapture(video_name)
+ while cap.isOpened():
+ ret, frame_bgr = cap.read()
+ if frame_bgr is None:
+ break
+ frames.append(frame_bgr)
+ frames = np.stack(frames)
+ frames = np.flip(frames, -1) # BGR ==> RGB
+ return frames
+
+class MediapipeLandmarker:
+ def __init__(self):
+ model_path = 'data_gen/utils/mp_feature_extractors/face_landmarker.task'
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+ print("downloading face_landmarker model from mediapipe...")
+ model_url = 'https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task'
+ os.system(f"wget {model_url}")
+ os.system(f"mv face_landmarker.task {model_path}")
+ print("download success")
+ base_options = python.BaseOptions(model_asset_path=model_path)
+ self.image_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
+ running_mode=vision.RunningMode.IMAGE, # IMAGE, VIDEO, LIVE_STREAM
+ num_faces=1)
+ self.video_mode_options = vision.FaceLandmarkerOptions(base_options=base_options,
+ running_mode=vision.RunningMode.VIDEO, # IMAGE, VIDEO, LIVE_STREAM
+ num_faces=1)
+
+ def extract_lm478_from_img_name(self, img_name):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_lm478 = self.extract_lm478_from_img(img)
+ return img_lm478
+
+ def extract_lm478_from_img(self, img):
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=img.astype(np.uint8))
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
+ H, W, _ = img.shape
+ img_lm478 = np.array(img_face_landmarks)[:, :2] * np.array([W, H]).reshape([1,2]) # [478, 2]
+ return img_lm478
+
+ def extract_lm478_from_video_name(self, video_name, fps=25, anti_smooth_factor=2):
+ frames = read_video_to_frames(video_name)
+ img_lm478, vid_lm478 = self.extract_lm478_from_frames(frames, fps, anti_smooth_factor)
+ return img_lm478, vid_lm478
+
+ def extract_lm478_from_frames(self, frames, fps=25, anti_smooth_factor=20):
+ """
+ frames: RGB, uint8
+        anti_smooth_factor: float, scales the timestamp interval used in VIDEO mode; 1 means unchanged, larger values behave closer to IMAGE mode
+ """
+ img_mpldms = []
+ vid_mpldms = []
+ img_landmarker = vision.FaceLandmarker.create_from_options(self.image_mode_options)
+ vid_landmarker = vision.FaceLandmarker.create_from_options(self.video_mode_options)
+
+ for i in range(len(frames)):
+ frame = mp.Image(image_format=mp.ImageFormat.SRGB, data=frames[i].astype(np.uint8))
+ img_face_landmarker_result = img_landmarker.detect(image=frame)
+ vid_face_landmarker_result = vid_landmarker.detect_for_video(image=frame, timestamp_ms=int((1000/fps)*anti_smooth_factor*i))
+ try:
+ img_ldm_i = img_face_landmarker_result.face_landmarks[0]
+ vid_ldm_i = vid_face_landmarker_result.face_landmarks[0]
+            except Exception:
+                print(f"Warning: failed to detect landmarks at frame idx={i}; reusing the previous frame's results.")
+ img_face_landmarks = np.array([[l.x, l.y, l.z] for l in img_ldm_i])
+ vid_face_landmarks = np.array([[l.x, l.y, l.z] for l in vid_ldm_i])
+ img_mpldms.append(img_face_landmarks)
+ vid_mpldms.append(vid_face_landmarks)
+ img_lm478 = np.stack(img_mpldms)[..., :2]
+ vid_lm478 = np.stack(vid_mpldms)[..., :2]
+ bs, H, W, _ = frames.shape
+ img_lm478 = np.array(img_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
+ vid_lm478 = np.array(vid_lm478)[..., :2] * np.array([W, H]).reshape([1,1,2]) # [T, 478, 2]
+ return img_lm478, vid_lm478
+
+ def combine_vid_img_lm478_to_lm68(self, img_lm478, vid_lm478):
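+        # Start from the per-frame (image-mode) landmarks and overwrite the face contour, brows and
+        # nose with the temporally smoothed video-mode landmarks; eyes and mouth keep the image-mode positions.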
+ img_lm68 = img_lm478[:, index_lm68_from_lm478]
+ vid_lm68 = vid_lm478[:, index_lm68_from_lm478]
+ combined_lm68 = copy.deepcopy(img_lm68)
+ combined_lm68[:, index_yaw_from_lm68] = vid_lm68[:, index_yaw_from_lm68]
+ combined_lm68[:, index_brow_from_lm68] = vid_lm68[:, index_brow_from_lm68]
+ combined_lm68[:, index_nose_from_lm68] = vid_lm68[:, index_nose_from_lm68]
+ return combined_lm68
+
+ def combine_vid_img_lm478_to_lm478(self, img_lm478, vid_lm478):
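+        # Opposite mixing to the lm68 variant: start from the smoothed video-mode landmarks and take
+        # the mouth and eye regions from the per-frame image-mode landmarks.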
+ combined_lm478 = copy.deepcopy(vid_lm478)
+ combined_lm478[:, index_mouth_from_lm478] = img_lm478[:, index_mouth_from_lm478]
+ combined_lm478[:, index_eye_from_lm478] = img_lm478[:, index_eye_from_lm478]
+ return combined_lm478
+
+if __name__ == '__main__':
+ landmarker = MediapipeLandmarker()
+ ret = landmarker.extract_lm478_from_video_name("00000.mp4")
diff --git a/data_gen/utils/mp_feature_extractors/face_landmarker.task b/data_gen/utils/mp_feature_extractors/face_landmarker.task
new file mode 100644
index 0000000000000000000000000000000000000000..fedb14de6d2b6708a56c04ae259783e23404c1aa
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/face_landmarker.task
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:64184e229b263107bc2b804c6625db1341ff2bb731874b0bcc2fe6544e0bc9ff
+size 3758596
diff --git a/data_gen/utils/mp_feature_extractors/mp_segmenter.py b/data_gen/utils/mp_feature_extractors/mp_segmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a9da3509e042d0a02699f92dcd71ce1342c6652
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/mp_segmenter.py
@@ -0,0 +1,303 @@
+import os
+import copy
+import numpy as np
+import tqdm
+import cv2
+import mediapipe as mp
+import torch
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm, multiprocess_run
+from utils.commons.tensor_utils import convert_to_np
+from sklearn.neighbors import NearestNeighbors
+
+def scatter_np(condition_img, classSeg=5):
+# def scatter(condition_img, classSeg=19, label_size=(512, 512)):
+ batch, c, height, width = condition_img.shape
+ # if height != label_size[0] or width != label_size[1]:
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
+ input_label = np.zeros([batch, classSeg, condition_img.shape[2], condition_img.shape[3]]).astype(np.int_)
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
+ np.put_along_axis(input_label, condition_img, 1, 1)
+ return input_label
+
+def scatter(condition_img, classSeg=19):
+# def scatter(condition_img, classSeg=19, label_size=(512, 512)):
+ batch, c, height, width = condition_img.size()
+ # if height != label_size[0] or width != label_size[1]:
+ # condition_img= F.interpolate(condition_img, size=label_size, mode='nearest')
+ input_label = torch.zeros(batch, classSeg, condition_img.shape[2], condition_img.shape[3], device=condition_img.device)
+ # input_label = torch.zeros(batch, classSeg, *label_size, device=condition_img.device)
+ return input_label.scatter_(1, condition_img.long(), 1)
+
+def encode_segmap_mask_to_image(segmap):
+ # rgb
+ _,h,w = segmap.shape
+ encoded_img = np.ones([h,w,3],dtype=np.uint8) * 255
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
+ for i, color in enumerate(colors):
+ mask = segmap[i].astype(int)
+ index = np.where(mask != 0)
+ encoded_img[index[0], index[1], :] = np.array(color)
+ return encoded_img.astype(np.uint8)
+
+def decode_segmap_mask_from_image(encoded_img):
+ # rgb
+ colors = [(255,255,255),(255,255,0),(255,0,255),(0,255,255),(255,0,0),(0,255,0)]
+ bg = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
+ hair = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
+ body_skin = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 255)
+ face_skin = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 255)
+ clothes = (encoded_img[..., 0] == 255) & (encoded_img[..., 1] == 0) & (encoded_img[..., 2] == 0)
+ others = (encoded_img[..., 0] == 0) & (encoded_img[..., 1] == 255) & (encoded_img[..., 2] == 0)
+ segmap = np.stack([bg, hair, body_skin, face_skin, clothes, others], axis=0)
+ return segmap.astype(np.uint8)
+
+def read_video_frame(video_name, frame_id):
+    # https://blog.csdn.net/bby1987/article/details/108923361
+    # frame_num = video_capture.get(cv2.CAP_PROP_FRAME_COUNT)   # ==> total number of frames
+    # fps = video_capture.get(cv2.CAP_PROP_FPS)                 # ==> frame rate
+    # width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH)       # ==> frame width
+    # height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT)     # ==> frame height
+    # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES)          # ==> current frame position
+    # video_capture.set(cv2.CAP_PROP_POS_FRAMES, 1000)          # ==> seek to frame 1000
+    # pos = video_capture.get(cv2.CAP_PROP_POS_FRAMES)          # ==> now pos = 1000.0
+    # video_capture.release()
+ vr = cv2.VideoCapture(video_name)
+ vr.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
+ _, frame = vr.read()
+ return frame
+
+def decode_segmap_mask_from_segmap_video_frame(video_frame):
+ # video_frame: 0~255 BGR, obtained by read_video_frame
+ def assign_values(array):
+        remainder = array % 40 # remainder of each value modulo 40
+ assigned_values = np.where(remainder <= 20, array - remainder, array + (40 - remainder))
+ return assigned_values
+ segmap = video_frame.mean(-1)
+ segmap = assign_values(segmap) // 40 # [H, W] with value 0~5
+    segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+    return segmap_mask.astype(np.uint8)
+
+def extract_background(img_lst, segmap_lst=None):
+ """
+ img_lst: list of rgb ndarray
+ """
+ # only use 1/20 images
+ num_frames = len(img_lst)
+ img_lst = img_lst[::20] if num_frames > 20 else img_lst[0:1]
+
+ if segmap_lst is not None:
+ segmap_lst = segmap_lst[::20] if num_frames > 20 else segmap_lst[0:1]
+ assert len(img_lst) == len(segmap_lst)
+ # get H/W
+ h, w = img_lst[0].shape[:2]
+
+ # nearest neighbors
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
+ distss = []
+ for idx, img in enumerate(img_lst):
+ if segmap_lst is not None:
+ segmap = segmap_lst[idx]
+ else:
+ segmap = seg_model._cal_seg_map(img)
+ bg = (segmap[0]).astype(bool)
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ dists, _ = nbrs.kneighbors(all_xys)
+ distss.append(dists)
+
+ distss = np.stack(distss)
+ max_dist = np.max(distss, 0)
+ max_id = np.argmax(distss, 0)
+
+ bc_pixs = max_dist > 10 # 5
+ bc_pixs_id = np.nonzero(bc_pixs)
+ bc_ids = max_id[bc_pixs]
+
+ num_pixs = distss.shape[1]
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
+
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
+ bg_img = bg_img.reshape(h, w, 3)
+
+ max_dist = max_dist.reshape(h, w)
+ bc_pixs = max_dist > 10 # 5
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ distances, indices = nbrs.kneighbors(bg_xys)
+ bg_fg_xys = fg_xys[indices[:, 0]]
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ return bg_img
+
+
+global_segmenter = None
+def job_cal_seg_map_for_image(img, segmenter_options=None, segmenter=None):
+ """
+    Used by MediapipeSegmenter.multiprocess_cal_seg_map_for_a_video; dedicated to processing a single long video.
+ """
+ global global_segmenter
+ if segmenter is not None:
+ segmenter_actual = segmenter
+ else:
+ global_segmenter = vision.ImageSegmenter.create_from_options(segmenter_options) if global_segmenter is None else global_segmenter
+ segmenter_actual = global_segmenter
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+ out = segmenter_actual.segment(mp_image)
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
+
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+ segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
+ segmap_image = (segmap_image * 40).astype(np.uint8)
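+ # class ids 0~5 are stored as gray levels 0,40,...,200 (presumably to keep them recoverable after lossy video compression);
+ # decode_segmap_mask_from_segmap_video_frame reverses this by rounding back to the nearest multiple of 40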
+
+ return segmap_mask, segmap_image
+
+class MediapipeSegmenter:
+ def __init__(self):
+ model_path = 'data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite'
+ if not os.path.exists(model_path):
+ os.makedirs(os.path.dirname(model_path), exist_ok=True)
+ print("downloading segmenter model from mediapipe...")
+ os.system(f"wget https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite")
+ os.system(f"mv selfie_multiclass_256x256.tflite {model_path}")
+ print("download success")
+ base_options = python.BaseOptions(model_asset_path=model_path)
+ self.options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.IMAGE, output_category_mask=True)
+ self.video_options = vision.ImageSegmenterOptions(base_options=base_options,running_mode=vision.RunningMode.VIDEO, output_category_mask=True)
+
+ def multiprocess_cal_seg_map_for_a_video(self, imgs, num_workers=4):
+ """
+ 并行处理单个长视频
+ imgs: list of rgb array in 0~255
+ """
+ segmap_masks = []
+ segmap_images = []
+ img_lst = [(imgs[i], self.options) for i in range(len(imgs))] # order matches job_cal_seg_map_for_image(img, segmenter_options)
+ for (i, res) in multiprocess_run_tqdm(job_cal_seg_map_for_image, args=img_lst, num_workers=num_workers, desc='extracting from a video in multi-process'):
+ segmap_mask, segmap_image = res
+ segmap_masks.append(segmap_mask)
+ segmap_images.append(segmap_image)
+ return segmap_masks, segmap_images
+
+ def _cal_seg_map_for_video(self, imgs, segmenter=None, return_onehot_mask=True, return_segmap_image=True):
+ segmenter = vision.ImageSegmenter.create_from_options(self.video_options) if segmenter is None else segmenter
+ assert return_onehot_mask or return_segmap_image # you should at least return one
+ segmap_masks = []
+ segmap_images = []
+ for i in tqdm.trange(len(imgs), desc="extracting segmaps from a video..."):
+ img = imgs[i]
+ mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
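+ # VIDEO running mode requires a monotonically increasing timestamp in ms; 40 ms per frame assumes 25 fps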
+ out = segmenter.segment_for_video(mp_image, 40 * i)
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
+
+ if return_onehot_mask:
+ segmap_mask = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+ segmap_masks.append(segmap_mask)
+ if return_segmap_image:
+ segmap_image = segmap[:, :, None].repeat(3, 2).astype(float)
+ segmap_image = (segmap_image * 40).astype(np.uint8)
+ segmap_images.append(segmap_image)
+
+ if return_onehot_mask and return_segmap_image:
+ return segmap_masks, segmap_images
+ elif return_onehot_mask:
+ return segmap_masks
+ elif return_segmap_image:
+ return segmap_images
+
+ def _cal_seg_map(self, img, segmenter=None, return_onehot_mask=True):
+ """
+ segmenter: vision.ImageSegmenter.create_from_options(options)
+ img: numpy, [H, W, 3], 0~255
+ segmap: [C, H, W]
+ 0 - background
+ 1 - hair
+ 2 - body-skin
+ 3 - face-skin
+ 4 - clothes
+ 5 - others (accessories)
+ """
+ assert img.ndim == 3
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=img)
+ out = segmenter.segment(image)
+ segmap = out.category_mask.numpy_view().copy() # [H, W]
+ if return_onehot_mask:
+ segmap = scatter_np(segmap[None, None, ...], classSeg=6)[0] # [6, H, W]
+ return segmap
+
+ def _seg_out_img_with_segmap(self, img, segmap, mode='head'):
+ """
+ img: [h,w,c], img is in 0~255, np
+ """
+ #
+ img = copy.deepcopy(img)
+ if mode == 'head':
+ selected_mask = segmap[[1,3,5] , :, :].sum(axis=0)[None,:] > 0.5 # glasses are also counted as 'others'
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # 0 here becomes -1 (black) under the [-1,1] convention used downstream
+ # selected_mask = segmap[[1,3] , :, :].sum(dim=0, keepdim=True) > 0.5
+ elif mode == 'person':
+ selected_mask = segmap[[1,2,3,4,5], :, :].sum(axis=0)[None,:] > 0.5
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'torso':
+ selected_mask = segmap[[2,4], :, :].sum(axis=0)[None,:] > 0.5
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'torso_with_bg':
+ selected_mask = segmap[[0, 2,4], :, :].sum(axis=0)[None,:] > 0.5
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'bg':
+ selected_mask = segmap[[0], :, :].sum(axis=0)[None,:] > 0.5 # only seg out 0, which means background
+ img[~selected_mask.repeat(3,axis=0).transpose(1,2,0)] = 0 # (-1,-1,-1) denotes black in our [-1,1] convention
+ elif mode == 'full':
+ pass
+ else:
+ raise NotImplementedError()
+ return img, selected_mask
+
+ def _seg_out_img(self, img, segmenter=None, mode='head'):
+ """
+ imgs [H, W, 3] 0-255
+ return : person_img [B, 3, H, W]
+ """
+ segmenter = vision.ImageSegmenter.create_from_options(self.options) if segmenter is None else segmenter
+ segmap = self._cal_seg_map(img, segmenter=segmenter, return_onehot_mask=True) # [6, H, W]
+ return self._seg_out_img_with_segmap(img, segmap, mode=mode)
+
+ def seg_out_imgs(self, img, mode='head'):
+ """
+ api for pytorch img, -1~1
+ img: [B, 3, H, W], -1~1
+ """
+ device = img.device
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
+ img = ((img + 1) * 127.5).astype(np.uint8)
+ img_lst = [copy.deepcopy(img[i]) for i in range(len(img))]
+ out_lst = []
+ for im in img_lst:
+ out, _ = self._seg_out_img(im, mode=mode) # _seg_out_img returns (img, mask); only the img is stacked
+ out_lst.append(out)
+ seg_imgs = np.stack(out_lst) # [B, H, W, 3]
+ seg_imgs = (seg_imgs - 127.5) / 127.5
+ seg_imgs = torch.from_numpy(seg_imgs).permute(0, 3, 1, 2).to(device)
+ return seg_imgs
+
+if __name__ == '__main__':
+ import imageio, cv2, tqdm
+ import torchshow as ts
+ img = imageio.imread("1.png")
+ img = cv2.resize(img, (512,512))
+
+ seg_model = MediapipeSegmenter()
+ img = torch.tensor(img).unsqueeze(0).repeat([1, 1, 1, 1]).permute(0, 3,1,2)
+ img = (img-127.5)/127.5
+ out = seg_model.seg_out_imgs(img, 'torso')
+ ts.save(out,"torso.png")
+ out = seg_model.seg_out_imgs(img, 'head')
+ ts.save(out,"head.png")
+ out = seg_model.seg_out_imgs(img, 'bg')
+ ts.save(out,"bg.png")
+ img = convert_to_np(img.permute(0, 2, 3, 1)) # [B, H, W, 3]
+ img = ((img + 1) * 127.5).astype(np.uint8)
+ bg = extract_background(img)
+ ts.save(bg,"bg2.png")
diff --git a/data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite b/data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite
new file mode 100644
index 0000000000000000000000000000000000000000..9ebdec318f4426502f8d825b8f0332c3e20e29b7
--- /dev/null
+++ b/data_gen/utils/mp_feature_extractors/selfie_multiclass_256x256.tflite
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
+size 16371837
diff --git a/data_gen/utils/path_converter.py b/data_gen/utils/path_converter.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6e862fb1810da7c6771d358a39a4043f93c9795
--- /dev/null
+++ b/data_gen/utils/path_converter.py
@@ -0,0 +1,24 @@
+import os
+
+
+class PathConverter():
+ def __init__(self):
+ self.prefixs = {
+ "vid": "/video/",
+ "gt": "/gt_imgs/",
+ "head": "/head_imgs/",
+ "torso": "/torso_imgs/",
+ "person": "/person_imgs/",
+ "torso_with_bg": "/torso_with_bg_imgs/",
+ "single_bg": "/bg_img/",
+ "bg": "/bg_imgs/",
+ "segmaps": "/segmaps/",
+ "inpaint_torso": "/inpaint_torso_imgs/",
+ "com": "/com_imgs/",
+ "inpaint_torso_with_com_bg": "/inpaint_torso_with_com_bg_imgs/",
+ }
+
+ def to(self, path: str, old_pattern: str, new_pattern: str):
+ return path.replace(self.prefixs[old_pattern], self.prefixs[new_pattern], 1)
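+ # Illustrative usage via the module-level instance `pc` defined below (hypothetical path):
+ #   pc.to("/data/video/May.mp4", "vid", "gt") -> "/data/gt_imgs/May.mp4"; only the first occurrence is replaced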
+
+pc = PathConverter()
\ No newline at end of file
diff --git a/data_gen/utils/process_audio/extract_hubert.py b/data_gen/utils/process_audio/extract_hubert.py
new file mode 100644
index 0000000000000000000000000000000000000000..85af486a983b7706f05ea2861565bc7b32d480dd
--- /dev/null
+++ b/data_gen/utils/process_audio/extract_hubert.py
@@ -0,0 +1,95 @@
+from transformers import Wav2Vec2Processor, HubertModel
+import soundfile as sf
+import numpy as np
+import torch
+import os
+from utils.commons.hparams import set_hparams, hparams
+
+
+wav2vec2_processor = None
+hubert_model = None
+
+
+def get_hubert_from_16k_wav(wav_16k_name):
+ speech_16k, _ = sf.read(wav_16k_name)
+ hubert = get_hubert_from_16k_speech(speech_16k)
+ return hubert
+
+@torch.no_grad()
+def get_hubert_from_16k_speech(speech, device="cuda:0"):
+ global hubert_model, wav2vec2_processor
+ local_path = '/home/tiger/.cache/huggingface/hub/models--facebook--hubert-large-ls960-ft/snapshots/ece5fabbf034c1073acae96d5401b25be96709d8'
+ if hubert_model is None:
+ print("Loading the HuBERT Model...")
+ if os.path.exists(local_path):
+ hubert_model = HubertModel.from_pretrained(local_path)
+ else:
+ hubert_model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
+ hubert_model = hubert_model.to(device)
+ if wav2vec2_processor is None:
+ print("Loading the Wav2Vec2 Processor...")
+ if os.path.exists(local_path):
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained(local_path)
+ else:
+ wav2vec2_processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
+
+ if speech.ndim ==2:
+ speech = speech[:, 0] # [T, 2] ==> [T,]
+
+ input_values_all = wav2vec2_processor(speech, return_tensors="pt", sampling_rate=16000).input_values # [1, T]
+ input_values_all = input_values_all.to(device)
+ # For a long audio sequence, memory limits prevent processing it in a single forward pass.
+ # HuBERT's feature extractor is a CNN with strides [5,2,2,2,2,2,2] (overall stride 320)
+ # and kernels [10,3,3,3,3,2,2], giving a receptive field of 400 samples per output step.
+ # It is therefore equivalent to one big Conv1D with kernel k=400 and stride s=320,
+ # so the number of output steps is T = floor((t-k)/s) + 1.
+ # Each clip is given length k + s*(N-1), where N is the expected number of steps for that clip,
+ # and the next clip starts at stride*N, i.e. it rolls back by (kernel - stride) samples so the outputs tile without gaps.
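+ # Worked example (illustrative numbers): a 10 s clip at 16 kHz has t = 160000 samples, so
+ # expected_T = (160000 - (400 - 320)) // 320 = 499 HuBERT frames; with clip_length = 320 * 1000
+ # = 320000 samples, num_iter = 0 and the whole clip is handled by the final remainder branch below.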
+ kernel = 400
+ stride = 320
+ clip_length = stride * 1000
+ num_iter = input_values_all.shape[1] // clip_length
+ expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
+ res_lst = []
+ for i in range(num_iter):
+ if i == 0:
+ start_idx = 0
+ end_idx = clip_length - stride + kernel
+ else:
+ start_idx = clip_length * i
+ end_idx = start_idx + (clip_length - stride + kernel)
+ input_values = input_values_all[:, start_idx: end_idx]
+ hidden_states = hubert_model.forward(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
+ res_lst.append(hidden_states[0])
+ if num_iter > 0:
+ input_values = input_values_all[:, clip_length * num_iter:]
+ else:
+ input_values = input_values_all
+
+ if input_values.shape[1] >= kernel: # if the last batch is shorter than kernel_size, skip it
+ hidden_states = hubert_model(input_values).last_hidden_state # [B=1, T=pts//320, hid=1024]
+ res_lst.append(hidden_states[0])
+ ret = torch.cat(res_lst, dim=0).cpu() # [T, 1024]
+
+ assert abs(ret.shape[0] - expected_T) <= 1
+ if ret.shape[0] < expected_T: # this happens when the last short clip was skipped
+ ret = torch.cat([ret, ret[-1:].repeat([expected_T - ret.shape[0], 1])], dim=0)
+ else:
+ ret = ret[:expected_T]
+
+ return ret
+
+
+if __name__ == '__main__':
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument('--video_id', type=str, default='May', help='')
+ args = parser.parse_args()
+ ### Process Single Long Audio for NeRF dataset
+ person_id = args.video_id
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
+ hubert_npy_name = f"data/processed/videos/{person_id}/aud_hubert.npy"
+ speech_16k, _ = sf.read(wav_16k_name)
+ hubert_hidden = get_hubert_from_16k_speech(speech_16k)
+ np.save(hubert_npy_name, hubert_hidden.detach().numpy())
+ print(f"Saved at {hubert_npy_name}")
diff --git a/data_gen/utils/process_audio/extract_mel_f0.py b/data_gen/utils/process_audio/extract_mel_f0.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7d29fe8515f61448431af70c5d3169856b4cef9
--- /dev/null
+++ b/data_gen/utils/process_audio/extract_mel_f0.py
@@ -0,0 +1,148 @@
+import numpy as np
+import torch
+import glob
+import os
+import tqdm
+import librosa
+import parselmouth
+from utils.commons.pitch_utils import f0_to_coarse
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from utils.commons.os_utils import multiprocess_glob
+from utils.audio.io import save_wav
+
+from moviepy.editor import VideoFileClip
+from utils.commons.hparams import hparams, set_hparams
+
+def resample_wav(wav_name, out_name, sr=16000):
+ wav_raw, sr = librosa.core.load(wav_name, sr=sr)
+ save_wav(wav_raw, out_name, sr)
+
+def split_wav(mp4_name, wav_name=None):
+ if wav_name is None:
+ wav_name = mp4_name.replace(".mp4", ".wav").replace("/video/", "/audio/")
+ if os.path.exists(wav_name):
+ return wav_name
+ os.makedirs(os.path.dirname(wav_name), exist_ok=True)
+
+ video = VideoFileClip(mp4_name,verbose=False)
+ dur = video.duration
+ audio = video.audio
+ assert audio is not None
+ audio.write_audiofile(wav_name,fps=16000,verbose=False,logger=None)
+ return wav_name
+
+def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
+ '''compute right padding (final frame) or both sides padding (first and final frames)
+ '''
+ assert pad_sides in (1, 2)
+ # return int(fsize // 2)
+ pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0]
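+ # e.g. a 16000-sample wav with fshift=320 gives pad = (16000//320 + 1)*320 - 16000 = 320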
+ if pad_sides == 1:
+ return 0, pad
+ else:
+ return pad // 2, pad // 2 + pad % 2
+
+def extract_mel_from_fname(wav_path,
+ fft_size=512,
+ hop_size=320,
+ win_length=512,
+ window="hann",
+ num_mels=80,
+ fmin=80,
+ fmax=7600,
+ eps=1e-6,
+ sample_rate=16000,
+ min_level_db=-100):
+ if isinstance(wav_path, str):
+ wav, _ = librosa.core.load(wav_path, sr=sample_rate)
+ else:
+ wav = wav_path
+
+ # get amplitude spectrogram
+ x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+ win_length=win_length, window=window, center=False)
+ spc = np.abs(x_stft) # (n_bins, T)
+
+ # get mel basis
+ fmin = 0 if fmin == -1 else fmin
+ fmax = sample_rate / 2 if fmax == -1 else fmax
+ mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
+ mel = mel_basis @ spc
+
+ mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T)
+ mel = mel.T
+
+ l_pad, r_pad = librosa_pad_lr(wav, fft_size, hop_size, 1)
+ wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+
+ return wav.T, mel
+
+def extract_f0_from_wav_and_mel(wav, mel,
+ hop_size=320,
+ audio_sample_rate=16000,
+ ):
+ time_step = hop_size / audio_sample_rate * 1000
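+ # frame period in ms matching the mel hop; parselmouth's to_pitch_ac takes seconds, hence time_step / 1000 below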
+ f0_min = 80
+ f0_max = 750
+ f0 = parselmouth.Sound(wav, audio_sample_rate).to_pitch_ac(
+ time_step=time_step / 1000, voicing_threshold=0.6,
+ pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
+
+ delta_l = len(mel) - len(f0)
+ assert np.abs(delta_l) <= 8
+ if delta_l > 0:
+ f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
+ f0 = f0[:len(mel)]
+ pitch_coarse = f0_to_coarse(f0)
+ return f0, pitch_coarse
+
+
+def extract_mel_f0_from_fname(wav_name=None, out_name=None):
+ try:
+ if out_name is None:
+ out_name = wav_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+
+ wav, mel = extract_mel_from_fname(wav_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ out_dict = {
+ "mel": mel, # [T, 80]
+ "f0": f0,
+ }
+ np.save(out_name, out_dict)
+ except Exception as e:
+ print(e)
+
+def extract_mel_f0_from_video_name(mp4_name, wav_name=None, out_name=None):
+ if mp4_name.endswith(".mp4"):
+ wav_name = split_wav(mp4_name, wav_name)
+ if out_name is None:
+ out_name = mp4_name.replace(".mp4", "_mel_f0.npy").replace("/video/", "/mel_f0/")
+ elif mp4_name.endswith(".wav"):
+ wav_name = mp4_name
+ if out_name is None:
+ out_name = mp4_name.replace(".wav", "_mel_f0.npy").replace("/audio/", "/mel_f0/")
+
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+
+ wav, mel = extract_mel_from_fname(wav_name)
+
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ out_dict = {
+ "mel": mel, # [T, 80]
+ "f0": f0,
+ }
+ np.save(out_name, out_dict)
+
+
+if __name__ == '__main__':
+ from argparse import ArgumentParser
+ parser = ArgumentParser()
+ parser.add_argument('--video_id', type=str, default='May', help='')
+ args = parser.parse_args()
+ ### Process Single Long Audio for NeRF dataset
+ person_id = args.video_id
+
+ wav_16k_name = f"data/processed/videos/{person_id}/aud.wav"
+ out_name = f"data/processed/videos/{person_id}/aud_mel_f0.npy"
+ extract_mel_f0_from_video_name(wav_16k_name, out_name=out_name)
+ print(f"Saved at {out_name}")
\ No newline at end of file
diff --git a/data_gen/utils/process_audio/resample_audio_to_16k.py b/data_gen/utils/process_audio/resample_audio_to_16k.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cc353b9385dc22c30256eb7dedbfb610cd33036
--- /dev/null
+++ b/data_gen/utils/process_audio/resample_audio_to_16k.py
@@ -0,0 +1,49 @@
+import os, glob
+from utils.commons.os_utils import multiprocess_glob
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+
+
+def extract_wav16k_job(audio_name:str):
+ out_path = audio_name.replace("/audio_raw/","/audio/",1)
+ assert out_path != audio_name # prevent inplace
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+
+ cmd = f'{ffmpeg_path} -i {audio_name} -ar 16000 -v quiet -y {out_path}'
+ os.system(cmd)
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--aud_dir", default='/home/tiger/datasets/raw/CMLR/audio_raw/')
+ parser.add_argument("--ds_name", default='CMLR')
+ parser.add_argument("--num_workers", default=64, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ args = parser.parse_args()
+ print(f"args {args}")
+
+ aud_dir = args.aud_dir
+ ds_name = args.ds_name
+ if ds_name in ['CMLR']:
+ aud_name_pattern = os.path.join(aud_dir, "*/*/*.wav")
+ aud_names = multiprocess_glob(aud_name_pattern)
+ else:
+ raise NotImplementedError()
+ aud_names = sorted(aud_names)
+ print(f"total audio number : {len(aud_names)}")
+ print(f"first {aud_names[0]} last {aud_names[-1]}")
+ # exit()
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(aud_names) // total_process
+ if process_id == total_process - 1:
+ aud_names = aud_names[process_id * num_samples_per_process : ]
+ else:
+ aud_names = aud_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ for i, res in multiprocess_run_tqdm(extract_wav16k_job, aud_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+
diff --git a/data_gen/utils/process_image/extract_lm2d.py b/data_gen/utils/process_image/extract_lm2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ee0ecc02dc94a04b69682a05a7b089d9cd4c8d6
--- /dev/null
+++ b/data_gen/utils/process_image/extract_lm2d.py
@@ -0,0 +1,197 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+import sys
+
+import glob
+import cv2
+import tqdm
+import numpy as np
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+import warnings
+warnings.filterwarnings('ignore')
+
+import random
+random.seed(42)
+
+import pickle
+import json
+import gzip
+from typing import Any
+
+def load_file(filename, is_gzip: bool = False, is_json: bool = False) -> Any:
+ if is_json:
+ if is_gzip:
+ with gzip.open(filename, "r", encoding="utf-8") as f:
+ loaded_object = json.load(f)
+ return loaded_object
+ else:
+ with open(filename, "r", encoding="utf-8") as f:
+ loaded_object = json.load(f)
+ return loaded_object
+ else:
+ if is_gzip:
+ with gzip.open(filename, "rb") as f:
+ loaded_object = pickle.load(f)
+ return loaded_object
+ else:
+ with open(filename, "rb") as f:
+ loaded_object = pickle.load(f)
+ return loaded_object
+
+def save_file(filename, content, is_gzip: bool = False, is_json: bool = False) -> None:
+ if is_json:
+ if is_gzip:
+ with gzip.open(filename, "w", encoding="utf-8") as f:
+ json.dump(content, f)
+ else:
+ with open(filename, "w", encoding="utf-8") as f:
+ json.dump(content, f)
+ else:
+ if is_gzip:
+ with gzip.open(filename, "wb") as f:
+ pickle.dump(content, f)
+ else:
+ with open(filename, "wb") as f:
+ pickle.dump(content, f)
+
+face_landmarker = None
+
+def extract_lms_mediapipe_job(img):
+ if img is None:
+ return None
+ global face_landmarker
+ if face_landmarker is None:
+ face_landmarker = MediapipeLandmarker()
+ lm478 = face_landmarker.extract_lm478_from_img(img)
+ return lm478
+
+def extract_landmark_job(img_name):
+ try:
+ # if img_name == 'datasets/PanoHeadGen/raw/images/multi_view/chunk_0/seed0000002.png':
+ # print(1)
+ # input()
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
+ if os.path.exists(out_name):
+ print("out exists, skip...")
+ return
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+ img = cv2.imread(img_name)
+
+ if img is not None:
+ img = img[:, :, ::-1] # BGR -> RGB
+ lm478 = extract_lms_mediapipe_job(img)
+ if lm478 is not None:
+ np.save(out_name, lm478)
+ # print("Hahaha, solve one item!!!")
+ except Exception as e:
+ print(e)
+ pass
+
+def out_exist_job(img_name):
+ out_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png","_lms.npy")
+ if os.path.exists(out_name):
+ return None
+ else:
+ return img_name
+
+# def get_todo_img_names(img_names):
+# todo_img_names = []
+# for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
+# if res is not None:
+# todo_img_names.append(res)
+# return todo_img_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512/')
+ parser.add_argument("--ds_name", default='FFHQ')
+ parser.add_argument("--num_workers", default=64, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--img_names_file", default="img_names.pkl", type=str)
+ parser.add_argument("--load_img_names", action="store_true")
+
+ args = parser.parse_args()
+ print(f"args {args}")
+ img_dir = args.img_dir
+ img_names_file = os.path.join(img_dir, args.img_names_file)
+ if args.load_img_names:
+ img_names = load_file(img_names_file)
+ print(f"load image names from {img_names_file}")
+ else:
+ if args.ds_name == 'FFHQ_MV':
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
+ img_names1 = glob.glob(img_name_pattern1)
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
+ img_names2 = glob.glob(img_name_pattern2)
+ img_names = img_names1 + img_names2
+ img_names = sorted(img_names)
+ elif args.ds_name == 'FFHQ':
+ img_name_pattern = os.path.join(img_dir, "*.png")
+ img_names = glob.glob(img_name_pattern)
+ img_names = sorted(img_names)
+ elif args.ds_name == "PanoHeadGen":
+ # img_name_patterns = ["ref/*/*.png", "multi_view/*/*.png", "reverse/*/*.png"]
+ img_name_patterns = ["ref/*/*.png"]
+ img_names = []
+ for img_name_pattern in img_name_patterns:
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
+ img_names_part = glob.glob(img_name_pattern_full)
+ img_names.extend(img_names_part)
+ img_names = sorted(img_names)
+
+ # save image names
+ if not args.load_img_names:
+ save_file(img_names_file, img_names)
+ print(f"save image names in {img_names_file}")
+
+ print(f"total images number: {len(img_names)}")
+
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(img_names) // total_process
+ if process_id == total_process - 1:
+ img_names = img_names[process_id * num_samples_per_process : ]
+ else:
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ # if not args.reset:
+ # img_names = get_todo_img_names(img_names)
+
+
+ print(f"todo_image {img_names[:10]}")
+ print(f"processing images number in this process: {len(img_names)}")
+ # print(f"todo images number: {len(img_names)}")
+ # input()
+ # exit()
+
+ if args.num_workers == 1:
+ index = 0
+ for img_name in tqdm.tqdm(img_names, desc=f"Root process {args.process_id}: extracting MP-based landmark2d"):
+ try:
+ extract_landmark_job(img_name)
+ except Exception as e:
+ print(e)
+ pass
+ if index % max(1, int(len(img_names) * 0.003)) == 0:
+ print(f"processed {index} / {len(img_names)}")
+ sys.stdout.flush()
+ index += 1
+ else:
+ for i, res in multiprocess_run_tqdm(
+ extract_landmark_job, img_names,
+ num_workers=args.num_workers,
+ desc=f"Root {args.process_id}: extracing MP-based landmark2d"):
+ # if index % max(1, int(len(img_names) * 0.003)) == 0:
+ print(f"processed {i+1} / {len(img_names)}")
+ sys.stdout.flush()
+ print(f"Root {args.process_id}: Finished extracting.")
\ No newline at end of file
diff --git a/data_gen/utils/process_image/extract_segment_imgs.py b/data_gen/utils/process_image/extract_segment_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..408a6d1b6229e9bd7e2aa1c7c7cdeb067cc0ae7f
--- /dev/null
+++ b/data_gen/utils/process_image/extract_segment_imgs.py
@@ -0,0 +1,114 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+
+import glob
+import cv2
+import tqdm
+import numpy as np
+import PIL
+from utils.commons.tensor_utils import convert_to_np
+import torch
+import mediapipe as mp
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background, save_rgb_image_to_path
+seg_model = MediapipeSegmenter()
+
+
+def extract_segment_job(img_name):
+ try:
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ segmap = seg_model._cal_seg_map(img)
+ bg_img = extract_background([img], [segmap])
+ out_img_name = img_name.replace("/images_512/",f"/bg_img/").replace(".mp4", ".jpg")
+ save_rgb_image_to_path(bg_img, out_img_name)
+
+ com_img = img.copy()
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
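+ # [H, W, 3] boolean mask of background pixels, used to composite the extracted background behind the person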
+ com_img[bg_part] = bg_img[bg_part]
+ out_img_name = img_name.replace("/images_512/",f"/com_imgs/")
+ save_rgb_image_to_path(com_img, out_img_name)
+
+ for mode in ['head', 'torso', 'person', 'torso_with_bg', 'bg']:
+ out_img, _ = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
+ out_img_name = img_name.replace("/images_512/",f"/{mode}_imgs/")
+ out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
+ except: pass
+ cv2.imwrite(out_img_name, out_img)
+
+ inpaint_torso_img, inpaint_torso_with_bg_img, _, _ = inpaint_torso_job(img, segmap)
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_imgs/")
+ save_rgb_image_to_path(inpaint_torso_img, out_img_name)
+ inpaint_torso_with_bg_img[bg_part] = bg_img[bg_part]
+ out_img_name = img_name.replace("/images_512/",f"/inpaint_torso_with_com_bg_imgs/")
+ save_rgb_image_to_path(inpaint_torso_with_bg_img, out_img_name)
+ return 0
+ except Exception as e:
+ print(e)
+ return 1
+
+def out_exist_job(img_name):
+ out_name1 = img_name.replace("/images_512/", "/head_imgs/")
+ out_name2 = img_name.replace("/images_512/", "/com_imgs/")
+ out_name3 = img_name.replace("/images_512/", "/inpaint_torso_with_com_bg_imgs/")
+
+ if os.path.exists(out_name1) and os.path.exists(out_name2) and os.path.exists(out_name3):
+ return None
+ else:
+ return img_name
+
+def get_todo_img_names(img_names):
+ todo_img_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=64):
+ if res is not None:
+ todo_img_names.append(res)
+ return todo_img_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default='./images_512')
+ # parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
+ parser.add_argument("--ds_name", default='FFHQ')
+ parser.add_argument("--num_workers", default=1, type=int)
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+
+ args = parser.parse_args()
+ img_dir = args.img_dir
+ if args.ds_name == 'FFHQ_MV':
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
+ img_names1 = glob.glob(img_name_pattern1)
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
+ img_names2 = glob.glob(img_name_pattern2)
+ img_names = img_names1 + img_names2
+ elif args.ds_name == 'FFHQ':
+ img_name_pattern = os.path.join(img_dir, "*.png")
+ img_names = glob.glob(img_name_pattern)
+
+ img_names = sorted(img_names)
+ random.seed(args.seed)
+ random.shuffle(img_names)
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(img_names) // total_process
+ if process_id == total_process - 1:
+ img_names = img_names[process_id * num_samples_per_process : ]
+ else:
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ img_names = get_todo_img_names(img_names)
+ print(f"todo images number: {len(img_names)}")
+
+ for vid_name in multiprocess_run_tqdm(extract_segment_job ,img_names, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
+ pass
\ No newline at end of file
diff --git a/data_gen/utils/process_image/fit_3dmm_landmark.py b/data_gen/utils/process_image/fit_3dmm_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fde7d94d919ab2b582fe7ac2e1a11fbe8129fad
--- /dev/null
+++ b/data_gen/utils/process_image/fit_3dmm_landmark.py
@@ -0,0 +1,369 @@
+import torch
+import torch.nn.functional as F
+import copy
+import numpy as np
+
+import os
+import sys
+import cv2
+import argparse
+import tqdm
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
+
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+import pickle
+
+face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
+face_model.to("cuda")
+
+
+index_lm68_from_lm468 = [127,234,93,132,58,136,150,176,152,400,379,365,288,361,323,454,356,70,63,105,66,107,336,296,334,293,300,168,197,5,4,75,97,2,326,305,
+ 33,160,158,133,153,144,362,385,387,263,373,380,61,40,37,0,267,270,291,321,314,17,84,91,78,81,13,311,308,402,14,178]
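+# maps the canonical 68-landmark convention onto indices of MediaPipe's 468 facial landmarks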
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+LAMBDA_REG_ID = 0.3
+LAMBDA_REG_EXP = 0.05
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+def cal_lan_loss_mp(proj_lan, gt_lan):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan).pow(2)
+ # loss = (proj_lan - gt_lan).abs()
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
+ weights = torch.ones_like(loss)
+ weights[:, eye] = 5
+ weights[:, inner_lip] = 2
+ weights[:, outer_lip] = 2
+ weights[:, unmatch_mask] = 0
+ loss = loss * weights
+ return torch.mean(loss)
+
+def cal_lan_loss(proj_lan, gt_lan):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan)** 2
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
+ weights = torch.ones_like(loss)
+ weights[:, 36:48, :] = 3 # eye 12 points
+ weights[:, -8:, :] = 3 # inner lip 8 points
+ weights[:, 28:31, :] = 3 # nose 3 points
+ loss = loss * weights
+ return torch.mean(loss)
+
+def set_requires_grad(tensor_list):
+ for tensor in tensor_list:
+ tensor.requires_grad = True
+
+def read_video_to_frames(img_name):
+ frames = []
+ cap = cv2.VideoCapture(img_name)
+ while cap.isOpened():
+ ret, frame_bgr = cap.read()
+ if frame_bgr is None:
+ break
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ frames.append(frame_rgb)
+ return np.stack(frames)
+
+@torch.enable_grad()
+def fit_3dmm_for_a_image(img_name, debug=False, keypoint_mode='mediapipe', device="cuda:0", save=True):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_h, img_w = img.shape[0], img.shape[1]
+ assert img_h == img_w
+ num_frames = 1
+
+ lm_name = img_name.replace("/images_512/", "/lms_2d/").replace(".png", "_lms.npy")
+ if lm_name.endswith('_lms.npy') and os.path.exists(lm_name):
+ lms = np.load(lm_name)
+ else:
+ # print("lms_2d file not found, try to extract it from image...")
+ try:
+ landmarker = MediapipeLandmarker()
+ lms = landmarker.extract_lm478_from_img_name(img_name)
+ # lms = landmarker.extract_lm478_from_img(img)
+ except Exception as e:
+ print(e)
+ return
+ if lms is None:
+ print("get None lms_2d, please check whether each frame has one head, exiting...")
+ return
+ lms = lms[:468].reshape([468,2])
+ lms = torch.FloatTensor(lms).to(device=device)
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+
+ if keypoint_mode == 'mediapipe':
+ cal_lan_loss_fn = cal_lan_loss_mp
+ out_name = img_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png", "_coeff_fit_mp.npy")
+ else:
+ cal_lan_loss_fn = cal_lan_loss
+ out_name = img_name.replace("/images_512/", "/coeff_fit_lm68/").replace(".png", "_coeff_fit_lm68.npy")
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+
+ id_dim, exp_dim = 80, 64
+ sel_ids = np.arange(0, num_frames, 40)
+ sel_num = sel_ids.shape[0]
+ arg_focal = face_model.focal
+
+ h = w = face_model.center * 2
+ img_scale_factor = img_h / h
+ lms /= img_scale_factor
+ cxy = torch.tensor((w / 2.0, h / 2.0), dtype=torch.float).to(device=device)
+
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True) # lms.new_zeros((1, id_dim), requires_grad=True)
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
+
+ focal_length = lms.new_zeros(1, requires_grad=True)
+ focal_length.data += arg_focal
+
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
+
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
+
+ # stage 1: keep id/exp at their initialization and fit only euler and trans
+ for _ in range(200):
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans)
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
+ loss = loss_lan
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_frame.step()
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ for param_group in optimizer_frame.param_groups:
+ param_group['lr'] = 0.1
+
+ # "jointly roughly training id exp euler trans"
+ for _ in range(200):
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms.detach())
+ loss_regid = torch.mean(id_para*id_para) # regularization
+ loss_regexp = torch.mean(exp_para * exp_para)
+
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
+ optimizer_idexp.zero_grad()
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_idexp.step()
+ optimizer_frame.step()
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ # stage 3: start fine fitting, initialized from the coarsely fitted results
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
+ id_para_.data = id_para.data.clone()
+ id_para = id_para_
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ exp_para_.data = exp_para.data.clone()
+ exp_para = exp_para_
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ euler_angle_.data = euler_angle.data.clone()
+ euler_angle = euler_angle_
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans_.data = trans.data.clone()
+ trans = trans_
+
+ batch_size = 1
+
+ # "fine fitting the 3DMM in batches"
+ for i in range(int((num_frames-1)/batch_size+1)):
+ if (i+1)*batch_size > num_frames:
+ start_n = num_frames-batch_size
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
+ else:
+ start_n = i*batch_size
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
+ sel_lms = lms[sel_ids]
+
+ sel_id_para = id_para.new_zeros(
+ (batch_size, id_dim), requires_grad=True)
+ sel_id_para.data = id_para[sel_ids].clone()
+ sel_exp_para = exp_para.new_zeros(
+ (batch_size, exp_dim), requires_grad=True)
+ sel_exp_para.data = exp_para[sel_ids].clone()
+ sel_euler_angle = euler_angle.new_zeros(
+ (batch_size, 3), requires_grad=True)
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
+ sel_trans.data = trans[sel_ids].clone()
+
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
+ optimizer_cur_batch = torch.optim.Adam(
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
+
+ for j in range(50):
+ proj_geo = face_model.compute_for_landmark_fit(
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms.unsqueeze(0).detach())
+
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # regularization
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP
+ optimizer_cur_batch.zero_grad()
+ loss.backward()
+ optimizer_cur_batch.step()
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f}")
+ id_para[sel_ids].data = sel_id_para.data.clone()
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
+ trans[sel_ids].data = sel_trans.data.clone()
+
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
+ if save:
+ np.save(out_name, coeff_dict, allow_pickle=True)
+
+ if debug:
+ import imageio
+ debug_name = img_name.replace("/images_512/", "/coeff_fit_mp_debug/").replace(".png", "_debug.png").replace(".jpg", "_debug.jpg")
+ try: os.makedirs(os.path.dirname(debug_name), exist_ok=True)
+ except: pass
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
+ lm68s = lm68s * img_scale_factor
+ lms = lms * img_scale_factor
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+ lm68s = lm68s.astype(int)
+ lm68s = lm68s.reshape([-1,2])
+ lms = lms.cpu().numpy().astype(int).reshape([-1,2])
+ for lm in lm68s:
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1)
+ for gt_lm in lms:
+ img = cv2.circle(img, gt_lm, 2, (255, 0, 0), thickness=1)
+ imageio.imwrite(debug_name, img)
+ print(f"debug img saved at {debug_name}")
+ return coeff_dict
+
+def out_exist_job(vid_name):
+ out_name = vid_name.replace("/images_512/", "/coeff_fit_mp/").replace(".png","_coeff_fit_mp.npy")
+ # if os.path.exists(out_name) or not os.path.exists(lms_name):
+ if os.path.exists(out_name):
+ return None
+ else:
+ return vid_name
+
+def get_todo_img_names(img_names):
+ todo_img_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, img_names, num_workers=16):
+ if res is not None:
+ todo_img_names.append(res)
+ return todo_img_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--img_dir", default='/home/tiger/datasets/raw/FFHQ/images_512')
+ parser.add_argument("--ds_name", default='FFHQ')
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
+ parser.add_argument("--debug", action='store_true')
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--device", default="cuda:0", type=str)
+ parser.add_argument("--output_log", action='store_true')
+ parser.add_argument("--load_names", action="store_true")
+
+ args = parser.parse_args()
+ img_dir = args.img_dir
+ load_names = args.load_names
+
+ print(f"args {args}")
+
+ if args.ds_name == 'single_img':
+ img_names = [img_dir]
+ else:
+ img_names_path = os.path.join(img_dir, "img_dir.pkl")
+ if os.path.exists(img_names_path) and load_names:
+ print(f"loading vid names from {img_names_path}")
+ img_names = load_file(img_names_path)
+ else:
+ if args.ds_name == 'FFHQ_MV':
+ img_name_pattern1 = os.path.join(img_dir, "ref_imgs/*.png")
+ img_names1 = glob.glob(img_name_pattern1)
+ img_name_pattern2 = os.path.join(img_dir, "mv_imgs/*.png")
+ img_names2 = glob.glob(img_name_pattern2)
+ img_names = img_names1 + img_names2
+ img_names = sorted(img_names)
+ elif args.ds_name == 'FFHQ':
+ img_name_pattern = os.path.join(img_dir, "*.png")
+ img_names = glob.glob(img_name_pattern)
+ img_names = sorted(img_names)
+ elif args.ds_name == "PanoHeadGen":
+ img_name_patterns = ["ref/*/*.png"]
+ img_names = []
+ for img_name_pattern in img_name_patterns:
+ img_name_pattern_full = os.path.join(img_dir, img_name_pattern)
+ img_names_part = glob.glob(img_name_pattern_full)
+ img_names.extend(img_names_part)
+ img_names = sorted(img_names)
+ print(f"saving image names to {img_names_path}")
+ save_file(img_names_path, img_names)
+
+ # import random
+ # random.seed(args.seed)
+ # random.shuffle(img_names)
+
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
+ face_model.to(torch.device(args.device))
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1 and process_id >= 0
+ num_samples_per_process = len(img_names) // total_process
+ if process_id == total_process - 1:
+ img_names = img_names[process_id * num_samples_per_process : ]
+ else:
+ img_names = img_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+ print(f"image names number (before fileter): {len(img_names)}")
+
+
+ if not args.reset:
+ img_names = get_todo_img_names(img_names)
+
+ print(f"image names number (after fileter): {len(img_names)}")
+ for i in tqdm.trange(len(img_names), desc=f"process {process_id}: fitting 3dmm ..."):
+ img_name = img_names[i]
+ try:
+ fit_3dmm_for_a_image(img_name, args.debug, device=args.device)
+ except Exception as e:
+ print(img_name, e)
+ if args.output_log and i % max(int(len(img_names) * 0.003), 1) == 0:
+ print(f"process {process_id}: {i + 1} / {len(img_names)} done")
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+ print(f"process {process_id}: fitting 3dmm all done")
+
diff --git a/data_gen/utils/process_video/euler2quaterion.py b/data_gen/utils/process_video/euler2quaterion.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3fd35af0e26285dafac2931fad5904e9d30321a
--- /dev/null
+++ b/data_gen/utils/process_video/euler2quaterion.py
@@ -0,0 +1,35 @@
+import numpy as np
+import torch
+import math
+import numba
+from scipy.spatial.transform import Rotation as R
+
+def euler2quaterion(euler, use_radian=True):
+ """
+ euler: np.array, [batch, 3]
+ return: the quaternion, np.array, [batch, 4]
+ """
+ r = R.from_euler('xyz',euler, degrees=not use_radian)
+ return r.as_quat()
+
+def quaterion2euler(quat, use_radian=True):
+ """
+ quat: np.array, [batch, 4]
+ return: the euler, np.array, [batch, 3]
+ """
+ r = R.from_quat(quat)
+ return r.as_euler('xyz', degrees=not use_radian)
+
+def rot2quaterion(rot):
+ r = R.from_matrix(rot)
+ return r.as_quat()
+
+def quaterion2rot(quat):
+ r = R.from_quat(quat)
+ return r.as_matrix()
+
+if __name__ == '__main__':
+ euler = np.array([89.999,89.999,89.999] * 100).reshape([100,3])
+ q = euler2quaterion(euler, use_radian=False)
+ e = quaterion2euler(q, use_radian=False)
+ print(" ")
diff --git a/data_gen/utils/process_video/extract_blink.py b/data_gen/utils/process_video/extract_blink.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6d27bb077d401a9c8e9b5b19b121c538db9e037
--- /dev/null
+++ b/data_gen/utils/process_video/extract_blink.py
@@ -0,0 +1,50 @@
+import numpy as np
+from data_util.face3d_helper import Face3DHelper
+from utils.commons.tensor_utils import convert_to_tensor
+
+def polygon_area(x, y):
+ """
+ x: [T, K=6]
+ y: [T, K=6]
+ return: [T,]
+ """
+ x_ = x - x.mean(axis=-1, keepdims=True)
+ y_ = y - y.mean(axis=-1, keepdims=True)
+ correction = x_[:,-1] * y_[:,0] - y_[:,-1]* x_[:,0]
+ main_area = (x_[:,:-1] * y_[:,1:]).sum(axis=-1) - (y_[:,:-1] * x_[:,1:]).sum(axis=-1)
+ return 0.5 * np.abs(main_area + correction)
+
+def get_eye_area_percent(id, exp, face3d_helper):
+ id = convert_to_tensor(id)
+ exp = convert_to_tensor(exp)
+ cano_lm3d = face3d_helper.reconstruct_cano_lm3d(id, exp)
+ cano_lm2d = (cano_lm3d[..., :2] + 1) / 2
+ lms = cano_lm2d.cpu().numpy()
+ eyes_left = slice(36, 42)
+ eyes_right = slice(42, 48)
+ area_left = polygon_area(lms[:, eyes_left, 0], lms[:, eyes_left, 1])
+ area_right = polygon_area(lms[:, eyes_right, 0], lms[:, eyes_right, 1])
+ # percentage of the whole (unit-square) image area covered by the two eyes
+ area_percent = (area_left + area_right) / 1 * 100 # recommended blink threshold is 0.25%
+ return area_percent # [T,]
+
+
+if __name__ == '__main__':
+ import numpy as np
+ import imageio
+ import cv2
+ import torch
+ from data_gen.utils.process_video.extract_lm2d import extract_lms_mediapipe_job, read_video_to_frames, index_lm68_from_lm468
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+ from data_util.face3d_helper import Face3DHelper
+
+ face3d_helper = Face3DHelper()
+ video_name = 'data/raw/videos/May_10s.mp4'
+ frames = read_video_to_frames(video_name)
+ coeff = fit_3dmm_for_a_video(video_name, save=False)
+ area_percent = get_eye_area_percent(torch.tensor(coeff['id']), torch.tensor(coeff['exp']), face3d_helper)
+ writer = imageio.get_writer("1.mp4", fps=25)
+ for idx, frame in enumerate(frames):
+ frame = cv2.putText(frame, f"{area_percent[idx]:.2f}", org=(128,128), fontFace=cv2.FONT_HERSHEY_COMPLEX, fontScale=1, color=(255,0,0), thickness=1)
+ writer.append_data(frame)
+ writer.close()
\ No newline at end of file
diff --git a/data_gen/utils/process_video/extract_lm2d.py b/data_gen/utils/process_video/extract_lm2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9ae0b13408c1d837af4c40912ffc58e0043469b
--- /dev/null
+++ b/data_gen/utils/process_video/extract_lm2d.py
@@ -0,0 +1,164 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+import sys
+import glob
+import cv2
+import pickle
+import tqdm
+import numpy as np
+import mediapipe as mp
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from utils.commons.os_utils import multiprocess_glob
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker
+import warnings
+import traceback
+
+warnings.filterwarnings('ignore')
+
+"""
+The face_alignment-based lm68 has been deprecated because:
+1. its prediction accuracy around the eyes is very low;
+2. it cannot accurately predict the occluded jawline at large head rotations, so the 3DMM GT labels themselves are wrong at large angles, which hurts performance.
+We now use the mediapipe-based lm68 instead.
+"""
+# def extract_landmarks(ori_imgs_dir):
+
+# print(f'[INFO] ===== extract face landmarks from {ori_imgs_dir} =====')
+
+# fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
+# image_paths = glob.glob(os.path.join(ori_imgs_dir, '*.png'))
+# for image_path in tqdm.tqdm(image_paths):
+# out_name = image_path.replace("/images_512/", "/lms_2d/").replace(".png",".lms")
+# if os.path.exists(out_name):
+# continue
+# input = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
+# input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
+# preds = fa.get_landmarks(input)
+# if preds is None:
+# print(f"Skip {image_path} for no face detected")
+# continue
+# if len(preds) > 0:
+# lands = preds[0].reshape(-1, 2)[:,:2]
+# os.makedirs(os.path.dirname(out_name), exist_ok=True)
+# np.savetxt(out_name, lands, '%f')
+# del fa
+# print(f'[INFO] ===== extracted face landmarks =====')
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+
+face_landmarker = None
+
+def extract_landmark_job(video_name, nerf=False):
+ try:
+ if nerf:
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
+ else:
+ out_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
+ if os.path.exists(out_name):
+ # print("out exists, skip...")
+ return
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+ global face_landmarker
+ if face_landmarker is None:
+ face_landmarker = MediapipeLandmarker()
+ img_lm478, vid_lm478 = face_landmarker.extract_lm478_from_video_name(video_name)
+ lm478 = face_landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
+ np.save(out_name, lm478)
+ return True
+ # print("Hahaha, solve one item!!!")
+ except Exception as e:
+ traceback.print_exc()
+ return False
+
+def out_exist_job(vid_name):
+ out_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
+ if os.path.exists(out_name):
+ return None
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names):
+ if len(vid_names) == 1: # nerf
+ return vid_names
+ todo_vid_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=128):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='nerf')
+ parser.add_argument("--ds_name", default='data/raw/videos/May.mp4')
+ parser.add_argument("--num_workers", default=2, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action="store_true")
+ parser.add_argument("--load_names", action="store_true")
+
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ load_names = args.load_names
+
+ if ds_name.lower() == 'nerf': # process a single video
+ vid_names = [vid_dir]
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy") for video_name in vid_names]
+ else: # process the whole dataset
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ else:
+ raise NotImplementedError()
+
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
+ if os.path.exists(vid_names_path) and load_names:
+ print(f"loading vid names from {vid_names_path}")
+ vid_names = load_file(vid_names_path)
+ else:
+ vid_names = multiprocess_glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+ if not load_names:
+ print(f"saving vid names to {vid_names_path}")
+ save_file(vid_names_path, vid_names)
+ out_names = [video_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy") for video_name in vid_names]
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process - 1:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ vid_names = get_todo_vid_names(vid_names)
+ print(f"todo videos number: {len(vid_names)}")
+
+ fail_cnt = 0
+ job_args = [(vid_name, ds_name=='nerf') for vid_name in vid_names]
+ for (i, res) in multiprocess_run_tqdm(extract_landmark_job, job_args, num_workers=args.num_workers, desc=f"Root {args.process_id}: extracting MP-based landmark2d"):
+ if res is False:
+ fail_cnt += 1
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {fail_cnt} / {i + 1} = {fail_cnt / (i + 1):.4f}")
+ sys.stdout.flush()
+ pass
\ No newline at end of file
diff --git a/data_gen/utils/process_video/extract_segment_imgs.py b/data_gen/utils/process_video/extract_segment_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..92be7778140a8cac1815323b69297487ab5ef142
--- /dev/null
+++ b/data_gen/utils/process_video/extract_segment_imgs.py
@@ -0,0 +1,494 @@
+import os
+os.environ["OMP_NUM_THREADS"] = "1"
+import random
+import glob
+import cv2
+import tqdm
+import numpy as np
+from typing import Union
+from utils.commons.tensor_utils import convert_to_np
+from utils.commons.os_utils import multiprocess_glob
+import pickle
+import traceback
+import multiprocessing
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from scipy.ndimage import binary_erosion, binary_dilation
+from sklearn.neighbors import NearestNeighbors
+from mediapipe.tasks.python import vision
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter, encode_segmap_mask_to_image, decode_segmap_mask_from_image, job_cal_seg_map_for_image
+
+seg_model = None
+segmenter = None
+mat_model = None
+lama_model = None
+lama_config = None
+
+from data_gen.utils.process_video.split_video_to_imgs import extract_img_job
+
+BG_NAME_MAP = {
+ "knn": "",
+}
+FRAME_SELECT_INTERVAL = 5
+SIM_METHOD = "mse"
+SIM_THRESHOLD = 3
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+def save_rgb_alpha_image_to_path(img, alpha, img_path):
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
+ except: pass
+ cv2.imwrite(img_path, np.concatenate([cv2.cvtColor(img, cv2.COLOR_RGB2BGR), alpha], axis=-1))
+
+def save_rgb_image_to_path(img, img_path):
+ try: os.makedirs(os.path.dirname(img_path), exist_ok=True)
+ except: pass
+ cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+def load_rgb_image_to_path(img_path):
+ return cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
+
+def image_similarity(x: np.ndarray, y: np.ndarray, method="mse"):
+ if method == "mse":
+ return np.mean((x - y) ** 2)
+ else:
+ raise NotImplementedError
+
+def extract_background(img_lst, segmap_mask_lst=None, method="knn", device='cpu', mix_bg=True):
+ """
+ img_lst: list of rgb ndarray
+ method: "knn"
+ """
+ global segmenter
+ global seg_model
+ global mat_model
+ global lama_model
+ global lama_config
+
+ assert len(img_lst) > 0
+ if segmap_mask_lst is not None:
+ assert len(segmap_mask_lst) == len(img_lst)
+ else:
+ del segmenter
+ del seg_model
+ seg_model = MediapipeSegmenter()
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.video_options)
+
+ def get_segmap_mask(img_lst, segmap_mask_lst, index):
+ if segmap_mask_lst is not None:
+ segmap = refresh_segment_mask(segmap_mask_lst[index])
+ else:
+ segmap = seg_model._cal_seg_map(refresh_image(img_lst[index]), segmenter=segmenter)
+ return segmap
+
+ if method == "knn":
+ num_frames = len(img_lst)
+ if num_frames < 100:
+ FRAME_SELECT_INTERVAL = 5
+ elif num_frames < 10000:
+ FRAME_SELECT_INTERVAL = 20
+ else:
+ FRAME_SELECT_INTERVAL = num_frames // 500
+
+ img_lst = img_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else img_lst[0:1]
+
+ if segmap_mask_lst is not None:
+ segmap_mask_lst = segmap_mask_lst[::FRAME_SELECT_INTERVAL] if num_frames > FRAME_SELECT_INTERVAL else segmap_mask_lst[0:1]
+ assert len(img_lst) == len(segmap_mask_lst)
+ # get H/W
+ h, w = refresh_image(img_lst[0]).shape[:2]
+
+ # nearest neighbors
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() # [512*512, 2] coordinate grid
+ distss = []
+ for idx, img in tqdm.tqdm(enumerate(img_lst), desc='combining backgrounds...', total=len(img_lst)):
+ segmap = get_segmap_mask(img_lst=img_lst, segmap_mask_lst=segmap_mask_lst, index=idx)
+ bg = (segmap[0]).astype(bool) # [h,w] bool mask
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) # [N_nonbg,2] coordinate of non-bg pixels
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ dists, _ = nbrs.kneighbors(all_xys) # [512*512, 1] distance to nearest non-bg pixel
+ distss.append(dists)
+
+ distss = np.stack(distss) # [B, 512*512, 1]
+ max_dist = np.max(distss, 0) # [512*512, 1]
+ max_id = np.argmax(distss, 0) # id of frame
+
+ bc_pixs = max_dist > 10 # pixels that appear as bg in at least one frame; a pixel counts as bg if its distance to the nearest non-bg pixel exceeds 10
+ bc_pixs_id = np.nonzero(bc_pixs)
+ bc_ids = max_id[bc_pixs]
+
+ # TODO: maybe we should reimplement here to avoid memory costs?
+ # though the number of sampled images is bounded, so the memory cost has an upper limit
+ num_pixs = distss.shape[1]
+ bg_img = np.zeros((h*w, 3), dtype=np.uint8)
+ img_lst = [refresh_image(img) for img in img_lst]
+ imgs = np.stack(img_lst).reshape(-1, num_pixs, 3)
+ bg_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] # for pixels that are definitely bg, sample the color directly from the corresponding frame
+ bg_img = bg_img.reshape(h, w, 3)
+
+ max_dist = max_dist.reshape(h, w)
+ bc_pixs = max_dist > 10 # 5
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ distances, indices = nbrs.kneighbors(bg_xys) # for the remaining non-bg pixels, use KNN to find the nearest bg pixel
+ bg_fg_xys = fg_xys[indices[:, 0]]
+ bg_img[bg_xys[:, 0], bg_xys[:, 1], :] = bg_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ else:
+ raise NotImplementedError # deprecated
+
+ return bg_img
+
+def inpaint_torso_job(gt_img, segmap):
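+ # A short note on the strategy below: after removing the head region, the exposed area is filled by
+ # taking, for each image column, the topmost torso / neck pixel whose upper neighbor was head, and
+ # repeating its color upward for L rows with a slight per-row darkening (0.98); the inpainted region
+ # is then Gaussian-blurred to avoid vertical-line artifacts.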
+ bg_part = (segmap[0]).astype(bool)
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
+ neck_part = (segmap[2]).astype(bool)
+ torso_part = (segmap[4]).astype(bool)
+ img = gt_img.copy()
+ img[head_part] = 0
+
+ # torso part "vertical" in-painting...
+ L = 8 + 1
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
+ torso_coords = torso_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
+ top_torso_coords = torso_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
+ mask = head_part[tuple(top_torso_coords_up.T)]
+ if mask.any():
+ top_torso_coords = top_torso_coords[mask]
+ # get the color
+ top_torso_colors = gt_img[tuple(top_torso_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_torso_coords += inpaint_offsets
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ img[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
+ inpaint_torso_mask = np.zeros_like(img[..., 0]).astype(bool)
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
+ else:
+ inpaint_torso_mask = None
+
+ # neck part "vertical" in-painting...
+ push_down = 4
+ L = 48 + push_down + 1
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
+ neck_coords = neck_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
+ top_neck_coords = neck_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
+ mask = head_part[tuple(top_neck_coords_up.T)]
+ top_neck_coords = top_neck_coords[mask]
+ # push these top down for 4 pixels to make the neck inpainting more natural...
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
+ # get the color
+ top_neck_colors = gt_img[tuple(top_neck_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_neck_coords += inpaint_offsets
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ img[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
+ # apply blurring to the inpainted area to avoid vertical-line artifacts...
+ inpaint_mask = np.zeros_like(img[..., 0]).astype(bool)
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
+
+ blur_img = img.copy()
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
+ img[inpaint_mask] = blur_img[inpaint_mask]
+
+ # set mask
+ torso_img_mask = (neck_part | torso_part | inpaint_mask)
+ torso_with_bg_img_mask = (bg_part | neck_part | torso_part | inpaint_mask)
+ if inpaint_torso_mask is not None:
+ torso_img_mask = torso_img_mask | inpaint_torso_mask
+ torso_with_bg_img_mask = torso_with_bg_img_mask | inpaint_torso_mask
+
+ torso_img = img.copy()
+ torso_img[~torso_img_mask] = 0
+ torso_with_bg_img = img.copy()
+ torso_with_bg_img[~torso_with_bg_img_mask] = 0
+
+ return torso_img, torso_img_mask, torso_with_bg_img, torso_with_bg_img_mask
+
+def load_segment_mask_from_file(filename: str):
+ encoded_segmap = load_rgb_image_to_path(filename)
+ segmap_mask = decode_segmap_mask_from_image(encoded_segmap)
+ return segmap_mask
+
+# load segment mask to memory if not loaded yet
+def refresh_segment_mask(segmap_mask: Union[str, np.ndarray]):
+ if isinstance(segmap_mask, str):
+ segmap_mask = load_segment_mask_from_file(segmap_mask)
+ return segmap_mask
+
+# load segment mask to memory if not loaded yet
+def refresh_image(image: Union[str, np.ndarray]):
+ if isinstance(image, str):
+ image = load_rgb_image_to_path(image)
+ return image
+
+def generate_segment_imgs_job(img_name, segmap, img):
+ out_img_name = segmap_name = img_name.replace("/gt_imgs/", "/segmaps/").replace(".jpg", ".png") # save as png; saving as jpg would introduce pixel-value errors
+ try: os.makedirs(os.path.dirname(out_img_name), exist_ok=True)
+ except: pass
+ encoded_segmap = encode_segmap_mask_to_image(segmap)
+ save_rgb_image_to_path(encoded_segmap, out_img_name)
+
+ for mode in ['head', 'torso', 'person', 'bg']:
+ out_img, mask = seg_model._seg_out_img_with_segmap(img, segmap, mode=mode)
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
+ mask = mask[0][..., None]
+ img_alpha[~mask] = 0
+ out_img_name = img_name.replace("/gt_imgs/", f"/{mode}_imgs/").replace(".jpg", ".png")
+ save_rgb_alpha_image_to_path(out_img, img_alpha, out_img_name)
+
+ inpaint_torso_img, inpaint_torso_img_mask, inpaint_torso_with_bg_img, inpaint_torso_with_bg_img_mask = inpaint_torso_job(img, segmap)
+ img_alpha = 255 * np.ones((img.shape[0], img.shape[1], 1), dtype=np.uint8) # alpha
+ img_alpha[~inpaint_torso_img_mask[..., None]] = 0
+ out_img_name = img_name.replace("/gt_imgs/", f"/inpaint_torso_imgs/").replace(".jpg", ".png")
+ save_rgb_alpha_image_to_path(inpaint_torso_img, img_alpha, out_img_name)
+ return segmap_name
+
+def segment_and_generate_for_image_job(img_name, img, segmenter_options=None, segmenter=None, store_in_memory=False):
+ img = refresh_image(img)
+ segmap_mask, segmap_image = job_cal_seg_map_for_image(img, segmenter_options=segmenter_options, segmenter=segmenter)
+ segmap_name = generate_segment_imgs_job(img_name=img_name, segmap=segmap_mask, img=img)
+ if store_in_memory:
+ return segmap_mask
+ else:
+ return segmap_name
+
+def extract_segment_job(
+ video_name,
+ nerf=False,
+ background_method='knn',
+ device="cpu",
+ total_gpus=0,
+ mix_bg=True,
+ store_in_memory=False, # set to True to speed up a bit of preprocess, but leads to HUGE memory costs (100GB for 5-min video)
+ force_single_process=False, # turn this on if you find multi-process does not work on your environment
+):
+ global segmenter
+ global seg_model
+ del segmenter
+ del seg_model
+ seg_model = MediapipeSegmenter()
+ segmenter = vision.ImageSegmenter.create_from_options(seg_model.options)
+ # nerf means that we extract only one video, so we can enable multi-process acceleration
+ multiprocess_enable = nerf and not force_single_process
+ try:
+ if "cuda" in device:
+ # determine which cuda index from subprocess id
+ pname = multiprocessing.current_process().name
+ pid = int(pname.rsplit("-", 1)[-1]) - 1
+ cuda_id = pid % total_gpus
+ device = f"cuda:{cuda_id}"
+
+ if nerf: # single video
+ raw_img_dir = video_name.replace(".mp4", "/gt_imgs/").replace("/raw/","/processed/")
+ else: # whole dataset
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/", "/gt_imgs/")
+ if not os.path.exists(raw_img_dir):
+ extract_img_job(video_name, raw_img_dir) # use ffmpeg to split video into imgs
+
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
+
+ img_lst = []
+
+ for img_name in img_names:
+ if store_in_memory:
+ img = load_rgb_image_to_path(img_name)
+ else:
+ img = img_name
+ img_lst.append(img)
+
+ print("| Extracting Segmaps && Saving...")
+ args = []
+ segmap_mask_lst = []
+ # preparing parameters for segment
+ for i in range(len(img_lst)):
+ img_name = img_names[i]
+ img = img_lst[i]
+ if multiprocess_enable: # create seg_model in subprocesses here
+ options = seg_model.options
+ segmenter_arg = None
+ else: # use seg_model of this process
+ options = None
+ segmenter_arg = segmenter
+ arg = (img_name, img, options, segmenter_arg, store_in_memory)
+ args.append(arg)
+
+ if multiprocess_enable:
+ for (_, res) in multiprocess_run_tqdm(segment_and_generate_for_image_job, args=args, num_workers=16, desc='generating segment images in multi-processes...'):
+ segmap_mask = res
+ segmap_mask_lst.append(segmap_mask)
+ else:
+ for index in tqdm.tqdm(range(len(img_lst)), desc="generating segment images in single-process..."):
+ segmap_mask = segment_and_generate_for_image_job(*args[index])
+ segmap_mask_lst.append(segmap_mask)
+ print("| Extracted Segmaps Done.")
+
+ print("| Extracting background...")
+ bg_prefix_name = f"bg{BG_NAME_MAP[background_method]}"
+ bg_img = extract_background(img_lst, segmap_mask_lst, method=background_method, device=device, mix_bg=mix_bg)
+ if nerf:
+ out_img_name = video_name.replace("/raw/", "/processed/").replace(".mp4", f"/{bg_prefix_name}.jpg")
+ else:
+ out_img_name = video_name.replace("/video/", f"/{bg_prefix_name}_img/").replace(".mp4", ".jpg")
+ save_rgb_image_to_path(bg_img, out_img_name)
+ print("| Extracted background done.")
+
+ print("| Extracting com_imgs...")
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
+ for i in tqdm.trange(len(img_names), desc='extracting com_imgs'):
+ img_name = img_names[i]
+ com_img = refresh_image(img_lst[i]).copy()
+ segmap = refresh_segment_mask(segmap_mask_lst[i])
+ bg_part = segmap[0].astype(bool)[..., None].repeat(3,axis=-1)
+ com_img[bg_part] = bg_img[bg_part]
+ out_img_name = img_name.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
+ save_rgb_image_to_path(com_img, out_img_name)
+ print("| Extracted com_imgs done.")
+
+ return 0
+ except Exception as e:
+ print(str(type(e)), e)
+ traceback.print_exc()
+ return 1
+
+def out_exist_job(vid_name, background_method='knn'):
+ com_prefix_name = f"com{BG_NAME_MAP[background_method]}"
+ img_dir = vid_name.replace("/video/", "/gt_imgs/").replace(".mp4", "")
+ out_dir1 = img_dir.replace("/gt_imgs/", "/head_imgs/")
+ out_dir2 = img_dir.replace("/gt_imgs/", f"/{com_prefix_name}_imgs/")
+
+ if os.path.exists(img_dir) and os.path.exists(out_dir1) and os.path.exists(out_dir2):
+ num_frames = len(os.listdir(img_dir))
+ if len(os.listdir(out_dir1)) == num_frames and len(os.listdir(out_dir2)) == num_frames:
+ return None
+ else:
+ return vid_name
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names, background_method='knn'):
+ if len(vid_names) == 1: # nerf
+ return vid_names
+ todo_vid_names = []
+ fn_args = [(vid_name, background_method) for vid_name in vid_names]
+ for i, res in multiprocess_run_tqdm(out_exist_job, fn_args, num_workers=16, desc="checking todo videos..."):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/TH1KH_512/video')
+ parser.add_argument("--ds_name", default='TH1KH_512')
+ parser.add_argument("--num_workers", default=48, type=int)
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--load_names", action="store_true")
+ parser.add_argument("--background_method", choices=['knn', 'mat', 'ddnm', 'lama'], type=str, default='knn')
+ parser.add_argument("--total_gpus", default=0, type=int) # zero gpus means utilizing cpu
+ parser.add_argument("--no_mix_bg", action="store_true")
+ parser.add_argument("--store_in_memory", action="store_true") # set to True to speed up preprocess, but leads to high memory costs
+ parser.add_argument("--force_single_process", action="store_true") # turn this on if you find multi-process does not work on your environment
+
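+ # Example invocations (paths are illustrative placeholders, not shipped data):
+ #   single video (nerf mode): python data_gen/utils/process_video/extract_segment_imgs.py --ds_name nerf --vid_dir data/raw/videos/xxx.mp4
+ #   whole dataset:            python data_gen/utils/process_video/extract_segment_imgs.py --ds_name TH1KH_512 --vid_dir /path/to/TH1KH_512/video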
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ load_names = args.load_names
+ background_method = args.background_method
+ total_gpus = args.total_gpus
+ mix_bg = not args.no_mix_bg
+ store_in_memory = args.store_in_memory
+ force_single_process = args.force_single_process
+
+ devices = os.environ.get('CUDA_VISIBLE_DEVICES', '').split(",")
+ for d in devices[:total_gpus]:
+ os.system(f'pkill -f "voidgpu{d}"')
+
+ if ds_name.lower() == 'nerf': # process a single video
+ vid_names = [vid_dir]
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_lms.npy") for video_name in vid_names]
+ else: # process the whole dataset
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ else:
+ raise NotImplementedError()
+
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
+ if os.path.exists(vid_names_path) and load_names:
+ print(f"loading vid names from {vid_names_path}")
+ vid_names = load_file(vid_names_path)
+ else:
+ vid_names = multiprocess_glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+ print(f"saving vid names to {vid_names_path}")
+ save_file(vid_names_path, vid_names)
+
+ vid_names = sorted(vid_names)
+ random.seed(args.seed)
+ random.shuffle(vid_names)
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process - 1:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ vid_names = get_todo_vid_names(vid_names, background_method)
+ print(f"todo videos number: {len(vid_names)}")
+
+ device = "cuda" if total_gpus > 0 else "cpu"
+ extract_job = extract_segment_job
+ fn_args = [(vid_name, ds_name=='nerf', background_method, device, total_gpus, mix_bg, store_in_memory, force_single_process) for i, vid_name in enumerate(vid_names)]
+
+ if ds_name == 'nerf': # process a single video
+ extract_job(*fn_args[0])
+ else:
+ for (_, res) in multiprocess_run_tqdm(extract_job, fn_args, desc=f"Root process {args.process_id}: segment images", num_workers=args.num_workers):
+ pass
\ No newline at end of file
diff --git a/data_gen/utils/process_video/fit_3dmm_landmark.py b/data_gen/utils/process_video/fit_3dmm_landmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..2622860f66989758fafc7f244b89b4f6da5f43f1
--- /dev/null
+++ b/data_gen/utils/process_video/fit_3dmm_landmark.py
@@ -0,0 +1,565 @@
+# This is a script for efficient 3DMM coefficient extraction.
+# It can reconstruct an accurate 3D face in real time.
+# It is built upon the BFM 2009 model and the mediapipe landmark extractor.
+# It is authored by ZhenhuiYe (zhenhuiye@zju.edu.cn), feel free to contact him with any suggestions for improvement!
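+#
+# Example invocations (the first mirrors the argparse defaults below; the dataset path is a placeholder):
+#   single video (nerf mode): python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name nerf --vid_dir data/raw/videos/May_10s.mp4
+#   whole dataset:            python data_gen/utils/process_video/fit_3dmm_landmark.py --ds_name TH1KH_512 --vid_dir /path/to/TH1KH_512/video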
+
+import torch
+import torch.nn.functional as F
+import copy
+import numpy as np
+
+import random
+import pickle
+import os
+import sys
+import cv2
+import argparse
+import tqdm
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from data_gen.utils.mp_feature_extractors.face_landmarker import MediapipeLandmarker, read_video_to_frames
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from deep_3drecon.secc_renderer import SECC_Renderer
+from utils.commons.os_utils import multiprocess_glob
+
+
+face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode='mediapipe')
+face_model.to(torch.device("cuda:0"))
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+
+def draw_axes(img, pitch, yaw, roll, tx, ty, size=50):
+ # yaw = -yaw
+ pitch = - pitch
+ roll = - roll
+ rotation_matrix = cv2.Rodrigues(np.array([pitch, yaw, roll]))[0].astype(np.float64)
+ axes_points = np.array([
+ [1, 0, 0, 0],
+ [0, 1, 0, 0],
+ [0, 0, 1, 0]
+ ], dtype=np.float64)
+ axes_points = rotation_matrix @ axes_points
+ axes_points = (axes_points[:2, :] * size).astype(int)
+ axes_points[0, :] = axes_points[0, :] + tx
+ axes_points[1, :] = axes_points[1, :] + ty
+
+ new_img = img.copy()
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 0].ravel()), (255, 0, 0), 3)
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 1].ravel()), (0, 255, 0), 3)
+ cv2.line(new_img, tuple(axes_points[:, 3].ravel()), tuple(axes_points[:, 2].ravel()), (0, 0, 255), 3)
+ return new_img
+
+def save_file(name, content):
+ with open(name, "wb") as f:
+ pickle.dump(content, f)
+
+def load_file(name):
+ with open(name, "rb") as f:
+ content = pickle.load(f)
+ return content
+
+def cal_lap_loss(in_tensor):
+ # [T, 68, 2]
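+ # The (-0.5, 1.0, -0.5) kernel is a discrete temporal Laplacian: convolving each coefficient
+ # sequence with it and penalizing the squared response suppresses frame-to-frame jitter
+ # (second-order differences) while leaving smooth motion mostly untouched.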
+ t = in_tensor.shape[0]
+ in_tensor = in_tensor.reshape([t, -1]).permute(1,0).unsqueeze(1) # [c, 1, t]
+ in_tensor = torch.cat([in_tensor[:, :, 0:1], in_tensor, in_tensor[:, :, -1:]], dim=-1)
+ lap_kernel = torch.Tensor((-0.5, 1.0, -0.5)).reshape([1,1,3]).float().to(in_tensor.device) # [1, 1, kw]
+ loss_lap = 0
+
+ out_tensor = F.conv1d(in_tensor, lap_kernel)
+ loss_lap += torch.mean(out_tensor**2)
+ return loss_lap
+
+def cal_vel_loss(ldm):
+ # [B, 68, 2]
+ vel = ldm[1:] - ldm[:-1]
+ return torch.mean(torch.abs(vel))
+
+def cal_lan_loss(proj_lan, gt_lan):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan)** 2
+ # use the ldm weights from deep3drecon, see deep_3drecon/deep_3drecon_models/losses.py
+ weights = torch.ones_like(loss)
+ weights[:, 36:48, :] = 3 # eye 12 points
+ weights[:, -8:, :] = 3 # inner lip 8 points
+ weights[:, 28:31, :] = 3 # nose 3 points
+ loss = loss * weights
+ return torch.mean(loss)
+
+def cal_lan_loss_mp(proj_lan, gt_lan, mean:bool=True):
+ # [B, 68, 2]
+ loss = (proj_lan - gt_lan).pow(2)
+ # loss = (proj_lan - gt_lan).abs()
+ unmatch_mask = [ 93, 127, 132, 234, 323, 356, 361, 454]
+ upper_eye = [161,160,159,158,157] + [388,387,386,385,384]
+ eye = [33,246,161,160,159,158,157,173,133,155,154,153,145,144,163,7] + [263,466,388,387,386,385,384,398,362,382,381,380,374,373,390,249]
+ inner_lip = [78,191,80,81,82,13,312,311,310,415,308,324,318,402,317,14,87,178,88,95]
+ outer_lip = [61,185,40,39,37,0,267,269,270,409,291,375,321,405,314,17,84,181,91,146]
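+ # weighting: eye contour x3, upper-eyelid points x20, inner / outer lip x5;
+ # the unmatch_mask points are masked out with zero weight.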
+ weights = torch.ones_like(loss)
+ weights[:, eye] = 3
+ weights[:, upper_eye] = 20
+ weights[:, inner_lip] = 5
+ weights[:, outer_lip] = 5
+ weights[:, unmatch_mask] = 0
+ loss = loss * weights
+ if mean:
+ loss = torch.mean(loss)
+ return loss
+
+def cal_acceleration_loss(trans):
+ vel = trans[1:] - trans[:-1]
+ acc = vel[1:] - vel[:-1]
+ return torch.mean(torch.abs(acc))
+
+def cal_acceleration_ldm_loss(ldm):
+ # [B, 68, 2]
+ vel = ldm[1:] - ldm[:-1]
+ acc = vel[1:] - vel[:-1]
+ lip_weight = 0.25 # we don't want to smooth the lips too much
+ acc[:, 48:68] *= lip_weight # down-weight the 20 lip landmarks (68-landmark convention)
+ return torch.mean(torch.abs(acc))
+
+def set_requires_grad(tensor_list):
+ for tensor in tensor_list:
+ tensor.requires_grad = True
+
+@torch.enable_grad()
+def fit_3dmm_for_a_video(
+ video_name,
+ nerf=False, # use the file name convention for GeneFace++
+ id_mode='global',
+ debug=False,
+ keypoint_mode='mediapipe',
+ large_yaw_threshold=9999999.9,
+ save=True
+) -> bool: # True: good, False: bad
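+ # Fitting pipeline, as implemented below:
+ #   1) load (or extract on the fly) the mediapipe 2D landmarks for every frame;
+ #   2) stage 1: optimize only euler angles and translation with the landmark loss;
+ #   3) stage 2: jointly (but roughly) optimize id / exp / euler / trans with regularization and laplacian terms;
+ #   4) stage 3: fine-tune the coefficients in batches of `batch_size` frames;
+ #   5) return (and optionally save) the coefficient dict {id, exp, euler, trans}.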
+ assert video_name.endswith(".mp4"), "this function only support video as input"
+ if id_mode == 'global':
+ LAMBDA_REG_ID = 0.2
+ LAMBDA_REG_EXP = 0.6
+ LAMBDA_REG_LAP = 1.0
+ LAMBDA_REG_VEL_ID = 0.0 # laplacian is all you need for temporal consistency
+ LAMBDA_REG_VEL_EXP = 0.0 # laplacian is all you need for temporal consistency
+ else:
+ LAMBDA_REG_ID = 0.3
+ LAMBDA_REG_EXP = 0.05
+ LAMBDA_REG_LAP = 1.0
+ LAMBDA_REG_VEL_ID = 0.0 # laplacian is all you need for temporal consistency
+ LAMBDA_REG_VEL_EXP = 0.0 # laplacian is all you need for temporal consistency
+
+ frames = read_video_to_frames(video_name) # [T, H, W, 3]
+ img_h, img_w = frames.shape[1], frames.shape[2]
+ assert img_h == img_w
+ num_frames = len(frames)
+
+ if nerf: # single video
+ lm_name = video_name.replace("/raw/", "/processed/").replace(".mp4","/lms_2d.npy")
+ else:
+ lm_name = video_name.replace("/video/", "/lms_2d/").replace(".mp4", "_lms.npy")
+
+ if os.path.exists(lm_name):
+ lms = np.load(lm_name)
+ else:
+ print(f"lms_2d file not found, try to extract it from video... {lm_name}")
+ try:
+ landmarker = MediapipeLandmarker()
+ img_lm478, vid_lm478 = landmarker.extract_lm478_from_frames(frames, anti_smooth_factor=20)
+ lms = landmarker.combine_vid_img_lm478_to_lm478(img_lm478, vid_lm478)
+ except Exception as e:
+ print(e)
+ return False
+ if lms is None:
+ print(f"get None lms_2d, please check whether each frame has one head, exiting... {lm_name}")
+ return False
+ lms = lms[:, :468, :]
+ lms = torch.FloatTensor(lms).cuda()
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+
+ if keypoint_mode == 'mediapipe':
+ # default
+ cal_lan_loss_fn = cal_lan_loss_mp
+ if nerf: # single video
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/coeff_fit_mp.npy")
+ else:
+ out_name = video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4", "_coeff_fit_mp.npy")
+ else:
+ # lm68 is less accurate than mp
+ cal_lan_loss_fn = cal_lan_loss
+ if nerf: # single video
+ out_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "_coeff_fit_lm68.npy")
+ else:
+ out_name = video_name.replace("/video/", "/coeff_fit_lm68/").replace(".mp4", "_coeff_fit_lm68.npy")
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except:
+ pass
+
+ id_dim, exp_dim = 80, 64
+ sel_ids = np.arange(0, num_frames, 40)
+
+ h = w = face_model.center * 2
+ img_scale_factor = img_h / h
+ lms /= img_scale_factor # rescale lms into [0,224]
+
+ if id_mode == 'global':
+ # default choice by GeneFace++ and later works
+ id_para = lms.new_zeros((1, id_dim), requires_grad=True)
+ elif id_mode == 'finegrained':
+ # legacy choice by GeneFace1 (ICLR 2023)
+ id_para = lms.new_zeros((num_frames, id_dim), requires_grad=True)
+ else: raise NotImplementedError(f"id mode {id_mode} not supported! we only support global or finegrained.")
+ exp_para = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ euler_angle = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans = lms.new_zeros((num_frames, 3), requires_grad=True)
+
+ set_requires_grad([id_para, exp_para, euler_angle, trans])
+
+ optimizer_idexp = torch.optim.Adam([id_para, exp_para], lr=.1)
+ optimizer_frame = torch.optim.Adam([euler_angle, trans], lr=.1)
+
+ # keep the other parameters at their initialization; first optimize only euler and trans
+ for _ in range(200):
+ if id_mode == 'global':
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans)
+ else:
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans)
+ loss_lan = cal_lan_loss_fn(proj_geo[:, :, :2], lms.detach())
+ loss = loss_lan
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_frame.step()
+
+ # print(f"loss_lan: {loss_lan.item():.2f}, euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ for param_group in optimizer_frame.param_groups:
+ param_group['lr'] = 0.1
+
+ # "jointly roughly training id exp euler trans"
+ for _ in range(200):
+ ret = {}
+ if id_mode == 'global':
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para.expand((num_frames, id_dim)), exp_para, euler_angle, trans, ret)
+ else:
+ proj_geo = face_model.compute_for_landmark_fit(
+ id_para, exp_para, euler_angle, trans, ret)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms.detach())
+ # loss_lap = cal_lap_loss(proj_geo)
+ # the laplacian term barely affects euler, but it improves trans a lot
+ loss_lap = cal_lap_loss(id_para) + cal_lap_loss(exp_para) + cal_lap_loss(euler_angle) * 0.3 + cal_lap_loss(trans) * 0.3
+
+ loss_regid = torch.mean(id_para*id_para) # regularization
+ loss_regexp = torch.mean(exp_para * exp_para)
+
+ loss_vel_id = cal_vel_loss(id_para)
+ loss_vel_exp = cal_vel_loss(exp_para)
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP + loss_lap * LAMBDA_REG_LAP
+ optimizer_idexp.zero_grad()
+ optimizer_frame.zero_grad()
+ loss.backward()
+ optimizer_idexp.step()
+ optimizer_frame.step()
+
+ # print(f"loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},")
+ # print(f"euler_abs_mean: {euler_angle.abs().mean().item():.4f}, euler_std: {euler_angle.std().item():.4f}, euler_min: {euler_angle.min().item():.4f}, euler_max: {euler_angle.max().item():.4f}")
+ # print(f"trans_z_mean: {trans[...,2].mean().item():.4f}, trans_z_std: {trans[...,2].std().item():.4f}, trans_min: {trans[...,2].min().item():.4f}, trans_max: {trans[...,2].max().item():.4f}")
+
+ # start fine training, initialized from the roughly trained results
+ if id_mode == 'global':
+ id_para_ = lms.new_zeros((1, id_dim), requires_grad=False)
+ else:
+ id_para_ = lms.new_zeros((num_frames, id_dim), requires_grad=True)
+ id_para_.data = id_para.data.clone()
+ id_para = id_para_
+ exp_para_ = lms.new_zeros((num_frames, exp_dim), requires_grad=True)
+ exp_para_.data = exp_para.data.clone()
+ exp_para = exp_para_
+ euler_angle_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ euler_angle_.data = euler_angle.data.clone()
+ euler_angle = euler_angle_
+ trans_ = lms.new_zeros((num_frames, 3), requires_grad=True)
+ trans_.data = trans.data.clone()
+ trans = trans_
+
+ batch_size = 50
+ # "fine fitting the 3DMM in batches"
+ for i in range(int((num_frames-1)/batch_size+1)):
+ if (i+1)*batch_size > num_frames:
+ start_n = num_frames-batch_size
+ sel_ids = np.arange(max(num_frames-batch_size,0), num_frames)
+ else:
+ start_n = i*batch_size
+ sel_ids = np.arange(i*batch_size, i*batch_size+batch_size)
+ sel_lms = lms[sel_ids]
+
+ if id_mode == 'global':
+ sel_id_para = id_para.expand((sel_ids.shape[0], id_dim))
+ else:
+ sel_id_para = id_para.new_zeros((batch_size, id_dim), requires_grad=True)
+ sel_id_para.data = id_para[sel_ids].clone()
+ sel_exp_para = exp_para.new_zeros(
+ (batch_size, exp_dim), requires_grad=True)
+ sel_exp_para.data = exp_para[sel_ids].clone()
+ sel_euler_angle = euler_angle.new_zeros(
+ (batch_size, 3), requires_grad=True)
+ sel_euler_angle.data = euler_angle[sel_ids].clone()
+ sel_trans = trans.new_zeros((batch_size, 3), requires_grad=True)
+ sel_trans.data = trans[sel_ids].clone()
+
+ if id_mode == 'global':
+ set_requires_grad([sel_exp_para, sel_euler_angle, sel_trans])
+ optimizer_cur_batch = torch.optim.Adam(
+ [sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
+ else:
+ set_requires_grad([sel_id_para, sel_exp_para, sel_euler_angle, sel_trans])
+ optimizer_cur_batch = torch.optim.Adam(
+ [sel_id_para, sel_exp_para, sel_euler_angle, sel_trans], lr=0.005)
+
+ for j in range(50):
+ ret = {}
+ proj_geo = face_model.compute_for_landmark_fit(
+ sel_id_para, sel_exp_para, sel_euler_angle, sel_trans, ret)
+ loss_lan = cal_lan_loss_fn(
+ proj_geo[:, :, :2], lms[sel_ids].detach())
+
+ # loss_lap = cal_lap_loss(proj_geo)
+ loss_lap = cal_lap_loss(sel_id_para) + cal_lap_loss(sel_exp_para) + cal_lap_loss(sel_euler_angle) * 0.3 + cal_lap_loss(sel_trans) * 0.3
+ loss_vel_id = cal_vel_loss(sel_id_para)
+ loss_vel_exp = cal_vel_loss(sel_exp_para)
+ log_dict = {
+ 'loss_vel_id': loss_vel_id,
+ 'loss_vel_exp': loss_vel_exp,
+ 'loss_vel_euler': cal_vel_loss(sel_euler_angle),
+ 'loss_vel_trans': cal_vel_loss(sel_trans),
+ }
+ loss_regid = torch.mean(sel_id_para*sel_id_para) # regularization
+ loss_regexp = torch.mean(sel_exp_para*sel_exp_para)
+ loss = loss_lan + loss_regid * LAMBDA_REG_ID + loss_regexp * LAMBDA_REG_EXP + loss_lap * LAMBDA_REG_LAP + loss_vel_id * LAMBDA_REG_VEL_ID + loss_vel_exp * LAMBDA_REG_VEL_EXP
+
+ optimizer_cur_batch.zero_grad()
+ loss.backward()
+ optimizer_cur_batch.step()
+
+ if debug:
+ print(f"batch {i} | loss_lan: {loss_lan.item():.2f}, loss_reg_id: {loss_regid.item():.2f},loss_reg_exp: {loss_regexp.item():.2f},loss_lap_ldm:{loss_lap.item():.4f}")
+ print("|--------" + ', '.join([f"{k}: {v:.4f}" for k,v in log_dict.items()]))
+ if id_mode != 'global':
+ id_para[sel_ids].data = sel_id_para.data.clone()
+ exp_para[sel_ids].data = sel_exp_para.data.clone()
+ euler_angle[sel_ids].data = sel_euler_angle.data.clone()
+ trans[sel_ids].data = sel_trans.data.clone()
+
+ coeff_dict = {'id': id_para.detach().cpu().numpy(), 'exp': exp_para.detach().cpu().numpy(),
+ 'euler': euler_angle.detach().cpu().numpy(), 'trans': trans.detach().cpu().numpy()}
+
+ # filter data by side-view pose
+ # bad_yaw = False
+ # yaws = [] # not so accurate
+ # for index in range(coeff_dict["trans"].shape[0]):
+ # yaw = coeff_dict["euler"][index][1]
+ # yaw = np.abs(yaw)
+ # yaws.append(yaw)
+ # if yaw > large_yaw_threshold:
+ # bad_yaw = True
+
+ if debug:
+ import imageio
+ from utils.visualization.vis_cam3d.camera_pose_visualizer import CameraPoseVisualizer
+ from data_util.face3d_helper import Face3DHelper
+ from data_gen.utils.process_video.extract_blink import get_eye_area_percent
+ face3d_helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='mediapipe')
+
+ t = coeff_dict['exp'].shape[0]
+ if len(coeff_dict['id']) == 1:
+ coeff_dict['id'] = np.repeat(coeff_dict['id'], t, axis=0)
+ idexp_lm3d = face3d_helper.reconstruct_idexp_lm3d_np(coeff_dict['id'], coeff_dict['exp']).reshape([t, -1])
+ cano_lm3d = idexp_lm3d / 10 + face3d_helper.key_mean_shape.squeeze().reshape([1, -1]).cpu().numpy()
+ cano_lm3d = cano_lm3d.reshape([t, -1, 3])
+ WH = 512
+ cano_lm3d = (cano_lm3d * WH/2 + WH/2).astype(int)
+
+ with torch.no_grad():
+ rot = ParametricFaceModel.compute_rotation(euler_angle)
+ extrinsic = torch.zeros([rot.shape[0], 4, 4]).to(rot.device)
+ extrinsic[:, :3,:3] = rot
+ extrinsic[:, :3, 3] = trans # / 10
+ extrinsic[:, 3, 3] = 1
+ extrinsic = extrinsic.cpu().numpy()
+
+ xy_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xy')
+ xz_camera_visualizer = CameraPoseVisualizer(xlim=[extrinsic[:,0,3].min().item()-0.5,extrinsic[:,0,3].max().item()+0.5],ylim=[extrinsic[:,1,3].min().item()-0.5,extrinsic[:,1,3].max().item()+0.5], zlim=[extrinsic[:,2,3].min().item()-0.5,extrinsic[:,2,3].max().item()+0.5], view_mode='xz')
+
+ if nerf:
+ debug_name = video_name.replace("/raw/", "/processed/").replace(".mp4", "/debug_fit_3dmm.mp4")
+ else:
+ debug_name = video_name.replace("/video/", "/coeff_fit_debug/").replace(".mp4", "_debug.mp4")
+ try:
+ os.makedirs(os.path.dirname(debug_name), exist_ok=True)
+ except: pass
+ writer = imageio.get_writer(debug_name, fps=25)
+ if id_mode == 'global':
+ id_para = id_para.repeat([exp_para.shape[0], 1])
+ proj_geo = face_model.compute_for_landmark_fit(id_para, exp_para, euler_angle, trans)
+ lm68s = proj_geo[:,:,:2].detach().cpu().numpy() # [T, 68,2]
+ lm68s = lm68s * img_scale_factor
+ lms = lms * img_scale_factor
+ lm68s[..., 1] = img_h - lm68s[..., 1] # flip the height axis
+ lms[..., 1] = img_h - lms[..., 1] # flip the height axis
+ lm68s = lm68s.astype(int)
+ for i in tqdm.trange(min(250, len(frames)), desc=f'rendering debug video to {debug_name}..'):
+ xy_cam3d_img = xy_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
+ xy_cam3d_img = cv2.resize(xy_cam3d_img, (512,512))
+ xz_cam3d_img = xz_camera_visualizer.extrinsic2pyramid(extrinsic[i], focal_len_scaled=0.25)
+ xz_cam3d_img = cv2.resize(xz_cam3d_img, (512,512))
+
+ img = copy.deepcopy(frames[i])
+ img2 = copy.deepcopy(frames[i])
+
+ img = draw_axes(img, euler_angle[i,0].item(), euler_angle[i,1].item(), euler_angle[i,2].item(), lm68s[i][4][0].item(), lm68s[i, 4][1].item(), size=50)
+
+ gt_lm_color = (255, 0, 0)
+
+ for lm in lm68s[i]:
+ img = cv2.circle(img, lm, 1, (0, 0, 255), thickness=-1) # blue
+ for gt_lm in lms[i]:
+ img2 = cv2.circle(img2, gt_lm.cpu().numpy().astype(int), 2, gt_lm_color, thickness=1)
+
+ cano_lm3d_img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+ for j in range(len(cano_lm3d[i])):
+ x, y, _ = cano_lm3d[i, j]
+ color = (255,0,0)
+ cano_lm3d_img = cv2.circle(cano_lm3d_img, center=(x,y), radius=3, color=color, thickness=-1)
+ cano_lm3d_img = cv2.flip(cano_lm3d_img, 0)
+
+ _, secc_img = secc_renderer(id_para[0:1], exp_para[i:i+1], euler_angle[i:i+1]*0, trans[i:i+1]*0)
+ secc_img = (secc_img +1)*127.5
+ secc_img = F.interpolate(secc_img, size=(img_h, img_w))
+ secc_img = secc_img.permute(0, 2,3,1).int().cpu().numpy()[0]
+ out_img1 = np.concatenate([img, img2, secc_img], axis=1).astype(np.uint8)
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ out_img2 = np.concatenate([xy_cam3d_img, xz_cam3d_img, cano_lm3d_img], axis=1).astype(np.uint8)
+ out_img = np.concatenate([out_img1, out_img2], axis=0)
+ writer.append_data(out_img)
+ writer.close()
+
+ # if bad_yaw:
+ # print(f"Skip {video_name} due to TOO LARGE YAW")
+ # return False
+
+ if save:
+ np.save(out_name, coeff_dict, allow_pickle=True)
+ return coeff_dict
+
+def out_exist_job(vid_name):
+ out_name = vid_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy")
+ lms_name = vid_name.replace("/video/", "/lms_2d/").replace(".mp4","_lms.npy")
+ if os.path.exists(out_name) or not os.path.exists(lms_name):
+ return None
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names):
+ if len(vid_names) == 1: # single video, nerf
+ return vid_names
+ todo_vid_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ # parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--vid_dir", default='data/raw/videos/May_10s.mp4')
+ parser.add_argument("--ds_name", default='nerf') # 'nerf' | 'CelebV-HQ' | 'TH1KH_512' | etc
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--id_mode", default='global', type=str) # global | finegrained
+ parser.add_argument("--keypoint_mode", default='mediapipe', type=str)
+ parser.add_argument("--large_yaw_threshold", default=9999999.9, type=float) # could be 0.7
+ parser.add_argument("--debug", action='store_true')
+ parser.add_argument("--reset", action='store_true')
+ parser.add_argument("--load_names", action="store_true")
+
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ load_names = args.load_names
+
+ print(f"args {args}")
+
+ if ds_name.lower() == 'nerf': # process a single video
+ vid_names = [vid_dir]
+ out_names = [video_name.replace("/raw/", "/processed/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
+ else: # process the whole dataset
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*.mp4")
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ else:
+ raise NotImplementedError()
+
+ vid_names_path = os.path.join(vid_dir, "vid_names.pkl")
+ if os.path.exists(vid_names_path) and load_names:
+ print(f"loading vid names from {vid_names_path}")
+ vid_names = load_file(vid_names_path)
+ else:
+ vid_names = multiprocess_glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+ print(f"saving vid names to {vid_names_path}")
+ save_file(vid_names_path, vid_names)
+ out_names = [video_name.replace("/video/", "/coeff_fit_mp/").replace(".mp4","_coeff_fit_mp.npy") for video_name in vid_names]
+
+ print(vid_names[:10])
+ random.seed(args.seed)
+ random.shuffle(vid_names)
+
+ face_model = ParametricFaceModel(bfm_folder='deep_3drecon/BFM',
+ camera_distance=10, focal=1015, keypoint_mode=args.keypoint_mode)
+ face_model.to(torch.device("cuda:0"))
+ secc_renderer = SECC_Renderer(512)
+ secc_renderer.to("cuda:0")
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process - 1:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if not args.reset:
+ vid_names = get_todo_vid_names(vid_names)
+
+ failed_img_names = []
+ for i in tqdm.trange(len(vid_names), desc=f"process {process_id}: fitting 3dmm ..."):
+ img_name = vid_names[i]
+ try:
+ is_person_specific_data = ds_name=='nerf'
+ success = fit_3dmm_for_a_video(img_name, is_person_specific_data, args.id_mode, args.debug, large_yaw_threshold=args.large_yaw_threshold)
+ if not success:
+ failed_img_names.append(img_name)
+ except Exception as e:
+ print(img_name, e)
+ failed_img_names.append(img_name)
+ print(f"finished {i + 1} / {len(vid_names)} = {(i + 1) / len(vid_names):.4f}, failed {len(failed_img_names)} / {i + 1} = {len(failed_img_names) / (i + 1):.4f}")
+ sys.stdout.flush()
+ print(f"all failed image names: {failed_img_names}")
+ print(f"All finished!")
\ No newline at end of file
diff --git a/data_gen/utils/process_video/inpaint_torso_imgs.py b/data_gen/utils/process_video/inpaint_torso_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..c938a6f79e7b796cc6f321e332eb7840244b4cf9
--- /dev/null
+++ b/data_gen/utils/process_video/inpaint_torso_imgs.py
@@ -0,0 +1,193 @@
+import cv2
+import os
+import glob
+import tqdm
+import numpy as np
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+from scipy.ndimage import binary_erosion, binary_dilation
+
+from tasks.eg3ds.loss_utils.segment_loss.mp_segmenter import MediapipeSegmenter
+seg_model = MediapipeSegmenter()
+
+def inpaint_torso_job(video_name, idx=None, total=None):
+ raw_img_dir = video_name.replace(".mp4", "").replace("/video/","/gt_imgs/")
+ img_names = glob.glob(os.path.join(raw_img_dir, "*.jpg"))
+
+ for image_path in tqdm.tqdm(img_names):
+ # read ori image
+ ori_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ segmap = seg_model._cal_seg_map(cv2.cvtColor(ori_image, cv2.COLOR_BGR2RGB))
+ head_part = (segmap[1] + segmap[3] + segmap[5]).astype(bool)
+ torso_part = (segmap[4]).astype(bool)
+ neck_part = (segmap[2]).astype(bool)
+ bg_part = segmap[0].astype(bool)
+ head_image = cv2.imread(image_path.replace("/gt_imgs/", "/head_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ torso_image = cv2.imread(image_path.replace("/gt_imgs/", "/torso_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
+ bg_image = cv2.imread(image_path.replace("/gt_imgs/", "/bg_imgs/"), cv2.IMREAD_UNCHANGED) # [H, W, 3]
+
+ # head_part = (head_image[...,0] != 0) & (head_image[...,1] != 0) & (head_image[...,2] != 0)
+ # torso_part = (torso_image[...,0] != 0) & (torso_image[...,1] != 0) & (torso_image[...,2] != 0)
+ # bg_part = (bg_image[...,0] != 0) & (bg_image[...,1] != 0) & (bg_image[...,2] != 0)
+
+ # get gt image
+ gt_image = ori_image.copy()
+ gt_image[bg_part] = bg_image[bg_part]
+ cv2.imwrite(image_path.replace('ori_imgs', 'gt_imgs'), gt_image)
+
+ # get torso image
+ torso_image = gt_image.copy() # rgb
+ torso_image[head_part] = 0
+ torso_alpha = 255 * np.ones((gt_image.shape[0], gt_image.shape[1], 1), dtype=np.uint8) # alpha
+
+ # torso part "vertical" in-painting...
+ L = 8 + 1
+ torso_coords = np.stack(np.nonzero(torso_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((torso_coords[:, 0], torso_coords[:, 1]))
+ torso_coords = torso_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(torso_coords[:, 1], return_index=True, return_counts=True)
+ top_torso_coords = torso_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_torso_coords_up = top_torso_coords.copy() - np.array([1, 0]) # [N, 2]
+ mask = head_part[tuple(top_torso_coords_up.T)]
+ if mask.any():
+ top_torso_coords = top_torso_coords[mask]
+ # get the color
+ top_torso_colors = gt_image[tuple(top_torso_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_torso_coords = top_torso_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_torso_coords += inpaint_offsets
+ inpaint_torso_coords = inpaint_torso_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_torso_colors = top_torso_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_torso_colors = (inpaint_torso_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ torso_image[tuple(inpaint_torso_coords.T)] = inpaint_torso_colors
+
+ inpaint_torso_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
+ inpaint_torso_mask[tuple(inpaint_torso_coords.T)] = True
+ else:
+ inpaint_torso_mask = None
+
+ # neck part "vertical" in-painting...
+ push_down = 4
+ L = 48 + push_down + 1
+
+ neck_part = binary_dilation(neck_part, structure=np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=bool), iterations=3)
+
+ neck_coords = np.stack(np.nonzero(neck_part), axis=-1) # [M, 2]
+ # lexsort: sort 2D coords first by y then by x,
+ # ref: https://stackoverflow.com/questions/2706605/sorting-a-2d-numpy-array-by-multiple-axes
+ inds = np.lexsort((neck_coords[:, 0], neck_coords[:, 1]))
+ neck_coords = neck_coords[inds]
+ # choose the top pixel for each column
+ u, uid, ucnt = np.unique(neck_coords[:, 1], return_index=True, return_counts=True)
+ top_neck_coords = neck_coords[uid] # [m, 2]
+ # only keep top-is-head pixels
+ top_neck_coords_up = top_neck_coords.copy() - np.array([1, 0])
+ mask = head_part[tuple(top_neck_coords_up.T)]
+
+ top_neck_coords = top_neck_coords[mask]
+ # push these top down for 4 pixels to make the neck inpainting more natural...
+ offset_down = np.minimum(ucnt[mask] - 1, push_down)
+ top_neck_coords += np.stack([offset_down, np.zeros_like(offset_down)], axis=-1)
+ # get the color
+ top_neck_colors = gt_image[tuple(top_neck_coords.T)] # [m, 3]
+ # construct inpaint coords (vertically up, or minus in x)
+ inpaint_neck_coords = top_neck_coords[None].repeat(L, 0) # [L, m, 2]
+ inpaint_offsets = np.stack([-np.arange(L), np.zeros(L, dtype=np.int32)], axis=-1)[:, None] # [L, 1, 2]
+ inpaint_neck_coords += inpaint_offsets
+ inpaint_neck_coords = inpaint_neck_coords.reshape(-1, 2) # [Lm, 2]
+ inpaint_neck_colors = top_neck_colors[None].repeat(L, 0) # [L, m, 3]
+ darken_scaler = 0.98 ** np.arange(L).reshape(L, 1, 1) # [L, 1, 1]
+ inpaint_neck_colors = (inpaint_neck_colors * darken_scaler).reshape(-1, 3) # [Lm, 3]
+ # set color
+ torso_image[tuple(inpaint_neck_coords.T)] = inpaint_neck_colors
+
+ # apply blurring to the inpainted area to avoid vertical-line artifacts...
+ inpaint_mask = np.zeros_like(torso_image[..., 0]).astype(bool)
+ inpaint_mask[tuple(inpaint_neck_coords.T)] = True
+
+ blur_img = torso_image.copy()
+ blur_img = cv2.GaussianBlur(blur_img, (5, 5), cv2.BORDER_DEFAULT)
+
+ torso_image[inpaint_mask] = blur_img[inpaint_mask]
+
+ # set mask
+ mask = (neck_part | torso_part | inpaint_mask)
+ if inpaint_torso_mask is not None:
+ mask = mask | inpaint_torso_mask
+ torso_image[~mask] = 0
+ torso_alpha[~mask] = 0
+
+ cv2.imwrite("0.png", np.concatenate([torso_image, torso_alpha], axis=-1))
+
+ print(f'[INFO] ===== extracted torso and gt images =====')
+
+
+def out_exist_job(vid_name):
+ out_dir1 = vid_name.replace("/video/", "/inpaint_torso_imgs/").replace(".mp4","")
+ out_dir2 = vid_name.replace("/video/", "/inpaint_torso_with_bg_imgs/").replace(".mp4","")
+ out_dir3 = vid_name.replace("/video/", "/torso_imgs/").replace(".mp4","")
+ out_dir4 = vid_name.replace("/video/", "/torso_with_bg_imgs/").replace(".mp4","")
+
+ if os.path.exists(out_dir1) and os.path.exists(out_dir2) and os.path.exists(out_dir3) and os.path.exists(out_dir4):
+ num_frames = len(os.listdir(out_dir1))
+ if len(os.listdir(out_dir2)) == num_frames and len(os.listdir(out_dir3)) == num_frames and len(os.listdir(out_dir4)) == num_frames:
+ return None
+ else:
+ return vid_name
+ else:
+ return vid_name
+
+def get_todo_vid_names(vid_names):
+ todo_vid_names = []
+ for i, res in multiprocess_run_tqdm(out_exist_job, vid_names, num_workers=16):
+ if res is not None:
+ todo_vid_names.append(res)
+ return todo_vid_names
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=48, type=int)
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ parser.add_argument("--reset", action='store_true')
+
+ inpaint_torso_job('/home/tiger/datasets/raw/CelebV-HQ/video/dgdEr-mXQT4_8.mp4')
+ # args = parser.parse_args()
+ # vid_dir = args.vid_dir
+ # ds_name = args.ds_name
+ # if ds_name in ['lrs3_trainval']:
+ # mp4_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ # if ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ # vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
+ # elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
+ # vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ # vid_names = glob.glob(vid_name_pattern)
+ # vid_names = sorted(vid_names)
+ # random.seed(args.seed)
+ # random.shuffle(vid_names)
+
+ # process_id = args.process_id
+ # total_process = args.total_process
+ # if total_process > 1:
+ # assert process_id <= total_process -1
+ # num_samples_per_process = len(vid_names) // total_process
+ # if process_id == total_process:
+ # vid_names = vid_names[process_id * num_samples_per_process : ]
+ # else:
+ # vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ # if not args.reset:
+ # vid_names = get_todo_vid_names(vid_names)
+ # print(f"todo videos number: {len(vid_names)}")
+
+ # fn_args = [(vid_name,i,len(vid_names)) for i, vid_name in enumerate(vid_names)]
+ # for vid_name in multiprocess_run_tqdm(inpaint_torso_job ,fn_args, desc=f"Root process {args.process_id}: extracting segment images", num_workers=args.num_workers):
+ # pass
\ No newline at end of file
diff --git a/data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py b/data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01c1681a8e39046645cfdb3e5d79b4b82cf9b46
--- /dev/null
+++ b/data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
@@ -0,0 +1,87 @@
+import os, glob
+import cv2
+from utils.commons.os_utils import multiprocess_glob
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+
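+# Three resampling variants are defined below (all re-encode to 25 fps, 512x512 H.264):
+#   extract_img_job              - plain resample/resize, assumes the input is already square;
+#   extract_img_job_crop         - additionally center-crops to a square first (used for CMLR);
+#   extract_img_job_crop_ravdess - crops a fixed 720x720 region first (used for RAVDESS).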
+def get_video_infos(video_path):
+ vid_cap = cv2.VideoCapture(video_path)
+ height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+ width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+ fps = vid_cap.get(cv2.CAP_PROP_FPS)
+ total_frames = int(vid_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ return {'height': height, 'width': width, 'fps': fps, 'total_frames':total_frames}
+
+def extract_img_job(video_name:str):
+ out_path = video_name.replace("/video_raw/","/video/",1)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ vid_info = get_video_infos(video_name)
+ assert vid_info['width'] == vid_info['height']
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
+ os.system(cmd)
+
+def extract_img_job_crop(video_name:str):
+ out_path = video_name.replace("/video_raw/","/video/",1)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ vid_info = get_video_infos(video_name)
+ wh = min(vid_info['width'], vid_info['height'])
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop={wh}:{wh},scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
+ os.system(cmd)
+
+def extract_img_job_crop_ravdess(video_name:str):
+ out_path = video_name.replace("/video_raw/","/video/",1)
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},crop=720:720,scale=w=512:h=512 -q:v 1 -c:v libx264 -pix_fmt yuv420p -b:v 2000k -v quiet -y {out_path}'
+ os.system(cmd)
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video_raw/')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=32, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ args = parser.parse_args()
+ print(f"args {args}")
+
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ if ds_name in ['lrs3_trainval']:
+ vid_names = multiprocess_glob(os.path.join(vid_dir, "*/*.mp4"))
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_names = multiprocess_glob(os.path.join(vid_dir, "*.mp4"))
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2', 'CMLR']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ vid_names = multiprocess_glob(vid_name_pattern)
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ vid_names = multiprocess_glob(vid_name_pattern)
+ else:
+ raise NotImplementedError()
+ vid_names = sorted(vid_names)
+ print(f"total video number : {len(vid_names)}")
+ print(f"first {vid_names[0]} last {vid_names[-1]}")
+ # exit()
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process - 1:
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ if ds_name == "RAVDESS":
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop_ravdess, vid_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+ elif ds_name == "CMLR":
+ for i, res in multiprocess_run_tqdm(extract_img_job_crop, vid_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+ else:
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="resampling videos"):
+ pass
+
diff --git a/data_gen/utils/process_video/split_video_to_imgs.py b/data_gen/utils/process_video/split_video_to_imgs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1c16c3415fb953c965cf56b3161a59460375079
--- /dev/null
+++ b/data_gen/utils/process_video/split_video_to_imgs.py
@@ -0,0 +1,53 @@
+import os, glob
+from utils.commons.multiprocess_utils import multiprocess_run_tqdm
+
+from data_gen.utils.path_converter import PathConverter, pc
+
+# mp4_names = glob.glob("/home/tiger/datasets/raw/CelebV-HQ/video/*.mp4")
+
+def extract_img_job(video_name, raw_img_dir=None):
+ if raw_img_dir is not None:
+ out_path = raw_img_dir
+ else:
+ out_path = pc.to(video_name.replace(".mp4", ""), "vid", "gt")
+ os.makedirs(out_path, exist_ok=True)
+ ffmpeg_path = "/usr/bin/ffmpeg"
+ cmd = f'{ffmpeg_path} -i {video_name} -vf fps={25},scale=w=512:h=512 -qmin 1 -q:v 1 -start_number 0 -v quiet {os.path.join(out_path, "%8d.jpg")}'
+ os.system(cmd)
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm, random
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--vid_dir", default='/home/tiger/datasets/raw/CelebV-HQ/video')
+ parser.add_argument("--ds_name", default='CelebV-HQ')
+ parser.add_argument("--num_workers", default=64, type=int)
+ parser.add_argument("--process_id", default=0, type=int)
+ parser.add_argument("--total_process", default=1, type=int)
+ args = parser.parse_args()
+ vid_dir = args.vid_dir
+ ds_name = args.ds_name
+ if ds_name in ['lrs3_trainval']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*.mp4")
+ vid_names = glob.glob(vid_name_pattern)
+ elif ds_name in ['TH1KH_512', 'CelebV-HQ']:
+ vid_names = glob.glob(os.path.join(vid_dir, "*.mp4"))
+ elif ds_name in ['lrs2', 'lrs3', 'voxceleb2']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*.mp4")
+ vid_names = glob.glob(vid_name_pattern)
+ elif ds_name in ["RAVDESS", 'VFHQ']:
+ vid_name_pattern = os.path.join(vid_dir, "*/*/*/*.mp4")
+ vid_names = glob.glob(vid_name_pattern)
+ vid_names = sorted(vid_names)
+
+ process_id = args.process_id
+ total_process = args.total_process
+ if total_process > 1:
+ assert process_id <= total_process -1
+ num_samples_per_process = len(vid_names) // total_process
+ if process_id == total_process - 1: # the last process also takes the remainder
+ vid_names = vid_names[process_id * num_samples_per_process : ]
+ else:
+ vid_names = vid_names[process_id * num_samples_per_process : (process_id+1) * num_samples_per_process]
+
+ for i, res in multiprocess_run_tqdm(extract_img_job, vid_names, num_workers=args.num_workers, desc="extracting images"):
+ pass
+
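For reference, extract_img_job above writes one JPEG per frame at fps=25, named with the "%8d.jpg" pattern starting from index 0. A small sketch of reading a clip's frames back in temporal order; the directory path is hypothetical (the real one comes from pc.to(...) or raw_img_dir):

import os, glob

frame_dir = 'data/processed/videos/example/gt_imgs'  # hypothetical frame directory
# sort numerically by the frame index encoded in the filename
frame_paths = sorted(glob.glob(os.path.join(frame_dir, '*.jpg')),
                     key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
duration_sec = len(frame_paths) / 25.0  # frames were extracted at fps=25
print(f'{len(frame_paths)} frames, ~{duration_sec:.1f}s of video')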
diff --git a/data_util/face3d_helper.py b/data_util/face3d_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4d260b1a9c320639ad035bb1d804ed21b076092
--- /dev/null
+++ b/data_util/face3d_helper.py
@@ -0,0 +1,309 @@
+import os
+import numpy as np
+import torch
+import torch.nn as nn
+from scipy.io import loadmat
+
+from deep_3drecon.deep_3drecon_models.bfm import perspective_projection
+
+
+class Face3DHelper(nn.Module):
+ def __init__(self, bfm_dir='deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=True):
+ super().__init__()
+ self.keypoint_mode = keypoint_mode # lm68 | mediapipe
+ self.bfm_dir = bfm_dir
+ self.load_3dmm()
+ if use_gpu: self.to("cuda")
+
+ def load_3dmm(self):
+ model = loadmat(os.path.join(self.bfm_dir, "BFM_model_front.mat"))
+ self.register_buffer('mean_shape',torch.from_numpy(model['meanshape'].transpose()).float()) # mean face shape. [3*N, 1], N=35709, xyz=3, ==> 3*N=107127
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ # re-center
+ mean_shape = mean_shape - torch.mean(mean_shape, dim=0, keepdim=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+ self.register_buffer('id_base',torch.from_numpy(model['idBase']).float()) # identity basis. [3*N,80], we have 80 eigen faces for identity
+ self.register_buffer('exp_base',torch.from_numpy(model['exBase']).float()) # expression basis. [3*N,64], we have 64 eigen faces for expression
+
+ self.register_buffer('mean_texure',torch.from_numpy(model['meantex'].transpose()).float()) # mean face texture. [3*N,1] (0-255)
+ self.register_buffer('tex_base',torch.from_numpy(model['texBase']).float()) # texture basis. [3*N,80], rgb=3
+
+ self.register_buffer('point_buf',torch.from_numpy(model['point_buf']).float()) # triangle indices for each vertex that lies in. starts from 1. [N,8] (1-F)
+ self.register_buffer('face_buf',torch.from_numpy(model['tri']).float()) # vertex indices in each triangle. starts from 1. [F,3] (1-N)
+ if self.keypoint_mode == 'mediapipe':
+ self.register_buffer('key_points', torch.from_numpy(np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)))
+ unmatch_mask = self.key_points < 0
+ self.key_points[unmatch_mask] = 0
+ else:
+ self.register_buffer('key_points',torch.from_numpy(model['keypoints'].squeeze().astype(np.int_)).long()) # vertex indices of 68 facial landmarks. starts from 1. [68,1]
+
+
+ self.register_buffer('key_mean_shape',self.mean_shape.reshape([-1,3])[self.key_points,:])
+ self.register_buffer('key_id_base', self.id_base.reshape([-1,3,80])[self.key_points, :, :].reshape([-1,80]))
+ self.register_buffer('key_exp_base', self.exp_base.reshape([-1,3,64])[self.key_points, :, :].reshape([-1,64]))
+ self.key_id_base_np = self.key_id_base.cpu().numpy()
+ self.key_exp_base_np = self.key_exp_base.cpu().numpy()
+
+ self.register_buffer('persc_proj', torch.tensor(perspective_projection(focal=1015, center=112)))
+
+ def split_coeff(self, coeff):
+ """
+ coeff: Tensor[B, T, c=257] or [T, c=257]
+ """
+ ret_dict = {
+ 'identity': coeff[..., :80], # identity, [b, t, c=80]
+ 'expression': coeff[..., 80:144], # expression, [b, t, c=64]
+ 'texture': coeff[..., 144:224], # texture, [b, t, c=80]
+ 'euler': coeff[..., 224:227], # euler angles for pose, [b, t, c=3]
+ 'translation': coeff[..., 254:257], # translation, [b, t, c=3]
+ 'gamma': coeff[..., 227:254] # lighting, [b, t, c=27]
+ }
+ return ret_dict
+
+ def reconstruct_face_mesh(self, id_coeff, exp_coeff):
+ """
+ Generate a pose-independent 3D face mesh!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.mean_shape.squeeze().reshape([1, -1]) # [3N, 1] ==> [1, 3N]
+ id_base, exp_base = self.id_base, self.exp_base # [3*N, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3N] ==> [t,3N]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ # mean_xyz = self.mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
+ # face_mesh = face - mean_xyz.unsqueeze(0) # [t,N,3]
+ return face
+
+ def reconstruct_cano_lm3d(self, id_coeff, exp_coeff):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ # mean_xyz = self.key_mean_shape.squeeze().reshape([-1,3]).mean(dim=0) # [1, 3]
+ # lm3d = face - mean_xyz.unsqueeze(0) # [t,N,3]
+ return face
+
+ def reconstruct_lm3d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ rot = self.compute_rotation(euler)
+ # transform
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
+ # to camera
+ if to_camera:
+ lm3d[...,-1] = 10 - lm3d[...,-1]
+ return lm3d
+
+ def reconstruct_lm2d_nerf(self, id_coeff, exp_coeff, euler, trans):
+ lm2d = self.reconstruct_lm2d(id_coeff, exp_coeff, euler, trans, to_camera=False)
+ lm2d[..., 0] = 1 - lm2d[..., 0]
+ lm2d[..., 1] = 1 - lm2d[..., 1]
+ return lm2d
+
+ def reconstruct_lm2d(self, id_coeff, exp_coeff, euler, trans, to_camera=True):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ is_btc_flag = id_coeff.ndim == 3 # batched [B, T, C] input vs. [T, C]
+ if is_btc_flag:
+ b,t,_ = id_coeff.shape
+ id_coeff = id_coeff.reshape([b*t,-1])
+ exp_coeff = exp_coeff.reshape([b*t,-1])
+ euler = euler.reshape([b*t,-1])
+ trans = trans.reshape([b*t,-1])
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ mean_face = self.key_mean_shape.squeeze().reshape([1, -1]) # [3*68, 1] ==> [1, 3*68]
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = mean_face + identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ # re-centering the face with mean_xyz, so the face will be in [-1, 1]
+ rot = self.compute_rotation(euler)
+ # transform
+ lm3d = face @ rot + trans.unsqueeze(1) # [t, N, 3]
+ # to camera
+ if to_camera:
+ lm3d[...,-1] = 10 - lm3d[...,-1]
+ # to image_plane
+ lm3d = lm3d @ self.persc_proj
+ lm2d = lm3d[..., :2] / lm3d[..., 2:]
+ # flip
+ lm2d[..., 1] = 224 - lm2d[..., 1]
+ lm2d /= 224
+ if is_btc_flag:
+ return lm2d.reshape([b,t,-1,2])
+ return lm2d
+
+ def compute_rotation(self, euler):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ euler -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = euler.shape[0]
+ euler = euler.to(self.key_id_base.device)
+ ones = torch.ones([batch_size, 1]).to(self.key_id_base.device)
+ zeros = torch.zeros([batch_size, 1]).to(self.key_id_base.device)
+ x, y, z = euler[:, :1], euler[:, 1:2], euler[:, 2:]
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
+ def reconstruct_idexp_lm3d(self, id_coeff, exp_coeff):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_coeff = id_coeff.to(self.key_id_base.device)
+ exp_coeff = exp_coeff.to(self.key_id_base.device)
+ id_base, exp_base = self.key_id_base, self.key_exp_base # [3*68, C]
+ identity_diff_face = torch.matmul(id_coeff, id_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = torch.matmul(exp_coeff, exp_base.transpose(0,1)) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ lm3d = face * 10
+ return lm3d
+
+ def reconstruct_idexp_lm3d_np(self, id_coeff, exp_coeff):
+ """
+ Generate 3D landmark with keypoint base!
+ id_coeff: Tensor[T, c=80]
+ exp_coeff: Tensor[T, c=64]
+ """
+ id_base, exp_base = self.key_id_base_np, self.key_exp_base_np # [3*68, C]
+ identity_diff_face = np.dot(id_coeff, id_base.T) # [t,c],[c,3*68] ==> [t,3*68]
+ expression_diff_face = np.dot(exp_coeff, exp_base.T) # [t,c],[c,3*68] ==> [t,3*68]
+
+ face = identity_diff_face + expression_diff_face # [t,3N]
+ face = face.reshape([face.shape[0], -1, 3]) # [t,N,3]
+ lm3d = face * 10
+ return lm3d
+
+ def get_eye_mouth_lm_from_lm3d(self, lm3d):
+ eye_lm = lm3d[:, 17:48] # [T, 31, 3]
+ mouth_lm = lm3d[:, 48:68] # [T, 20, 3]
+ return eye_lm, mouth_lm
+
+ def get_eye_mouth_lm_from_lm3d_batch(self, lm3d):
+ eye_lm = lm3d[:, :, 17:48] # [T, 31, 3]
+ mouth_lm = lm3d[:, :, 48:68] # [T, 20, 3]
+ return eye_lm, mouth_lm
+
+ def close_mouth_for_idexp_lm3d(self, idexp_lm3d, freeze_as_first_frame=True):
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
+ num_frames = idexp_lm3d.shape[0]
+ eps = 0.0
+ # [n_landmarks=68, xyz=3]; x is left-right, y is up-down, z is depth
+ idexp_lm3d[:,49:54, 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 + eps * 2
+ idexp_lm3d[:,range(59,54,-1), 1] = (idexp_lm3d[:,49:54, 1] + idexp_lm3d[:,range(59,54,-1), 1])/2 - eps * 2
+
+ idexp_lm3d[:,61:64, 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 + eps
+ idexp_lm3d[:,range(67,64,-1), 1] = (idexp_lm3d[:,61:64, 1] + idexp_lm3d[:,range(67,64,-1), 1])/2 - eps
+
+ idexp_lm3d[:,49:54, 1] += (0.03 - idexp_lm3d[:,49:54, 1].mean(dim=1) + idexp_lm3d[:,61:64, 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
+ idexp_lm3d[:,range(59,54,-1), 1] += (-0.03 - idexp_lm3d[:,range(59,54,-1), 1].mean(dim=1) + idexp_lm3d[:,range(67,64,-1), 1].mean(dim=1)).unsqueeze(1).repeat([1,5])
+
+ if freeze_as_first_frame:
+ # freeze the mouth region: zero out the offsets of landmarks 48-67 for every frame
+ idexp_lm3d[:, 48:68] = idexp_lm3d[0, 48:68].unsqueeze(0).clone().repeat([num_frames, 1, 1]) * 0
+ return idexp_lm3d.cpu()
+
+ def close_eyes_for_idexp_lm3d(self, idexp_lm3d):
+ idexp_lm3d = idexp_lm3d.reshape([-1, 68,3])
+ eps = 0.003
+ idexp_lm3d[:,37:39, 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 + eps
+ idexp_lm3d[:,range(41,39,-1), 1] = (idexp_lm3d[:,37:39, 1] + idexp_lm3d[:,range(41,39,-1), 1])/2 - eps
+
+ idexp_lm3d[:,43:45, 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 + eps
+ idexp_lm3d[:,range(47,45,-1), 1] = (idexp_lm3d[:,43:45, 1] + idexp_lm3d[:,range(47,45,-1), 1])/2 - eps
+
+ return idexp_lm3d
+
+if __name__ == '__main__':
+ import cv2
+
+ font = cv2.FONT_HERSHEY_SIMPLEX
+
+ face_mesh_helper = Face3DHelper('deep_3drecon/BFM')
+ coeff_npy = 'data/coeff_fit_mp/crop_nana_003_coeff_fit_mp.npy'
+ coeff_dict = np.load(coeff_npy, allow_pickle=True).tolist()
+ lm3d = face_mesh_helper.reconstruct_lm2d(torch.tensor(coeff_dict['id']).cuda(), torch.tensor(coeff_dict['exp']).cuda(), torch.tensor(coeff_dict['euler']).cuda(), torch.tensor(coeff_dict['trans']).cuda() )
+
+ WH = 512
+ lm3d = (lm3d * WH).cpu().int().numpy()
+ eye_idx = list(range(36,48))
+ mouth_idx = list(range(48,68))
+ import imageio
+ debug_name = 'debug_lm3d.mp4'
+ writer = imageio.get_writer(debug_name, fps=25)
+ for i_img in range(len(lm3d)):
+ lm2d = lm3d[i_img ,:, :2] # [68, 2]
+ img = np.ones([WH, WH, 3], dtype=np.uint8) * 255
+ for i in range(len(lm2d)):
+ x, y = lm2d[i]
+ if i in eye_idx:
+ color = (0,0,255)
+ elif i in mouth_idx:
+ color = (0,255,0)
+ else:
+ color = (255,0,0)
+ img = cv2.circle(img, center=(x,y), radius=3, color=color, thickness=-1)
+ img = cv2.putText(img, f"{i}", org=(x,y), fontFace=font, fontScale=0.3, color=(255,0,0))
+ writer.append_data(img)
+ writer.close()
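A minimal usage sketch of the Face3DHelper defined above, assuming the BFM assets (BFM_model_front.mat etc.) are present under deep_3drecon/BFM; the random tensor is a stand-in for a fitted [T, 257] coefficient sequence and is only meant to show the shapes flowing through split_coeff and reconstruct_lm2d:

import torch
from data_util.face3d_helper import Face3DHelper

helper = Face3DHelper('deep_3drecon/BFM', keypoint_mode='lm68', use_gpu=False)
coeff = torch.randn([100, 257])      # stand-in; real coeffs come from the fitting pipeline
parts = helper.split_coeff(coeff)    # identity / expression / euler / translation / ...
lm2d = helper.reconstruct_lm2d(parts['identity'], parts['expression'],
                               parts['euler'], parts['translation'])
print(lm2d.shape)                    # torch.Size([100, 68, 2]); coordinates divided by the 224 image size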
diff --git a/deep_3drecon/BFM/.gitkeep b/deep_3drecon/BFM/.gitkeep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deep_3drecon/BFM/basel_53201.txt b/deep_3drecon/BFM/basel_53201.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d53cf20802230fe5310b674b15645421ca86643f
--- /dev/null
+++ b/deep_3drecon/BFM/basel_53201.txt
@@ -0,0 +1,53201 @@
+1
+2
+3
+4
+5
+6
+7
+8
+9
+10
+11
+12
+13
+14
+15
+16
+17
+18
+19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
+361
+362
+363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
+396
+397
+398
+399
+400
+401
+402
+403
+404
+405
+406
+407
+408
+409
+410
+411
+412
+413
+414
+415
+416
+417
+418
+419
+420
+421
+422
+423
+424
+425
+426
+427
+428
+429
+430
+431
+432
+433
+434
+435
+436
+437
+438
+439
+440
+441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
+460
+461
+462
+463
+464
+465
+466
+467
+468
+469
+470
+471
+472
+473
+474
+475
+476
+477
+478
+479
+480
+481
+482
+483
+484
+485
+486
+487
+488
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
+499
+500
+501
+502
+503
+504
+505
+506
+507
+508
+509
+510
+511
+512
+513
+514
+515
+516
+517
+518
+519
+520
+521
+522
+523
+524
+525
+526
+527
+528
+529
+530
+531
+532
+533
+534
+535
+536
+537
+538
+539
+540
+541
+542
+543
+544
+545
+546
+547
+548
+549
+550
+551
+552
+553
+554
+555
+556
+557
+558
+559
+560
+561
+562
+563
+564
+565
+566
+567
+568
+569
+570
+571
+572
+573
+574
+575
+576
+577
+578
+579
+580
+581
+582
+583
+584
+585
+586
+587
+588
+589
+590
+591
+592
+593
+594
+595
+596
+597
+598
+599
+600
+601
+602
+603
+604
+605
+606
+607
+608
+609
+610
+611
+612
+613
+614
+615
+616
+617
+618
+619
+620
+621
+622
+623
+624
+625
+626
+627
+628
+629
+630
+631
+632
+633
+634
+635
+636
+637
+638
+639
+640
+641
+642
+643
+644
+645
+646
+647
+648
+649
+650
+651
+652
+653
+654
+655
+656
+657
+658
+659
+660
+661
+662
+663
+664
+665
+666
+667
+668
+669
+670
+671
+672
+673
+674
+675
+676
+677
+678
+679
+680
+681
+682
+683
+684
+685
+686
+687
+688
+689
+690
+691
+692
+693
+694
+695
+696
+697
+698
+699
+700
+701
+702
+703
+704
+705
+706
+707
+708
+709
+710
+711
+712
+713
+714
+715
+716
+717
+718
+719
+720
+721
+722
+723
+724
+725
+726
+727
+728
+729
+730
+731
+732
+733
+734
+735
+736
+737
+738
+739
+740
+741
+742
+743
+744
+745
+746
+747
+748
+749
+750
+751
+752
+753
+754
+755
+756
+757
+758
+759
+760
+761
+762
+763
+764
+765
+766
+767
+768
+769
+770
+771
+772
+773
+774
+775
+776
+777
+778
+779
+780
+781
+782
+783
+784
+785
+786
+787
+788
+789
+790
+791
+792
+793
+794
+795
+796
+797
+798
+799
+800
+801
+802
+803
+804
+805
+806
+807
+808
+809
+810
+811
+812
+813
+814
+815
+816
+817
+818
+819
+820
+821
+822
+823
+824
+825
+826
+827
+828
+829
+830
+831
+832
+833
+834
+835
+836
+837
+838
+839
+840
+841
+842
+843
+844
+845
+846
+847
+848
+849
+850
+851
+852
+853
+854
+855
+856
+857
+858
+859
+860
+861
+862
+863
+864
+865
+866
+867
+868
+869
+870
+871
+872
+873
+874
+875
+876
+877
+878
+879
+880
+881
+882
+883
+884
+885
+886
+887
+888
+889
+890
+891
+892
+893
+894
+895
+896
+897
+898
+899
+900
+901
+902
+903
+904
+905
+906
+907
+908
+909
+910
+911
+912
+913
+914
+915
+916
+917
+918
+919
+920
+921
+922
+923
+924
+925
+926
+927
+928
+929
+930
+931
+932
+933
+934
+935
+936
+937
+938
+939
+940
+941
+942
+943
+944
+945
+946
+947
+948
+949
+950
+951
+952
+953
+954
+955
+956
+957
+958
+959
+960
+961
+962
+963
+964
+965
+966
+967
+968
+969
+970
+971
+972
+973
+974
+975
+976
+977
+978
+979
+980
+981
+982
+983
+984
+985
+986
+987
+988
+989
+990
+991
+992
+993
+994
+995
+996
+997
+998
+999
+1000
+1001
+1002
+1003
+1004
+1005
+1006
+1007
+1008
+1009
+1010
+1011
+1012
+1013
+1014
+1015
+1016
+1017
+1018
+1019
+1020
+1021
+1022
+1023
+1024
+1025
+1026
+1027
+1028
+1029
+1030
+1031
+1032
+1033
+1034
+1035
+1036
+1037
+1038
+1039
+1040
+1041
+1042
+1043
+1044
+1045
+1046
+1047
+1048
+1049
+1050
+1051
+1052
+1053
+1054
+1055
+1056
+1057
+1058
+1059
+1060
+1061
+1062
+1063
+1064
+1065
+1066
+1067
+1068
+1069
+1070
+1071
+1072
+1073
+1074
+1075
+1076
+1077
+1078
+1079
+1080
+1081
+1082
+1083
+1084
+1085
+1086
+1087
+1088
+1089
+1090
+1091
+1092
+1093
+1094
+1095
+1096
+1097
+1098
+1099
+1100
+1101
+1102
+1103
+1104
+1105
+1106
+1107
+1108
+1109
+1110
+1111
+1112
+1113
+1114
+1115
+1116
+1117
+1118
+1119
+1120
+1121
+1122
+1123
+1124
+1125
+1126
+1127
+1128
+1129
+1130
+1131
+1132
+1133
+1134
+1135
+1136
+1137
+1138
+1139
+1140
+1141
+1142
+1143
+1144
+1145
+1146
+1147
+1148
+1149
+1150
+1151
+1152
+1153
+1154
+1155
+1156
+1157
+1158
+1159
+1160
+1161
+1162
+1163
+1164
+1165
+1166
+1167
+1168
+1169
+1170
+1171
+1172
+1173
+1174
+1175
+1176
+1177
+1178
+1179
+1180
+1181
+1182
+1183
+1184
+1185
+1186
+1187
+1188
+1189
+1190
+1191
+1192
+1193
+1194
+1195
+1196
+1197
+1198
+1199
+1200
+1201
+1202
+1203
+1204
+1205
+1206
+1207
+1208
+1209
+1210
+1211
+1212
+1213
+1214
+1215
+1216
+1217
+1218
+1219
+1220
+1221
+1222
+1223
+1224
+1225
+1226
+1227
+1228
+1229
+1230
+1231
+1232
+1233
+1234
+1235
+1236
+1237
+1238
+1239
+1240
+1241
+1242
+1243
+1244
+1245
+1246
+1247
+1248
+1249
+1250
+1251
+1252
+1253
+1254
+1255
+1256
+1257
+1258
+1259
+1260
+1261
+1262
+1263
+1264
+1265
+1266
+1267
+1268
+1269
+1270
+1271
+1272
+1273
+1274
+1275
+1276
+1277
+1278
+1279
+1280
+1281
+1282
+1283
+1284
+1285
+1286
+1287
+1288
+1289
+1290
+1291
+1292
+1293
+1294
+1295
+1296
+1297
+1298
+1299
+1300
+1301
+1302
+1303
+1304
+1305
+1306
+1307
+1308
+1309
+1310
+1311
+1312
+1313
+1314
+1315
+1316
+1317
+1318
+1319
+1320
+1321
+1322
+1323
+1324
+1325
+1326
+1327
+1328
+1329
+1330
+1331
+1332
+1333
+1334
+1335
+1336
+1337
+1338
+1339
+1340
+1341
+1342
+1343
+1344
+1345
+1346
+1347
+1348
+1349
+1350
+1351
+1352
+1353
+1354
+1355
+1356
+1357
+1358
+1359
+1360
+1361
+1362
+1363
+1364
+1365
+1366
+1367
+1368
+1369
+1370
+1371
+1372
+1373
+1374
+1375
+1376
+1377
+1378
+1379
+1380
+1381
+1382
+1383
+1384
+1385
+1386
+1387
+1388
+1389
+1390
+1391
+1392
+1393
+1394
+1395
+1396
+1397
+1398
+1399
+1400
+1401
+1402
+1403
+1404
+1405
+1406
+1407
+1408
+1409
+1410
+1411
+1412
+1413
+1414
+1415
+1416
+1417
+1418
+1419
+1420
+1421
+1422
+1423
+1424
+1425
+1426
+1427
+1428
+1429
+1430
+1431
+1432
+1433
+1434
+1435
+1436
+1437
+1438
+1439
+1440
+1441
+1442
+1443
+1444
+1445
+1446
+1447
+1448
+1449
+1450
+1451
+1452
+1453
+1454
+1455
+1456
+1457
+1458
+1459
+1460
+1461
+1462
+1463
+1464
+1465
+1466
+1467
+1468
+1469
+1470
+1471
+1472
+1473
+1474
+1475
+1476
+1477
+1478
+1479
+1480
+1481
+1482
+1483
+1484
+1485
+1486
+1487
+1488
+1489
+1490
+1491
+1492
+1493
+1494
+1495
+1496
+1497
+1498
+1499
+1500
+1501
+1502
+1503
+1504
+1505
+1506
+1507
+1508
+1509
+1510
+1511
+1512
+1513
+1514
+1515
+1516
+1517
+1518
+1519
+1520
+1521
+1522
+1523
+1524
+1525
+1526
+1527
+1528
+1529
+1530
+1531
+1532
+1533
+1534
+1535
+1536
+1537
+1538
+1539
+1540
+1541
+1542
+1543
+1544
+1545
+1546
+1547
+1548
+1549
+1550
+1551
+1552
+1553
+1554
+1555
+1556
+1557
+1558
+1559
+1560
+1561
+1562
+1563
+1564
+1565
+1566
+1567
+1568
+1569
+1570
+1571
+1572
+1573
+1574
+1575
+1576
+1577
+1578
+1579
+1580
+1581
+1582
+1583
+1584
+1585
+1586
+1587
+1588
+1589
+1590
+1591
+1592
+1593
+1594
+1595
+1596
+1597
+1598
+1599
+1600
+1601
+1602
+1603
+1604
+1605
+1606
+1607
+1608
+1609
+1610
+1611
+1612
+1613
+1614
+1615
+1616
+1617
+1618
+1619
+1620
+1621
+1622
+1623
+1624
+1625
+1626
+1627
+1628
+1629
+1630
+1631
+1632
+1633
+1634
+1635
+1636
+1637
+1638
+1639
+1640
+1641
+1642
+1643
+1644
+1645
+1646
+1647
+1648
+1649
+1650
+1651
+1652
+1653
+1654
+1655
+1656
+1657
+1658
+1659
+1660
+1661
+1662
+1663
+1664
+1665
+1666
+1667
+1668
+1669
+1670
+1671
+1672
+1673
+1674
+1675
+1676
+1677
+1678
+1679
+1680
+1681
+1682
+1683
+1684
+1685
+1686
+1687
+1688
+1689
+1690
+1691
+1692
+1693
+1694
+1695
+1696
+1697
+1698
+1699
+1700
+1701
+1702
+1703
+1704
+1705
+1706
+1707
+1708
+1709
+1710
+1711
+1712
+1713
+1714
+1715
+1716
+1717
+1718
+1719
+1720
+1721
+1722
+1723
+1724
+1725
+1726
+1727
+1728
+1729
+1730
+1731
+1732
+1733
+1734
+1735
+1736
+1737
+1738
+1739
+1740
+1741
+1742
+1743
+1744
+1745
+1746
+1747
+1748
+1749
+1750
+1751
+1752
+1753
+1754
+1755
+1756
+1757
+1758
+1759
+1760
+1761
+1762
+1763
+1764
+1765
+1766
+1767
+1768
+1769
+1770
+1771
+1772
+1773
+1774
+1775
+1776
+1777
+1778
+1779
+1780
+1781
+1782
+1783
+1784
+1785
+1786
+1787
+1788
+1789
+1790
+1791
+1792
+1793
+1794
+1795
+1796
+1797
+1798
+1799
+1800
+1801
+1802
+1803
+1804
+1805
+1806
+1807
+1808
+1809
+1810
+1811
+1812
+1813
+1814
+1815
+1816
+1817
+1818
+1819
+1820
+1821
+1822
+1823
+1824
+1825
+1826
+1827
+1828
+1829
+1830
+1831
+1832
+1833
+1834
+1835
+1836
+1837
+1838
+1839
+1840
+1841
+1842
+1843
+1844
+1845
+1846
+1847
+1848
+1849
+1850
+1851
+1852
+1853
+1854
+1855
+1856
+1857
+1858
+1859
+1860
+1861
+1862
+1863
+1864
+1865
+1866
+1867
+1868
+1869
+1870
+1871
+1872
+1873
+1874
+1875
+1876
+1877
+1878
+1879
+1880
+1881
+1882
+1883
+1884
+1885
+1886
+1887
+1888
+1889
+1890
+1891
+1892
+1893
+1894
+1895
+1896
+1897
+1898
+1899
+1900
+1901
+1902
+1903
+1904
+1905
+1906
+1907
+1908
+1909
+1910
+1911
+1912
+1913
+1914
+1915
+1916
+1917
+1918
+1919
+1920
+1921
+1922
+1923
+1924
+1925
+1926
+1927
+1928
+1929
+1930
+1931
+1932
+1933
+1934
+1935
+1936
+1937
+1938
+1939
+1940
+1941
+1942
+1943
+1944
+1945
+1946
+1947
+1948
+1949
+1950
+1951
+1952
+1953
+1954
+1955
+1956
+1957
+1958
+1959
+1960
+1961
+1962
+1963
+1964
+1965
+1966
+1967
+1968
+1969
+1970
+1971
+1972
+1973
+1974
+1975
+1976
+1977
+1978
+1979
+1980
+1981
+1982
+1983
+1984
+1985
+1986
+1987
+1988
+1989
+1990
+1991
+1992
+1993
+1994
+1995
+1996
+1997
+1998
+1999
+2000
+2001
+2002
+2003
+2004
+2005
+2006
+2007
+2008
+2009
+2010
+2011
+2012
+2013
+2014
+2015
+2016
+2017
+2018
+2019
+2020
+2021
+2022
+2023
+2024
+2025
+2026
+2027
+2028
+2029
+2030
+2031
+2032
+2033
+2034
+2035
+2036
+2037
+2038
+2039
+2040
+2041
+2042
+2043
+2044
+2045
+2046
+2047
+2048
+2049
+2050
+2051
+2052
+2053
+2054
+2055
+2056
+2057
+2058
+2059
+2060
+2061
+2062
+2063
+2064
+2065
+2066
+2067
+2068
+2069
+2070
+2071
+2072
+2073
+2074
+2075
+2076
+2077
+2078
+2079
+2080
+2081
+2082
+2083
+2084
+2085
+2086
+2087
+2088
+2089
+2090
+2091
+2092
+2093
+2094
+2095
+2096
+2097
+2098
+2099
+2100
+2101
+2102
+2103
+2104
+2105
+2106
+2107
+2108
+2109
+2110
+2111
+2112
+2113
+2114
+2115
+2116
+2117
+2118
+2119
+2120
+2121
+2122
+2123
+2124
+2125
+2126
+2127
+2128
+2129
+2130
+2131
+2132
+2133
+2134
+2135
+2136
+2137
+2138
+2139
+2140
+2141
+2142
+2143
+2144
+2145
+2146
+2147
+2148
+2149
+2150
+2151
+2152
+2153
+2154
+2155
+2156
+2157
+2158
+2159
+2160
+2161
+2162
+2163
+2164
+2165
+2166
+2167
+2168
+2169
+2170
+2171
+2172
+2173
+2174
+2175
+2176
+2177
+2178
+2179
+2180
+2181
+2182
+2183
+2184
+2185
+2186
+2187
+2188
+2189
+2190
+2191
+2192
+2193
+2194
+2195
+2196
+2197
+2198
+2199
+2200
+2201
+2202
+2203
+2204
+2205
+2206
+2207
+2208
+2209
+2210
+2211
+2212
+2213
+2214
+2215
+2216
+2217
+2218
+2219
+2220
+2221
+2222
+2223
+2224
+2225
+2226
+2227
+2228
+2229
+2230
+2231
+2232
+2233
+2234
+2235
+2236
+2237
+2238
+2239
+2240
+2241
+2242
+2243
+2244
+2245
+2246
+2247
+2248
+2249
+2250
+2251
+2252
+2253
+2254
+2255
+2256
+2257
+2258
+2259
+2260
+2261
+2262
+2263
+2264
+2265
+2266
+2267
+2268
+2269
+2270
+2271
+2272
+2273
+2274
+2275
+2276
+2277
+2278
+2279
+2280
+2281
+2282
+2283
+2284
+2285
+2286
+2287
+2288
+2289
+2290
+2291
+2292
+2293
+2294
+2295
+2296
+2297
+2298
+2299
+2300
+2301
+2302
+2303
+2304
+2305
+2306
+2307
+2308
+2309
+2310
+2311
+2312
+2313
+2314
+2315
+2316
+2317
+2318
+2319
+2320
+2321
+2322
+2323
+2324
+2325
+2326
+2327
+2328
+2329
+2330
+2331
+2332
+2333
+2334
+2335
+2336
+2337
+2338
+2339
+2340
+2341
+2342
+2343
+2344
+2345
+2346
+2347
+2348
+2349
+2350
+2351
+2352
+2353
+2354
+2355
+2356
+2357
+2358
+2359
+2360
+2361
+2362
+2363
+2364
+2365
+2366
+2367
+2368
+2369
+2370
+2371
+2372
+2373
+2374
+2375
+2376
+2377
+2378
+2379
+2380
+2381
+2382
+2383
+2384
+2385
+2386
+2387
+2388
+2389
+2390
+2391
+2392
+2393
+2394
+2395
+2396
+2397
+2398
+2399
+2400
+2401
+2402
+2403
+2404
+2405
+2406
+2407
+2408
+2409
+2410
+2411
+2412
+2413
+2414
+2415
+2416
+2417
+2418
+2419
+2420
+2421
+2422
+2423
+2424
+2425
+2426
+2427
+2428
+2429
+2430
+2431
+2432
+2433
+2434
+2435
+2436
+2437
+2438
+2439
+2440
+2441
+2442
+2443
+2444
+2445
+2446
+2447
+2448
+2449
+2450
+2451
+2452
+2453
+2454
+2455
+2456
+2457
+2458
+2459
+2460
+2461
+2462
+2463
+2464
+2465
+2466
+2467
+2468
+2469
+2470
+2471
+2472
+2473
+2474
+2475
+2476
+2477
+2478
+2479
+2480
+2481
+2482
+2483
+2484
+2485
+2486
+2487
+2488
+2489
+2490
+2491
+2492
+2493
+2494
+2495
+2496
+2497
+2498
+2499
+2500
+2501
+2502
+2503
+2504
+2505
+2506
+2507
+2508
+2509
+2510
+2511
+2512
+2513
+2514
+2515
+2516
+2517
+2518
+2519
+2520
+2521
+2522
+2523
+2524
+2525
+2526
+2527
+2528
+2529
+2530
+2531
+2532
+2533
+2534
+2535
+2536
+2537
+2538
+2539
+2540
+2541
+2542
+2543
+2544
+2545
+2546
+2547
+2548
+2549
+2550
+2551
+2552
+2553
+2554
+2555
+2556
+2557
+2558
+2559
+2560
+2561
+2562
+2563
+2564
+2565
+2566
+2567
+2568
+2569
+2570
+2571
+2572
+2573
+2574
+2575
+2576
+2577
+2578
+2579
+2580
+2581
+2582
+2583
+2584
+2585
+2586
+2587
+2588
+2589
+2590
+2591
+2592
+2593
+2594
+2595
+2596
+2597
+2598
+2599
+2600
+2601
+2602
+2603
+2604
+2605
+2606
+2607
+2608
+2609
+2610
+2611
+2612
+2613
+2614
+2615
+2616
+2617
+2618
+2619
+2620
+2621
+2622
+2623
+2624
+2625
+2626
+2627
+2628
+2629
+2630
+2631
+2632
+2633
+2634
+2635
+2636
+2637
+2638
+2639
+2640
+2641
+2642
+2643
+2644
+2645
+2646
+2647
+2648
+2649
+2650
+2651
+2652
+2653
+2654
+2655
+2656
+2657
+2658
+2659
+2660
+2661
+2662
+2663
+2664
+2665
+2666
+2667
+2668
+2669
+2670
+2671
+2672
+2673
+2674
+2675
+2676
+2677
+2678
+2679
+2680
+2681
+2682
+2683
+2684
+2685
+2686
+2687
+2688
+2689
+2690
+2691
+2692
+2693
+2694
+2695
+2696
+2697
+2698
+2699
+2700
+2701
+2702
+2703
+2704
+2705
+2706
+2707
+2708
+2709
+2710
+2711
+2712
+2713
+2714
+2715
+2716
+2717
+2718
+2719
+2720
+2721
+2722
+2723
+2724
+2725
+2726
+2727
+2728
+2729
+2730
+2731
+2732
+2733
+2734
+2735
+2736
+2737
+2738
+2739
+2740
+2741
+2742
+2743
+2744
+2745
+2746
+2747
+2748
+2749
+2750
+2751
+2752
+2753
+2754
+2755
+2756
+2757
+2758
+2759
+2760
+2761
+2762
+2763
+2764
+2765
+2766
+2767
+2768
+2769
+2770
+2771
+2772
+2773
+2774
+2775
+2776
+2777
+2778
+2779
+2780
+2781
+2782
+2783
+2784
+2785
+2786
+2787
+2788
+2789
+2790
+2791
+2792
+2793
+2794
+2795
+2796
+2797
+2798
+2799
+2800
+2801
+2802
+2803
+2804
+2805
+2806
+2807
+2808
+2809
+2810
+2811
+2812
+2813
+2814
+2815
+2816
+2817
+2818
+2819
+2820
+2821
+2822
+2823
+2824
+2825
+2826
+2827
+2828
+2829
+2830
+2831
+2832
+2833
+2834
+2835
+2836
+2837
+2838
+2839
+2840
+2841
+2842
+2843
+2844
+2845
+2846
+2847
+2848
+2849
+2850
+2851
+2852
+2853
+2854
+2855
+2856
+2857
+2858
+2859
+2860
+2861
+2862
+2863
+2864
+2865
+2866
+2867
+2868
+2869
+2870
+2871
+2872
+2873
+2874
+2875
+2876
+2877
+2878
+2879
+2880
+2881
+2882
+2883
+2884
+2885
+2886
+2887
+2888
+2889
+2890
+2891
+2892
+2893
+2894
+2895
+2896
+2897
+2898
+2899
+2900
+2901
+2902
+2903
+2904
+2905
+2906
+2907
+2908
+2909
+2910
+2911
+2912
+2913
+2914
+2915
+2916
+2917
+2918
+2919
+2920
+2921
+2922
+2923
+2924
+2925
+2926
+2927
+2928
+2929
+2930
+2931
+2932
+2933
+2934
+2935
+2936
+2937
+2938
+2939
+2940
+2941
+2942
+2943
+2944
+2945
+2946
+2947
+2948
+2949
+2950
+2951
+2952
+2953
+2954
+2955
+2956
+2957
+2958
+2959
+2960
+2961
+2962
+2963
+2964
+2965
+2966
+2967
+2968
+2969
+2970
+2971
+2972
+2973
+2974
+2975
+2976
+2977
+2978
+2979
+2980
+2981
+2982
+2983
+2984
+2985
+2986
+2987
+2988
+2989
+2990
+2991
+2992
+2993
+2994
+2995
+2996
+2997
+2998
+2999
+3000
+3001
+3002
+3003
+3004
+3005
+3006
+3007
+3008
+3009
+3010
+3011
+3012
+3013
+3014
+3015
+3016
+3017
+3018
+3019
+3020
+3021
+3022
+3023
+3024
+3025
+3026
+3027
+3028
+3029
+3030
+3031
+3032
+3033
+3034
+3035
+3036
+3037
+3038
+3039
+3040
+3041
+3042
+3043
+3044
+3045
+3046
+3047
+3048
+3049
+3050
+3051
+3052
+3053
+3054
+3055
+3056
+3057
+3058
+3059
+3060
+3061
+3062
+3063
+3064
+3065
+3066
+3067
+3068
+3069
+3070
+3071
+3072
+3073
+3074
+3075
+3076
+3077
+3078
+3079
+3080
+3081
+3082
+3083
+3084
+3085
+3086
+3087
+3088
+3089
+3090
+3091
+3092
+3093
+3094
+3095
+3096
+3097
+3098
+3099
+3100
+3101
+3102
+3103
+3104
+3105
+3106
+3107
+3108
+3109
+3110
+3111
+3112
+3113
+3114
+3115
+3116
+3117
+3118
+3119
+3120
+3121
+3122
+3123
+3124
+3125
+3126
+3127
+3128
+3129
+3130
+3131
+3132
+3133
+3134
+3135
+3136
+3137
+3138
+3139
+3140
+3141
+3142
+3143
+3144
+3145
+3146
+3147
+3148
+3149
+3150
+3151
+3152
+3153
+3154
+3155
+3156
+3157
+3158
+3159
+3160
+3161
+3162
+3163
+3164
+3165
+3166
+3167
+3168
+3169
+3170
+3171
+3172
+3173
+3174
+3175
+3176
+3177
+3178
+3179
+3180
+3181
+3182
+3183
+3184
+3185
+3186
+3187
+3188
+3189
+3190
+3191
+3192
+3193
+3194
+3195
+3196
+3197
+3198
+3199
+3200
+3201
+3202
+3203
+3204
+3205
+3206
+3207
+3208
+3209
+3210
+3211
+3212
+3213
+3214
+3215
+3216
+3217
+3218
+3219
+3220
+3221
+3222
+3223
+3224
+3225
+3226
+3227
+3228
+3229
+3230
+3231
+3232
+3233
+3234
+3235
+3236
+3237
+3238
+3239
+3240
+3241
+3242
+3243
+3244
+3245
+3246
+3247
+3248
+3249
+3250
+3251
+3252
+3253
+3254
+3255
+3256
+3257
+3258
+3259
+3260
+3261
+3262
+3263
+3264
+3265
+3266
+3267
+3268
+3269
+3270
+3271
+3272
+3273
+3274
+3275
+3276
+3277
+3278
+3279
+3280
+3281
+3282
+3283
+3284
+3285
+3286
+3287
+3288
+3289
+3290
+3291
+3292
+3293
+3294
+3295
+3296
+3297
+3298
+3299
+3300
+3301
+3302
+3303
+3304
+3305
+3306
+3307
+3308
+3309
+3310
+3311
+3312
+3313
+3314
+3315
+3316
+3317
+3318
+3319
+3320
+3321
+3322
+3323
+3324
+3325
+3326
+3327
+3328
+3329
+3330
+3331
+3332
+3333
+3334
+3335
+3336
+3337
+3338
+3339
+3340
+3341
+3342
+3343
+3344
+3345
+3346
+3347
+3348
+3349
+3350
+3351
+3352
+3353
+3354
+3355
+3356
+3357
+3358
+3359
+3360
+3361
+3362
+3363
+3364
+3365
+3366
+3367
+3368
+3369
+3370
+3371
+3372
+3373
+3374
+3375
+3376
+3377
+3378
+3379
+3380
+3381
+3382
+3383
+3384
+3385
+3386
+3387
+3388
+3389
+3390
+3391
+3392
+3393
+3394
+3395
+3396
+3397
+3398
+3399
+3400
+3401
+3402
+3403
+3404
+3405
+3406
+3407
+3408
+3409
+3410
+3411
+3412
+3413
+3414
+3415
+3416
+3417
+3418
+3419
+3420
+3421
+3422
+3423
+3424
+3425
+3426
+3427
+3428
+3429
+3430
+3431
+3432
+3433
+3434
+3435
+3436
+3437
+3438
+3439
+3440
+3441
+3442
+3443
+3444
+3445
+3446
+3447
+3448
+3449
+3450
+3451
+3452
+3453
+3454
+3455
+3456
+3457
+3458
+3459
+3460
+3461
+3462
+3463
+3464
+3465
+3466
+3467
+3468
+3469
+3470
+3471
+3472
+3473
+3474
+3475
+3476
+3477
+3478
+3479
+3480
+3481
+3482
+3483
+3484
+3485
+3486
+3487
+3488
+3489
+3490
+3491
+3492
+3493
+3494
+3495
+3496
+3497
+3498
+3499
+3500
+3501
+3502
+3503
+3504
+3505
+3506
+3507
+3508
+3509
+3510
+3511
+3512
+3513
+3514
+3515
+3516
+3517
+3518
+3519
+3520
+3521
+3522
+3523
+3524
+3525
+3526
+3527
+3528
+3529
+3530
+3531
+3532
+3533
+3534
+3535
+3536
+3537
+3538
+3539
+3540
+3541
+3542
+3543
+3544
+3545
+3546
+3547
+3548
+3549
+3550
+3551
+3552
+3553
+3554
+3555
+3556
+3557
+3558
+3559
+3560
+3561
+3562
+3563
+3564
+3565
+3566
+3567
+3568
+3569
+3570
+3571
+3572
+3573
+3574
+3575
+3576
+3577
+3578
+3579
+3580
+3581
+3582
+3583
+3584
+3585
+3586
+3587
+3588
+3589
+3590
+3591
+3592
+3593
+3594
+3595
+3596
+3597
+3598
+3599
+3600
+3601
+3602
+3603
+3604
+3605
+3606
+3607
+3608
+3609
+3610
+3611
+3612
+3613
+3614
+3615
+3616
+3617
+3618
+3619
+3620
+3621
+3622
+3623
+3624
+3625
+3626
+3627
+3628
+3629
+3630
+3631
+3632
+3633
+3634
+3635
+3636
+3637
+3638
+3639
+3640
+3641
+3642
+3643
+3644
+3645
+3646
+3647
+3648
+3649
+3650
+3651
+3652
+3653
+3654
+3655
+3656
+3657
+3658
+3659
+3660
+3661
+3662
+3663
+3664
+3665
+3666
+3667
+3668
+3669
+3670
+3671
+3672
+3673
+3674
+3675
+3676
+3677
+3678
+3679
+3680
+3681
+3682
+3683
+3684
+3685
+3686
+3687
+3688
+3689
+3690
+3691
+3692
+3693
+3694
+3695
+3696
+3697
+3698
+3699
+3700
+3701
+3702
+3703
+3704
+3705
+3706
+3707
+3708
+3709
+3710
+3711
+3712
+3713
+3714
+3715
+3716
+3717
+3718
+3719
+3720
+3721
+3722
+3723
+3724
+3725
+3726
+3727
+3728
+3729
+3730
+3731
+3732
+3733
+3734
+3735
+3736
+3737
+3738
+3739
+3740
+3741
+3742
+3743
+3744
+3745
+3746
+3747
+3748
+3749
+3750
+3751
+3752
+3753
+3754
+3755
+3756
+3757
+3758
+3759
+3760
+3761
+3762
+3763
+3764
+3765
+3766
+3767
+3768
+3769
+3770
+3771
+3772
+3773
+3774
+3775
+3776
+3777
+3778
+3779
+3780
+3781
+3782
+3783
+3784
+3785
+3786
+3787
+3788
+3789
+3790
+3791
+3792
+3793
+3794
+3795
+3796
+3797
+3798
+3799
+3800
+3801
+3802
+3803
+3804
+3805
+3806
+3807
+3808
+3809
+3810
+3811
+3812
+3813
+3814
+3815
+3816
+3817
+3818
+3819
+3820
+3821
+3822
+3823
+3824
+3825
+3826
+3827
+3828
+3829
+3830
+3831
+3832
+3833
+3834
+3835
+3836
+3837
+3838
+3839
+3840
+3841
+3842
+3843
+3844
+3845
+3846
+3847
+3848
+3849
+3850
+3851
+3852
+3853
+3854
+3855
+3856
+3857
+3858
+3859
+3860
+3861
+3862
+3863
+3864
+3865
+3866
+3867
+3868
+3869
+3870
+3871
+3872
+3873
+3874
+3875
+3876
+3877
+3878
+3879
+3880
+3881
+3882
+3883
+3884
+3885
+3886
+3887
+3888
+3889
+3890
+3891
+3892
+3893
+3894
+3895
+3896
+3897
+3898
+3899
+3900
+3901
+3902
+3903
+3904
+3905
+3906
+3907
+3908
+3909
+3910
+3911
+3912
+3913
+3914
+3915
+3916
+3917
+3918
+3919
+3920
+3921
+3922
+3923
+3924
+3925
+3926
+3927
+3928
+3929
+3930
+3931
+3932
+3933
+3934
+3935
+3936
+3937
+3938
+3939
+3940
+3941
+3942
+3943
+3944
+3945
+3946
+3947
+3948
+3949
+3950
+3951
+3952
+3953
+3954
+3955
+3956
+3957
+3958
+3959
+3960
+3961
+3962
+3963
+3964
+3965
+3966
+3967
+3968
+3969
+3970
+3971
+3972
+3973
+3974
+3975
+3976
+3977
+3978
+3979
+3980
+3981
+3982
+3983
+3984
+3985
+3986
+3987
+3988
+3989
+3990
+3991
+3992
+3993
+3994
+3995
+3996
+3997
+3998
+3999
+4000
+4001
+4002
+4003
+4004
+4005
+4006
+4007
+4008
+4009
+4010
+4011
+4012
+4013
+4014
+4015
+4016
+4017
+4018
+4019
+4020
+4021
+4022
+4023
+4024
+4025
+4026
+4027
+4028
+4029
+4030
+4031
+4032
+4033
+4034
+4035
+4036
+4037
+4038
+4039
+4040
+4041
+4042
+4043
+4044
+4045
+4046
+4047
+4048
+4049
+4050
+4051
+4052
+4053
+4054
+4055
+4056
+4057
+4058
+4059
+4060
+4061
+4062
+4063
+4064
+4065
+4066
+4067
+4068
+4069
+4070
+4071
+4072
+4073
+4074
+4075
+4076
+4077
+4078
+4079
+4080
+4081
+4082
+4083
+4084
+4085
+4086
+4087
+4088
+4089
+4090
+4091
+4092
+4093
+4094
+4095
+4096
+4097
+4098
+4099
+4100
+4101
+4102
+4103
+4104
+4105
+4106
+4107
+4108
+4109
+4110
+4111
+4112
+4113
+4114
+4115
+4116
+4117
+4118
+4119
+4120
+4121
+4122
+4123
+4124
+4125
+4126
+4127
+4128
+4129
+4130
+4131
+4132
+4133
+4134
+4135
+4136
+4137
+4138
+4139
+4140
+4141
+4142
+4143
+4144
+4145
+4146
+4147
+4148
+4149
+4150
+4151
+4152
+4153
+4154
+4155
+4156
+4157
+4158
+4159
+4160
+4161
+4162
+4163
+4164
+4165
+4166
+4167
+4168
+4169
+4170
+4171
+4172
+4173
+4174
+4175
+4176
+4177
+4178
+4179
+4180
+4181
+4182
+4183
+4184
+4185
+4186
+4187
+4188
+4189
+4190
+4191
+4192
+4193
+4194
+4195
+4196
+4197
+4198
+4199
+4200
+4201
+4202
+4203
+4204
+4205
+4206
+4207
+4208
+4209
+4210
+4211
+4212
+4213
+4214
+4215
+4216
+4217
+4218
+4219
+4220
+4221
+4222
+4223
+4224
+4225
+4226
+4227
+4228
+4229
+4230
+4231
+4232
+4233
+4234
+4235
+4236
+4237
+4238
+4239
+4240
+4241
+4242
+4243
+4244
+4245
+4246
+4247
+4248
+4249
+4250
+4251
+4252
+4253
+4254
+4255
+4256
+4257
+4258
+4259
+4260
+4261
+4262
+4263
+4264
+4265
+4266
+4267
+4268
+4269
+4270
+4271
+4272
+4273
+4274
+4275
+4276
+4277
+4278
+4279
+4280
+4281
+4282
+4283
+4284
+4285
+4286
+4287
+4288
+4289
+4290
+4291
+4292
+4293
+4294
+4295
+4296
+4297
+4298
+4299
+4300
+4301
+4302
+4303
+4304
+4305
+4306
+4307
+4308
+4309
+4310
+4311
+4312
+4313
+4314
+4315
+4316
+4317
+4318
+4319
+4320
+4321
+4322
+4323
+4324
+4325
+4326
+4327
+4328
+4329
+4330
+4331
+4332
+4333
+4334
+4335
+4336
+4337
+4338
+4339
+4340
+4341
+4342
+4343
+4344
+4345
+4346
+4347
+4348
+4349
+4350
+4351
+4352
+4353
+4354
+4355
+4356
+4357
+4358
+4359
+4360
+4361
+4362
+4363
+4364
+4365
+4366
+4367
+4368
+4369
+4370
+4371
+4372
+4373
+4374
+4375
+4376
+4377
+4378
+4379
+4380
+4381
+4382
+4383
+4384
+4385
+4386
+4387
+4388
+4389
+4390
+4391
+4392
+4393
+4394
+4395
+4396
+4397
+4398
+4399
+4400
+4401
+4402
+4403
+4404
+4405
+4406
+4407
+4408
+4409
+4410
+4411
+4412
+4413
+4414
+4415
+4416
+4417
+4418
+4419
+4420
+4421
+4422
+4423
+4424
+4425
+4426
+4427
+4428
+4429
+4430
+4431
+4432
+4433
+4434
+4435
+4436
+4437
+4438
+4439
+4440
+4441
+4442
+4443
+4444
+4445
+4446
+4447
+4448
+4449
+4450
+4451
+4452
+4453
+4454
+4455
+4456
+4457
+4458
+4459
+4460
+4461
+4462
+4463
+4464
+4465
+4466
+4467
+4468
+4469
+4470
+4471
+4472
+4473
+4474
+4475
+4476
+4477
+4478
+4479
+4480
+4481
+4482
+4483
+4484
+4485
+4486
+4487
+4488
+4489
+4490
+4491
+4492
+4493
+4494
+4495
+4496
+4497
+4498
+4499
+4500
+4501
+4502
+4503
+4504
+4505
+4506
+4507
+4508
+4509
+4510
+4511
+4512
+4513
+4514
+4515
+4516
+4517
+4518
+4519
+4520
+4521
+4522
+4523
+4524
+4525
+4526
+4527
+4528
+4529
+4530
+4531
+4532
+4533
+4534
+4535
+4536
+4537
+4538
+4539
+4540
+4541
+4542
+4543
+4544
+4545
+4546
+4547
+4548
+4549
+4550
+4551
+4552
+4553
+4554
+4555
+4556
+4557
+4558
+4559
+4560
+4561
+4562
+4563
+4564
+4565
+4566
+4567
+4568
+4569
+4570
+4571
+4572
+4573
+4574
+4575
+4576
+4577
+4578
+4579
+4580
+4581
+4582
+4583
+4584
+4585
+4586
+4587
+4588
+4589
+4590
+4591
+4592
+4593
+4594
+4595
+4596
+4597
+4598
+4599
+4600
+4601
+4602
+4603
+4604
+4605
+4606
+4607
+4608
+4609
+4610
+4611
+4612
+4613
+4614
+4615
+4616
+4617
+4618
+4619
+4620
+4621
+4622
+4623
+4624
+4625
+4626
+4627
+4628
+4629
+4630
+4631
+4632
+4633
+4634
+4635
+4636
+4637
+4638
+4639
+4640
+4641
+4642
+4643
+4644
+4645
+4646
+4647
+4648
+4649
+4650
+4651
+4652
+4653
+4654
+4655
+4656
+4657
+4658
+4659
+4660
+4661
+4662
+4663
+4664
+4665
+4666
+4667
+4668
+4669
+4670
+4671
+4672
+4673
+4674
+4675
+4676
+4677
+4678
+4679
+4680
+4681
+4682
+4683
+4684
+4685
+4686
+4687
+4688
+4689
+4690
+4691
+4692
+4693
+4694
+4695
+4696
+4697
+4698
+4699
+4700
+4701
+4702
+4703
+4704
+4705
+4706
+4707
+4708
+4709
+4710
+4711
+4712
+4713
+4714
+4715
+4716
+4717
+4718
+4719
+4720
+4721
+4722
+4723
+4724
+4725
+4726
+4727
+4728
+4729
+4730
+4731
+4732
+4733
+4734
+4735
+4736
+4737
+4738
+4739
+4740
+4741
+4742
+4743
+4744
+4745
+4746
+4747
+4748
+4749
+4750
+4751
+4752
+4753
+4754
+4755
+4756
+4757
+4758
+4759
+4760
+4761
+4762
+4763
+4764
+4765
+4766
+4767
+4768
+4769
+4770
+4771
+4772
+4773
+4774
+4775
+4776
+4777
+4778
+4779
+4780
+4781
+4782
+4783
+4784
+4785
+4786
+4787
+4788
+4789
+4790
+4791
+4792
+4793
+4794
+4795
+4796
+4797
+4798
+4799
+4800
+4801
+4802
+4803
+4804
+4805
+4806
+4807
+4808
+4809
+4810
+4811
+4812
+4813
+4814
+4815
+4816
+4817
+4818
+4819
+4820
+4821
+4822
+4823
+4824
+4825
+4826
+4827
+4828
+4829
+4830
+4831
+4832
+4833
+4834
+4835
+4836
+4837
+4838
+4839
+4840
+4841
+4842
+4843
+4844
+4845
+4846
+4847
+4848
+4849
+4850
+4851
+4852
+4853
+4854
+4855
+4856
+4857
+4858
+4859
+4860
+4861
+4862
+4863
+4864
+4865
+4866
+4867
+4868
+4869
+4870
+4871
+4872
+4873
+4874
+4875
+4876
+4877
+4878
+4879
+4880
+4881
+4882
+4883
+4884
+4885
+4886
+4887
+4888
+4889
+4890
+4891
+4892
+4893
+4894
+4895
+4896
+4897
+4898
+4899
+4900
+4901
+4902
+4903
+4904
+4905
+4906
+4907
+4908
+4909
+4910
+4911
+4912
+4913
+4914
+4915
+4916
+4917
+4918
+4919
+4920
+4921
+4922
+4923
+4924
+4925
+4926
+4927
+4928
+4929
+4930
+4931
+4932
+4933
+4934
+4935
+4936
+4937
+4938
+4939
+4940
+4941
+4942
+4943
+4944
+4945
+4946
+4947
+4948
+4949
+4950
+4951
+4952
+4953
+4954
+4955
+4956
+4957
+4958
+4959
+4960
+4961
+4962
+4963
+4964
+4965
+4966
+4967
+4968
+4969
+4970
+4971
+4972
+4973
+4974
+4975
+4976
+4977
+4978
+4979
+4980
+4981
+4982
+4983
+4984
+4985
+4986
+4987
+4988
+4989
+4990
+4991
+4992
+4993
+4994
+4995
+4996
+4997
+4998
+4999
+5000
+5001
+5002
+5003
+5004
+5005
+5006
+5007
+5008
+5009
+5010
+5011
+5012
+5013
+5014
+5015
+5016
+5017
+5018
+5019
+5020
+5021
+5022
+5023
+5024
+5025
+5026
+5027
+5028
+5029
+5030
+5031
+5032
+5033
+5034
+5035
+5036
+5037
+5038
+5039
+5040
+5041
+5042
+5043
+5044
+5045
+5046
+5047
+5048
+5049
+5050
+5051
+5052
+5053
+5054
+5055
+5056
+5057
+5058
+5059
+5060
+5061
+5062
+5063
+5064
+5065
+5066
+5067
+5068
+5069
+5070
+5071
+5072
+5073
+5074
+5075
+5076
+5077
+5078
+5079
+5080
+5081
+5082
+5083
+5084
+5085
+5086
+5087
+5088
+5089
+5090
+5091
+5092
+5093
+5094
+5095
+5096
+5097
+5098
+5099
+5100
+5101
+5102
+5103
+5104
+5105
+5106
+5107
+5108
+5109
+5110
+5111
+5112
+5113
+5114
+5115
+5116
+5117
+5118
+5119
+5120
+5121
+5122
+5123
+5124
+5125
+5126
+5127
+5128
+5129
+5130
+5131
+5132
+5133
+5134
+5135
+5136
+5137
+5138
+5139
+5140
+5141
+5142
+5143
+5144
+5145
+5146
+5147
+5148
+5149
+5150
+5151
+5152
+5153
+5154
+5155
+5156
+5157
+5158
+5159
+5160
+5161
+5162
+5163
+5164
+5165
+5166
+5167
+5168
+5169
+5170
+5171
+5172
+5173
+5174
+5175
+5176
+5177
+5178
+5179
+5180
+5181
+5182
+5183
+5184
+5185
+5186
+5187
+5188
+5189
+5190
+5191
+5192
+5193
+5194
+5195
+5196
+5197
+5198
+5199
+5200
+5201
+5202
+5203
+5204
+5205
+5206
+5207
+5208
+5209
+5210
+5211
+5212
+5213
+5214
+5215
+5216
+5217
+5218
+5219
+5220
+5221
+5222
+5223
+5224
+5225
+5226
+5227
+5228
+5229
+5230
+5231
+5232
+5233
+5234
+5235
+5236
+5237
+5238
+5239
+5240
+5241
+5242
+5243
+5244
+5245
+5246
+5247
+5248
+5249
+5250
+5251
+5252
+5253
+5254
+5255
+5256
+5257
+5258
+5259
+5260
+5261
+5262
+5263
+5264
+5265
+5266
+5267
+5268
+5269
+5270
+5271
+5272
+5273
+5274
+5275
+5276
+5277
+5278
+5279
+5280
+5281
+5282
+5283
+5284
+5285
+5286
+5287
+5288
+5289
+5290
+5291
+5292
+5293
+5294
+5295
+5296
+5297
+5298
+5299
+5300
+5301
+5302
+5303
+5304
+5305
+5306
+5307
+5308
+5309
+5310
+5311
+5312
+5313
+5314
+5315
+5316
+5317
+5318
+5319
+5320
+5321
+5322
+5323
+5324
+5325
+5326
+5327
+5328
+5329
+5330
+5331
+5332
+5333
+5334
+5335
+5336
+5337
+5338
+5339
+5340
+5341
+5342
+5343
+5344
+5345
+5346
+5347
+5348
+5349
+5350
+5351
+5352
+5353
+5354
+5355
+5356
+5357
+5358
+5359
+5360
+5361
+5362
+5363
+5364
+5365
+5366
+5367
+5368
+5369
+5370
+5371
+5372
+5373
+5374
+5375
+5376
+5377
+5378
+5379
+5380
+5381
+5382
+5383
+5384
+5385
+5386
+5387
+5388
+5389
+5390
+5391
+5392
+5393
+5394
+5395
+5396
+5397
+5398
+5399
+5400
+5401
+5402
+5403
+5404
+5405
+5406
+5407
+5408
+5409
+5410
+5411
+5412
+5413
+5414
+5415
+5416
+5417
+5418
+5419
+5420
+5421
+5422
+5423
+5424
+5425
+5426
+5427
+5428
+5429
+5430
+5431
+5432
+5433
+5434
+5435
+5436
+5437
+5438
+5439
+5440
+5441
+5442
+5443
+5444
+5445
+5446
+5447
+5448
+5449
+5450
+5451
+5452
+5453
+5454
+5455
+5456
+5457
+5458
+5459
+5460
+5461
+5462
+5463
+5464
+5465
+5466
+5467
+5468
+5469
+5470
+5471
+5472
+5473
+5474
+5475
+5476
+5477
+5478
+5479
+5480
+5481
+5482
+5483
+5484
+5485
+5486
+5487
+5488
+5489
+5490
+5491
+5492
+5493
+5494
+5495
+5496
+5497
+5498
+5499
+5500
+5501
+5502
+5503
+5504
+5505
+5506
+5507
+5508
+5509
+5510
+5511
+5512
+5513
+5514
+5515
+5516
+5517
+5518
+5519
+5520
+5521
+5522
+5523
+5524
+5525
+5526
+5527
+5528
+5529
+5530
+5531
+5532
+5533
+5534
+5535
+5536
+5537
+5538
+5539
+5540
+5541
+5542
+5543
+5544
+5545
+5546
+5547
+5548
+5549
+5550
+5551
+5552
+5553
+5554
+5555
+5556
+5557
+5558
+5559
+5560
+5561
+5562
+5563
+5564
+5565
+5566
+5567
+5568
+5569
+5570
+5571
+5572
+5573
+5574
+5575
+5576
+5577
+5578
+5579
+5580
+5581
+5582
+5583
+5584
+5585
+5586
+5587
+5588
+5589
+5590
+5591
+5592
+5593
+5594
+5595
+5596
+5597
+5598
+5599
+5600
+5601
+5602
+5603
+5604
+5605
+5606
+5607
+5608
+5609
+5610
+5611
+5612
+5613
+5614
+5615
+5616
+5617
+5618
+5619
+5620
+5621
+5622
+5623
+5624
+5625
+5626
+5627
+5628
+5629
+5630
+5631
+5632
+5633
+5634
+5635
+5636
+5637
+5638
+5639
+5640
+5641
+5642
+5643
+5644
+5645
+5646
+5647
+5648
+5649
+5650
+5651
+5652
+5653
+5654
+5655
+5656
+5657
+5658
+5659
+5660
+5661
+5662
+5663
+5664
+5665
+5666
+5667
+5668
+5669
+5670
+5671
+5672
+5673
+5674
+5675
+5676
+5677
+5678
+5679
+5680
+5681
+5682
+5683
+5684
+5685
+5686
+5687
+5688
+5689
+5690
+5691
+5692
+5693
+5694
+5695
+5696
+5697
+5698
+5699
+5700
+5701
+5702
+5703
+5704
+5705
+5706
+5707
+5708
+5709
+5710
+5711
+5712
+5713
+5714
+5715
+5716
+5717
+5718
+5719
+5720
+5721
+5722
+5723
+5724
+5725
+5726
+5727
+5728
+5729
+5730
+5731
+5732
+5733
+5734
+5735
+5736
+5737
+5738
+5739
+5740
+5741
+5742
+5743
+5744
+5745
+5746
+5747
+5748
+5749
+5750
+5751
+5752
+5753
+5754
+5755
+5756
+5757
+5758
+5759
+5760
+5761
+5762
+5763
+5764
+5765
+5766
+5767
+5768
+5769
+5770
+5771
+5772
+5773
+5774
+5775
+5776
+5777
+5778
+5779
+5780
+5781
+5782
+5783
+5784
+5785
+5786
+5787
+5788
+5789
+5790
+5791
+5792
+5793
+5794
+5795
+5796
+5797
+5798
+5799
+5800
+5801
+5802
+5803
+5804
+5805
+5806
+5807
+5808
+5809
+5810
+5811
+5812
+5813
+5814
+5815
+5816
+5817
+5818
+5819
+5820
+5821
+5822
+5823
+5824
+5825
+5826
+5827
+5828
+5829
+5830
+5831
+5832
+5833
+5834
+5835
+5836
+5837
+5838
+5839
+5840
+5841
+5842
+5843
+5844
+5845
+5846
+5847
+5848
+5849
+5850
+5851
+5852
+5853
+5854
+5855
+5856
+5857
+5858
+5859
+5860
+5861
+5862
+5863
+5864
+5865
+5866
+5867
+5868
+5869
+5870
+5871
+5872
+5873
+5874
+5875
+5876
+5877
+5878
+5879
+5880
+5881
+5882
+5883
+5884
+5885
+5886
+5887
+5888
+5889
+5890
+5891
+5892
+5893
+5894
+5895
+5896
+5897
+5898
+5899
+5900
+5901
+5902
+5903
+5904
+5905
+5906
+5907
+5908
+5909
+5910
+5911
+5912
+5913
+5914
+5915
+5916
+5917
+5918
+5919
+5920
+5921
+5922
+5923
+5924
+5925
+5926
+5927
+5928
+5929
+5930
+5931
+5932
+5933
+5934
+5935
+5936
+5937
+5938
+5939
+5940
+5941
+5942
+5943
+5944
+5945
+5946
+5947
+5948
+5949
+5950
+5951
+5952
+5953
+5954
+5955
+5956
+5957
+5958
+5959
+5960
+5961
+5962
+5963
+5964
+5965
+5966
+5967
+5968
+5969
+5970
+5971
+5972
+5973
+5974
+5975
+5976
+5977
+5978
+5979
+5980
+5981
+5982
+5983
+5984
+5985
+5986
+5987
+5988
+5989
+5990
+5991
+5992
+5993
+5994
+5995
+5996
+5997
+5998
+5999
+6000
+6001
+6002
+6003
+6004
+6005
+6006
+6007
+6008
+6009
+6010
+6011
+6012
+6013
+6014
+6015
+6016
+6017
+6018
+6019
+6020
+6021
+6022
+6023
+6024
+6025
+6026
+6027
+6028
+6029
+6030
+6031
+6032
+6033
+6034
+6038
+6039
+6040
+6041
+6042
+6043
+6044
+6045
+6046
+6047
+6048
+6049
+6050
+6051
+6052
+6053
+6054
+6055
+6056
+6057
+6058
+6059
+6060
+6061
+6062
+6063
+6064
+6065
+6066
+6067
+6068
+6069
+6070
+6071
+6072
+6073
+6074
+6075
+6076
+6077
+6078
+6079
+6080
+6081
+6082
+6083
+6084
+6085
+6086
+6087
+6088
+6089
+6090
+6091
+6092
+6093
+6094
+6095
+6096
+6097
+6098
+6099
+6100
+6101
+6102
+6103
+6104
+6105
+6106
+6107
+6108
+6109
+6110
+6111
+6112
+6113
+6114
+6115
+6116
+6117
+6118
+6119
+6120
+6121
+6122
+6123
+6124
+6125
+6126
+6127
+6128
+6129
+6130
+6131
+6132
+6133
+6134
+6135
+6136
+6137
+6138
+6139
+6140
+6141
+6142
+6143
+6144
+6145
+6146
+6147
+6148
+6149
+6150
+6151
+6152
+6153
+6154
+6155
+6156
+6157
+6158
+6159
+6160
+6161
+6162
+6163
+6164
+6166
+6168
+6169
+6170
+6171
+6172
+6173
+6174
+6175
+6176
+6177
+6178
+6179
+6180
+6181
+6182
+6183
+6184
+6185
+6186
+6187
+6188
+6189
+6190
+6191
+6192
+6193
+6194
+6195
+6196
+6197
+6198
+6199
+6200
+6201
+6202
+6203
+6204
+6205
+6206
+6207
+6208
+6209
+6210
+6211
+6212
+6213
+6214
+6215
+6216
+6217
+6218
+6219
+6220
+6221
+6222
+6223
+6224
+6225
+6226
+6227
+6228
+6229
+6230
+6231
+6232
+6233
+6234
+6235
+6236
+6237
+6238
+6239
+6240
+6241
+6242
+6243
+6244
+6245
+6246
+6247
+6248
+6249
+6250
+6251
+6252
+6253
+6254
+6255
+6256
+6257
+6258
+6259
+6260
+6261
+6262
+6263
+6264
+6265
+6266
+6267
+6268
+6269
+6270
+6271
+6272
+6273
+6274
+6275
+6276
+6277
+6278
+6279
+6280
+6281
+6282
+6283
+6284
+6285
+6286
+6287
+6288
+6289
+6290
+6291
+6294
+6295
+6296
+6297
+6298
+6299
+6300
+6301
+6302
+6303
+6304
+6305
+6306
+6307
+6308
+6309
+6310
+6311
+6312
+6313
+6314
+6315
+6316
+6317
+6318
+6319
+6320
+6321
+6322
+6323
+6324
+6325
+6326
+6327
+6328
+6329
+6330
+6331
+6332
+6333
+6334
+6335
+6336
+6337
+6338
+6339
+6340
+6341
+6342
+6343
+6344
+6345
+6346
+6347
+6348
+6349
+6350
+6351
+6352
+6353
+6354
+6355
+6356
+6357
+6358
+6359
+6360
+6361
+6362
+6363
+6364
+6365
+6366
+6367
+6368
+6369
+6370
+6371
+6372
+6373
+6374
+6375
+6376
+6377
+6378
+6379
+6380
+6381
+6382
+6383
+6384
+6385
+6386
+6387
+6388
+6389
+6390
+6391
+6392
+6393
+6394
+6395
+6396
+6397
+6398
+6399
+6400
+6401
+6402
+6403
+6404
+6405
+6406
+6407
+6408
+6409
+6410
+6411
+6412
+6413
+6414
+6415
+6416
+6417
+6418
+6419
+6420
+6421
+6422
+6423
+6424
+6425
+6426
+6427
+6428
+6429
+6430
+6431
+6432
+6433
+6434
+6435
+6436
+6437
+6438
+6439
+6440
+6441
+6442
+6443
+6444
+6445
+6446
+6447
+6448
+6449
+6450
+6451
+6452
+6453
+6454
+6455
+6456
+6457
+6458
+6459
+6460
+6461
+6462
+6463
+6464
+6465
+6466
+6467
+6468
+6469
+6470
+6471
+6472
+6473
+6474
+6475
+6476
+6477
+6478
+6479
+6480
+6481
+6482
+6483
+6484
+6485
+6486
+6487
+6488
+6489
+6490
+6491
+6492
+6493
+6494
+6495
+6496
+6497
+6498
+6499
+6500
+6501
+6502
+6503
+6504
+6505
+6506
+6507
+6508
+6509
+6510
+6511
+6512
+6513
+6514
+6515
+6516
+6517
+6518
+6519
+6520
+6521
+6522
+6523
+6524
+6525
+6526
+6527
+6528
+6529
+6530
+6531
+6532
+6533
+6534
+6535
+6536
+6537
+6538
+6539
+6540
+6541
+6542
+6543
+6544
+6545
+6546
+6547
+6548
+6549
+6550
+6551
+6552
+6553
+6554
+6555
+6556
+6557
+6558
+6559
+6560
+6561
+6562
+6563
+6564
+6565
+6566
+6567
+6568
+6569
+6570
+6571
+6572
+6573
+6574
+6575
+6576
+6577
+6578
+6579
+6580
+6581
+6582
+6583
+6584
+6585
+6586
+6587
+6588
+6589
+6590
+6591
+6592
+6593
+6594
+6595
+6596
+6597
+6598
+6599
+6600
+6601
+6602
+6603
+6604
+6605
+6606
+6607
+6608
+6609
+6610
+6611
+6612
+6613
+6614
+6615
+6616
+6617
+6618
+6619
+6620
+6621
+6622
+6623
+6624
+6625
+6626
+6627
+6628
+6629
+6630
+6631
+6632
+6633
+6634
+6635
+6636
+6637
+6638
+6639
+6640
+6641
+6642
+6643
+6644
+6645
+6646
+6647
+6648
+6649
+6650
+6651
+6652
+6653
+6654
+6655
+6656
+6657
+6658
+6659
+6660
+6661
+6662
+6663
+6666
+6667
+6668
+6669
+6670
+6671
+6672
+6673
+6674
+6675
+6676
+6677
+6678
+6679
+6680
+6681
+6682
+6683
+6684
+6685
+6686
+6687
+6688
+6689
+6690
+6691
+6692
+6693
+6694
+6695
+6696
+6697
+6698
+6699
+6700
+6701
+6702
+6703
+6704
+6705
+6706
+6707
+6708
+6709
+6710
+6711
+6712
+6713
+6714
+6715
+6716
+6717
+6718
+6719
+6720
+6721
+6722
+6723
+6724
+6725
+6726
+6727
+6728
+6729
+6730
+6731
+6732
+6733
+6734
+6735
+6736
+6737
+6738
+6739
+6740
+6741
+6742
+6743
+6744
+6745
+6746
+6747
+6748
+6749
+6750
+6751
+6752
+6753
+6754
+6755
+6756
+6757
+6758
+6759
+6760
+6761
+6762
+6763
+6764
+6765
+6766
+6767
+6768
+6769
+6770
+6771
+6772
+6773
+6774
+6775
+6776
+6777
+6778
+6779
+6780
+6781
+6782
+6783
+6784
+6785
+6786
+6787
+6789
+6790
+6791
+6792
+6793
+6794
+6795
+6796
+6797
+6798
+6799
+6800
+6801
+6802
+6803
+6804
+6805
+6806
+6807
+6808
+6809
+6810
+6811
+6812
+6813
+6814
+6815
+6816
+6817
+6818
+6819
+6820
+6821
+6822
+6823
+6824
+6825
+6826
+6827
+6828
+6829
+6830
+6831
+6832
+6833
+6834
+6835
+6836
+6837
+6838
+6839
+6840
+6841
+6842
+6843
+6844
+6845
+6846
+6847
+6848
+6849
+6850
+6851
+6852
+6853
+6854
+6855
+6856
+6857
+6858
+6859
+6860
+6861
+6862
+6863
+6864
+6865
+6866
+6867
+6868
+6869
+6870
+6871
+6872
+6873
+6874
+6875
+6876
+6877
+6878
+6879
+6880
+6881
+6882
+6883
+6884
+6885
+6886
+6887
+6888
+6889
+6890
+6891
+6892
+6893
+6894
+6895
+6896
+6897
+6898
+6899
+6900
+6901
+6902
+6903
+6904
+6905
+6906
+6907
+6908
+6910
+6911
+6912
+6913
+6914
+6915
+6916
+6917
+6918
+6919
+6920
+6921
+6922
+6923
+6924
+6925
+6926
+6927
+6928
+6929
+6930
+6931
+6932
+6933
+6934
+6935
+6936
+6937
+6938
+6939
+6940
+6941
+6942
+6943
+6944
+6945
+6946
+6947
+6948
+6949
+6950
+6951
+6952
+6953
+6954
+6955
+6956
+6957
+6958
+6959
+6960
+6961
+6962
+6963
+6964
+6965
+6966
+6967
+6968
+6969
+6970
+6971
+6972
+6973
+6974
+6975
+6976
+6977
+6978
+6979
+6980
+6981
+6982
+6983
+6984
+6985
+6986
+6987
+6988
+6989
+6990
+6991
+6992
+6993
+6994
+6995
+6996
+6997
+6998
+6999
+7000
+7001
+7002
+7003
+7004
+7005
+7006
+7007
+7008
+7009
+7010
+7011
+7012
+7013
+7014
+7015
+7016
+7017
+7018
+7019
+7020
+7021
+7022
+7023
+7024
+7025
+7026
+7027
+7028
+7029
+7030
+7031
+7032
+7033
+7034
+7035
+7036
+7037
+7038
+7039
+7040
+7041
+7042
+7043
+7044
+7045
+7046
+7047
+7048
+7049
+7050
+7051
+7052
+7053
+7054
+7055
+7056
+7057
+7058
+7059
+7060
+7061
+7062
+7063
+7064
+7065
+7066
+7067
+7068
+7069
+7070
+7071
+7072
+7073
+7074
+7075
+7076
+7077
+7078
+7079
+7080
+7081
+7082
+7083
+7084
+7085
+7086
+7087
+7088
+7089
+7090
+7091
+7092
+7093
+7094
+7095
+7096
+7097
+7098
+7099
+7100
+7101
+7102
+7103
+7104
+7105
+7106
+7107
+7108
+7109
+7110
+7111
+7112
+7113
+7114
+7115
+7116
+7117
+7118
+7119
+7120
+7121
+7122
+7123
+7124
+7125
+7126
+7127
+7128
+7129
+7130
+7131
+7132
+7133
+7134
+7135
+7136
+7137
+7138
+7139
+7140
+7141
+7142
+7143
+7144
+7145
+7146
+7147
+7148
+7149
+7150
+7151
+7152
+7153
+7154
+7155
+7156
+7157
+7158
+7159
+7160
+7161
+7162
+7163
+7164
+7165
+7166
+7167
+7168
+7169
+7170
+7171
+7172
+7173
+7174
+7175
+7176
+7177
+7178
+7179
+7180
+7181
+7182
+7183
+7184
+7185
+7186
+7187
+7188
+7189
+7190
+7191
+7192
+7193
+7194
+7195
+7196
+7197
+7198
+7199
+7200
+7201
+7202
+7203
+7204
+7205
+7206
+7207
+7208
+7209
+7210
+7211
+7212
+7213
+7214
+7215
+7216
+7217
+7218
+7219
+7220
+7221
+7222
+7223
+7224
+7225
+7226
+7227
+7228
+7229
+7230
+7231
+7232
+7233
+7234
+7235
+7236
+7237
+7238
+7239
+7240
+7241
+7242
+7243
+7244
+7245
+7246
+7247
+7248
+7249
+7250
+7251
+7252
+7253
+7254
+7255
+7256
+7257
+7258
+7259
+7260
+7261
+7262
+7263
+7264
+7265
+7266
+7267
+7268
+7269
+7270
+7271
+7272
+7273
+7274
+7275
+7276
+7277
+7278
+7279
+7280
+7281
+7282
+7283
+7284
+7285
+7286
+7287
+7288
+7289
+7290
+7291
+7292
+7293
+7294
+7295
+7296
+7297
+7298
+7299
+7300
+7301
+7302
+7303
+7304
+7305
+7306
+7307
+7308
+7309
+7310
+7311
+7312
+7313
+7314
+7315
+7316
+7317
+7318
+7319
+7320
+7321
+7322
+7323
+7324
+7325
+7326
+7327
+7328
+7329
+7330
+7331
+7332
+7333
+7334
+7335
+7336
+7337
+7338
+7339
+7340
+7341
+7342
+7343
+7344
+7345
+7346
+7347
+7348
+7349
+7350
+7351
+7352
+7353
+7354
+7355
+7356
+7357
+7358
+7359
+7360
+7361
+7362
+7363
+7364
+7365
+7366
+7367
+7368
+7369
+7370
+7371
+7372
+7373
+7374
+7375
+7376
+7377
+7378
+7379
+7380
+7381
+7382
+7383
+7384
+7385
+7386
+7387
+7388
+7389
+7390
+7391
+7392
+7393
+7394
+7395
+7396
+7397
+7398
+7399
+7400
+7401
+7402
+7403
+7404
+7405
+7406
+7407
+7408
+7409
+7410
+7411
+7412
+7413
+7414
+7415
+7416
+7417
+7418
+7419
+7420
+7421
+7422
+7423
+7424
+7425
+7426
+7427
+7428
+7429
+7430
+7431
+7432
+7433
+7434
+7435
+7436
+7437
+7438
+7439
+7440
+7441
+7442
+7443
+7444
+7445
+7446
+7447
+7448
+7449
+7450
+7451
+7452
+7453
+7454
+7455
+7456
+7457
+7458
+7459
+7460
+7461
+7462
+7463
+7464
+7465
+7466
+7467
+7468
+7469
+7470
+7471
+7472
+7473
+7474
+7475
+7476
+7477
+7478
+7479
+7480
+7481
+7482
+7483
+7484
+7485
+7486
+7487
+7488
+7489
+7490
+7491
+7492
+7493
+7494
+7495
+7496
+7497
+7498
+7499
+7500
+7501
+7502
+7503
+7504
+7505
+7506
+7507
+7508
+7509
+7510
+7511
+7512
+7513
+7514
+7515
+7516
+7517
+7518
+7519
+7520
+7521
+7522
+7523
+7524
+7525
+7526
+7527
+7528
+7529
+7530
+7531
+7532
+7533
+7534
+7535
+7536
+7537
+7538
+7539
+7540
+7541
+7542
+7543
+7544
+7545
+7546
+7547
+7548
+7549
+7550
+7551
+7552
+7553
+7554
+7555
+7556
+7557
+7558
+7559
+7560
+7561
+7562
+7563
+7564
+7565
+7566
+7567
+7568
+7569
+7570
+7571
+7572
+7573
+7574
+7575
+7576
+7577
+7578
+7579
+7580
+7581
+7582
+7583
+7584
+7585
+7586
+7587
+7588
+7589
+7590
+7591
+7592
+7593
+7594
+7595
+7596
+7597
+7598
+7599
+7600
+7601
+7602
+7603
+7604
+7605
+7606
+7607
+7608
+7609
+7610
+7611
+7612
+7613
+7614
+7615
+7616
+7617
+7618
+7619
+7620
+7621
+7622
+7623
+7624
+7625
+7626
+7627
+7628
+7629
+7630
+7631
+7632
+7633
+7634
+7635
+7636
+7637
+7638
+7639
+7640
+7641
+7642
+7643
+7644
+7645
+7646
+7647
+7648
+7649
+7650
+7651
+7652
+7653
+7654
+7655
+7656
+7657
+7658
+7659
+7660
+7661
+7662
+7663
+7664
+7665
+7666
+7667
+7668
+7669
+7670
+7671
+7672
+7673
+7674
+7675
+7676
+7677
+7678
+7679
+7680
+7681
+7682
+7683
+7684
+7685
+7686
+7687
+7688
+7689
+7690
+7691
+7692
+7693
+7694
+7695
+7696
+7697
+7698
+7699
+7700
+7701
+7702
+7703
+7704
+7705
+7706
+7707
+7708
+7709
+7710
+7711
+7712
+7713
+7714
+7715
+7716
+7717
+7718
+7719
+7720
+7721
+7722
+7723
+7724
+7725
+7726
+7727
+7728
+7729
+7730
+7731
+7732
+7733
+7734
+7735
+7736
+7737
+7738
+7739
+7740
+7741
+7742
+7743
+7744
+7745
+7746
+7747
+7748
+7749
+7750
+7751
+7752
+7753
+7754
+7755
+7756
+7757
+7758
+7759
+7760
+7761
+7762
+7763
+7764
+7765
+7766
+7767
+7768
+7769
+7770
+7771
+7772
+7773
+7774
+7775
+7776
+7777
+7778
+7779
+7780
+7781
+7782
+7783
+7784
+7785
+7786
+7787
+7788
+7789
+7790
+7791
+7792
+7793
+7794
+7795
+7796
+7797
+7798
+7799
+7800
+7801
+7802
+7803
+7804
+7805
+7806
+7807
+7808
+7809
+7810
+7811
+7812
+7813
+7814
+7815
+7816
+7817
+7818
+7819
+7820
+7821
+7822
+7823
+7824
+7825
+7826
+7827
+7828
+7829
+7830
+7831
+7832
+7833
+7834
+7835
+7836
+7837
+7838
+7839
+7840
+7841
+7842
+7843
+7844
+7845
+7846
+7847
+7848
+7849
+7850
+7851
+7852
+7853
+7854
+7855
+7856
+7857
+7858
+7859
+7860
+7861
+7862
+7863
+7864
+7865
+7866
+7867
+7868
+7869
+7870
+7871
+7872
+7873
+7874
+7875
+7876
+7877
+7878
+7879
+7880
+7881
+7882
+7883
+7884
+7885
+7886
+7887
+7888
+7889
+7890
+7891
+7892
+7893
+7894
+7895
+7896
+7897
+7898
+7899
+7900
+7901
+7902
+7903
+7904
+7905
+7906
+7907
+7908
+7909
+7910
+7911
+7912
+7913
+7914
+7915
+7916
+7917
+7918
+7919
+7920
+7921
+7922
+7923
+7924
+7925
+7926
+7927
+7928
+7929
+7930
+7931
+7932
+7933
+7934
+7935
+7936
+7937
+7938
+7939
+7940
+7941
+7942
+7943
+7944
+7945
+7946
+7947
+7948
+7949
+7950
+7951
+7952
+7953
+7954
+7955
+7956
+7957
+7958
+7959
+7960
+7961
+7962
+7963
+7964
+7965
+7966
+7967
+7968
+7969
+7970
+7971
+7972
+7973
+7974
+7975
+7976
+7977
+7978
+7979
+7980
+7981
+7982
+7983
+7984
+7985
+7986
+7987
+7988
+7989
+7990
+7991
+7992
+7993
+7994
+7995
+7996
+7997
+7998
+7999
+8000
+8001
+8002
+8003
+8004
+8005
+8006
+8007
+8008
+8009
+8010
+8011
+8012
+8013
+8014
+8015
+8016
+8017
+8018
+8019
+8020
+8021
+8022
+8023
+8024
+8025
+8026
+8027
+8028
+8029
+8030
+8031
+8032
+8033
+8034
+8035
+8036
+8037
+8038
+8039
+8040
+8041
+8042
+8043
+8044
+8045
+8046
+8047
+8048
+8049
+8050
+8051
+8052
+8053
+8054
+8055
+8056
+8057
+8058
+8059
+8060
+8061
+8062
+8063
+8064
+8065
+8066
+8067
+8068
+8069
+8070
+8071
+8072
+8073
+8074
+8075
+8076
+8077
+8078
+8079
+8080
+8081
+8082
+8083
+8084
+8085
+8086
+8087
+8088
+8089
+8090
+8091
+8092
+8093
+8094
+8095
+8096
+8097
+8098
+8099
+8100
+8101
+8102
+8103
+8104
+8105
+8106
+8107
+8108
+8109
+8110
+8111
+8112
+8113
+8114
+8115
+8116
+8117
+8118
+8119
+8120
+8121
+8122
+8123
+8124
+8125
+8126
+8127
+8128
+8129
+8130
+8131
+8132
+8133
+8134
+8135
+8136
+8137
+8138
+8139
+8140
+8141
+8142
+8143
+8144
+8145
+8146
+8147
+8148
+8149
+8150
+8151
+8152
+8153
+8154
+8155
+8156
+8157
+8158
+8159
+8160
+8161
+8162
+8163
+8164
+8165
+8166
+8167
+8168
+8169
+8170
+8171
+8172
+8173
+8174
+8175
+8176
+8177
+8178
+8179
+8180
+8181
+8182
+8183
+8184
+8185
+8186
+8187
+8188
+8189
+8190
+8191
+8192
+8193
+8194
+8195
+8196
+8197
+8198
+8199
+8200
+8201
+8202
+8203
+8204
+8205
+8206
+8207
+8208
+8209
+8210
+8211
+8212
+8213
+8214
+8215
+8216
+8217
+8218
+8219
+8220
+8221
+8222
+8223
+8224
+8225
+8226
+8227
+8228
+8229
+8230
+8231
+8232
+8233
+8234
+8235
+8236
+8237
+8238
+8239
+8240
+8241
+8242
+8243
+8244
+8245
+8246
+8247
+8248
+8249
+8250
+8251
+8252
+8253
+8254
+8255
+8256
+8257
+8258
+8259
+8260
+8261
+8262
+8263
+8264
+8265
+8266
+8267
+8268
+8269
+8270
+8271
+8272
+8273
+8274
+8275
+8276
+8277
+8278
+8279
+8280
+8281
+8282
+8283
+8284
+8285
+8286
+8287
+8288
+8289
+8290
+8291
+8292
+8293
+8294
+8295
+8296
+8297
+8298
+8299
+8300
+8301
+8302
+8303
+8304
+8305
+8306
+8307
+8308
+8309
+8310
+8311
+8312
+8313
+8314
+8315
+8316
+8317
+8318
+8319
+8320
+8321
+8322
+8323
+8324
+8325
+8326
+8327
+8328
+8329
+8330
+8331
+8332
+8333
+8334
+8335
+8336
+8337
+8338
+8339
+8340
+8341
+8342
+8343
+8344
+8345
+8346
+8347
+8348
+8349
+8350
+8351
+8352
+8353
+8354
+8355
+8356
+8357
+8358
+8359
+8360
+8361
+8362
+8363
+8364
+8365
+8366
+8367
+8368
+8369
+8370
+8371
+8372
+8373
+8374
+8375
+8376
+8377
+8378
+8379
+8380
+8381
+8382
+8383
+8384
+8385
+8386
+8387
+8388
+8389
+8390
+8391
+8392
+8393
+8394
+8395
+8396
+8397
+8398
+8399
+8400
+8401
+8402
+8403
+8404
+8405
+8406
+8407
+8408
+8409
+8410
+8411
+8412
+8413
+8414
+8415
+8416
+8417
+8418
+8419
+8420
+8421
+8422
+8423
+8424
+8425
+8426
+8427
+8428
+8429
+8430
+8431
+8432
+8433
+8434
+8435
+8436
+8437
+8438
+8439
+8440
+8441
+8442
+8443
+8444
+8445
+8446
+8447
+8448
+8449
+8450
+8451
+8452
+8453
+8454
+8455
+8456
+8457
+8458
+8459
+8460
+8461
+8462
+8463
+8464
+8465
+8466
+8467
+8468
+8469
+8470
+8471
+8472
+8473
+8474
+8475
+8476
+8477
+8478
+8479
+8480
+8481
+8482
+8483
+8484
+8485
+8486
+8487
+8488
+8489
+8490
+8491
+8492
+8493
+8494
+8495
+8496
+8497
+8498
+8499
+8500
+8501
+8502
+8503
+8504
+8505
+8506
+8507
+8508
+8509
+8510
+8511
+8512
+8513
+8514
+8515
+8516
+8517
+8518
+8519
+8520
+8521
+8522
+8523
+8524
+8525
+8526
+8527
+8528
+8529
+8530
+8531
+8532
+8533
+8534
+8535
+8536
+8537
+8538
+8539
+8540
+8541
+8542
+8543
+8544
+8545
+8546
+8547
+8548
+8549
+8550
+8551
+8552
+8553
+8554
+8555
+8556
+8557
+8558
+8559
+8560
+8561
+8562
+8563
+8564
+8565
+8566
+8567
+8568
+8569
+8570
+8571
+8572
+8573
+8574
+8575
+8576
+8577
+8578
+8579
+8580
+8581
+8582
+8583
+8584
+8585
+8586
+8587
+8588
+8589
+8590
+8591
+8592
+8593
+8594
+8595
+8596
+8597
+8598
+8599
+8600
+8601
+8602
+8603
+8604
+8605
+8606
+8607
+8608
+8609
+8610
+8611
+8612
+8613
+8614
+8615
+8616
+8617
+8618
+8619
+8620
+8621
+8622
+8623
+8624
+8625
+8626
+8627
+8628
+8629
+8630
+8631
+8632
+8633
+8634
+8635
+8636
+8637
+8638
+8639
+8640
+8641
+8642
+8643
+8644
+8645
+8646
+8647
+8648
+8649
+8650
+8651
+8652
+8653
+8654
+8655
+8656
+8657
+8658
+8659
+8660
+8661
+8662
+8663
+8664
+8665
+8666
+8667
+8668
+8669
+8670
+8671
+8672
+8673
+8674
+8675
+8676
+8677
+8678
+8679
+8680
+8681
+8682
+8683
+8684
+8685
+8686
+8687
+8688
+8689
+8690
+8691
+8692
+8693
+8694
+8695
+8696
+8697
+8698
+8699
+8700
+8701
+8702
+8703
+8704
+8705
+8706
+8707
+8708
+8709
+8710
+8711
+8712
+8713
+8714
+8715
+8716
+8717
+8718
+8719
+8720
+8721
+8722
+8723
+8724
+8725
+8726
+8727
+8728
+8729
+8730
+8731
+8732
+8733
+8734
+8735
+8736
+8737
+8738
+8739
+8740
+8741
+8742
+8743
+8744
+8745
+8746
+8747
+8748
+8749
+8750
+8751
+8752
+8753
+8754
+8755
+8756
+8757
+8758
+8759
+8760
+8761
+8762
+8763
+8764
+8765
+8766
+8767
+8768
+8769
+8770
+8771
+8772
+8773
+8774
+8775
+8776
+8777
+8778
+8779
+8780
+8781
+8782
+8783
+8784
+8785
+8786
+8787
+8788
+8789
+8790
+8791
+8792
+8793
+8794
+8795
+8796
+8797
+8798
+8799
+8800
+8801
+8802
+8803
+8804
+8805
+8806
+8807
+8808
+8809
+8810
+8811
+8812
+8813
+8814
+8815
+8816
+8817
+8818
+8819
+8820
+8821
+8822
+8823
+8824
+8825
+8826
+8827
+8828
+8829
+8830
+8831
+8832
+8833
+8834
+8835
+8836
+8837
+8838
+8839
+8840
+8841
+8842
+8843
+8844
+8845
+8846
+8847
+8848
+8849
+8850
+8851
+8852
+8853
+8854
+8855
+8856
+8857
+8858
+8859
+8860
+8861
+8862
+8863
+8864
+8865
+8866
+8867
+8868
+8869
+8870
+8871
+8872
+8873
+8874
+8875
+8876
+8877
+8878
+8879
+8880
+8881
+8882
+8883
+8884
+8885
+8886
+8887
+8888
+8889
+8890
+8891
+8892
+8893
+8894
+8895
+8896
+8897
+8898
+8899
+8900
+8901
+8902
+8903
+8904
+8905
+8906
+8907
+8908
+8909
+8910
+8911
+8912
+8913
+8914
+8915
+8916
+8917
+8918
+8919
+8920
+8921
+8922
+8923
+8924
+8925
+8926
+8927
+8928
+8929
+8930
+8931
+8932
+8933
+8934
+8935
+8936
+8937
+8938
+8939
+8940
+8941
+8942
+8943
+8944
+8945
+8946
+8947
+8948
+8949
+8950
+8951
+8952
+8953
+8954
+8955
+8956
+8957
+8958
+8959
+8960
+8961
+8962
+8963
+8964
+8965
+8966
+8967
+8968
+8969
+8970
+8971
+8972
+8973
+8974
+8975
+8976
+8977
+8978
+8979
+8980
+8981
+8982
+8983
+8984
+8985
+8986
+8987
+8988
+8989
+8990
+8991
+8992
+8993
+8994
+8995
+8996
+8997
+8998
+8999
+9000
+9001
+9002
+9003
+9004
+9005
+9006
+9007
+9008
+9009
+9010
+9011
+9012
+9013
+9014
+9015
+9016
+9017
+9018
+9019
+9020
+9021
+9022
+9023
+9024
+9025
+9026
+9027
+9028
+9029
+9030
+9031
+9032
+9033
+9034
+9035
+9036
+9037
+9038
+9039
+9040
+9041
+9042
+9043
+9044
+9045
+9046
+9047
+9048
+9049
+9050
+9051
+9052
+9053
+9054
+9055
+9056
+9057
+9058
+9059
+9060
+9061
+9062
+9063
+9064
+9065
+9066
+9067
+9068
+9069
+9070
+9071
+9072
+9073
+9074
+9075
+9076
+9077
+9078
+9079
+9080
+9081
+9082
+9083
+9084
+9085
+9086
+9087
+9088
+9089
+9090
+9091
+9092
+9093
+9094
+9095
+9096
+9097
+9098
+9099
+9100
+9101
+9102
+9103
+9104
+9105
+9106
+9107
+9108
+9109
+9110
+9111
+9112
+9113
+9114
+9115
+9116
+9117
+9118
+9119
+9120
+9121
+9122
+9123
+9124
+9125
+9126
+9127
+9128
+9129
+9130
+9131
+9132
+9133
+9134
+9135
+9136
+9137
+9138
+9139
+9140
+9141
+9142
+9143
+9144
+9145
+9146
+9147
+9148
+9149
+9150
+9151
+9152
+9153
+9154
+9155
+9156
+9157
+9158
+9159
+9160
+9161
+9162
+9163
+9164
+9165
+9166
+9167
+9168
+9169
+9170
+9171
+9172
+9173
+9174
+9175
+9176
+9177
+9178
+9179
+9180
+9181
+9182
+9183
+9184
+9185
+9186
+9187
+9188
+9189
+9190
+9191
+9192
+9193
+9194
+9195
+9196
+9197
+9198
+9199
+9200
+9201
+9202
+9203
+9204
+9205
+9206
+9207
+9208
+9209
+9210
+9211
+9212
+9213
+9214
+9215
+9216
+9217
+9218
+9219
+9220
+9221
+9222
+9223
+9224
+9225
+9226
+9227
+9228
+9229
+9230
+9231
+9232
+9233
+9234
+9235
+9236
+9237
+9238
+9239
+9240
+9241
+9242
+9243
+9244
+9245
+9246
+9247
+9248
+9249
+9250
+9251
+9252
+9253
+9254
+9255
+9256
+9257
+9258
+9259
+9260
+9261
+9262
+9263
+9264
+9265
+9266
+9267
+9268
+9269
+9270
+9271
+9272
+9273
+9274
+9275
+9276
+9277
+9278
+9279
+9280
+9281
+9282
+9283
+9284
+9285
+9286
+9287
+9288
+9289
+9290
+9291
+9292
+9293
+9294
+9295
+9296
+9297
+9298
+9299
+9300
+9301
+9302
+9303
+9304
+9305
+9306
+9307
+9308
+9309
+9310
+9311
+9312
+9313
+9314
+9315
+9316
+9317
+9318
+9319
+9320
+9321
+9322
+9323
+9324
+9325
+9326
+9327
+9328
+9329
+9330
+9331
+9332
+9333
+9334
+9335
+9336
+9337
+9338
+9339
+9340
+9341
+9342
+9343
+9344
+9345
+9346
+9347
+9348
+9349
+9350
+9351
+9352
+9353
+9354
+9355
+9356
+9357
+9358
+9359
+9360
+9361
+9362
+9363
+9364
+9365
+9366
+9367
+9368
+9369
+9370
+9371
+9372
+9373
+9374
+9375
+9376
+9377
+9378
+9379
+9380
+9381
+9382
+9383
+9384
+9385
+9386
+9387
+9388
+9389
+9390
+9391
+9392
+9393
+9394
+9395
+9396
+9397
+9398
+9399
+9400
+9401
+9402
+9403
+9404
+9405
+9406
+9407
+9408
+9409
+9410
+9411
+9412
+9413
+9414
+9415
+9416
+9417
+9418
+9419
+9420
+9421
+9422
+9423
+9424
+9425
+9426
+9427
+9428
+9429
+9430
+9431
+9432
+9433
+9434
+9435
+9436
+9437
+9438
+9439
+9440
+9441
+9442
+9443
+9444
+9445
+9446
+9447
+9448
+9449
+9450
+9451
+9452
+9453
+9454
+9455
+9456
+9457
+9458
+9459
+9460
+9461
+9462
+9463
+9464
+9465
+9466
+9467
+9468
+9469
+9470
+9471
+9472
+9473
+9474
+9475
+9476
+9477
+9478
+9479
+9480
+9481
+9482
+9483
+9484
+9485
+9486
+9487
+9488
+9489
+9490
+9491
+9492
+9493
+9494
+9495
+9496
+9497
+9498
+9499
+9500
+9501
+9502
+9503
+9504
+9505
+9506
+9507
+9508
+9509
+9510
+9511
+9512
+9513
+9514
+9515
+9516
+9517
+9518
+9519
+9520
+9521
+9522
+9523
+9524
+9525
+9526
+9527
+9528
+9529
+9530
+9531
+9532
+9533
+9534
+9535
+9536
+9537
+9538
+9539
+9540
+9541
+9542
+9543
+9544
+9545
+9546
+9547
+9548
+9549
+9550
+9551
+9552
+9553
+9554
+9555
+9556
+9557
+9558
+9559
+9560
+9561
+9562
+9563
+9564
+9565
+9566
+9567
+9568
+9569
+9570
+9571
+9572
+9573
+9574
+9575
+9576
+9577
+9578
+9579
+9580
+9581
+9582
+9583
+9584
+9585
+9586
+9587
+9588
+9589
+9590
+9591
+9592
+9593
+9594
+9595
+9596
+9597
+9598
+9599
+9600
+9601
+9602
+9603
+9604
+9605
+9606
+9607
+9608
+9609
+9610
+9611
+9612
+9613
+9614
+9615
+9616
+9617
+9618
+9619
+9620
+9621
+9622
+9623
+9624
+9625
+9626
+9627
+9628
+9629
+9630
+9631
+9632
+9633
+9634
+9635
+9636
+9637
+9638
+9639
+9640
+9641
+9642
+9643
+9644
+9645
+9646
+9647
+9648
+9649
+9650
+9651
+9652
+9653
+9654
+9655
+9656
+9657
+9658
+9659
+9660
+9661
+9662
+9663
+9664
+9665
+9666
+9667
+9668
+9669
+9670
+9671
+9672
+9673
+9674
+9675
+9676
+9677
+9678
+9679
+9680
+9681
+9682
+9683
+9684
+9685
+9686
+9687
+9688
+9689
+9690
+9691
+9692
+9693
+9694
+9695
+9696
+9697
+9698
+9699
+9700
+9701
+9702
+9703
+9704
+9705
+9706
+9707
+9708
+9709
+9710
+9711
+9712
+9713
+9714
+9715
+9716
+9717
+9718
+9719
+9720
+9721
+9722
+9723
+9724
+9725
+9726
+9727
+9728
+9729
+9730
+9731
+9732
+9733
+9734
+9735
+9736
+9737
+9738
+9739
+9740
+9741
+9742
+9743
+9744
+9745
+9746
+9747
+9748
+9749
+9750
+9751
+9752
+9753
+9754
+9755
+9756
+9757
+9758
+9759
+9760
+9761
+9762
+9763
+9764
+9765
+9766
+9767
+9768
+9769
+9770
+9771
+9772
+9773
+9774
+9775
+9776
+9777
+9778
+9779
+9780
+9781
+9782
+9783
+9784
+9785
+9786
+9787
+9788
+9789
+9790
+9791
+9792
+9793
+9794
+9795
+9796
+9797
+9798
+9799
+9800
+9801
+9802
+9803
+9804
+9805
+9806
+9807
+9808
+9809
+9810
+9811
+9812
+9813
+9814
+9815
+9816
+9817
+9818
+9819
+9820
+9821
+9822
+9823
+9824
+9825
+9826
+9827
+9828
+9829
+9830
+9831
+9832
+9833
+9834
+9835
+9836
+9837
+9838
+9839
+9840
+9841
+9842
+9843
+9844
+9845
+9846
+9847
+9848
+9849
+9850
+9851
+9852
+9853
+9854
+9855
+9856
+9857
+9858
+9859
+9860
+9861
+9862
+9863
+9864
+9865
+9866
+9867
+9868
+9869
+9870
+9871
+9872
+9873
+9874
+9875
+9876
+9877
+9878
+9879
+9880
+9881
+9882
+9883
+9884
+9885
+9886
+9887
+9888
+9889
+9890
+9891
+9892
+9893
+9894
+9895
+9896
+9897
+9898
+9899
+9900
+9901
+9902
+9903
+9904
+9905
+9906
+9907
+9908
+9909
+9910
+9911
+9912
+9913
+9914
+9915
+9916
+9917
+9918
+9919
+9920
+9921
+9922
+9923
+9924
+9925
+9926
+9927
+9928
+9929
+9930
+9931
+9932
+9933
+9934
+9935
+9936
+9937
+9938
+9939
+9940
+9941
+9942
+9943
+9944
+9945
+9946
+9947
+9948
+9949
+9950
+9951
+9952
+9953
+9954
+9955
+9956
+9957
+9958
+9959
+9960
+9961
+9962
+9963
+9964
+9965
+9966
+9967
+9968
+9969
+9970
+9971
+9972
+9973
+9974
+9975
+9976
+9977
+9978
+9979
+9980
+9981
+9982
+9983
+9984
+9985
+9986
+9987
+9988
+9989
+9990
+9991
+9992
+9993
+9994
+9995
+9996
+9997
+9998
+9999
+10000
+10001
+10002
+10003
+10004
+10005
+10006
+10007
+10008
+10009
+10010
+10011
+10012
+10013
+10014
+10015
+10016
+10017
+10018
+10019
+10020
+10021
+10022
+10023
+10024
+10025
+10026
+10027
+10028
+10029
+10030
+10031
+10032
+10033
+10034
+10035
+10036
+10037
+10038
+10039
+10040
+10041
+10042
+10043
+10044
+10045
+10046
+10047
+10048
+10049
+10050
+10051
+10052
+10053
+10054
+10055
+10056
+10057
+10058
+10059
+10060
+10061
+10062
+10063
+10064
+10065
+10066
+10067
+10068
+10069
+10070
+10071
+10072
+10073
+10074
+10075
+10076
+10077
+10078
+10079
+10080
+10081
+10082
+10083
+10084
+10085
+10086
+10087
+10088
+10089
+10090
+10091
+10092
+10093
+10094
+10095
+10096
+10097
+10098
+10099
+10100
+10101
+10102
+10103
+10104
+10105
+10106
+10107
+10108
+10109
+10110
+10111
+10112
+10113
+10114
+10115
+10116
+10117
+10118
+10119
+10120
+10121
+10122
+10123
+10124
+10125
+10126
+10127
+10128
+10129
+10130
+10131
+10132
+10133
+10134
+10135
+10136
+10137
+10138
+10139
+10140
+10141
+10142
+10143
+10144
+10145
+10146
+10147
+10148
+10149
+10150
+10151
+10152
+10153
+10154
+10155
+10156
+10157
+10159
+10160
+10161
+10162
+10163
+10164
+10165
+10166
+10167
+10168
+10169
+10170
+10171
+10172
+10173
+10174
+10175
+10176
+10177
+10178
+10179
+10180
+10181
+10182
+10183
+10184
+10185
+10186
+10187
+10188
+10189
+10190
+10191
+10192
+10193
+10194
+10195
+10196
+10197
+10198
+10199
+10200
+10201
+10202
+10203
+10204
+10205
+10206
+10207
+10208
+10209
+10210
+10211
+10212
+10213
+10214
+10215
+10216
+10217
+10218
+10219
+10220
+10221
+10222
+10223
+10224
+10225
+10226
+10227
+10228
+10229
+10230
+10231
+10232
+10233
+10234
+10235
+10236
+10237
+10238
+10239
+10240
+10241
+10242
+10243
+10244
+10245
+10246
+10247
+10248
+10249
+10250
+10251
+10252
+10253
+10254
+10255
+10256
+10257
+10258
+10259
+10260
+10261
+10262
+10263
+10264
+10265
+10266
+10267
+10268
+10269
+10270
+10271
+10272
+10273
+10274
+10275
+10276
+10277
+10278
+10279
+10282
+10283
+10284
+10285
+10286
+10287
+10288
+10289
+10290
+10291
+10292
+10293
+10294
+10295
+10296
+10297
+10298
+10299
+10300
+10301
+10302
+10303
+10304
+10305
+10306
+10307
+10308
+10309
+10310
+10311
+10312
+10313
+10314
+10315
+10316
+10317
+10318
+10319
+10320
+10321
+10322
+10323
+10324
+10325
+10326
+10327
+10328
+10329
+10330
+10331
+10332
+10333
+10334
+10335
+10336
+10337
+10338
+10339
+10340
+10341
+10342
+10343
+10344
+10345
+10346
+10347
+10348
+10349
+10350
+10351
+10352
+10353
+10354
+10355
+10356
+10357
+10358
+10359
+10360
+10361
+10362
+10363
+10364
+10365
+10366
+10367
+10368
+10369
+10370
+10371
+10372
+10373
+10374
+10375
+10376
+10377
+10378
+10379
+10380
+10381
+10382
+10383
+10384
+10385
+10386
+10387
+10388
+10389
+10390
+10391
+10392
+10393
+10394
+10395
+10396
+10397
+10398
+10399
+10400
+10401
+10402
+10403
+10404
+10405
+10406
+10407
+10408
+10409
+10410
+10411
+10412
+10413
+10414
+10415
+10416
+10417
+10418
+10419
+10420
+10421
+10422
+10423
+10424
+10425
+10426
+10427
+10428
+10429
+10430
+10431
+10432
+10433
+10434
+10435
+10436
+10437
+10438
+10439
+10440
+10441
+10442
+10443
+10444
+10445
+10446
+10447
+10448
+10449
+10450
+10451
+10452
+10453
+10454
+10455
+10456
+10457
+10458
+10459
+10460
+10461
+10462
+10463
+10464
+10465
+10466
+10467
+10468
+10469
+10470
+10471
+10472
+10473
+10474
+10475
+10476
+10477
+10478
+10479
+10480
+10481
+10482
+10483
+10484
+10485
+10486
+10487
+10488
+10489
+10490
+10491
+10492
+10493
+10494
+10495
+10496
+10497
+10498
+10499
+10500
+10501
+10502
+10503
+10504
+10505
+10506
+10507
+10508
+10509
+10510
+10511
+10512
+10513
+10514
+10515
+10516
+10517
+10518
+10519
+10520
+10521
+10522
+10523
+10524
+10525
+10526
+10527
+10528
+10529
+10530
+10531
+10532
+10533
+10534
+10535
+10536
+10537
+10538
+10539
+10540
+10541
+10542
+10543
+10544
+10545
+10546
+10547
+10548
+10549
+10550
+10551
+10552
+10553
+10554
+10555
+10556
+10557
+10558
+10559
+10560
+10561
+10562
+10563
+10564
+10565
+10566
+10567
+10568
+10569
+10570
+10571
+10572
+10573
+10574
+10575
+10576
+10577
+10578
+10579
+10580
+10581
+10582
+10583
+10584
+10585
+10586
+10587
+10588
+10589
+10590
+10591
+10592
+10593
+10594
+10595
+10596
+10597
+10598
+10599
+10600
+10601
+10602
+10603
+10604
+10605
+10606
+10607
+10608
+10609
+10610
+10611
+10612
+10613
+10614
+10615
+10616
+10617
+10618
+10619
+10620
+10621
+10622
+10623
+10624
+10625
+10626
+10627
+10628
+10629
+10630
+10631
+10632
+10633
+10634
+10635
+10636
+10637
+10638
+10639
+10640
+10641
+10642
+10643
+10644
+10645
+10646
+10647
+10648
+10649
+10650
+10651
+10652
+10653
+10654
+10655
+10656
+10657
+10658
+10659
+10660
+10661
+10662
+10663
+10664
+10665
+10666
+10667
+10668
+10669
+10670
+10671
+10672
+10673
+10674
+10675
+10676
+10677
+10678
+10679
+10680
+10681
+10682
+10683
+10684
+10685
+10686
+10687
+10688
+10689
+10690
+10691
+10692
+10693
+10694
+10695
+10696
+10697
+10698
+10699
+10700
+10701
+10702
+10703
+10704
+10705
+10706
+10707
+10708
+10709
+10710
+10711
+10712
+10713
+10714
+10715
+10716
+10717
+10718
+10719
+10720
+10721
+10722
+10723
+10724
+10725
+10726
+10727
+10728
+10729
+10730
+10731
+10732
+10733
+10734
+10735
+10736
+10737
+10738
+10739
+10740
+10741
+10742
+10743
+10744
+10745
+10746
+10747
+10748
+10749
+10750
+10751
+10752
+10753
+10754
+10755
+10756
+10757
+10758
+10759
+10760
+10761
+10762
+10763
+10764
+10765
+10766
+10767
+10768
+10769
+10770
+10771
+10772
+10773
+10774
+10775
+10776
+10777
+10778
+10779
+10780
+10781
+10782
+10783
+10784
+10785
+10786
+10787
+10788
+10789
+10790
+10791
+10792
+10793
+10794
+10795
+10796
+10797
+10798
+10799
+10800
+10801
+10802
+10803
+10804
+10805
+10806
+10807
+10808
+10809
+10810
+10811
+10812
+10813
+10814
+10815
+10816
+10817
+10818
+10819
+10820
+10821
+10822
+10823
+10824
+10825
+10826
+10827
+10828
+10829
+10830
+10831
+10832
+10833
+10834
+10835
+10836
+10837
+10838
+10839
+10840
+10841
+10842
+10843
+10844
+10845
+10846
+10847
+10848
+10849
+10850
+10851
+10852
+10853
+10854
+10855
+10856
+10857
+10858
+10859
+10860
+10861
+10862
+10863
+10864
+10865
+10866
+10867
+10868
+10869
+10870
+10871
+10872
+10873
+10874
+10875
+10876
+10877
+10878
+10879
+10880
+10881
+10882
+10883
+10884
+10885
+10886
+10887
+10888
+10889
+10890
+10891
+10892
+10893
+10894
+10895
+10896
+10897
+10898
+10899
+10900
+10901
+10902
+10903
+10904
+10905
+10906
+10907
+10908
+10909
+10910
+10911
+10912
+10913
+10914
+10915
+10916
+10917
+10918
+10919
+10920
+10921
+10922
+10923
+10924
+10925
+10926
+10927
+10928
+10929
+10930
+10931
+10932
+10933
+10934
+10935
+10936
+10937
+10938
+10939
+10940
+10941
+10942
+10943
+10944
+10945
+10946
+10947
+10948
+10949
+10950
+10951
+10952
+10953
+10954
+10955
+10956
+10957
+10958
+10959
+10960
+10961
+10962
+10963
+10964
+10965
+10966
+10967
+10968
+10969
+10970
+10971
+10972
+10973
+10974
+10975
+10976
+10977
+10978
+10979
+10980
+10981
+10982
+10983
+10984
+10985
+10986
+10987
+10988
+10989
+10990
+10991
+10992
+10993
+10994
+10995
+10996
+10997
+10998
+10999
+11000
+11001
+11002
+11003
+11004
+11005
+11006
+11007
+11008
+11009
+11010
+11011
+11012
+11013
+11014
+11015
+11016
+11017
+11018
+11019
+11020
+11021
+11022
+11023
+11024
+11025
+11026
+11027
+11028
+11029
+11030
+11031
+11032
+11033
+11034
+11035
+11036
+11037
+11038
+11039
+11040
+11041
+11042
+11043
+11044
+11045
+11046
+11047
+11048
+11049
+11050
+11051
+11052
+11053
+11054
+11055
+11056
+11057
+11058
+11059
+11060
+11061
+11062
+11063
+11064
+11065
+11066
+11067
+11068
+11069
+11070
+11071
+11072
+11073
+11074
+11075
+11076
+11077
+11078
+11079
+11080
+11081
+11082
+11083
+11084
+11085
+11086
+11087
+11088
+11089
+11090
+11091
+11092
+11093
+11094
+11095
+11096
+11097
+11098
+11099
+11100
+11101
+11102
+11103
+11104
+11105
+11106
+11107
+11108
+11109
+11110
+11111
+11112
+11113
+11114
+11115
+11116
+11117
+11118
+11119
+11120
+11121
+11122
+11123
+11124
+11125
+11126
+11127
+11128
+11129
+11130
+11131
+11132
+11133
+11134
+11135
+11136
+11137
+11138
+11139
+11140
+11141
+11142
+11143
+11144
+11145
+11146
+11147
+11148
+11149
+11150
+11151
+11152
+11153
+11154
+11155
+11156
+11157
+11158
+11159
+11160
+11161
+11162
+11163
+11164
+11165
+11166
+11167
+11168
+11169
+11170
+11171
+11172
+11173
+11174
+11175
+11176
+11177
+11178
+11179
+11180
+11181
+11182
+11183
+11184
+11185
+11186
+11187
+11188
+11189
+11190
+11191
+11192
+11193
+11194
+11195
+11196
+11197
+11198
+11199
+11200
+11201
+11202
+11203
+11204
+11205
+11206
+11207
+11208
+11209
+11210
+11211
+11212
+11213
+11214
+11215
+11216
+11217
+11218
+11219
+11220
+11221
+11222
+11223
+11224
+11225
+11226
+11227
+11228
+11229
+11230
+11231
+11232
+11233
+11234
+11235
+11236
+11237
+11238
+11239
+11240
+11241
+11242
+11243
+11244
+11245
+11246
+11247
+11248
+11249
+11250
+11251
+11252
+11253
+11254
+11255
+11256
+11257
+11258
+11259
+11260
+11261
+11262
+11263
+11264
+11265
+11266
+11267
+11268
+11269
+11270
+11271
+11272
+11273
+11274
+11275
+11276
+11277
+11278
+11279
+11280
+11281
+11282
+11283
+11284
+11285
+11286
+11287
+11288
+11289
+11290
+11291
+11292
+11293
+11294
+11295
+11296
+11297
+11298
+11299
+11300
+11301
+11302
+11303
+11304
+11305
+11306
+11307
+11308
+11309
+11310
+11311
+11312
+11313
+11314
+11315
+11316
+11317
+11318
+11319
+11320
+11321
+11322
+11323
+11324
+11325
+11326
+11327
+11328
+11329
+11330
+11331
+11332
+11333
+11334
+11335
+11336
+11337
+11338
+11339
+11340
+11341
+11342
+11343
+11344
+11345
+11346
+11347
+11348
+11349
+11350
+11351
+11352
+11353
+11354
+11355
+11356
+11357
+11358
+11359
+11360
+11361
+11362
+11363
+11364
+11365
+11366
+11367
+11368
+11369
+11370
+11371
+11372
+11373
+11374
+11375
+11376
+11377
+11378
+11379
+11380
+11381
+11382
+11383
+11384
+11385
+11386
+11387
+11388
+11389
+11390
+11391
+11392
+11393
+11394
+11395
+11396
+11397
+11398
+11399
+11400
+11401
+11402
+11403
+11404
+11405
+11406
+11407
+11408
+11409
+11410
+11411
+11412
+11413
+11414
+11415
+11416
+11417
+11418
+11419
+11420
+11421
+11422
+11423
+11424
+11425
+11426
+11427
+11428
+11429
+11430
+11431
+11432
+11433
+11434
+11435
+11436
+11437
+11438
+11439
+11440
+11441
+11442
+11443
+11444
+11445
+11446
+11447
+11448
+11449
+11450
+11451
+11452
+11453
+11454
+11455
+11456
+11457
+11458
+11459
+11460
+11461
+11462
+11463
+11464
+11465
+11466
+11467
+11468
+11469
+11470
+11471
+11472
+11473
+11474
+11475
+11476
+11477
+11478
+11479
+11480
+11481
+11482
+11483
+11484
+11485
+11486
+11487
+11488
+11489
+11490
+11491
+11492
+11493
+11494
+11495
+11496
+11497
+11498
+11499
+11500
+11501
+11502
+11503
+11504
+11505
+11506
+11507
+11508
+11509
+11510
+11511
+11512
+11513
+11514
+11515
+11516
+11517
+11518
+11519
+11520
+11521
+11522
+11523
+11524
+11525
+11526
+11527
+11528
+11529
+11530
+11531
+11532
+11533
+11534
+11535
+11536
+11537
+11538
+11539
+11540
+11541
+11542
+11543
+11544
+11545
+11546
+11547
+11548
+11549
+11550
+11551
+11552
+11553
+11554
+11555
+11556
+11557
+11558
+11559
+11560
+11561
+11562
+11563
+11564
+11565
+11566
+11567
+11568
+11569
+11570
+11571
+11572
+11573
+11574
+11575
+11576
+11577
+11578
+11579
+11580
+11581
+11582
+11583
+11584
+11585
+11586
+11587
+11588
+11589
+11590
+11591
+11592
+11593
+11594
+11595
+11596
+11597
+11598
+11599
+11600
+11601
+11602
+11603
+11604
+11605
+11606
+11607
+11608
+11609
+11610
+11611
+11612
+11613
+11614
+11615
+11616
+11617
+11618
+11619
+11620
+11621
+11622
+11623
+11624
+11625
+11626
+11627
+11628
+11629
+11630
+11631
+11632
+11633
+11634
+11635
+11636
+11637
+11638
+11639
+11640
+11641
+11642
+11643
+11644
+11645
+11646
+11647
+11648
+11649
+11650
+11651
+11652
+11653
+11654
+11655
+11656
+11657
+11658
+11659
+11660
+11661
+11662
+11663
+11664
+11665
+11666
+11667
+11668
+11669
+11670
+11671
+11672
+11673
+11674
+11675
+11676
+11677
+11678
+11679
+11680
+11681
+11682
+11683
+11684
+11685
+11686
+11687
+11688
+11689
+11690
+11691
+11692
+11693
+11694
+11695
+11696
+11697
+11698
+11699
+11700
+11701
+11702
+11703
+11704
+11705
+11706
+11707
+11708
+11709
+11710
+11711
+11712
+11713
+11714
+11715
+11716
+11717
+11718
+11719
+11720
+11721
+11722
+11723
+11724
+11725
+11726
+11727
+11728
+11729
+11730
+11731
+11732
+11733
+11734
+11735
+11736
+11737
+11738
+11739
+11740
+11741
+11742
+11743
+11744
+11745
+11746
+11747
+11748
+11749
+11750
+11751
+11752
+11753
+11754
+11755
+11756
+11757
+11758
+11759
+11760
+11761
+11762
+11763
+11764
+11765
+11766
+11767
+11768
+11769
+11770
+11771
+11772
+11773
+11774
+11775
+11776
+11777
+11778
+11779
+11780
+11781
+11782
+11783
+11784
+11785
+11786
+11787
+11788
+11789
+11790
+11791
+11792
+11793
+11794
+11795
+11796
+11797
+11798
+11799
+11800
+11801
+11802
+11803
+11804
+11805
+11806
+11807
+11808
+11809
+11810
+11811
+11812
+11813
+11814
+11815
+11816
+11817
+11818
+11819
+11820
+11821
+11822
+11823
+11824
+11825
+11826
+11827
+11828
+11829
+11830
+11831
+11832
+11833
+11834
+11835
+11836
+11837
+11838
+11839
+11840
+11841
+11842
+11843
+11844
+11845
+11846
+11847
+11848
+11849
+11850
+11851
+11852
+11853
+11854
+11855
+11856
+11857
+11858
+11859
+11860
+11861
+11862
+11863
+11864
+11865
+11866
+11867
+11868
+11869
+11870
+11871
+11872
+11873
+11874
+11875
+11876
+11877
+11878
+11879
+11880
+11881
+11882
+11883
+11884
+11885
+11886
+11887
+11888
+11889
+11890
+11891
+11892
+11893
+11894
+11895
+11896
+11897
+11898
+11899
+11900
+11901
+11902
+11903
+11904
+11905
+11906
+11907
+11908
+11909
+11910
+11911
+11912
+11913
+11914
+11915
+11916
+11917
+11918
+11919
+11920
+11921
+11922
+11923
+11924
+11925
+11926
+11927
+11928
+11929
+11930
+11931
+11932
+11933
+11934
+11935
+11936
+11937
+11938
+11939
+11940
+11941
+11942
+11943
+11944
+11945
+11946
+11947
+11948
+11949
+11950
+11951
+11952
+11953
+11954
+11955
+11956
+11957
+11958
+11959
+11960
+11961
+11962
+11963
+11964
+11965
+11966
+11967
+11968
+11969
+11970
+11971
+11972
+11973
+11974
+11975
+11976
+11977
+11978
+11979
+11980
+11981
+11982
+11983
+11984
+11985
+11986
+11987
+11988
+11989
+11990
+11991
+11992
+11993
+11994
+11995
+11996
+11997
+11998
+11999
+12000
+12001
+12002
+12003
+12004
+12005
+12006
+12007
+12008
+12009
+12010
+12011
+12012
+12013
+12014
+12015
+12016
+12017
+12018
+12019
+12020
+12021
+12022
+12023
+12024
+12025
+12026
+12027
+12028
+12029
+12030
+12031
+12032
+12033
+12034
+12035
+12036
+12037
+12038
+12039
+12040
+12041
+12042
+12043
+12044
+12045
+12046
+12047
+12048
+12049
+12050
+12051
+12052
+12053
+12054
+12055
+12056
+12057
+12058
+12059
+12060
+12061
+12062
+12063
+12064
+12065
+12066
+12067
+12068
+12069
+12070
+12071
+12072
+12073
+12074
+12075
+12076
+12077
+12078
+12079
+12080
+12081
+12082
+12083
+12084
+12085
+12086
+12087
+12088
+12089
+12090
+12091
+12092
+12093
+12094
+12095
+12096
+12097
+12098
+12099
+12100
+12101
+12102
+12103
+12104
+12105
+12106
+12107
+12108
+12109
+12110
+12111
+12112
+12113
+12114
+12115
+12116
+12117
+12118
+12119
+12120
+12121
+12122
+12123
+12124
+12125
+12126
+12127
+12128
+12129
+12130
+12131
+12132
+12133
+12134
+12135
+12136
+12137
+12138
+12139
+12140
+12141
+12142
+12143
+12144
+12145
+12146
+12147
+12148
+12149
+12150
+12151
+12152
+12153
+12154
+12155
+12156
+12157
+12158
+12159
+12160
+12161
+12162
+12163
+12164
+12165
+12166
+12167
+12168
+12169
+12170
+12171
+12172
+12173
+12174
+12175
+12176
+12177
+12178
+12179
+12180
+12181
+12182
+12183
+12184
+12185
+12186
+12187
+12188
+12189
+12190
+12191
+12192
+12193
+12194
+12195
+12196
+12197
+12198
+12199
+12200
+12201
+12202
+12203
+12204
+12205
+12206
+12207
+12208
+12209
+12210
+12211
+12212
+12213
+12214
+12215
+12216
+12217
+12218
+12219
+12220
+12221
+12222
+12223
+12224
+12225
+12226
+12227
+12228
+12229
+12230
+12231
+12232
+12233
+12234
+12235
+12236
+12237
+12238
+12239
+12240
+12241
+12242
+12243
+12244
+12245
+12246
+12247
+12248
+12249
+12250
+12251
+12252
+12253
+12254
+12255
+12256
+12257
+12258
+12259
+12260
+12261
+12262
+12263
+12264
+12265
+12266
+12267
+12268
+12269
+12270
+12271
+12272
+12273
+12274
+12275
+12276
+12277
+12278
+12279
+12280
+12281
+12282
+12283
+12284
+12285
+12286
+12287
+12288
+12289
+12290
+12291
+12292
+12293
+12294
+12295
+12296
+12297
+12298
+12299
+12300
+12301
+12302
+12303
+12304
+12305
+12306
+12307
+12308
+12309
+12310
+12311
+12312
+12313
+12314
+12315
+12316
+12317
+12318
+12319
+12320
+12321
+12322
+12323
+12324
+12325
+12326
+12327
+12328
+12329
+12330
+12331
+12332
+12333
+12334
+12335
+12336
+12337
+12338
+12339
+12340
+12341
+12342
+12343
+12344
+12345
+12346
+12347
+12348
+12349
+12350
+12351
+12352
+12353
+12354
+12355
+12356
+12357
+12358
+12359
+12360
+12361
+12362
+12363
+12364
+12365
+12366
+12367
+12368
+12369
+12370
+12371
+12372
+12373
+12374
+12375
+12376
+12377
+12378
+12379
+12380
+12381
+12382
+12383
+12384
+12385
+12386
+12387
+12388
+12389
+12390
+12391
+12392
+12393
+12394
+12395
+12396
+12397
+12398
+12399
+12400
+12401
+12402
+12403
+12404
+12405
+12406
+12407
+12408
+12409
+12410
+12411
+12412
+12413
+12414
+12415
+12416
+12417
+12418
+12419
+12420
+12421
+12422
+12423
+12424
+12425
+12426
+12427
+12428
+12429
+12430
+12431
+12432
+12433
+12434
+12435
+12436
+12437
+12438
+12439
+12440
+12441
+12442
+12443
+12444
+12445
+12446
+12447
+12448
+12449
+12450
+12451
+12452
+12453
+12454
+12455
+12456
+12457
+12458
+12459
+12460
+12461
+12462
+12463
+12464
+12465
+12466
+12467
+12468
+12469
+12470
+12471
+12472
+12473
+12474
+12475
+12476
+12477
+12478
+12479
+12480
+12481
+12482
+12483
+12484
+12485
+12486
+12487
+12488
+12489
+12490
+12491
+12492
+12493
+12494
+12495
+12496
+12497
+12498
+12499
+12500
+12501
+12502
+12503
+12504
+12505
+12506
+12507
+12508
+12509
+12510
+12511
+12512
+12513
+12514
+12515
+12516
+12517
+12518
+12519
+12520
+12521
+12522
+12523
+12524
+12525
+12526
+12527
+12528
+12529
+12530
+12531
+12532
+12533
+12534
+12535
+12536
+12537
+12538
+12539
+12540
+12541
+12542
+12543
+12544
+12545
+12546
+12547
+12548
+12549
+12550
+12551
+12552
+12553
+12554
+12555
+12556
+12557
+12558
+12559
+12560
+12561
+12562
+12563
+12564
+12565
+12566
+12567
+12568
+12569
+12570
+12571
+12572
+12573
+12574
+12575
+12576
+12577
+12578
+12579
+12580
+12581
+12582
+12583
+12584
+12585
+12586
+12587
+12588
+12589
+12590
+12591
+12592
+12593
+12594
+12595
+12596
+12597
+12598
+12599
+12600
+12601
+12602
+12603
+12604
+12605
+12606
+12607
+12608
+12609
+12610
+12611
+12612
+12613
+12614
+12615
+12616
+12617
+12618
+12619
+12620
+12621
+12622
+12623
+12624
+12625
+12626
+12627
+12628
+12629
+12630
+12631
+12632
+12633
+12634
+12635
+12636
+12637
+12638
+12639
+12640
+12641
+12642
+12643
+12644
+12645
+12646
+12647
+12648
+12649
+12650
+12651
+12652
+12653
+12654
+12655
+12656
+12657
+12658
+12659
+12660
+12661
+12662
+12663
+12664
+12665
+12666
+12667
+12668
+12669
+12670
+12671
+12672
+12673
+12674
+12675
+12676
+12677
+12678
+12679
+12680
+12681
+12682
+12683
+12684
+12685
+12686
+12687
+12688
+12689
+12690
+12691
+12692
+12693
+12694
+12695
+12696
+12697
+12698
+12699
+12700
+12701
+12702
+12703
+12704
+12705
+12706
+12707
+12708
+12709
+12710
+12711
+12712
+12713
+12714
+12715
+12716
+12717
+12718
+12719
+12720
+12721
+12722
+12723
+12724
+12725
+12726
+12727
+12728
+12729
+12730
+12731
+12732
+12733
+12734
+12735
+12736
+12737
+12738
+12739
+12740
+12741
+12742
+12743
+12744
+12745
+12746
+12747
+12748
+12749
+12750
+12751
+12752
+12753
+12754
+12755
+12756
+12757
+12758
+12759
+12760
+12761
+12762
+12763
+12764
+12765
+12766
+12767
+12768
+12769
+12770
+12771
+12772
+12773
+12774
+12775
+12776
+12777
+12778
+12779
+12780
+12781
+12782
+12783
+12784
+12785
+12786
+12787
+12788
+12789
+12790
+12791
+12792
+12793
+12794
+12795
+12796
+12797
+12798
+12799
+12800
+12801
+12802
+12803
+12804
+12805
+12806
+12807
+12808
+12809
+12810
+12811
+12812
+12813
+12814
+12815
+12816
+12817
+12818
+12819
+12820
+12821
+12822
+12823
+12824
+12825
+12826
+12827
+12828
+12829
+12830
+12831
+12832
+12833
+12834
+12835
+12836
+12837
+12838
+12839
+12840
+12841
+12842
+12843
+12844
+12845
+12846
+12847
+12848
+12849
+12850
+12851
+12852
+12853
+12854
+12855
+12856
+12857
+12858
+12859
+12860
+12861
+12862
+12863
+12864
+12865
+12866
+12867
+12868
+12869
+12870
+12871
+12872
+12873
+12874
+12875
+12876
+12877
+12878
+12879
+12880
+12881
+12882
+12883
+12884
+12885
+12886
+12887
+12888
+12889
+12890
+12891
+12892
+12893
+12894
+12895
+12896
+12897
+12898
+12899
+12900
+12901
+12902
+12903
+12904
+12905
+12906
+12907
+12908
+12909
+12910
+12911
+12912
+12913
+12914
+12915
+12916
+12917
+12918
+12919
+12920
+12921
+12922
+12923
+12924
+12925
+12926
+12927
+12928
+12929
+12930
+12931
+12932
+12933
+12934
+12935
+12936
+12937
+12938
+12939
+12940
+12941
+12942
+12943
+12944
+12945
+12946
+12947
+12948
+12949
+12950
+12951
+12952
+12953
+12954
+12955
+12956
+12957
+12958
+12959
+12960
+12961
+12962
+12963
+12964
+12965
+12966
+12967
+12968
+12969
+12970
+12971
+12972
+12973
+12974
+12975
+12976
+12977
+12978
+12979
+12980
+12981
+12982
+12983
+12984
+12985
+12986
+12987
+12988
+12989
+12990
+12991
+12992
+12993
+12994
+12995
+12996
+12997
+12998
+12999
+13000
+13001
+13002
+13003
+13004
+13005
+13006
+13007
+13008
+13009
+13010
+13011
+13012
+13013
+13014
+13015
+13016
+13017
+13018
+13019
+13020
+13021
+13022
+13023
+13024
+13025
+13026
+13027
+13028
+13029
+13030
+13031
+13032
+13033
+13034
+13035
+13036
+13037
+13038
+13039
+13040
+13041
+13042
+13043
+13044
+13045
+13046
+13047
+13048
+13049
+13050
+13051
+13052
+13053
+13054
+13055
+13056
+13057
+13058
+13059
+13060
+13061
+13062
+13063
+13064
+13065
+13066
+13067
+13068
+13069
+13070
+13071
+13072
+13073
+13074
+13075
+13076
+13077
+13078
+13079
+13080
+13081
+13082
+13083
+13084
+13085
+13086
+13087
+13088
+13089
+13090
+13091
+13092
+13093
+13094
+13095
+13096
+13097
+13098
+13099
+13100
+13101
+13102
+13103
+13104
+13105
+13106
+13107
+13108
+13109
+13110
+13111
+13112
+13113
+13114
+13115
+13116
+13117
+13118
+13119
+13120
+13121
+13122
+13123
+13124
+13125
+13126
+13127
+13128
+13129
+13130
+13131
+13132
+13133
+13134
+13135
+13136
+13137
+13138
+13139
+13140
+13141
+13142
+13143
+13144
+13145
+13146
+13147
+13148
+13149
+13150
+13151
+13152
+13153
+13154
+13155
+13156
+13157
+13158
+13159
+13160
+13161
+13162
+13163
+13164
+13165
+13166
+13167
+13168
+13169
+13170
+13171
+13172
+13173
+13174
+13175
+13176
+13177
+13178
+13179
+13180
+13181
+13182
+13183
+13184
+13185
+13186
+13187
+13188
+13189
+13190
+13191
+13192
+13193
+13194
+13195
+13196
+13197
+13198
+13199
+13200
+13201
+13202
+13203
+13204
+13205
+13206
+13207
+13208
+13209
+13210
+13211
+13212
+13213
+13214
+13215
+13216
+13217
+13218
+13219
+13220
+13221
+13222
+13223
+13224
+13225
+13226
+13227
+13228
+13229
+13230
+13231
+13232
+13233
+13234
+13235
+13236
+13237
+13238
+13239
+13240
+13241
+13242
+13243
+13244
+13245
+13246
+13247
+13248
+13249
+13250
+13251
+13252
+13253
+13254
+13255
+13256
+13257
+13258
+13259
+13260
+13261
+13262
+13263
+13264
+13265
+13266
+13267
+13268
+13269
+13270
+13271
+13272
+13273
+13274
+13275
+13276
+13277
+13278
+13279
+13280
+13281
+13282
+13283
+13284
+13285
+13286
+13287
+13288
+13289
+13290
+13291
+13292
+13293
+13294
+13295
+13296
+13297
+13298
+13299
+13300
+13301
+13302
+13303
+13304
+13305
+13306
+13307
+13308
+13309
+13310
+13311
+13312
+13313
+13314
+13315
+13316
+13317
+13318
+13319
+13320
+13321
+13322
+13323
+13324
+13325
+13326
+13327
+13328
+13329
+13330
+13331
+13332
+13333
+13334
+13335
+13336
+13337
+13338
+13339
+13340
+13341
+13342
+13343
+13344
+13345
+13346
+13347
+13348
+13349
+13350
+13351
+13352
+13353
+13354
+13355
+13356
+13357
+13358
+13359
+13360
+13361
+13362
+13363
+13364
+13365
+13366
+13367
+13368
+13369
+13370
+13371
+13372
+13373
+13374
+13375
+13376
+13377
+13378
+13379
+13380
+13381
+13382
+13383
+13384
+13385
+13386
+13387
+13388
+13389
+13390
+13391
+13392
+13393
+13394
+13395
+13396
+13397
+13398
+13399
+13400
+13401
+13402
+13403
+13404
+13405
+13406
+13407
+13408
+13409
+13410
+13411
+13412
+13413
+13414
+13415
+13416
+13417
+13418
+13419
+13420
+13421
+13422
+13423
+13424
+13425
+13426
+13427
+13428
+13429
+13430
+13431
+13432
+13433
+13434
+13435
+13436
+13437
+13438
+13439
+13440
+13441
+13442
+13443
+13444
+13445
+13446
+13447
+13448
+13449
+13450
+13451
+13452
+13453
+13454
+13455
+13456
+13457
+13458
+13459
+13460
+13461
+13462
+13463
+13464
+13465
+13466
+13467
+13468
+13469
+13470
+13471
+13472
+13473
+13474
+13475
+13476
+13477
+13478
+13479
+13480
+13481
+13482
+13483
+13484
+13485
+13486
+13487
+13488
+13489
+13490
+13491
+13492
+13493
+13494
+13495
+13496
+13497
+13498
+13499
+13500
+13501
+13502
+13503
+13504
+13505
+13506
+13507
+13508
+13509
+13510
+13511
+13512
+13513
+13514
+13515
+13516
+13517
+13518
+13519
+13520
+13521
+13522
+13523
+13524
+13525
+13526
+13527
+13528
+13529
+13530
+13531
+13532
+13533
+13534
+13535
+13536
+13537
+13538
+13539
+13540
+13541
+13542
+13543
+13544
+13545
+13546
+13547
+13548
+13549
+13550
+13551
+13552
+13553
+13554
+13555
+13556
+13557
+13558
+13559
+13560
+13561
+13562
+13563
+13564
+13565
+13566
+13567
+13568
+13569
+13570
+13571
+13572
+13573
+13574
+13575
+13576
+13577
+13578
+13579
+13580
+13581
+13582
+13583
+13584
+13585
+13586
+13587
+13588
+13589
+13590
+13591
+13592
+13593
+13594
+13595
+13596
+13597
+13598
+13599
+13600
+13601
+13602
+13603
+13604
+13605
+13606
+13607
+13608
+13609
+13610
+13611
+13612
+13613
+13614
+13615
+13616
+13617
+13618
+13619
+13620
+13621
+13622
+13623
+13624
+13625
+13626
+13627
+13628
+13629
+13630
+13631
+13632
+13633
+13634
+13635
+13636
+13637
+13638
+13639
+13640
+13641
+13642
+13643
+13644
+13645
+13646
+13647
+13648
+13649
+13650
+13651
+13652
+13653
+13654
+13655
+13656
+13657
+13658
+13659
+13660
+13661
+13662
+13663
+13664
+13665
+13666
+13667
+13668
+13669
+13670
+13671
+13672
+13673
+13674
+13675
+13676
+13677
+13678
+13679
+13680
+13681
+13682
+13683
+13684
+13685
+13686
+13687
+13688
+13689
+13690
+13691
+13692
+13693
+13694
+13695
+13696
+13697
+13698
+13699
+13700
+13701
+13702
+13703
+13704
+13705
+13706
+13707
+13708
+13709
+13710
+13711
+13712
+13713
+13714
+13715
+13716
+13717
+13718
+13719
+13720
+13721
+13722
+13723
+13724
+13725
+13726
+13727
+13728
+13729
+13730
+13731
+13732
+13733
+13734
+13735
+13736
+13737
+13738
+13739
+13740
+13741
+13742
+13743
+13744
+13745
+13746
+13747
+13748
+13749
+13750
+13751
+13752
+13753
+13754
+13755
+13756
+13757
+13758
+13759
+13760
+13761
+13762
+13763
+13764
+13765
+13766
+13767
+13768
+13769
+13770
+13771
+13772
+13773
+13774
+13775
+13776
+13777
+13778
+13779
+13780
+13781
+13782
+13783
+13784
+13785
+13786
+13787
+13788
+13789
+13790
+13791
+13792
+13793
+13794
+13795
+13796
+13797
+13798
+13799
+13800
+13801
+13802
+13803
+13804
+13805
+13806
+13807
+13808
+13809
+13810
+13811
+13812
+13813
+13814
+13815
+13816
+13817
+13818
+13819
+13820
+13821
+13822
+13823
+13824
+13825
+13826
+13827
+13828
+13829
+13830
+13831
+13832
+13833
+13834
+13835
+13836
+13837
+13838
+13839
+13840
+13841
+13842
+13843
+13844
+13845
+13846
+13847
+13848
+13849
+13850
+13851
+13852
+13853
+13854
+13855
+13856
+13857
+13858
+13859
+13860
+13861
+13862
+13863
+13864
+13865
+13866
+13867
+13868
+13869
+13870
+13871
+13872
+13873
+13874
+13875
+13876
+13877
+13878
+13879
+13880
+13881
+13882
+13883
+13884
+13885
+13886
+13887
+13888
+13889
+13890
+13891
+13892
+13893
+13894
+13895
+13896
+13897
+13898
+13899
+13900
+13901
+13902
+13903
+13904
+13905
+13906
+13907
+13908
+13909
+13910
+13911
+13912
+13913
+13914
+13915
+13916
+13917
+13918
+13919
+13920
+13921
+13922
+13923
+13924
+13925
+13926
+13927
+13928
+13929
+13930
+13931
+13932
+13933
+13934
+13935
+13936
+13937
+13938
+13939
+13940
+13941
+13942
+13943
+13944
+13945
+13946
+13947
+13948
+13949
+13950
+13951
+13952
+13953
+13954
+13955
+13956
+13957
+13958
+13959
+13960
+13961
+13962
+13963
+13964
+13965
+13966
+13967
+13968
+13969
+13970
+13971
+13972
+13973
+13974
+13975
+13976
+13977
+13978
+13979
+13980
+13981
+13982
+13983
+13984
+13985
+13986
+13987
+13988
+13989
+13990
+13991
+13992
+13993
+13994
+13995
+13996
+13997
+13998
+13999
+14000
+14001
+14002
+14003
+14004
+14005
+14006
+14007
+14008
+14009
+14010
+14011
+14012
+14013
+14014
+14015
+14016
+14017
+14018
+14019
+14020
+14021
+14022
+14023
+14024
+14025
+14026
+14027
+14028
+14029
+14030
+14031
+14032
+14033
+14034
+14035
+14036
+14037
+14038
+14039
+14040
+14041
+14042
+14043
+14044
+14045
+14046
+14047
+14048
+14049
+14050
+14051
+14052
+14053
+14054
+14055
+14056
+14057
+14058
+14059
+14060
+14061
+14062
+14063
+14064
+14065
+14066
+14067
+14068
+14069
+14070
+14071
+14072
+14073
+14074
+14075
+14076
+14077
+14078
+14079
+14080
+14081
+14082
+14083
+14084
+14085
+14086
+14087
+14088
+14089
+14090
+14091
+14092
+14093
+14094
+14095
+14096
+14097
+14098
+14099
+14100
+14101
+14102
+14103
+14104
+14105
+14106
+14107
+14108
+14109
+14110
+14111
+14112
+14113
+14114
+14115
+14116
+14117
+14118
+14119
+14120
+14121
+14122
+14123
+14124
+14125
+14126
+14127
+14128
+14129
+14130
+14131
+14132
+14133
+14134
+14135
+14136
+14137
+14138
+14139
+14140
+14141
+14142
+14143
+14144
+14145
+14146
+14147
+14148
+14149
+14150
+14151
+14152
+14153
+14154
+14155
+14156
+14157
+14158
+14159
+14160
+14161
+14162
+14163
+14164
+14165
+14166
+14167
+14168
+14169
+14170
+14171
+14172
+14173
+14174
+14175
+14176
+14177
+14178
+14179
+14180
+14181
+14182
+14183
+14184
+14185
+14186
+14187
+14188
+14189
+14190
+14191
+14192
+14193
+14194
+14195
+14196
+14197
+14198
+14199
+14200
+14201
+14202
+14203
+14204
+14205
+14206
+14207
+14208
+14209
+14210
+14211
+14212
+14213
+14214
+14215
+14216
+14217
+14218
+14219
+14220
+14221
+14222
+14223
+14224
+14225
+14226
+14227
+14228
+14229
+14230
+14231
+14232
+14233
+14234
+14235
+14236
+14237
+14238
+14239
+14240
+14241
+14242
+14243
+14244
+14245
+14246
+14247
+14248
+14249
+14250
+14251
+14252
+14253
+14254
+14255
+14256
+14257
+14258
+14259
+14260
+14261
+14262
+14263
+14264
+14265
+14266
+14267
+14268
+14269
+14270
+14271
+14272
+14273
+14274
+14275
+14276
+14277
+14278
+14279
+14280
+14281
+14282
+14283
+14284
+14285
+14286
+14287
+14288
+14289
+14290
+14291
+14292
+14293
+14294
+14295
+14296
+14297
+14298
+14299
+14300
+14301
+14302
+14303
+14304
+14305
+14306
+14307
+14308
+14309
+14310
+14311
+14312
+14313
+14314
+14315
+14316
+14317
+14318
+14319
+14320
+14321
+14322
+14323
+14324
+14325
+14326
+14327
+14328
+14329
+14330
+14331
+14332
+14333
+14334
+14335
+14336
+14337
+14338
+14339
+14340
+14341
+14342
+14343
+14344
+14345
+14346
+14347
+14348
+14349
+14350
+14351
+14352
+14353
+14354
+14355
+14356
+14357
+14358
+14359
+14360
+14361
+14362
+14363
+14364
+14365
+14366
+14367
+14368
+14369
+14370
+14371
+14372
+14373
+14374
+14375
+14376
+14377
+14378
+14379
+14380
+14381
+14382
+14383
+14384
+14385
+14386
+14387
+14388
+14389
+14390
+14391
+14392
+14393
+14394
+14395
+14396
+14397
+14398
+14399
+14400
+14401
+14402
+14403
+14404
+14405
+14406
+14407
+14408
+14409
+14410
+14411
+14412
+14413
+14414
+14415
+14416
+14417
+14418
+14419
+14420
+14421
+14422
+14423
+14424
+14425
+14426
+14427
+14428
+14429
+14430
+14431
+14432
+14433
+14434
+14435
+14436
+14437
+14438
+14439
+14440
+14441
+14442
+14443
+14444
+14445
+14446
+14447
+14448
+14449
+14450
+14451
+14452
+14453
+14454
+14455
+14456
+14457
+14458
+14459
+14460
+14461
+14462
+14463
+14464
+14465
+14466
+14467
+14468
+14469
+14470
+14471
+14472
+14473
+14474
+14475
+14476
+14477
+14478
+14479
+14480
+14481
+14482
+14483
+14484
+14485
+14486
+14487
+14488
+14489
+14490
+14491
+14492
+14493
+14494
+14495
+14496
+14497
+14498
+14499
+14500
+14501
+14502
+14503
+14504
+14505
+14506
+14507
+14508
+14509
+14510
+14511
+14512
+14513
+14514
+14515
+14516
+14517
+14518
+14519
+14520
+14521
+14522
+14523
+14524
+14525
+14526
+14527
+14528
+14529
+14530
+14531
+14532
+14533
+14534
+14535
+14536
+14537
+14538
+14539
+14540
+14541
+14542
+14543
+14544
+14545
+14546
+14547
+14548
+14549
+14550
+14551
+14552
+14553
+14554
+14555
+14556
+14557
+14558
+14559
+14560
+14561
+14562
+14563
+14564
+14565
+14566
+14567
+14568
+14569
+14570
+14571
+14572
+14573
+14574
+14575
+14576
+14577
+14578
+14579
+14580
+14581
+14582
+14583
+14584
+14585
+14586
+14587
+14588
+14589
+14590
+14591
+14592
+14593
+14594
+14595
+14596
+14597
+14598
+14599
+14600
+14601
+14602
+14603
+14604
+14605
+14606
+14607
+14608
+14609
+14610
+14611
+14612
+14613
+14614
+14615
+14616
+14617
+14618
+14619
+14620
+14621
+14622
+14623
+14624
+14625
+14626
+14627
+14628
+14629
+14630
+14631
+14632
+14633
+14634
+14635
+14636
+14637
+14638
+14639
+14640
+14641
+14642
+14643
+14644
+14645
+14646
+14647
+14648
+14649
+14650
+14651
+14652
+14653
+14654
+14655
+14656
+14657
+14658
+14659
+14660
+14661
+14662
+14663
+14664
+14665
+14666
+14667
+14668
+14669
+14670
+14671
+14672
+14673
+14674
+14675
+14676
+14677
+14678
+14679
+14680
+14681
+14682
+14683
+14684
+14685
+14686
+14687
+14688
+14689
+14690
+14691
+14692
+14693
+14694
+14695
+14696
+14697
+14698
+14699
+14700
+14701
+14702
+14703
+14704
+14705
+14706
+14707
+14708
+14709
+14710
+14711
+14712
+14713
+14714
+14715
+14716
+14717
+14718
+14719
+14720
+14721
+14722
+14723
+14724
+14725
+14726
+14727
+14728
+14729
+14730
+14731
+14732
+14733
+14734
+14735
+14736
+14737
+14738
+14739
+14740
+14741
+14742
+14743
+14744
+14745
+14746
+14747
+14748
+14749
+14750
+14751
+14752
+14753
+14754
+14755
+14756
+14757
+14758
+14759
+14760
+14761
+14762
+14763
+14764
+14765
+14766
+14767
+14768
+14769
+14770
+14771
+14772
+14773
+14774
+14775
+14776
+14777
+14778
+14779
+14780
+14781
+14782
+14783
+14784
+14785
+14786
+14787
+14788
+14789
+14790
+14791
+14792
+14793
+14794
+14795
+14796
+14797
+14798
+14799
+14800
+14801
+14802
+14803
+14804
+14805
+14806
+14807
+14808
+14809
+14810
+14811
+14812
+14813
+14814
+14815
+14816
+14817
+14818
+14819
+14820
+14821
+14822
+14823
+14824
+14825
+14826
+14827
+14828
+14829
+14830
+14831
+14832
+14833
+14834
+14835
+14836
+14837
+14838
+14839
+14840
+14841
+14842
+14843
+14844
+14845
+14846
+14847
+14848
+14849
+14850
+14851
+14852
+14853
+14854
+14855
+14856
+14857
+14858
+14859
+14860
+14861
+14862
+14863
+14864
+14865
+14866
+14867
+14868
+14869
+14870
+14871
+14872
+14873
+14874
+14875
+14876
+14877
+14878
+14879
+14880
+14881
+14882
+14883
+14884
+14885
+14886
+14887
+14888
+14889
+14890
+14891
+14892
+14893
+14894
+14895
+14896
+14897
+14898
+14899
+14900
+14901
+14902
+14903
+14904
+14905
+14906
+14907
+14908
+14909
+14910
+14911
+14912
+14913
+14914
+14915
+14916
+14917
+14918
+14919
+14920
+14921
+14922
+14923
+14924
+14925
+14926
+14927
+14928
+14929
+14930
+14931
+14932
+14933
+14934
+14935
+14936
+14937
+14938
+14939
+14940
+14941
+14942
+14943
+14944
+14945
+14946
+14947
+14948
+14949
+14950
+14951
+14952
+14953
+14954
+14955
+14956
+14957
+14958
+14959
+14960
+14961
+14962
+14963
+14964
+14965
+14966
+14967
+14968
+14969
+14970
+14971
+14972
+14973
+14974
+14975
+14976
+14977
+14978
+14979
+14980
+14981
+14982
+14983
+14984
+14985
+14986
+14987
+14988
+14989
+14990
+14991
+14992
+14993
+14994
+14995
+14996
+14997
+14998
+14999
+15000
+15001
+15002
+15003
+15004
+15005
+15006
+15007
+15008
+15009
+15010
+15011
+15012
+15013
+15014
+15015
+15016
+15017
+15018
+15019
+15020
+15021
+15022
+15023
+15024
+15025
+15026
+15027
+15028
+15029
+15030
+15031
+15032
+15033
+15034
+15035
+15036
+15037
+15038
+15039
+15040
+15041
+15042
+15043
+15044
+15045
+15046
+15047
+15048
+15049
+15050
+15051
+15052
+15053
+15054
+15055
+15056
+15057
+15058
+15059
+15060
+15061
+15062
+15063
+15064
+15065
+15066
+15067
+15068
+15069
+15070
+15071
+15072
+15073
+15074
+15075
+15076
+15077
+15078
+15079
+15080
+15081
+15082
+15083
+15084
+15085
+15086
+15087
+15088
+15089
+15090
+15091
+15092
+15093
+15094
+15095
+15096
+15097
+15098
+15099
+15100
+15101
+15102
+15103
+15104
+15105
+15106
+15107
+15108
+15109
+15110
+15111
+15112
+15113
+15114
+15115
+15116
+15117
+15118
+15119
+15120
+15121
+15122
+15123
+15124
+15125
+15126
+15127
+15128
+15129
+15130
+15131
+15132
+15133
+15134
+15135
+15136
+15137
+15138
+15139
+15140
+15141
+15142
+15143
+15144
+15145
+15146
+15147
+15148
+15149
+15150
+15151
+15152
+15153
+15154
+15155
+15156
+15157
+15158
+15159
+15160
+15161
+15162
+15163
+15164
+15165
+15166
+15167
+15168
+15169
+15170
+15171
+15172
+15173
+15174
+15175
+15176
+15177
+15178
+15179
+15180
+15181
+15182
+15183
+15184
+15185
+15186
+15187
+15188
+15189
+15190
+15191
+15192
+15193
+15194
+15195
+15196
+15197
+15198
+15199
+15200
+15201
+15202
+15203
+15204
+15205
+15206
+15207
+15208
+15209
+15210
+15211
+15212
+15213
+15214
+15215
+15216
+15217
+15218
+15219
+15220
+15221
+15222
+15223
+15224
+15225
+15226
+15227
+15228
+15229
+15230
+15231
+15232
+15233
+15234
+15235
+15236
+15237
+15238
+15239
+15240
+15241
+15242
+15243
+15244
+15245
+15246
+15247
+15248
+15249
+15250
+15251
+15252
+15253
+15254
+15255
+15256
+15257
+15258
+15259
+15260
+15261
+15262
+15263
+15264
+15265
+15266
+15267
+15268
+15269
+15270
+15271
+15272
+15273
+15274
+15275
+15276
+15277
+15278
+15279
+15280
+15281
+15282
+15283
+15284
+15285
+15286
+15287
+15288
+15289
+15290
+15291
+15292
+15293
+15294
+15295
+15296
+15297
+15298
+15299
+15300
+15301
+15302
+15303
+15304
+15305
+15306
+15307
+15308
+15309
+15310
+15311
+15312
+15313
+15314
+15315
+15316
+15317
+15318
+15319
+15320
+15321
+15322
+15323
+15324
+15325
+15326
+15327
+15328
+15329
+15330
+15331
+15332
+15333
+15334
+15335
+15336
+15337
+15338
+15339
+15340
+15341
+15342
+15343
+15344
+15345
+15346
+15347
+15348
+15349
+15350
+15351
+15352
+15353
+15354
+15355
+15356
+15357
+15358
+15359
+15360
+15361
+15362
+15363
+15364
+15365
+15366
+15367
+15368
+15369
+15370
+15371
+15372
+15373
+15374
+15375
+15376
+15377
+15378
+15379
+15380
+15381
+15382
+15383
+15384
+15385
+15386
+15387
+15388
+15389
+15390
+15391
+15392
+15393
+15394
+15395
+15396
+15397
+15398
+15399
+15400
+15401
+15402
+15403
+15404
+15405
+15406
+15407
+15408
+15409
+15410
+15411
+15412
+15413
+15414
+15415
+15416
+15417
+15418
+15419
+15420
+15421
+15422
+15423
+15424
+15425
+15426
+15427
+15428
+15429
+15430
+15431
+15432
+15433
+15434
+15435
+15436
+15437
+15438
+15439
+15440
+15441
+15442
+15443
+15444
+15445
+15446
+15447
+15448
+15449
+15450
+15451
+15452
+15453
+15454
+15455
+15456
+15457
+15458
+15459
+15460
+15461
+15462
+15463
+15464
+15465
+15466
+15467
+15468
+15469
+15470
+15471
+15472
+15473
+15474
+15475
+15476
+15477
+15478
+15479
+15480
+15481
+15482
+15483
+15484
+15485
+15486
+15487
+15488
+15489
+15490
+15491
+15492
+15493
+15494
+15495
+15496
+15497
+15498
+15499
+15500
+15501
+15502
+15503
+15504
+15505
+15506
+15507
+15508
+15509
+15510
+15511
+15512
+15513
+15514
+15515
+15516
+15517
+15518
+15519
+15520
+15521
+15522
+15523
+15524
+15525
+15526
+15527
+15528
+15529
+15530
+15531
+15532
+15533
+15534
+15535
+15536
+15537
+15538
+15539
+15540
+15541
+15542
+15543
+15544
+15545
+15546
+15547
+15548
+15549
+15550
+15551
+15552
+15553
+15554
+15555
+15556
+15557
+15558
+15559
+15560
+15561
+15562
+15563
+15564
+15565
+15566
+15567
+15568
+15569
+15570
+15571
+15572
+15573
+15574
+15575
+15576
+15577
+15578
+15579
+15580
+15581
+15582
+15583
+15584
+15585
+15586
+15587
+15588
+15589
+15590
+15591
+15592
+15593
+15594
+15595
+15596
+15597
+15598
+15599
+15600
+15601
+15602
+15603
+15604
+15605
+15606
+15607
+15608
+15609
+15610
+15611
+15612
+15613
+15614
+15615
+15616
+15617
+15618
+15619
+15620
+15621
+15622
+15623
+15624
+15625
+15626
+15627
+15628
+15629
+15630
+15631
+15632
+15633
+15634
+15635
+15636
+15637
+15638
+15639
+15640
+15641
+15642
+15643
+15644
+15645
+15646
+15647
+15648
+15649
+15650
+15651
+15652
+15653
+15654
+15655
+15656
+15657
+15658
+15659
+15660
+15661
+15662
+15663
+15664
+15665
+15666
+15667
+15668
+15669
+15670
+15671
+15672
+15673
+15674
+15675
+15676
+15677
+15678
+15679
+15680
+15681
+15682
+15683
+15684
+15685
+15686
+15687
+15688
+15689
+15690
+15691
+15692
+15693
+15694
+15695
+15696
+15697
+15698
+15699
+15700
+15701
+15702
+15703
+15704
+15705
+15706
+15707
+15708
+15709
+15710
+15711
+15712
+15713
+15714
+15715
+15716
+15717
+15718
+15719
+15720
+15721
+15722
+15723
+15724
+15725
+15726
+15727
+15728
+15729
+15730
+15731
+15732
+15733
+15734
+15735
+15736
+15737
+15738
+15739
+15740
+15741
+15742
+15743
+15744
+15745
+15746
+15747
+15748
+15749
+15750
+15751
+15752
+15753
+15754
+15755
+15756
+15757
+15758
+15759
+15760
+15761
+15762
+15763
+15764
+15765
+15766
+15767
+15768
+15769
+15770
+15771
+15772
+15773
+15774
+15775
+15776
+15777
+15778
+15779
+15780
+15781
+15782
+15783
+15784
+15785
+15786
+15787
+15788
+15789
+15790
+15791
+15792
+15793
+15794
+15795
+15796
+15797
+15798
+15799
+15800
+15801
+15802
+15803
+15804
+15805
+15806
+15807
+15808
+15809
+15810
+15811
+15812
+15813
+15814
+15815
+15816
+15817
+15818
+15819
+15820
+15821
+15822
+15823
+15824
+15825
+15826
+15827
+15828
+15829
+15830
+15831
+15832
+15833
+15834
+15835
+15836
+15837
+15838
+15839
+15840
+15841
+15842
+15843
+15844
+15845
+15846
+15847
+15848
+15849
+15850
+15851
+15852
+15853
+15854
+15855
+15856
+15857
+15858
+15859
+15860
+15861
+15862
+15863
+15864
+15865
+15866
+15867
+15868
+15869
+15870
+15871
+15872
+15873
+15874
+15875
+15876
+15877
+15878
+15879
+15880
+15881
+15882
+15883
+15884
+15885
+15886
+15887
+15888
+15889
+15890
+15891
+15892
+15893
+15894
+15895
+15896
+15897
+15898
+15899
+15900
+15901
+15902
+15903
+15904
+15905
+15906
+15907
+15908
+15909
+15910
+15911
+15912
+15913
+15914
+15915
+15916
+15917
+15918
+15919
+15920
+15921
+15922
+15923
+15924
+15925
+15926
+15927
+15928
+15929
+15930
+15931
+15932
+15933
+15934
+15935
+15936
+15937
+15938
+15939
+15940
+15941
+15942
+15943
+15944
+15945
+15946
+15947
+15948
+15949
+15950
+15951
+15952
+15953
+15954
+15955
+15956
+15957
+15958
+15959
+15960
+15961
+15962
+15963
+15964
+15965
+15966
+15967
+15968
+15969
+15970
+15971
+15972
+15973
+15974
+15975
+15976
+15977
+15978
+15979
+15980
+15981
+15982
+15983
+15984
+15985
+15986
+15987
+15988
+15989
+15990
+15991
+15992
+15993
+15994
+15995
+15996
+15997
+15998
+15999
+16000
+16001
+16002
+16003
+16004
+16005
+16006
+16007
+16008
+16009
+16010
+16011
+16012
+16013
+16014
+16015
+16016
+16017
+16018
+16019
+16020
+16021
+16022
+16023
+16024
+16025
+16026
+16027
+16028
+16029
+16030
+16031
+16032
+16033
+16034
+16035
+16036
+16037
+16038
+16039
+16040
+16041
+16042
+16043
+16044
+16045
+16046
+16047
+16048
+16049
+16050
+16051
+16052
+16053
+16054
+16055
+16056
+16057
+16058
+16059
+16060
+16061
+16062
+16063
+16064
+16065
+16066
+16067
+16068
+16069
+16070
+16071
+16072
+16073
+16074
+16075
+16076
+16077
+16078
+16079
+16080
+16081
+16082
+16083
+16084
+16085
+16086
+16087
+16088
+16089
+16090
+16091
+16092
+16093
+16094
+16095
+16096
+16097
+16098
+16099
+16100
+16101
+16102
+16103
+16104
+16105
+16106
+16107
+16108
+16109
+16110
+16111
+16112
+16113
+16114
+16115
+16116
+16117
+16118
+16119
+16120
+16121
+16122
+16123
+16124
+16125
+16126
+16127
+16128
+16129
+16130
+16131
+16132
+16133
+16134
+16135
+16136
+16137
+16138
+16139
+16140
+16141
+16142
+16143
+16144
+16145
+16146
+16147
+16148
+16149
+16150
+16151
+16152
+16153
+16154
+16155
+16156
+16157
+16158
+16159
+16160
+16161
+16162
+16163
+16164
+16165
+16166
+16167
+16168
+16169
+16170
+16171
+16172
+16173
+16174
+16175
+16176
+16177
+16178
+16179
+16180
+16181
+16182
+16183
+16184
+16185
+16186
+16187
+16188
+16189
+16190
+16191
+16192
+16193
+16194
+16195
+16196
+16197
+16198
+16199
+16200
+16201
+16202
+16203
+16204
+16205
+16206
+16207
+16208
+16209
+16210
+16211
+16212
+16213
+16214
+16215
+16216
+16217
+16218
+16219
+16220
+16221
+16222
+16223
+16224
+16225
+16226
+16227
+16228
+16229
+16230
+16231
+16232
+16233
+16234
+16235
+16236
+16237
+16238
+16239
+16240
+16241
+16242
+16243
+16244
+16245
+16246
+16247
+16248
+16249
+16250
+16251
+16252
+16253
+16254
+16255
+16256
+16257
+16258
+16259
+16260
+16261
+16262
+16263
+16264
+16265
+16266
+16267
+16268
+16269
+16270
+16271
+16272
+16273
+16274
+16275
+16276
+16277
+16278
+16279
+16280
+16281
+16282
+16283
+16284
+16285
+16286
+16287
+16288
+16289
+16290
+16291
+16292
+16293
+16294
+16295
+16296
+16297
+16298
+16299
+16300
+16301
+16302
+16303
+16304
+16305
+16306
+16307
+16308
+16309
+16310
+16311
+16312
+16313
+16314
+16315
+16316
+16317
+16318
+16319
+16320
+16321
+16322
+16323
+16324
+16325
+16326
+16327
+16328
+16329
+16330
+16331
+16332
+16333
+16334
+16335
+16336
+16337
+16338
+16339
+16340
+16341
+16342
+16343
+16344
+16345
+16346
+16347
+16348
+16349
+16350
+16351
+16352
+16353
+16354
+16355
+16356
+16357
+16358
+16359
+16360
+16361
+16362
+16363
+16364
+16365
+16366
+16367
+16368
+16369
+16370
+16371
+16372
+16373
+16374
+16375
+16376
+16377
+16378
+16379
+16380
+16381
+16382
+16383
+16384
+16385
+16386
+16387
+16388
+16389
+16390
+16391
+16392
+16393
+16394
+16395
+16396
+16397
+16398
+16399
+16400
+16401
+16402
+16403
+16404
+16405
+16406
+16407
+16408
+16409
+16410
+16411
+16412
+16413
+16414
+16415
+16416
+16417
+16418
+16419
+16420
+16421
+16422
+16423
+16424
+16425
+16426
+16427
+16428
+16429
+16430
+16431
+16432
+16433
+16434
+16435
+16436
+16437
+16438
+16439
+16440
+16441
+16442
+16443
+16444
+16445
+16446
+16447
+16448
+16449
+16450
+16451
+16452
+16453
+16454
+16455
+16456
+16457
+16458
+16459
+16460
+16461
+16462
+16463
+16464
+16465
+16466
+16467
+16468
+16469
+16470
+16471
+16472
+16473
+16474
+16475
+16476
+16477
+16478
+16479
+16480
+16481
+16482
+16483
+16484
+16485
+16486
+16487
+16488
+16489
+16490
+16491
+16492
+16493
+16494
+16495
+16496
+16497
+16498
+16499
+16500
+16501
+16502
+16503
+16504
+16505
+16506
+16507
+16508
+16509
+16510
+16511
+16512
+16513
+16514
+16515
+16516
+16517
+16518
+16519
+16520
+16521
+16522
+16523
+16524
+16525
+16526
+16527
+16528
+16529
+16530
+16531
+16532
+16533
+16534
+16535
+16536
+16537
+16538
+16539
+16540
+16541
+16542
+16543
+16544
+16545
+16546
+16547
+16548
+16549
+16550
+16551
+16552
+16553
+16554
+16555
+16556
+16557
+16558
+16559
+16560
+16561
+16562
+16563
+16564
+16565
+16566
+16567
+16568
+16569
+16570
+16571
+16572
+16573
+16574
+16575
+16576
+16577
+16578
+16579
+16580
+16581
+16582
+16583
+16584
+16585
+16586
+16587
+16588
+16589
+16590
+16591
+16592
+16593
+16594
+16595
+16596
+16597
+16598
+16599
+16600
+16601
+16602
+16603
+16604
+16605
+16606
+16607
+16608
+16609
+16610
+16611
+16612
+16613
+16614
+16615
+16616
+16617
+16618
+16619
+16620
+16621
+16622
+16623
+16624
+16625
+16626
+16627
+16628
+16629
+16630
+16631
+16632
+16633
+16634
+16635
+16636
+16637
+16638
+16639
+16640
+16641
+16642
+16643
+16644
+16645
+16646
+16647
+16648
+16649
+16650
+16651
+16652
+16653
+16654
+16655
+16656
+16657
+16658
+16659
+16660
+16661
+16662
+16663
+16664
+16665
+16666
+16667
+16668
+16669
+16670
+16671
+16672
+16673
+16674
+16675
+16676
+16677
+16678
+16679
+16680
+16681
+16682
+16683
+16684
+16685
+16686
+16687
+16688
+16689
+16690
+16691
+16692
+16693
+16694
+16695
+16696
+16697
+16698
+16699
+16700
+16701
+16702
+16703
+16704
+16705
+16706
+16707
+16708
+16709
+16710
+16711
+16712
+16713
+16714
+16715
+16716
+16717
+16718
+16719
+16720
+16721
+16722
+16723
+16724
+16725
+16726
+16727
+16728
+16729
+16730
+16731
+16732
+16733
+16734
+16735
+16736
+16737
+16738
+16739
+16740
+16741
+16742
+16743
+16744
+16745
+16746
+16747
+16748
+16749
+16750
+16751
+16752
+16753
+16754
+16755
+16756
+16757
+16758
+16759
+16760
+16761
+16762
+16763
+16764
+16765
+16766
+16767
+16768
+16769
+16770
+16771
+16772
+16773
+16774
+16775
+16776
+16777
+16778
+16779
+16780
+16781
+16782
+16783
+16784
+16785
+16786
+16787
+16788
+16789
+16790
+16791
+16792
+16793
+16794
+16795
+16796
+16797
+16798
+16799
+16800
+16801
+16802
+16803
+16804
+16805
+16806
+16807
+16808
+16809
+16810
+16811
+16812
+16813
+16814
+16815
+16816
+16817
+16818
+16819
+16820
+16821
+16822
+16823
+16824
+16825
+16826
+16827
+16828
+16829
+16830
+16831
+16832
+16833
+16834
+16835
+16836
+16837
+16838
+16839
+16840
+16841
+16842
+16843
+16844
+16845
+16846
+16847
+16848
+16849
+16850
+16851
+16852
+16853
+16854
+16855
+16856
+16857
+16858
+16859
+16860
+16861
+16862
+16863
+16864
+16865
+16866
+16867
+16868
+16869
+16870
+16871
+16872
+16873
+16874
+16875
+16876
+16877
+16878
+16879
+16880
+16881
+16882
+16883
+16884
+16885
+16886
+16887
+16888
+16889
+16890
+16891
+16892
+16893
+16894
+16895
+16896
+16897
+16898
+16899
+16900
+16901
+16902
+16903
+16904
+16905
+16906
+16907
+16908
+16909
+16910
+16911
+16912
+16913
+16914
+16915
+16916
+16917
+16918
+16919
+16920
+16921
+16922
+16923
+16924
+16925
+16926
+16927
+16928
+16929
+16930
+16931
+16932
+16933
+16934
+16935
+16936
+16937
+16938
+16939
+16940
+16941
+16942
+16943
+16944
+16945
+16946
+16947
+16948
+16949
+16950
+16951
+16952
+16953
+16954
+16955
+16956
+16957
+16958
+16959
+16960
+16961
+16962
+16963
+16964
+16965
+16966
+16967
+16968
+16969
+16970
+16971
+16972
+16973
+16974
+16975
+16976
+16977
+16978
+16979
+16980
+16981
+16982
+16983
+16984
+16985
+16986
+16987
+16988
+16989
+16990
+16991
+16992
+16993
+16994
+16995
+16996
+16997
+16998
+16999
+17000
+17001
+17002
+17003
+17004
+17005
+17006
+17007
+17008
+17009
+17010
+17011
+17012
+17013
+17014
+17015
+17016
+17017
+17018
+17019
+17020
+17021
+17022
+17023
+17024
+17025
+17026
+17027
+17028
+17029
+17030
+17031
+17032
+17033
+17034
+17035
+17036
+17037
+17038
+17039
+17040
+17041
+17042
+17043
+17044
+17045
+17046
+17047
+17048
+17049
+17050
+17051
+17052
+17053
+17054
+17055
+17056
+17057
+17058
+17059
+17060
+17061
+17062
+17063
+17064
+17065
+17066
+17067
+17068
+17069
+17070
+17071
+17072
+17073
+17074
+17075
+17076
+17077
+17078
+17079
+17080
+17081
+17082
+17083
+17084
+17085
+17086
+17087
+17088
+17089
+17090
+17091
+17092
+17093
+17094
+17095
+17096
+17097
+17098
+17099
+17100
+17101
+17102
+17103
+17104
+17105
+17106
+17107
+17108
+17109
+17110
+17111
+17112
+17113
+17114
+17115
+17116
+17117
+17118
+17119
+17120
+17121
+17122
+17123
+17124
+17125
+17126
+17127
+17128
+17129
+17130
+17131
+17132
+17133
+17134
+17135
+17136
+17137
+17138
+17139
+17140
+17141
+17142
+17143
+17144
+17145
+17146
+17147
+17148
+17149
+17150
+17151
+17152
+17153
+17154
+17155
+17156
+17157
+17158
+17159
+17160
+17161
+17162
+17163
+17164
+17165
+17166
+17167
+17168
+17169
+17170
+17171
+17172
+17173
+17174
+17175
+17176
+17177
+17178
+17179
+17180
+17181
+17182
+17183
+17184
+17185
+17186
+17187
+17188
+17189
+17190
+17191
+17192
+17193
+17194
+17195
+17196
+17197
+17198
+17199
+17200
+17201
+17202
+17203
+17204
+17205
+17206
+17207
+17208
+17209
+17210
+17211
+17212
+17213
+17214
+17215
+17216
+17217
+17218
+17219
+17220
+17221
+17222
+17223
+17224
+17225
+17226
+17227
+17228
+17229
+17230
+17231
+17232
+17233
+17234
+17235
+17236
+17237
+17238
+17239
+17240
+17241
+17242
+17243
+17244
+17245
+17246
+17247
+17248
+17249
+17250
+17251
+17252
+17253
+17254
+17255
+17256
+17257
+17258
+17259
+17260
+17261
+17262
+17263
+17264
+17265
+17266
+17267
+17268
+17269
+17270
+17271
+17272
+17273
+17274
+17275
+17276
+17277
+17278
+17279
+17280
+17281
+17282
+17283
+17284
+17285
+17286
+17287
+17288
+17289
+17290
+17291
+17292
+17293
+17294
+17295
+17296
+17297
+17298
+17299
+17300
+17301
+17302
+17303
+17304
+17305
+17306
+17307
+17308
+17309
+17310
+17311
+17312
+17313
+17314
+17315
+17316
+17317
+17318
+17319
+17320
+17321
+17322
+17323
+17324
+17325
+17326
+17327
+17328
+17329
+17330
+17331
+17332
+17333
+17334
+17335
+17336
+17337
+17338
+17339
+17340
+17341
+17342
+17343
+17344
+17345
+17346
+17347
+17348
+17349
+17350
+17351
+17352
+17353
+17354
+17355
+17356
+17357
+17358
+17359
+17360
+17361
+17362
+17363
+17364
+17365
+17366
+17367
+17368
+17369
+17370
+17371
+17372
+17373
+17374
+17375
+17376
+17377
+17378
+17379
+17380
+17381
+17382
+17383
+17384
+17385
+17386
+17387
+17388
+17389
+17390
+17391
+17392
+17393
+17394
+17395
+17396
+17397
+17398
+17399
+17400
+17401
+17402
+17403
+17404
+17405
+17406
+17407
+17408
+17409
+17410
+17411
+17412
+17413
+17414
+17415
+17416
+17417
+17418
+17419
+17420
+17421
+17422
+17423
+17424
+17425
+17426
+17427
+17428
+17429
+17430
+17431
+17432
+17433
+17434
+17435
+17436
+17437
+17438
+17439
+17440
+17441
+17442
+17443
+17444
+17445
+17446
+17447
+17448
+17449
+17450
+17451
+17452
+17453
+17454
+17455
+17456
+17457
+17458
+17459
+17460
+17461
+17462
+17463
+17464
+17465
+17466
+17467
+17468
+17469
+17470
+17471
+17472
+17473
+17474
+17475
+17476
+17477
+17478
+17479
+17480
+17481
+17482
+17483
+17484
+17485
+17486
+17487
+17488
+17489
+17490
+17491
+17492
+17493
+17494
+17495
+17496
+17497
+17498
+17499
+17500
+17501
+17502
+17503
+17504
+17505
+17506
+17507
+17508
+17509
+17510
+17511
+17512
+17513
+17514
+17515
+17516
+17517
+17518
+17519
+17520
+17521
+17522
+17523
+17524
+17525
+17526
+17527
+17528
+17529
+17530
+17531
+17532
+17533
+17534
+17535
+17536
+17537
+17538
+17539
+17540
+17541
+17542
+17543
+17544
+17545
+17546
+17547
+17548
+17549
+17550
+17551
+17552
+17553
+17554
+17555
+17556
+17557
+17558
+17559
+17560
+17561
+17562
+17563
+17564
+17565
+17566
+17567
+17568
+17569
+17570
+17571
+17572
+17573
+17574
+17575
+17576
+17577
+17578
+17579
+17580
+17581
+17582
+17583
+17584
+17585
+17586
+17587
+17588
+17589
+17590
+17591
+17592
+17593
+17594
+17595
+17596
+17597
+17598
+17599
+17600
+17601
+17602
+17603
+17604
+17605
+17606
+17607
+17608
+17609
+17610
+17611
+17612
+17613
+17614
+17615
+17616
+17617
+17618
+17619
+17620
+17621
+17622
+17623
+17624
+17625
+17626
+17627
+17628
+17629
+17630
+17631
+17632
+17633
+17634
+17635
+17636
+17637
+17638
+17639
+17640
+17641
+17642
+17643
+17644
+17645
+17646
+17647
+17648
+17649
+17650
+17651
+17652
+17653
+17654
+17655
+17656
+17657
+17658
+17659
+17660
+17661
+17662
+17663
+17664
+17665
+17666
+17667
+17668
+17669
+17670
+17671
+17672
+17673
+17674
+17675
+17676
+17677
+17678
+17679
+17680
+17681
+17682
+17683
+17684
+17685
+17686
+17687
+17688
+17689
+17690
+17691
+17692
+17693
+17694
+17695
+17696
+17697
+17698
+17699
+17700
+17701
+17702
+17703
+17704
+17705
+17706
+17707
+17708
+17709
+17710
+17711
+17712
+17713
+17714
+17715
+17716
+17717
+17718
+17719
+17720
+17721
+17722
+17723
+17724
+17725
+17726
+17727
+17728
+17729
+17730
+17731
+17732
+17733
+17734
+17735
+17736
+17737
+17738
+17739
+17740
+17741
+17742
+17743
+17744
+17745
+17746
+17747
+17748
+17749
+17750
+17751
+17752
+17753
+17754
+17755
+17756
+17757
+17758
+17759
+17760
+17761
+17762
+17763
+17764
+17765
+17766
+17767
+17768
+17769
+17770
+17771
+17772
+17773
+17774
+17775
+17776
+17777
+17778
+17779
+17780
+17781
+17782
+17783
+17784
+17785
+17786
+17787
+17788
+17789
+17790
+17791
+17792
+17793
+17794
+17795
+17796
+17797
+17798
+17799
+17800
+17801
+17802
+17803
+17804
+17805
+17806
+17807
+17808
+17809
+17810
+17811
+17812
+17813
+17814
+17815
+17816
+17817
+17818
+17819
+17820
+17821
+17822
+17823
+17824
+17825
+17826
+17827
+17828
+17829
+17830
+17831
+17832
+17833
+17834
+17835
+17836
+17837
+17838
+17839
+17840
+17841
+17842
+17843
+17844
+17845
+17846
+17847
+17848
+17849
+17850
+17851
+17852
+17853
+17854
+17855
+17856
+17857
+17858
+17859
+17860
+17861
+17862
+17863
+17864
+17865
+17866
+17867
+17868
+17869
+17870
+17871
+17872
+17873
+17874
+17875
+17876
+17877
+17878
+17879
+17880
+17881
+17882
+17883
+17884
+17885
+17886
+17887
+17888
+17889
+17890
+17891
+17892
+17893
+17894
+17895
+17896
+17897
+17898
+17899
+17900
+17901
+17902
+17903
+17904
+17905
+17906
+17907
+17908
+17909
+17910
+17911
+17912
+17913
+17914
+17915
+17916
+17917
+17918
+17919
+17920
+17921
+17922
+17923
+17924
+17925
+17926
+17927
+17928
+17929
+17930
+17931
+17932
+17933
+17934
+17935
+17936
+17937
+17938
+17939
+17940
+17941
+17942
+17943
+17944
+17945
+17946
+17947
+17948
+17949
+17950
+17951
+17952
+17953
+17954
+17955
+17956
+17957
+17958
+17959
+17960
+17961
+17962
+17963
+17964
+17965
+17966
+17967
+17968
+17969
+17970
+17971
+17972
+17973
+17974
+17975
+17976
+17977
+17978
+17979
+17980
+17981
+17982
+17983
+17984
+17985
+17986
+17987
+17988
+17989
+17990
+17991
+17992
+17993
+17994
+17995
+17996
+17997
+17998
+17999
+18000
+18001
+18002
+18003
+18004
+18005
+18006
+18007
+18008
+18009
+18010
+18011
+18012
+18013
+18014
+18015
+18016
+18017
+18018
+18019
+18020
+18021
+18022
+18023
+18024
+18025
+18026
+18027
+18028
+18029
+18030
+18031
+18032
+18033
+18034
+18035
+18036
+18037
+18038
+18039
+18040
+18041
+18042
+18043
+18044
+18045
+18046
+18047
+18048
+18049
+18050
+18051
+18052
+18053
+18054
+18055
+18056
+18057
+18058
+18059
+18060
+18061
+18062
+18063
+18064
+18065
+18066
+18067
+18068
+18069
+18070
+18071
+18072
+18073
+18074
+18075
+18076
+18077
+18078
+18079
+18080
+18081
+18082
+18083
+18084
+18085
+18086
+18087
+18088
+18089
+18090
+18091
+18092
+18093
+18094
+18095
+18096
+18097
+18098
+18099
+18100
+18101
+18102
+18103
+18104
+18105
+18106
+18107
+18108
+18109
+18110
+18111
+18112
+18113
+18114
+18115
+18116
+18117
+18118
+18119
+18120
+18121
+18122
+18123
+18124
+18125
+18126
+18127
+18128
+18129
+18130
+18131
+18132
+18133
+18134
+18135
+18136
+18137
+18138
+18139
+18140
+18141
+18142
+18143
+18144
+18145
+18146
+18147
+18148
+18149
+18150
+18151
+18152
+18153
+18154
+18155
+18156
+18157
+18158
+18159
+18160
+18161
+18162
+18163
+18164
+18165
+18166
+18167
+18168
+18169
+18170
+18171
+18172
+18173
+18174
+18175
+18176
+18177
+18178
+18179
+18180
+18181
+18182
+18183
+18184
+18185
+18186
+18187
+18188
+18189
+18190
+18191
+18192
+18193
+18194
+18195
+18196
+18197
+18198
+18199
+18200
+18201
+18202
+18203
+18204
+18205
+18206
+18207
+18208
+18209
+18210
+18211
+18212
+18213
+18214
+18215
+18216
+18217
+18218
+18219
+18220
+18221
+18222
+18223
+18224
+18225
+18226
+18227
+18228
+18229
+18230
+18231
+18232
+18233
+18234
+18235
+18236
+18237
+18238
+18239
+18240
+18241
+18242
+18243
+18244
+18245
+18246
+18247
+18248
+18249
+18250
+18251
+18252
+18253
+18254
+18255
+18256
+18257
+18258
+18259
+18260
+18261
+18262
+18263
+18264
+18265
+18266
+18267
+18268
+18269
+18270
+18271
+18272
+18273
+18274
+18275
+18276
+18277
+18278
+18279
+18280
+18281
+18282
+18283
+18284
+18285
+18286
+18287
+18288
+18289
+18290
+18291
+18292
+18293
+18294
+18295
+18296
+18297
+18298
+18299
+18300
+18301
+18302
+18303
+18304
+18305
+18306
+18307
+18308
+18309
+18310
+18311
+18312
+18313
+18314
+18315
+18316
+18317
+18318
+18319
+18320
+18321
+18322
+18323
+18324
+18325
+18326
+18327
+18328
+18329
+18330
+18331
+18332
+18333
+18334
+18335
+18336
+18337
+18338
+18339
+18340
+18341
+18342
+18343
+18344
+18345
+18346
+18347
+18348
+18349
+18350
+18351
+18352
+18353
+18354
+18355
+18356
+18357
+18358
+18359
+18360
+18361
+18362
+18363
+18364
+18365
+18366
+18367
+18368
+18369
+18370
+18371
+18372
+18373
+18374
+18375
+18376
+18377
+18378
+18379
+18380
+18381
+18382
+18383
+18384
+18385
+18386
+18387
+18388
+18389
+18390
+18391
+18392
+18393
+18394
+18395
+18396
+18397
+18398
+18399
+18400
+18401
+18402
+18403
+18404
+18405
+18406
+18407
+18408
+18409
+18410
+18411
+18412
+18413
+18414
+18415
+18416
+18417
+18418
+18419
+18420
+18421
+18422
+18423
+18424
+18425
+18426
+18427
+18428
+18429
+18430
+18431
+18432
+18433
+18434
+18435
+18436
+18437
+18438
+18439
+18440
+18441
+18442
+18443
+18444
+18445
+18446
+18447
+18448
+18449
+18450
+18451
+18452
+18453
+18454
+18455
+18456
+18457
+18458
+18459
+18460
+18461
+18462
+18463
+18464
+18465
+18466
+18467
+18468
+18469
+18470
+18471
+18472
+18473
+18474
+18475
+18476
+18477
+18478
+18479
+18480
+18481
+18482
+18483
+18484
+18485
+18486
+18487
+18488
+18489
+18490
+18491
+18492
+18493
+18494
+18495
+18496
+18497
+18498
+18499
+18500
+18501
+18502
+18503
+18504
+18505
+18506
+18507
+18508
+18509
+18510
+18511
+18512
+18513
+18514
+18515
+18516
+18517
+18518
+18519
+18520
+18521
+18522
+18523
+18524
+18525
+18526
+18527
+18528
+18529
+18530
+18531
+18532
+18533
+18534
+18535
+18536
+18537
+18538
+18539
+18540
+18541
+18542
+18543
+18544
+18545
+18546
+18547
+18548
+18549
+18550
+18551
+18552
+18553
+18554
+18555
+18556
+18557
+18558
+18559
+18560
+18561
+18562
+18563
+18564
+18565
+18566
+18567
+18568
+18569
+18570
+18571
+18572
+18573
+18574
+18575
+18576
+18577
+18578
+18579
+18580
+18581
+18582
+18583
+18584
+18585
+18586
+18587
+18588
+18589
+18590
+18591
+18592
+18593
+18594
+18595
+18596
+18597
+18598
+18599
+18600
+18601
+18602
+18603
+18604
+18605
+18606
+18607
+18608
+18609
+18610
+18611
+18612
+18613
+18614
+18615
+18616
+18617
+18618
+18619
+18620
+18621
+18622
+18623
+18624
+18625
+18626
+18627
+18628
+18629
+18630
+18631
+18632
+18633
+18634
+18635
+18636
+18637
+18638
+18639
+18640
+18641
+18642
+18643
+18644
+18645
+18646
+18647
+18648
+18649
+18650
+18651
+18652
+18653
+18654
+18655
+18656
+18657
+18658
+18659
+18660
+18661
+18662
+18663
+18664
+18665
+18666
+18667
+18668
+18669
+18670
+18671
+18672
+18673
+18674
+18675
+18676
+18677
+18678
+18679
+18680
+18681
+18682
+18683
+18684
+18685
+18686
+18687
+18688
+18689
+18690
+18691
+18692
+18693
+18694
+18695
+18696
+18697
+18698
+18699
+18700
+18701
+18702
+18703
+18704
+18705
+18706
+18707
+18708
+18709
+18710
+18711
+18712
+18713
+18714
+18715
+18716
+18717
+18718
+18719
+18720
+18721
+18722
+18723
+18724
+18725
+18726
+18727
+18728
+18729
+18730
+18731
+18732
+18733
+18734
+18735
+18736
+18737
+18738
+18739
+18740
+18741
+18742
+18743
+18744
+18745
+18746
+18747
+18748
+18749
+18750
+18751
+18752
+18753
+18754
+18755
+18756
+18757
+18758
+18759
+18760
+18761
+18762
+18763
+18764
+18765
+18766
+18767
+18768
+18769
+18770
+18771
+18772
+18773
+18774
+18775
+18776
+18777
+18778
+18779
+18780
+18781
+18782
+18783
+18784
+18785
+18786
+18787
+18788
+18789
+18790
+18791
+18792
+18793
+18794
+18795
+18796
+18797
+18798
+18799
+18800
+18801
+18802
+18803
+18804
+18805
+18806
+18807
+18808
+18809
+18810
+18811
+18812
+18813
+18814
+18815
+18816
+18817
+18818
+18819
+18820
+18821
+18822
+18823
+18824
+18825
+18826
+18827
+18828
+18829
+18830
+18831
+18832
+18833
+18834
+18835
+18836
+18837
+18838
+18839
+18840
+18841
+18842
+18843
+18844
+18845
+18846
+18847
+18848
+18849
+18850
+18851
+18852
+18853
+18854
+18855
+18856
+18857
+18858
+18859
+18860
+18861
+18862
+18863
+18864
+18865
+18866
+18867
+18868
+18869
+18870
+18871
+18872
+18873
+18874
+18875
+18876
+18877
+18878
+18879
+18880
+18881
+18882
+18883
+18884
+18885
+18886
+18887
+18888
+18889
+18890
+18891
+18892
+18893
+18894
+18895
+18896
+18897
+18898
+18899
+18900
+18901
+18902
+18903
+18904
+18905
+18906
+18907
+18908
+18909
+18910
+18911
+18912
+18913
+18914
+18915
+18916
+18917
+18918
+18919
+18920
+18921
+18922
+18923
+18924
+18925
+18926
+18927
+18928
+18929
+18930
+18931
+18932
+18933
+18934
+18935
+18936
+18937
+18938
+18939
+18940
+18941
+18942
+18943
+18944
+18945
+18946
+18947
+18948
+18949
+18950
+18951
+18952
+18953
+18954
+18955
+18956
+18957
+18958
+18959
+18960
+18961
+18962
+18963
+18964
+18965
+18966
+18967
+18968
+18969
+18970
+18971
+18972
+18973
+18974
+18975
+18976
+18977
+18978
+18979
+18980
+18981
+18982
+18983
+18984
+18985
+18986
+18987
+18988
+18989
+18990
+18991
+18992
+18993
+18994
+18995
+18996
+18997
+18998
+18999
+19000
+19001
+19002
+19003
+19004
+19005
+19006
+19007
+19008
+19009
+19010
+19011
+19012
+19013
+19014
+19015
+19016
+19017
+19018
+19019
+19020
+19021
+19022
+19023
+19024
+19025
+19026
+19027
+19028
+19029
+19030
+19031
+19032
+19033
+19034
+19035
+19036
+19037
+19038
+19039
+19040
+19041
+19042
+19043
+19044
+19045
+19046
+19047
+19048
+19049
+19050
+19051
+19052
+19053
+19054
+19055
+19056
+19057
+19058
+19059
+19060
+19061
+19062
+19063
+19064
+19065
+19066
+19067
+19068
+19069
+19070
+19071
+19072
+19073
+19074
+19075
+19076
+19077
+19078
+19079
+19080
+19081
+19082
+19083
+19084
+19085
+19086
+19087
+19088
+19089
+19090
+19091
+19092
+19093
+19094
+19095
+19096
+19097
+19098
+19099
+19100
+19101
+19102
+19103
+19104
+19105
+19106
+19107
+19108
+19109
+19110
+19111
+19112
+19113
+19114
+19115
+19116
+19117
+19118
+19119
+19120
+19121
+19122
+19123
+19124
+19125
+19126
+19127
+19128
+19129
+19130
+19131
+19132
+19133
+19134
+19135
+19136
+19137
+19138
+19139
+19140
+19141
+19142
+19143
+19144
+19145
+19146
+19147
+19148
+19149
+19150
+19151
+19152
+19153
+19154
+19155
+19156
+19157
+19158
+19159
+19160
+19161
+19162
+19163
+19164
+19165
+19166
+19167
+19168
+19169
+19170
+19171
+19172
+19173
+19174
+19175
+19176
+19177
+19178
+19179
+19180
+19181
+19182
+19183
+19184
+19185
+19186
+19187
+19188
+19189
+19190
+19191
+19192
+19193
+19194
+19195
+19196
+19197
+19198
+19199
+19200
+19201
+19202
+19203
+19204
+19205
+19206
+19207
+19208
+19209
+19210
+19211
+19212
+19213
+19214
+19215
+19216
+19217
+19218
+19219
+19220
+19221
+19222
+19223
+19224
+19225
+19226
+19227
+19228
+19229
+19230
+19231
+19232
+19233
+19234
+19235
+19236
+19237
+19238
+19239
+19240
+19241
+19242
+19243
+19244
+19245
+19246
+19247
+19248
+19249
+19250
+19251
+19252
+19253
+19254
+19255
+19256
+19257
+19258
+19259
+19260
+19261
+19262
+19263
+19264
+19265
+19266
+19267
+19268
+19269
+19270
+19271
+19272
+19273
+19274
+19275
+19276
+19277
+19278
+19279
+19280
+19281
+19282
+19283
+19284
+19285
+19286
+19287
+19288
+19289
+19290
+19291
+19292
+19293
+19294
+19295
+19296
+19297
+19298
+19299
+19300
+19301
+19302
+19303
+19304
+19305
+19306
+19307
+19308
+19309
+19310
+19311
+19312
+19313
+19314
+19315
+19316
+19317
+19318
+19319
+19320
+19321
+19322
+19323
+19324
+19325
+19326
+19327
+19328
+19329
+19330
+19331
+19332
+19333
+19334
+19335
+19336
+19337
+19338
+19339
+19340
+19341
+19342
+19343
+19344
+19345
+19346
+19347
+19348
+19349
+19350
+19351
+19352
+19353
+19354
+19355
+19356
+19357
+19358
+19359
+19360
+19361
+19362
+19363
+19364
+19365
+19366
+19367
+19368
+19369
+19370
+19371
+19372
+19373
+19374
+19375
+19376
+19377
+19378
+19379
+19380
+19381
+19382
+19383
+19384
+19385
+19386
+19387
+19388
+19389
+19390
+19391
+19392
+19393
+19394
+19395
+19396
+19397
+19398
+19399
+19400
+19401
+19402
+19403
+19404
+19405
+19406
+19407
+19408
+19409
+19410
+19411
+19412
+19413
+19414
+19415
+19416
+19417
+19418
+19419
+19420
+19421
+19422
+19423
+19424
+19425
+19426
+19427
+19428
+19429
+19430
+19431
+19432
+19433
+19434
+19435
+19436
+19437
+19438
+19439
+19440
+19441
+19442
+19443
+19444
+19445
+19446
+19447
+19448
+19449
+19450
+19451
+19452
+19453
+19454
+19455
+19456
+19457
+19458
+19459
+19460
+19461
+19462
+19463
+19464
+19465
+19466
+19467
+19468
+19469
+19470
+19471
+19472
+19473
+19474
+19475
+19476
+19477
+19478
+19479
+19480
+19481
+19482
+19483
+19484
+19485
+19486
+19487
+19488
+19489
+19490
+19491
+19492
+19493
+19494
+19495
+19496
+19497
+19498
+19499
+19500
+19501
+19502
+19503
+19504
+19505
+19506
+19507
+19508
+19509
+19510
+19511
+19512
+19513
+19514
+19515
+19516
+19517
+19518
+19519
+19520
+19521
+19522
+19523
+19524
+19525
+19526
+19527
+19528
+19529
+19530
+19531
+19532
+19533
+19534
+19535
+19536
+19537
+19538
+19539
+19540
+19541
+19542
+19543
+19544
+19545
+19546
+19547
+19548
+19549
+19550
+19551
+19552
+19553
+19554
+19555
+19556
+19557
+19558
+19559
+19560
+19561
+19562
+19563
+19564
+19565
+19566
+19567
+19568
+19569
+19570
+19571
+19572
+19573
+19574
+19575
+19576
+19577
+19578
+19579
+19580
+19581
+19582
+19583
+19584
+19585
+19586
+19587
+19588
+19589
+19590
+19591
+19592
+19593
+19594
+19595
+19596
+19597
+19598
+19599
+19600
+19601
+19602
+19603
+19604
+19605
+19606
+19607
+19608
+19609
+19610
+19611
+19612
+19613
+19614
+19615
+19616
+19617
+19618
+19619
+19620
+19621
+19622
+19623
+19624
+19625
+19626
+19627
+19628
+19629
+19630
+19631
+19632
+19633
+19634
+19635
+19636
+19637
+19638
+19639
+19640
+19641
+19642
+19643
+19644
+19645
+19646
+19647
+19648
+19649
+19650
+19651
+19652
+19653
+19654
+19655
+19656
+19657
+19658
+19659
+19660
+19661
+19662
+19663
+19664
+19665
+19666
+19667
+19668
+19669
+19670
+19671
+19672
+19673
+19674
+19675
+19676
+19677
+19678
+19679
+19680
+19681
+19682
+19683
+19684
+19685
+19686
+19687
+19688
+19689
+19690
+19691
+19692
+19693
+19694
+19695
+19696
+19697
+19698
+19699
+19700
+19701
+19702
+19703
+19704
+19705
+19706
+19707
+19708
+19709
+19710
+19711
+19712
+19713
+19714
+19715
+19716
+19717
+19718
+19719
+19720
+19721
+19722
+19723
+19724
+19725
+19726
+19727
+19728
+19729
+19730
+19731
+19732
+19733
+19734
+19735
+19736
+19737
+19738
+19739
+19740
+19741
+19742
+19743
+19744
+19745
+19746
+19747
+19748
+19749
+19750
+19751
+19752
+19753
+19754
+19755
+19756
+19757
+19758
+19759
+19760
+19761
+19762
+19763
+19764
+19765
+19766
+19767
+19768
+19769
+19770
+19771
+19772
+19773
+19774
+19775
+19776
+19777
+19778
+19779
+19780
+19781
+19782
+19783
+19784
+19785
+19786
+19787
+19788
+19789
+19790
+19791
+19792
+19793
+19794
+19795
+19796
+19797
+19798
+19799
+19800
+19801
+19802
+19803
+19804
+19805
+19806
+19807
+19808
+19809
+19810
+19811
+19812
+19813
+19814
+19815
+19816
+19817
+19818
+19819
+19820
+19821
+19822
+19823
+19824
+19825
+19826
+19827
+19828
+19829
+19830
+19831
+19832
+19833
+19834
+19835
+19836
+19837
+19838
+19839
+19840
+19841
+19842
+19843
+19844
+19845
+19846
+19847
+19848
+19849
+19850
+19851
+19852
+19853
+19854
+19855
+19856
+19857
+19858
+19859
+19860
+19861
+19862
+19863
+19864
+19865
+19866
+19867
+19868
+19869
+19870
+19871
+19872
+19873
+19874
+19875
+19876
+19877
+19878
+19879
+19880
+19881
+19882
+19883
+19884
+19885
+19886
+19887
+19888
+19889
+19890
+19891
+19892
+19893
+19894
+19895
+19896
+19897
+19898
+19899
+19900
+19901
+19902
+19903
+19904
+19905
+19906
+19907
+19908
+19909
+19910
+19911
+19912
+19913
+19914
+19915
+19916
+19917
+19918
+19919
+19920
+19921
+19922
+19923
+19924
+19925
+19926
+19927
+19928
+19929
+19930
+19931
+19932
+19933
+19934
+19935
+19936
+19937
+19938
+19939
+19940
+19941
+19942
+19943
+19944
+19945
+19946
+19947
+19948
+19949
+19950
+19951
+19952
+19953
+19954
+19955
+19956
+19957
+19958
+19959
+19960
+19961
+19962
+19963
+19964
+19965
+19966
+19967
+19968
+19969
+19970
+19971
+19972
+19973
+19974
+19975
+19976
+19977
+19978
+19979
+19980
+19981
+19982
+19983
+19984
+19985
+19986
+19987
+19988
+19989
+19990
+19991
+19992
+19993
+19994
+19995
+19996
+19997
+19998
+19999
+20000
+20001
+20002
+20003
+20004
+20005
+20006
+20007
+20008
+20009
+20010
+20011
+20012
+20013
+20014
+20015
+20016
+20017
+20018
+20019
+20020
+20021
+20022
+20023
+20024
+20025
+20026
+20027
+20028
+20029
+20030
+20031
+20032
+20033
+20034
+20035
+20036
+20037
+20038
+20039
+20040
+20041
+20042
+20043
+20044
+20045
+20046
+20047
+20048
+20049
+20050
+20051
+20052
+20053
+20054
+20055
+20056
+20057
+20058
+20059
+20060
+20061
+20062
+20063
+20064
+20065
+20066
+20067
+20068
+20069
+20070
+20071
+20072
+20073
+20074
+20075
+20076
+20077
+20078
+20079
+20080
+20081
+20082
+20083
+20084
+20085
+20086
+20087
+20088
+20089
+20090
+20091
+20092
+20093
+20094
+20095
+20096
+20097
+20098
+20099
+20100
+20101
+20102
+20103
+20104
+20105
+20106
+20107
+20108
+20109
+20110
+20111
+20112
+20113
+20114
+20115
+20116
+20117
+20118
+20119
+20120
+20121
+20122
+20123
+20124
+20125
+20126
+20127
+20128
+20129
+20130
+20131
+20132
+20133
+20134
+20135
+20136
+20137
+20138
+20139
+20140
+20141
+20142
+20143
+20144
+20145
+20146
+20147
+20148
+20149
+20150
+20151
+20152
+20153
+20154
+20155
+20156
+20157
+20158
+20159
+20160
+20161
+20162
+20163
+20164
+20165
+20166
+20167
+20168
+20169
+20170
+20171
+20172
+20173
+20174
+20175
+20176
+20177
+20178
+20179
+20180
+20181
+20182
+20183
+20184
+20185
+20186
+20187
+20188
+20189
+20190
+20191
+20192
+20193
+20194
+20195
+20196
+20197
+20198
+20199
+20200
+20201
+20202
+20203
+20204
+20205
+20206
+20207
+20208
+20209
+20210
+20211
+20212
+20213
+20214
+20215
+20216
+20217
+20218
+20219
+20220
+20221
+20222
+20223
+20224
+20225
+20226
+20227
+20228
+20229
+20230
+20231
+20232
+20233
+20234
+20235
+20236
+20237
+20238
+20239
+20240
+20241
+20242
+20243
+20244
+20245
+20246
+20247
+20248
+20249
+20250
+20251
+20252
+20253
+20254
+20255
+20256
+20257
+20258
+20259
+20260
+20261
+20262
+20263
+20264
+20265
+20266
+20267
+20268
+20269
+20270
+20271
+20272
+20273
+20274
+20275
+20276
+20277
+20278
+20279
+20280
+20281
+20282
+20283
+20284
+20285
+20286
+20287
+20288
+20289
+20290
+20291
+20292
+20293
+20294
+20295
+20296
+20297
+20298
+20299
+20300
+20301
+20302
+20303
+20304
+20305
+20306
+20307
+20308
+20309
+20310
+20311
+20312
+20313
+20314
+20315
+20316
+20317
+20318
+20319
+20320
+20321
+20322
+20323
+20324
+20325
+20326
+20327
+20328
+20329
+20330
+20331
+20332
+20333
+20334
+20335
+20336
+20337
+20338
+20339
+20340
+20341
+20342
+20343
+20344
+20345
+20346
+20347
+20348
+20349
+20350
+20351
+20352
+20353
+20354
+20355
+20356
+20357
+20358
+20359
+20360
+20361
+20362
+20363
+20364
+20365
+20366
+20367
+20368
+20369
+20370
+20371
+20372
+20373
+20374
+20375
+20376
+20377
+20378
+20379
+20380
+20381
+20382
+20383
+20384
+20385
+20386
+20387
+20388
+20389
+20390
+20391
+20392
+20393
+20394
+20395
+20396
+20397
+20398
+20399
+20400
+20401
+20402
+20403
+20404
+20405
+20406
+20407
+20408
+20409
+20410
+20411
+20412
+20413
+20414
+20415
+20416
+20417
+20418
+20419
+20420
+20421
+20422
+20423
+20424
+20425
+20426
+20427
+20428
+20429
+20430
+20431
+20432
+20433
+20434
+20435
+20436
+20437
+20438
+20439
+20440
+20441
+20442
+20443
+20444
+20445
+20446
+20447
+20448
+20449
+20450
+20451
+20452
+20453
+20454
+20455
+20456
+20457
+20458
+20459
+20460
+20461
+20462
+20463
+20464
+20465
+20466
+20467
+20468
+20469
+20470
+20471
+20472
+20473
+20474
+20475
+20476
+20477
+20478
+20479
+20480
+20481
+20482
+20483
+20484
+20485
+20486
+20487
+20488
+20489
+20490
+20491
+20492
+20493
+20494
+20495
+20496
+20497
+20498
+20499
+20500
+20501
+20502
+20503
+20504
+20505
+20506
+20507
+20508
+20509
+20510
+20511
+20512
+20513
+20514
+20515
+20516
+20517
+20518
+20519
+20520
+20521
+20522
+20523
+20524
+20525
+20526
+20527
+20528
+20529
+20530
+20531
+20532
+20533
+20534
+20535
+20536
+20537
+20538
+20539
+20540
+20541
+20542
+20543
+20544
+20545
+20546
+20547
+20548
+20549
+20550
+20551
+20552
+20553
+20554
+20555
+20556
+20557
+20558
+20559
+20560
+20561
+20562
+20563
+20564
+20565
+20566
+20567
+20568
+20569
+20570
+20571
+20572
+20573
+20574
+20575
+20576
+20577
+20578
+20579
+20580
+20581
+20582
+20583
+20584
+20585
+20586
+20587
+20588
+20589
+20590
+20591
+20592
+20593
+20594
+20595
+20596
+20597
+20598
+20599
+20600
+20601
+20602
+20603
+20604
+20605
+20606
+20607
+20608
+20609
+20610
+20611
+20612
+20613
+20614
+20615
+20616
+20617
+20618
+20619
+20620
+20621
+20622
+20623
+20624
+20625
+20626
+20627
+20628
+20629
+20630
+20631
+20632
+20633
+20634
+20635
+20636
+20637
+20638
+20639
+20640
+20641
+20642
+20643
+20644
+20645
+20646
+20647
+20648
+20649
+20650
+20651
+20652
+20653
+20654
+20655
+20656
+20657
+20658
+20659
+20660
+20661
+20662
+20663
+20664
+20665
+20666
+20667
+20668
+20669
+20670
+20671
+20672
+20673
+20674
+20675
+20676
+20677
+20678
+20679
+20680
+20681
+20682
+20683
+20684
+20685
+20686
+20687
+20688
+20689
+20690
+20691
+20692
+20693
+20694
+20695
+20696
+20697
+20698
+20699
+20700
+20701
+20702
+20703
+20704
+20705
+20706
+20707
+20708
+20709
+20710
+20711
+20712
+20713
+20714
+20715
+20716
+20717
+20718
+20719
+20720
+20721
+20722
+20723
+20724
+20725
+20726
+20727
+20728
+20729
+20730
+20731
+20732
+20733
+20734
+20735
+20736
+20737
+20738
+20739
+20740
+20741
+20742
+20743
+20744
+20745
+20746
+20747
+20748
+20749
+20750
+20751
+20752
+20753
+20754
+20755
+20756
+20757
+20758
+20759
+20760
+20761
+20762
+20763
+20764
+20765
+20766
+20767
+20768
+20769
+20770
+20771
+20772
+20773
+20774
+20775
+20776
+20777
+20778
+20779
+20780
+20781
+20782
+20783
+20784
+20785
+20786
+20787
+20788
+20789
+20790
+20791
+20792
+20793
+20794
+20795
+20796
+20797
+20798
+20799
+20800
+20801
+20802
+20803
+20804
+20805
+20806
+20807
+20808
+20809
+20810
+20811
+20812
+20813
+20814
+20815
+20816
+20817
+20818
+20819
+20820
+20821
+20822
+20823
+20824
+20825
+20826
+20827
+20828
+20829
+20830
+20831
+20832
+20833
+20834
+20835
+20836
+20837
+20838
+20839
+20840
+20841
+20842
+20843
+20844
+20845
+20846
+20847
+20848
+20849
+20850
+20851
+20852
+20853
+20854
+20855
+20856
+20857
+20858
+20859
+20860
+20861
+20862
+20863
+20864
+20865
+20866
+20867
+20868
+20869
+20870
+20871
+20872
+20873
+20874
+20875
+20876
+20877
+20878
+20879
+20880
+20881
+20882
+20883
+20884
+20885
+20886
+20887
+20888
+20889
+20890
+20891
+20892
+20893
+20894
+20895
+20896
+20897
+20898
+20899
+20900
+20901
+20902
+20903
+20904
+20905
+20906
+20907
+20908
+20909
+20910
+20911
+20912
+20913
+20914
+20915
+20916
+20917
+20918
+20919
+20920
+20921
+20922
+20923
+20924
+20925
+20926
+20927
+20928
+20929
+20930
+20931
+20932
+20933
+20934
+20935
+20936
+20937
+20938
+20939
+20940
+20941
+20942
+20943
+20944
+20945
+20946
+20947
+20948
+20949
+20950
+20951
+20952
+20953
+20954
+20955
+20956
+20957
+20958
+20959
+20960
+20961
+20962
+20963
+20964
+20965
+20966
+20967
+20968
+20969
+20970
+20971
+20972
+20973
+20974
+20975
+20976
+20977
+20978
+20979
+20980
+20981
+20982
+20983
+20984
+20985
+20986
+20987
+20988
+20989
+20990
+20991
+20992
+20993
+20994
+20995
+20996
+20997
+20998
+20999
+21000
+21001
+21002
+21003
+21004
+21005
+21006
+21007
+21008
+21009
+21010
+21011
+21012
+21013
+21014
+21015
+21016
+21017
+21018
+21019
+21020
+21021
+21022
+21023
+21024
+21025
+21026
+21027
+21028
+21029
+21030
+21031
+21032
+21033
+21034
+21035
+21036
+21037
+21038
+21039
+21040
+21041
+21042
+21043
+21044
+21045
+21046
+21047
+21048
+21049
+21050
+21051
+21052
+21053
+21054
+21055
+21056
+21057
+21058
+21059
+21060
+21061
+21062
+21063
+21064
+21065
+21066
+21067
+21068
+21069
+21070
+21071
+21072
+21073
+21074
+21075
+21076
+21077
+21078
+21079
+21080
+21081
+21082
+21083
+21084
+21085
+21086
+21087
+21088
+21089
+21090
+21091
+21092
+21093
+21094
+21095
+21096
+21097
+21098
+21099
+21100
+21101
+21102
+21103
+21104
+21105
+21106
+21107
+21108
+21109
+21110
+21111
+21112
+21113
+21114
+21115
+21116
+21117
+21118
+21119
+21120
+21121
+21122
+21123
+21124
+21125
+21126
+21127
+21128
+21129
+21130
+21131
+21132
+21133
+21134
+21135
+21136
+21137
+21138
+21139
+21140
+21141
+21142
+21143
+21144
+21145
+21146
+21147
+21148
+21149
+21150
+21151
+21152
+21153
+21154
+21155
+21156
+21157
+21158
+21159
+21160
+21161
+21162
+21163
+21164
+21165
+21166
+21167
+21168
+21169
+21170
+21171
+21172
+21173
+21174
+21175
+21176
+21177
+21178
+21179
+21180
+21181
+21182
+21183
+21184
+21185
+21186
+21187
+21188
+21189
+21190
+21191
+21192
+21193
+21194
+21195
+21196
+21197
+21198
+21199
+21200
+21201
+21202
+21203
+21204
+21205
+21206
+21207
+21208
+21209
+21210
+21211
+21212
+21213
+21214
+21215
+21216
+21217
+21218
+21219
+21220
+21221
+21222
+21223
+21224
+21225
+21226
+21227
+21228
+21229
+21230
+21231
+21232
+21233
+21234
+21235
+21236
+21237
+21238
+21239
+21240
+21241
+21242
+21243
+21244
+21245
+21246
+21247
+21248
+21249
+21250
+21251
+21252
+21253
+21254
+21255
+21256
+21257
+21258
+21259
+21260
+21261
+21262
+21263
+21264
+21265
+21266
+21267
+21268
+21269
+21270
+21271
+21272
+21273
+21274
+21275
+21276
+21277
+21278
+21279
+21280
+21281
+21282
+21283
+21284
+21285
+21286
+21287
+21288
+21289
+21290
+21291
+21292
+21293
+21294
+21295
+21296
+21297
+21298
+21299
+21300
+21301
+21302
+21303
+21304
+21305
+21306
+21307
+21308
+21309
+21310
+21311
+21312
+21313
+21314
+21315
+21316
+21317
+21318
+21319
+21320
+21321
+21322
+21323
+21324
+21325
+21326
+21327
+21328
+21329
+21330
+21331
+21332
+21333
+21334
+21335
+21336
+21337
+21338
+21339
+21340
+21341
+21342
+21343
+21344
+21345
+21346
+21347
+21348
+21349
+21350
+21351
+21352
+21353
+21354
+21355
+21356
+21357
+21358
+21359
+21360
+21361
+21362
+21363
+21364
+21365
+21366
+21367
+21368
+21369
+21370
+21371
+21372
+21373
+21374
+21375
+21376
+21377
+21378
+21379
+21380
+21381
+21382
+21383
+21384
+21385
+21386
+21387
+21388
+21389
+21390
+21391
+21392
+21393
+21394
+21395
+21396
+21397
+21398
+21399
+21400
+21401
+21402
+21403
+21404
+21405
+21406
+21407
+21408
+21409
+21410
+21411
+21412
+21413
+21414
+21415
+21416
+21417
+21418
+21419
+21420
+21421
+21422
+21423
+21424
+21425
+21426
+21427
+21428
+21429
+21430
+21431
+21432
+21433
+21434
+21435
+21436
+21437
+21438
+21439
+21440
+21441
+21442
+21443
+21444
+21445
+21446
+21447
+21448
+21449
+21450
+21451
+21452
+21453
+21454
+21455
+21456
+21457
+21458
+21459
+21460
+21461
+21462
+21463
+21464
+21465
+21466
+21467
+21468
+21469
+21470
+21471
+21472
+21473
+21474
+21475
+21476
+21477
+21478
+21479
+21480
+21481
+21482
+21483
+21484
+21485
+21486
+21487
+21488
+21489
+21490
+21491
+21492
+21493
+21494
+21495
+21496
+21497
+21498
+21499
+21500
+21501
+21502
+21503
+21504
+21505
+21506
+21507
+21508
+21509
+21510
+21511
+21512
+21513
+21514
+21515
+21516
+21517
+21518
+21519
+21520
+21521
+21522
+21523
+21524
+21525
+21526
+21527
+21528
+21529
+21530
+21531
+21532
+21533
+21534
+21535
+21536
+21537
+21538
+21539
+21540
+21541
+21542
+21543
+21544
+21545
+21546
+21547
+21548
+21549
+21550
+21551
+21552
+21553
+21554
+21555
+21556
+21557
+21558
+21559
+21560
+21561
+21562
+21563
+21564
+21565
+21566
+21567
+21568
+21569
+21570
+21571
+21572
+21573
+21574
+21575
+21576
+21577
+21578
+21579
+21580
+21581
+21582
+21583
+21584
+21585
+21586
+21587
+21588
+21589
+21590
+21591
+21592
+21593
+21594
+21595
+21596
+21597
+21598
+21599
+21600
+21601
+21602
+21603
+21604
+21605
+21606
+21607
+21608
+21609
+21610
+21611
+21612
+21613
+21614
+21615
+21616
+21617
+21618
+21619
+21620
+21621
+21622
+21623
+21624
+21625
+21626
+21627
+21628
+21629
+21630
+21631
+21632
+21633
+21634
+21635
+21636
+21637
+21638
+21639
+21640
+21641
+21642
+21643
+21644
+21645
+21646
+21647
+21648
+21649
+21650
+21651
+21652
+21653
+21654
+21655
+21656
+21657
+21658
+21659
+21660
+21661
+21662
+21663
+21664
+21665
+21666
+21667
+21668
+21669
+21670
+21671
+21672
+21673
+21674
+21675
+21676
+21677
+21678
+21679
+21680
+21681
+21682
+21683
+21684
+21685
+21686
+21687
+21688
+21689
+21690
+21691
+21692
+21693
+21694
+21695
+21696
+21697
+21698
+21699
+21700
+21701
+21702
+21703
+21704
+21705
+21706
+21707
+21708
+21709
+21710
+21711
+21712
+21713
+21714
+21715
+21716
+21717
+21718
+21719
+21720
+21721
+21722
+21723
+21724
+21725
+21726
+21727
+21728
+21729
+21730
+21731
+21732
+21733
+21734
+21735
+21736
+21737
+21738
+21739
+21740
+21741
+21742
+21743
+21744
+21745
+21746
+21747
+21748
+21749
+21750
+21751
+21752
+21753
+21754
+21755
+21756
+21757
+21758
+21759
+21760
+21761
+21762
+21763
+21764
+21765
+21766
+21767
+21768
+21769
+21770
+21771
+21772
+21773
+21774
+21775
+21776
+21777
+21778
+21779
+21780
+21781
+21782
+21783
+21784
+21785
+21786
+21787
+21788
+21789
+21790
+21791
+21792
+21793
+21794
+21795
+21796
+21797
+21798
+21799
+21800
+21801
+21802
+21803
+21804
+21805
+21806
+21807
+21808
+21809
+21810
+21811
+21812
+21813
+21814
+21815
+21816
+21817
+21818
+21819
+21820
+21821
+21822
+21823
+21824
+21825
+21826
+21827
+21828
+21829
+21830
+21831
+21832
+21833
+21834
+21835
+21836
+21837
+21838
+21839
+21840
+21841
+21842
+21843
+21844
+21845
+21846
+21847
+21848
+21849
+21850
+21851
+21852
+21853
+21854
+21855
+21856
+21857
+21858
+21859
+21860
+21861
+21862
+21863
+21864
+21865
+21866
+21867
+21868
+21869
+21870
+21871
+21872
+21873
+21874
+21875
+21876
+21877
+21878
+21879
+21880
+21881
+21882
+21883
+21884
+21885
+21886
+21887
+21888
+21889
+21890
+21891
+21892
+21893
+21894
+21895
+21896
+21897
+21898
+21899
+21900
+21901
+21902
+21903
+21904
+21905
+21906
+21907
+21908
+21909
+21910
+21911
+21912
+21913
+21914
+21915
+21916
+21917
+21918
+21919
+21920
+21921
+21922
+21923
+21924
+21925
+21926
+21927
+21928
+21929
+21930
+21931
+21932
+21933
+21934
+21935
+21936
+21937
+21938
+21939
+21940
+21941
+21942
+21943
+21944
+21945
+21946
+21947
+21948
+21949
+21950
+21951
+21952
+21953
+21954
+21955
+21956
+21957
+21958
+21959
+21960
+21961
+21962
+21963
+21964
+21965
+21966
+21967
+21968
+21969
+21970
+21971
+21972
+21973
+21974
+21975
+21976
+21977
+21978
+21979
+21980
+21981
+21982
+21983
+21984
+21985
+21986
+21987
+21988
+21989
+21990
+21991
+21992
+21993
+21994
+21995
+21996
+21997
+21998
+21999
+22000
+22001
+22002
+22003
+22004
+22005
+22006
+22007
+22008
+22009
+22010
+22011
+22012
+22013
+22014
+22015
+22016
+22017
+22018
+22019
+22020
+22021
+22022
+22023
+22024
+22025
+22026
+22027
+22028
+22029
+22030
+22031
+22032
+22033
+22034
+22035
+22036
+22037
+22038
+22039
+22040
+22041
+22042
+22043
+22044
+22045
+22046
+22047
+22048
+22049
+22050
+22051
+22052
+22053
+22054
+22055
+22056
+22057
+22058
+22059
+22060
+22061
+22062
+22063
+22064
+22065
+22066
+22067
+22068
+22069
+22070
+22071
+22072
+22073
+22074
+22075
+22076
+22077
+22078
+22079
+22080
+22081
+22082
+22083
+22084
+22085
+22086
+22087
+22088
+22089
+22090
+22091
+22092
+22093
+22094
+22095
+22096
+22097
+22098
+22099
+22100
+22101
+22102
+22103
+22104
+22105
+22106
+22107
+22108
+22109
+22110
+22111
+22112
+22113
+22114
+22115
+22116
+22117
+22118
+22119
+22120
+22121
+22122
+22123
+22124
+22125
+22126
+22127
+22128
+22129
+22130
+22131
+22132
+22133
+22134
+22135
+22136
+22137
+22138
+22139
+22140
+22141
+22142
+22143
+22144
+22145
+22146
+22147
+22148
+22149
+22150
+22151
+22152
+22153
+22154
+22155
+22156
+22157
+22158
+22159
+22160
+22161
+22162
+22163
+22164
+22165
+22166
+22167
+22168
+22169
+22170
+22171
+22172
+22173
+22174
+22175
+22176
+22177
+22178
+22179
+22180
+22181
+22182
+22183
+22184
+22185
+22186
+22187
+22188
+22189
+22190
+22191
+22192
+22193
+22194
+22195
+22196
+22197
+22198
+22199
+22200
+22201
+22202
+22203
+22204
+22205
+22206
+22207
+22208
+22209
+22210
+22211
+22212
+22213
+22214
+22215
+22216
+22217
+22218
+22219
+22220
+22221
+22222
+22223
+22224
+22225
+22226
+22227
+22228
+22229
+22230
+22231
+22232
+22233
+22234
+22235
+22236
+22237
+22238
+22239
+22240
+22241
+22242
+22243
+22244
+22245
+22246
+22247
+22248
+22249
+22250
+22251
+22252
+22253
+22254
+22255
+22256
+22257
+22258
+22259
+22260
+22261
+22262
+22263
+22264
+22265
+22266
+22267
+22268
+22269
+22270
+22271
+22272
+22273
+22274
+22275
+22276
+22277
+22278
+22279
+22280
+22281
+22282
+22283
+22284
+22285
+22286
+22287
+22288
+22289
+22290
+22291
+22292
+22293
+22294
+22295
+22296
+22297
+22298
+22299
+22300
+22301
+22302
+22303
+22304
+22305
+22306
+22307
+22308
+22309
+22310
+22311
+22312
+22313
+22314
+22315
+22316
+22317
+22318
+22319
+22320
+22321
+22322
+22323
+22324
+22325
+22326
+22327
+22328
+22329
+22330
+22331
+22332
+22333
+22334
+22335
+22336
+22337
+22338
+22339
+22340
+22341
+22342
+22343
+22344
+22345
+22346
+22347
+22348
+22349
+22350
+22351
+22352
+22353
+22354
+22355
+22356
+22357
+22358
+22359
+22360
+22361
+22362
+22363
+22364
+22365
+22366
+22367
+22368
+22369
+22370
+22371
+22372
+22373
+22374
+22375
+22376
+22377
+22378
+22379
+22380
+22381
+22382
+22383
+22384
+22385
+22386
+22387
+22388
+22389
+22390
+22391
+22392
+22393
+22394
+22395
+22396
+22397
+22398
+22399
+22400
+22401
+22402
+22403
+22404
+22405
+22406
+22407
+22408
+22409
+22410
+22411
+22412
+22413
+22414
+22415
+22416
+22417
+22418
+22419
+22420
+22421
+22422
+22423
+22424
+22425
+22426
+22427
+22428
+22429
+22430
+22431
+22432
+22433
+22434
+22435
+22436
+22437
+22438
+22439
+22440
+22441
+22442
+22443
+22444
+22445
+22446
+22447
+22448
+22449
+22450
+22451
+22452
+22453
+22454
+22455
+22456
+22457
+22458
+22459
+22460
+22461
+22462
+22463
+22464
+22465
+22466
+22467
+22468
+22469
+22470
+22471
+22472
+22473
+22474
+22475
+22476
+22477
+22478
+22479
+22480
+22481
+22482
+22483
+22484
+22485
+22486
+22487
+22488
+22489
+22490
+22491
+22492
+22493
+22494
+22495
+22496
+22497
+22498
+22499
+22500
+22501
+22502
+22503
+22504
+22505
+22506
+22507
+22508
+22509
+22510
+22511
+22512
+22513
+22514
+22515
+22516
+22517
+22518
+22519
+22520
+22521
+22522
+22523
+22524
+22525
+22526
+22527
+22528
+22529
+22530
+22531
+22532
+22533
+22534
+22535
+22536
+22537
+22538
+22539
+22540
+22541
+22542
+22543
+22544
+22545
+22546
+22547
+22548
+22549
+22550
+22551
+22552
+22553
+22554
+22555
+22556
+22557
+22558
+22559
+22560
+22561
+22562
+22563
+22564
+22565
+22566
+22567
+22568
+22569
+22570
+22571
+22572
+22573
+22574
+22575
+22576
+22577
+22578
+22579
+22580
+22581
+22582
+22583
+22584
+22585
+22586
+22587
+22588
+22589
+22590
+22591
+22592
+22593
+22594
+22595
+22596
+22597
+22598
+22599
+22600
+22601
+22602
+22603
+22604
+22605
+22606
+22607
+22608
+22609
+22610
+22611
+22612
+22613
+22614
+22615
+22616
+22617
+22618
+22619
+22620
+22621
+22622
+22623
+22624
+22625
+22626
+22627
+22628
+22629
+22630
+22631
+22632
+22633
+22634
+22635
+22636
+22637
+22638
+22639
+22640
+22641
+22642
+22643
+22644
+22645
+22646
+22647
+22648
+22649
+22650
+22651
+22652
+22653
+22654
+22655
+22656
+22657
+22658
+22659
+22660
+22661
+22662
+22663
+22664
+22665
+22666
+22667
+22668
+22669
+22670
+22671
+22672
+22673
+22674
+22675
+22676
+22677
+22678
+22679
+22680
+22681
+22682
+22683
+22684
+22685
+22686
+22687
+22688
+22689
+22690
+22691
+22692
+22693
+22694
+22695
+22696
+22697
+22698
+22699
+22700
+22701
+22702
+22703
+22704
+22705
+22706
+22707
+22708
+22709
+22710
+22711
+22712
+22713
+22714
+22715
+22716
+22717
+22718
+22719
+22720
+22721
+22722
+22723
+22724
+22725
+22726
+22727
+22728
+22729
+22730
+22731
+22732
+22733
+22734
+22735
+22736
+22737
+22738
+22739
+22740
+22741
+22742
+22743
+22744
+22745
+22746
+22747
+22748
+22749
+22750
+22751
+22752
+22753
+22754
+22755
+22756
+22757
+22758
+22759
+22760
+22761
+22762
+22763
+22764
+22765
+22766
+22767
+22768
+22769
+22770
+22771
+22772
+22773
+22774
+22775
+22776
+22777
+22778
+22779
+22780
+22781
+22782
+22783
+22784
+22785
+22786
+22787
+22788
+22789
+22790
+22791
+22792
+22793
+22794
+22795
+22796
+22797
+22798
+22799
+22800
+22801
+22802
+22803
+22804
+22805
+22806
+22807
+22808
+22809
+22810
+22811
+22812
+22813
+22814
+22815
+22816
+22817
+22818
+22819
+22820
+22821
+22822
+22823
+22824
+22825
+22826
+22827
+22828
+22829
+22830
+22831
+22832
+22833
+22834
+22835
+22836
+22837
+22838
+22839
+22840
+22841
+22842
+22843
+22844
+22845
+22846
+22847
+22848
+22849
+22850
+22851
+22852
+22853
+22854
+22855
+22856
+22857
+22858
+22859
+22860
+22861
+22862
+22863
+22864
+22865
+22866
+22867
+22868
+22869
+22870
+22871
+22872
+22873
+22874
+22875
+22876
+22877
+22878
+22879
+22880
+22881
+22882
+22883
+22884
+22885
+22886
+22887
+22888
+22889
+22890
+22891
+22892
+22893
+22894
+22895
+22896
+22897
+22898
+22899
+22900
+22901
+22902
+22903
+22904
+22905
+22906
+22907
+22908
+22909
+22910
+22911
+22912
+22913
+22914
+22915
+22916
+22917
+22918
+22919
+22920
+22921
+22922
+22923
+22924
+22925
+22926
+22927
+22928
+22929
+22930
+22931
+22932
+22933
+22934
+22935
+22936
+22937
+22938
+22939
+22940
+22941
+22942
+22943
+22944
+22945
+22946
+22947
+22948
+22949
+22950
+22951
+22952
+22953
+22954
+22955
+22956
+22957
+22958
+22959
+22960
+22961
+22962
+22963
+22964
+22965
+22966
+22967
+22968
+22969
+22970
+22971
+22972
+22973
+22974
+22975
+22976
+22977
+22978
+22979
+22980
+22981
+22982
+22983
+22984
+22985
+22986
+22987
+22988
+22989
+22990
+22991
+22992
+22993
+22994
+22995
+22996
+22997
+22998
+22999
+23000
+23001
+23002
+23003
+23004
+23005
+23006
+23007
+23008
+23009
+23010
+23011
+23012
+23013
+23014
+23015
+23016
+23017
+23018
+23019
+23020
+23021
+23022
+23023
+23024
+23025
+23026
+23027
+23028
+23029
+23030
+23031
+23032
+23033
+23034
+23035
+23036
+23037
+23038
+23039
+23040
+23041
+23042
+23043
+23044
+23045
+23046
+23047
+23048
+23049
+23050
+23051
+23052
+23053
+23054
+23055
+23056
+23057
+23058
+23059
+23060
+23061
+23062
+23063
+23064
+23065
+23066
+23067
+23068
+23069
+23070
+23071
+23072
+23073
+23074
+23075
+23076
+23077
+23078
+23079
+23080
+23081
+23082
+23083
+23084
+23085
+23086
+23087
+23088
+23089
+23090
+23091
+23092
+23093
+23094
+23095
+23096
+23097
+23098
+23099
+23100
+23101
+23102
+23103
+23104
+23105
+23106
+23107
+23108
+23109
+23110
+23111
+23112
+23113
+23114
+23115
+23116
+23117
+23118
+23119
+23120
+23121
+23122
+23123
+23124
+23125
+23126
+23127
+23128
+23129
+23130
+23131
+23132
+23133
+23134
+23135
+23136
+23137
+23138
+23139
+23140
+23141
+23142
+23143
+23144
+23145
+23146
+23147
+23148
+23149
+23150
+23151
+23152
+23153
+23154
+23155
+23156
+23157
+23158
+23159
+23160
+23161
+23162
+23163
+23164
+23165
+23166
+23167
+23168
+23169
+23170
+23171
+23172
+23173
+23174
+23175
+23176
+23177
+23178
+23179
+23180
+23181
+23182
+23183
+23184
+23185
+23186
+23187
+23188
+23189
+23190
+23191
+23192
+23193
+23194
+23195
+23196
+23197
+23198
+23199
+23200
+23201
+23202
+23203
+23204
+23205
+23206
+23207
+23208
+23209
+23210
+23211
+23212
+23213
+23214
+23215
+23216
+23217
+23218
+23219
+23220
+23221
+23222
+23223
+23224
+23225
+23226
+23227
+23228
+23229
+23230
+23231
+23232
+23233
+23234
+23235
+23236
+23237
+23238
+23239
+23240
+23241
+23242
+23243
+23244
+23245
+23246
+23247
+23248
+23249
+23250
+23251
+23252
+23253
+23254
+23255
+23256
+23257
+23258
+23259
+23260
+23261
+23262
+23263
+23264
+23265
+23266
+23267
+23268
+23269
+23270
+23271
+23272
+23273
+23274
+23275
+23276
+23277
+23278
+23279
+23280
+23281
+23282
+23283
+23284
+23285
+23286
+23287
+23288
+23289
+23290
+23291
+23292
+23293
+23294
+23295
+23296
+23297
+23298
+23299
+23300
+23301
+23302
+23303
+23304
+23305
+23306
+23307
+23308
+23309
+23310
+23311
+23312
+23313
+23314
+23315
+23316
+23317
+23318
+23319
+23320
+23321
+23322
+23323
+23324
+23325
+23326
+23327
+23328
+23329
+23330
+23331
+23332
+23333
+23334
+23335
+23336
+23337
+23338
+23339
+23340
+23341
+23342
+23343
+23344
+23345
+23346
+23347
+23348
+23349
+23350
+23351
+23352
+23353
+23354
+23355
+23356
+23357
+23358
+23359
+23360
+23361
+23362
+23363
+23364
+23365
+23366
+23367
+23368
+23369
+23370
+23371
+23372
+23373
+23374
+23375
+23376
+23377
+23378
+23379
+23380
+23381
+23382
+23383
+23384
+23385
+23386
+23387
+23388
+23389
+23390
+23391
+23392
+23393
+23394
+23395
+23396
+23397
+23398
+23399
+23400
+23401
+23402
+23403
+23404
+23405
+23406
+23407
+23408
+23409
+23410
+23411
+23412
+23413
+23414
+23415
+23416
+23417
+23418
+23419
+23420
+23421
+23422
+23423
+23424
+23425
+23426
+23427
+23428
+23429
+23430
+23431
+23432
+23433
+23434
+23435
+23436
+23437
+23438
+23439
+23440
+23441
+23442
+23443
+23444
+23445
+23446
+23447
+23448
+23449
+23450
+23451
+23452
+23453
+23454
+23455
+23456
+23457
+23458
+23459
+23460
+23461
+23462
+23463
+23464
+23465
+23466
+23467
+23468
+23469
+23470
+23471
+23472
+23473
+23474
+23475
+23476
+23477
+23478
+23479
+23480
+23481
+23482
+23483
+23484
+23485
+23486
+23487
+23488
+23489
+23490
+23491
+23492
+23493
+23494
+23495
+23496
+23497
+23498
+23499
+23500
+23501
+23502
+23503
+23504
+23505
+23506
+23507
+23508
+23509
+23510
+23511
+23512
+23513
+23514
+23515
+23516
+23517
+23518
+23519
+23520
+23521
+23522
+23523
+23524
+23525
+23526
+23527
+23528
+23529
+23530
+23531
+23532
+23533
+23534
+23535
+23536
+23537
+23538
+23539
+23540
+23541
+23542
+23543
+23544
+23545
+23546
+23547
+23548
+23549
+23550
+23551
+23552
+23553
+23554
+23555
+23556
+23557
+23558
+23559
+23560
+23561
+23562
+23563
+23564
+23565
+23566
+23567
+23568
+23569
+23570
+23571
+23572
+23573
+23574
+23575
+23576
+23577
+23578
+23579
+23580
+23581
+23582
+23583
+23584
+23585
+23586
+23587
+23588
+23589
+23590
+23591
+23592
+23593
+23594
+23595
+23596
+23597
+23598
+23599
+23600
+23601
+23602
+23603
+23604
+23605
+23606
+23607
+23608
+23609
+23610
+23611
+23612
+23613
+23614
+23615
+23616
+23617
+23618
+23619
+23620
+23621
+23622
+23623
+23624
+23625
+23626
+23627
+23628
+23629
+23630
+23631
+23632
+23633
+23634
+23635
+23636
+23637
+23638
+23639
+23640
+23641
+23642
+23643
+23644
+23645
+23646
+23647
+23648
+23649
+23650
+23651
+23652
+23653
+23654
+23655
+23656
+23657
+23658
+23659
+23660
+23661
+23662
+23663
+23664
+23665
+23666
+23667
+23668
+23669
+23670
+23671
+23672
+23673
+23674
+23675
+23676
+23677
+23678
+23679
+23680
+23681
+23682
+23683
+23684
+23685
+23686
+23687
+23688
+23689
+23690
+23691
+23692
+23693
+23694
+23695
+23696
+23697
+23698
+23699
+23700
+23701
+23702
+23703
+23704
+23705
+23706
+23707
+23708
+23709
+23710
+23711
+23712
+23713
+23714
+23715
+23716
+23717
+23718
+23719
+23720
+23721
+23722
+23723
+23724
+23725
+23726
+23727
+23728
+23729
+23730
+23731
+23732
+23733
+23734
+23735
+23736
+23737
+23738
+23739
+23740
+23741
+23742
+23743
+23744
+23745
+23746
+23747
+23748
+23749
+23750
+23751
+23752
+23753
+23754
+23755
+23756
+23757
+23758
+23759
+23760
+23761
+23762
+23763
+23764
+23765
+23766
+23767
+23768
+23769
+23770
+23771
+23772
+23773
+23774
+23775
+23776
+23777
+23778
+23779
+23780
+23781
+23782
+23783
+23784
+23785
+23786
+23787
+23788
+23789
+23790
+23791
+23792
+23793
+23794
+23795
+23796
+23797
+23798
+23799
+23800
+23801
+23802
+23803
+23804
+23805
+23806
+23807
+23808
+23809
+23810
+23811
+23812
+23813
+23814
+23815
+23816
+23817
+23818
+23819
+23820
+23821
+23822
+23823
+23824
+23825
+23826
+23827
+23828
+23829
+23830
+23831
+23832
+23833
+23834
+23835
+23836
+23837
+23838
+23839
+23840
+23841
+23842
+23843
+23844
+23845
+23846
+23847
+23848
+23849
+23850
+23851
+23852
+23853
+23854
+23855
+23856
+23857
+23858
+23859
+23860
+23861
+23862
+23863
+23864
+23865
+23866
+23867
+23868
+23869
+23870
+23871
+23872
+23873
+23874
+23875
+23876
+23877
+23878
+23879
+23880
+23881
+23882
+23883
+23884
+23885
+23886
+23887
+23888
+23889
+23890
+23891
+23892
+23893
+23894
+23895
+23896
+23897
+23898
+23899
+23900
+23901
+23902
+23903
+23904
+23905
+23906
+23907
+23908
+23909
+23910
+23911
+23912
+23913
+23914
+23915
+23916
+23917
+23918
+23919
+23920
+23921
+23922
+23923
+23924
+23925
+23926
+23927
+23928
+23929
+23930
+23931
+23932
+23933
+23934
+23935
+23936
+23937
+23938
+23939
+23940
+23941
+23942
+23943
+23944
+23945
+23946
+23947
+23948
+23949
+23950
+23951
+23952
+23953
+23954
+23955
+23956
+23957
+23958
+23959
+23960
+23961
+23962
+23963
+23964
+23965
+23966
+23967
+23968
+23969
+23970
+23971
+23972
+23973
+23974
+23975
+23976
+23977
+23978
+23979
+23980
+23981
+23982
+23983
+23984
+23985
+23986
+23987
+23988
+23989
+23990
+23991
+23992
+23993
+23994
+23995
+23996
+23997
+23998
+23999
+24000
+24001
+24002
+24003
+24004
+24005
+24006
+24007
+24008
+24009
+24010
+24011
+24012
+24013
+24014
+24015
+24016
+24017
+24018
+24019
+24020
+24021
+24022
+24023
+24024
+24025
+24026
+24027
+24028
+24029
+24030
+24031
+24032
+24033
+24034
+24035
+24036
+24037
+24038
+24039
+24040
+24041
+24042
+24043
+24044
+24045
+24046
+24047
+24048
+24049
+24050
+24051
+24052
+24053
+24054
+24055
+24056
+24057
+24058
+24059
+24060
+24061
+24062
+24063
+24064
+24065
+24066
+24067
+24068
+24069
+24070
+24071
+24072
+24073
+24074
+24075
+24076
+24077
+24078
+24079
+24080
+24081
+24082
+24083
+24084
+24085
+24086
+24087
+24088
+24089
+24090
+24091
+24092
+24093
+24094
+24095
+24096
+24097
+24098
+24099
+24100
+24101
+24102
+24103
+24104
+24105
+24106
+24107
+24108
+24109
+24110
+24111
+24112
+24113
+24114
+24115
+24116
+24117
+24118
+24119
+24120
+24121
+24122
+24123
+24124
+24125
+24126
+24127
+24128
+24129
+24130
+24131
+24132
+24133
+24134
+24135
+24136
+24137
+24138
+24139
+24140
+24141
+24142
+24143
+24144
+24145
+24146
+24147
+24148
+24149
+24150
+24151
+24152
+24153
+24154
+24155
+24156
+24157
+24158
+24159
+24160
+24161
+24162
+24163
+24164
+24165
+24166
+24167
+24168
+24169
+24170
+24171
+24172
+24173
+24174
+24175
+24176
+24177
+24178
+24179
+24180
+24181
+24182
+24183
+24184
+24185
+24186
+24187
+24188
+24189
+24190
+24191
+24192
+24193
+24194
+24195
+24196
+24197
+24198
+24199
+24200
+24201
+24202
+24203
+24204
+24205
+24206
+24207
+24208
+24209
+24210
+24211
+24212
+24213
+24214
+24215
+24216
+24217
+24218
+24219
+24220
+24221
+24222
+24223
+24224
+24225
+24226
+24227
+24228
+24229
+24230
+24231
+24232
+24233
+24234
+24235
+24236
+24237
+24238
+24239
+24240
+24241
+24242
+24243
+24244
+24245
+24246
+24247
+24248
+24249
+24250
+24251
+24252
+24253
+24254
+24255
+24256
+24257
+24258
+24259
+24260
+24261
+24262
+24263
+24264
+24265
+24266
+24267
+24268
+24269
+24270
+24271
+24272
+24273
+24274
+24275
+24276
+24277
+24278
+24279
+24280
+24281
+24282
+24283
+24284
+24285
+24286
+24287
+24288
+24289
+24290
+24291
+24292
+24293
+24294
+24295
+24296
+24297
+24298
+24299
+24300
+24301
+24302
+24303
+24304
+24305
+24306
+24307
+24308
+24309
+24310
+24311
+24312
+24313
+24314
+24315
+24316
+24317
+24318
+24319
+24320
+24321
+24322
+24323
+24324
+24325
+24326
+24327
+24328
+24329
+24330
+24331
+24332
+24333
+24334
+24335
+24336
+24337
+24338
+24339
+24340
+24341
+24342
+24343
+24344
+24345
+24346
+24347
+24348
+24349
+24350
+24351
+24352
+24353
+24354
+24355
+24356
+24357
+24358
+24359
+24360
+24361
+24362
+24363
+24364
+24365
+24366
+24367
+24368
+24369
+24370
+24371
+24372
+24373
+24374
+24375
+24376
+24377
+24378
+24379
+24380
+24381
+24382
+24383
+24384
+24385
+24386
+24387
+24388
+24389
+24390
+24391
+24392
+24393
+24394
+24395
+24396
+24397
+24398
+24399
+24400
+24401
+24402
+24403
+24404
+24405
+24406
+24407
+24408
+24409
+24410
+24411
+24412
+24413
+24414
+24415
+24416
+24417
+24418
+24419
+24420
+24421
+24422
+24423
+24424
+24425
+24426
+24427
+24428
+24429
+24430
+24431
+24432
+24433
+24434
+24435
+24436
+24437
+24438
+24439
+24440
+24441
+24442
+24443
+24444
+24445
+24446
+24447
+24448
+24449
+24450
+24451
+24452
+24453
+24454
+24455
+24456
+24457
+24458
+24459
+24460
+24461
+24462
+24463
+24464
+24465
+24466
+24467
+24468
+24469
+24470
+24471
+24472
+24473
+24474
+24475
+24476
+24477
+24478
+24479
+24480
+24481
+24482
+24483
+24484
+24485
+24486
+24487
+24488
+24489
+24490
+24491
+24492
+24493
+24494
+24495
+24496
+24497
+24498
+24499
+24500
+24501
+24502
+24503
+24504
+24505
+24506
+24507
+24508
+24509
+24510
+24511
+24512
+24513
+24514
+24515
+24516
+24517
+24518
+24519
+24520
+24521
+24522
+24523
+24524
+24525
+24526
+24527
+24528
+24529
+24530
+24531
+24532
+24533
+24534
+24535
+24536
+24537
+24538
+24539
+24540
+24541
+24542
+24543
+24544
+24545
+24546
+24547
+24548
+24549
+24550
+24551
+24552
+24553
+24554
+24555
+24556
+24557
+24558
+24559
+24560
+24561
+24562
+24563
+24564
+24565
+24566
+24567
+24568
+24569
+24570
+24571
+24572
+24573
+24574
+24575
+24576
+24577
+24578
+24579
+24580
+24581
+24582
+24583
+24584
+24585
+24586
+24587
+24588
+24589
+24590
+24591
+24592
+24593
+24594
+24595
+24596
+24597
+24598
+24599
+24600
+24601
+24602
+24603
+24604
+24605
+24606
+24607
+24608
+24609
+24610
+24611
+24612
+24613
+24614
+24615
+24616
+24617
+24618
+24619
+24620
+24621
+24622
+24623
+24624
+24625
+24626
+24627
+24628
+24629
+24630
+24631
+24632
+24633
+24634
+24635
+24636
+24637
+24638
+24639
+24640
+24641
+24642
+24643
+24644
+24645
+24646
+24647
+24648
+24649
+24650
+24651
+24652
+24653
+24654
+24655
+24656
+24657
+24658
+24659
+24660
+24661
+24662
+24663
+24664
+24665
+24666
+24667
+24668
+24669
+24670
+24671
+24672
+24673
+24674
+24675
+24676
+24677
+24678
+24679
+24680
+24681
+24682
+24683
+24684
+24685
+24686
+24687
+24688
+24689
+24690
+24691
+24692
+24693
+24694
+24695
+24696
+24697
+24698
+24699
+24700
+24701
+24702
+24703
+24704
+24705
+24706
+24707
+24708
+24709
+24710
+24711
+24712
+24713
+24714
+24715
+24716
+24717
+24718
+24719
+24720
+24721
+24722
+24723
+24724
+24725
+24726
+24727
+24728
+24729
+24730
+24731
+24732
+24733
+24734
+24735
+24736
+24737
+24738
+24739
+24740
+24741
+24742
+24743
+24744
+24745
+24746
+24747
+24748
+24749
+24750
+24751
+24752
+24753
+24754
+24755
+24756
+24757
+24758
+24759
+24760
+24761
+24762
+24763
+24764
+24765
+24766
+24767
+24768
+24769
+24770
+24771
+24772
+24773
+24774
+24775
+24776
+24777
+24778
+24779
+24780
+24781
+24782
+24783
+24784
+24785
+24786
+24787
+24788
+24789
+24790
+24791
+24792
+24793
+24794
+24795
+24796
+24797
+24798
+24799
+24800
+24801
+24802
+24803
+24804
+24805
+24806
+24807
+24808
+24809
+24810
+24811
+24812
+24813
+24814
+24815
+24816
+24817
+24818
+24819
+24820
+24821
+24822
+24823
+24824
+24825
+24826
+24827
+24828
+24829
+24830
+24831
+24832
+24833
+24834
+24835
+24836
+24837
+24838
+24839
+24840
+24841
+24842
+24843
+24844
+24845
+24846
+24847
+24848
+24849
+24850
+24851
+24852
+24853
+24854
+24855
+24856
+24857
+24858
+24859
+24860
+24861
+24862
+24863
+24864
+24865
+24866
+24867
+24868
+24869
+24870
+24871
+24872
+24873
+24874
+24875
+24876
+24877
+24878
+24879
+24880
+24881
+24882
+24883
+24884
+24885
+24886
+24887
+24888
+24889
+24890
+24891
+24892
+24893
+24894
+24895
+24896
+24897
+24898
+24899
+24900
+24901
+24902
+24903
+24904
+24905
+24906
+24907
+24908
+24909
+24910
+24911
+24912
+24913
+24914
+24915
+24916
+24917
+24918
+24919
+24920
+24921
+24922
+24923
+24924
+24925
+24926
+24927
+24928
+24929
+24930
+24931
+24932
+24933
+24934
+24935
+24936
+24937
+24938
+24939
+24940
+24941
+24942
+24943
+24944
+24945
+24946
+24947
+24948
+24949
+24950
+24951
+24952
+24953
+24954
+24955
+24956
+24957
+24958
+24959
+24960
+24961
+24962
+24963
+24964
+24965
+24966
+24967
+24968
+24969
+24970
+24971
+24972
+24973
+24974
+24975
+24976
+24977
+24978
+24979
+24980
+24981
+24982
+24983
+24984
+24985
+24986
+24987
+24988
+24989
+24990
+24991
+24992
+24993
+24994
+24995
+24996
+24997
+24998
+24999
+25000
+25001
+25002
+25003
+25004
+25005
+25006
+25007
+25008
+25009
+25010
+25011
+25012
+25013
+25014
+25015
+25016
+25017
+25018
+25019
+25020
+25021
+25022
+25023
+25024
+25025
+25026
+25027
+25028
+25029
+25030
+25031
+25032
+25033
+25034
+25035
+25036
+25037
+25038
+25039
+25040
+25041
+25042
+25043
+25044
+25045
+25046
+25047
+25048
+25049
+25050
+25051
+25052
+25053
+25054
+25055
+25056
+25057
+25058
+25059
+25060
+25061
+25062
+25063
+25064
+25065
+25066
+25067
+25068
+25069
+25070
+25071
+25072
+25073
+25074
+25075
+25076
+25077
+25078
+25079
+25080
+25081
+25082
+25083
+25084
+25085
+25086
+25087
+25088
+25089
+25090
+25091
+25092
+25093
+25094
+25095
+25096
+25097
+25098
+25099
+25100
+25101
+25102
+25103
+25104
+25105
+25106
+25107
+25108
+25109
+25110
+25111
+25112
+25113
+25114
+25115
+25116
+25117
+25118
+25119
+25120
+25121
+25122
+25123
+25124
+25125
+25126
+25127
+25128
+25129
+25130
+25131
+25132
+25133
+25134
+25135
+25136
+25137
+25138
+25139
+25140
+25141
+25142
+25143
+25144
+25145
+25146
+25147
+25148
+25149
+25150
+25151
+25152
+25153
+25154
+25155
+25156
+25157
+25158
+25159
+25160
+25161
+25162
+25163
+25164
+25165
+25166
+25167
+25168
+25169
+25170
+25171
+25172
+25173
+25174
+25175
+25176
+25177
+25178
+25179
+25180
+25181
+25182
+25183
+25184
+25185
+25186
+25187
+25188
+25189
+25190
+25191
+25192
+25193
+25194
+25195
+25196
+25197
+25198
+25199
+25200
+25201
+25202
+25203
+25204
+25205
+25206
+25207
+25208
+25209
+25210
+25211
+25212
+25213
+25214
+25215
+25216
+25217
+25218
+25219
+25220
+25221
+25222
+25223
+25224
+25225
+25226
+25227
+25228
+25229
+25230
+25231
+25232
+25233
+25234
+25235
+25236
+25237
+25238
+25239
+25240
+25241
+25242
+25243
+25244
+25245
+25246
+25247
+25248
+25249
+25250
+25251
+25252
+25253
+25254
+25255
+25256
+25257
+25258
+25259
+25260
+25261
+25262
+25263
+25264
+25265
+25266
+25267
+25268
+25269
+25270
+25271
+25272
+25273
+25274
+25275
+25276
+25277
+25278
+25279
+25280
+25281
+25282
+25283
+25284
+25285
+25286
+25287
+25288
+25289
+25290
+25291
+25292
+25293
+25294
+25295
+25296
+25297
+25298
+25299
+25300
+25301
+25302
+25303
+25304
+25305
+25306
+25307
+25308
+25309
+25310
+25311
+25312
+25313
+25314
+25315
+25316
+25317
+25318
+25319
+25320
+25321
+25322
+25323
+25324
+25325
+25326
+25327
+25328
+25329
+25330
+25331
+25332
+25333
+25334
+25335
+25336
+25337
+25338
+25339
+25340
+25341
+25342
+25343
+25344
+25345
+25346
+25347
+25348
+25349
+25350
+25351
+25352
+25353
+25354
+25355
+25356
+25357
+25358
+25359
+25360
+25361
+25362
+25363
+25364
+25365
+25366
+25367
+25368
+25369
+25370
+25371
+25372
+25373
+25374
+25375
+25376
+25377
+25378
+25379
+25380
+25381
+25382
+25383
+25384
+25385
+25386
+25387
+25388
+25389
+25390
+25391
+25392
+25393
+25394
+25395
+25396
+25397
+25398
+25399
+25400
+25401
+25402
+25403
+25404
+25405
+25406
+25407
+25408
+25409
+25410
+25411
+25412
+25413
+25414
+25415
+25416
+25417
+25418
+25419
+25420
+25421
+25422
+25423
+25424
+25425
+25426
+25427
+25428
+25429
+25430
+25431
+25432
+25433
+25434
+25435
+25436
+25437
+25438
+25439
+25440
+25441
+25442
+25443
+25444
+25445
+25446
+25447
+25448
+25449
+25450
+25451
+25452
+25453
+25454
+25455
+25456
+25457
+25458
+25459
+25460
+25461
+25462
+25463
+25464
+25465
+25466
+25467
+25468
+25469
+25470
+25471
+25472
+25473
+25474
+25475
+25476
+25477
+25478
+25479
+25480
+25481
+25482
+25483
+25484
+25485
+25486
+25487
+25488
+25489
+25490
+25491
+25492
+25493
+25494
+25495
+25496
+25497
+25498
+25499
+25500
+25501
+25502
+25503
+25504
+25505
+25506
+25507
+25508
+25509
+25510
+25511
+25512
+25513
+25514
+25515
+25516
+25517
+25518
+25519
+25520
+25521
+25522
+25523
+25524
+25525
+25526
+25527
+25528
+25529
+25530
+25531
+25532
+25533
+25534
+25535
+25536
+25537
+25538
+25539
+25540
+25541
+25542
+25543
+25544
+25545
+25546
+25547
+25548
+25549
+25550
+25551
+25552
+25553
+25554
+25555
+25556
+25557
+25558
+25559
+25560
+25561
+25562
+25563
+25564
+25565
+25566
+25567
+25568
+25569
+25570
+25571
+25572
+25573
+25574
+25575
+25576
+25577
+25578
+25579
+25580
+25581
+25582
+25583
+25584
+25585
+25586
+25587
+25588
+25589
+25590
+25591
+25592
+25593
+25594
+25595
+25596
+25597
+25598
+25599
+25600
+25601
+25602
+25603
+25604
+25605
+25606
+25607
+25608
+25609
+25610
+25611
+25612
+25613
+25614
+25615
+25616
+25617
+25618
+25619
+25620
+25621
+25622
+25623
+25624
+25625
+25626
+25627
+25628
+25629
+25630
+25631
+25632
+25633
+25634
+25635
+25636
+25637
+25638
+25639
+25640
+25641
+25642
+25643
+25644
+25645
+25646
+25647
+25648
+25649
+25650
+25651
+25652
+25653
+25654
+25655
+25656
+25657
+25658
+25659
+25660
+25661
+25662
+25663
+25664
+25665
+25666
+25667
+25668
+25669
+25670
+25671
+25672
+25673
+25674
+25675
+25676
+25677
+25678
+25679
+25680
+25681
+25682
+25683
+25684
+25685
+25686
+25687
+25688
+25689
+25690
+25691
+25692
+25693
+25694
+25695
+25696
+25697
+25698
+25699
+25700
+25701
+25702
+25703
+25704
+25705
+25706
+25707
+25708
+25709
+25710
+25711
+25712
+25713
+25714
+25715
+25716
+25717
+25718
+25719
+25720
+25721
+25722
+25723
+25724
+25725
+25726
+25727
+25728
+25729
+25730
+25731
+25732
+25733
+25734
+25735
+25736
+25737
+25738
+25739
+25740
+25741
+25742
+25743
+25744
+25745
+25746
+25747
+25748
+25749
+25750
+25751
+25752
+25753
+25754
+25755
+25756
+25757
+25758
+25759
+25760
+25761
+25762
+25763
+25764
+25765
+25766
+25767
+25768
+25769
+25770
+25771
+25772
+25773
+25774
+25775
+25776
+25777
+25778
+25779
+25780
+25781
+25782
+25783
+25784
+25785
+25786
+25787
+25788
+25789
+25790
+25791
+25792
+25793
+25794
+25795
+25796
+25797
+25798
+25799
+25800
+25801
+25802
+25803
+25804
+25805
+25806
+25807
+25808
+25809
+25810
+25811
+25812
+25813
+25814
+25815
+25816
+25817
+25818
+25819
+25820
+25821
+25822
+25823
+25824
+25825
+25826
+25827
+25828
+25829
+25830
+25831
+25832
+25833
+25834
+25835
+25836
+25837
+25838
+25839
+25840
+25841
+25842
+25843
+25844
+25845
+25846
+25847
+25848
+25849
+25850
+25851
+25852
+25853
+25854
+25855
+25856
+25857
+25858
+25859
+25860
+25861
+25862
+25863
+25864
+25865
+25866
+25867
+25868
+25869
+25870
+25871
+25872
+25873
+25874
+25875
+25876
+25877
+25878
+25879
+25880
+25881
+25882
+25883
+25884
+25885
+25886
+25887
+25888
+25889
+25890
+25891
+25892
+25893
+25894
+25895
+25896
+25897
+25898
+25899
+25900
+25901
+25902
+25903
+25904
+25905
+25906
+25907
+25908
+25909
+25910
+25911
+25912
+25913
+25914
+25915
+25916
+25917
+25918
+25919
+25920
+25921
+25922
+25923
+25924
+25925
+25926
+25927
+25928
+25929
+25930
+25931
+25932
+25933
+25934
+25935
+25936
+25937
+25938
+25939
+25940
+25941
+25942
+25943
+25944
+25945
+25946
+25947
+25948
+25949
+25950
+25951
+25952
+25953
+25954
+25955
+25956
+25957
+25958
+25959
+25960
+25961
+25962
+25963
+25964
+25965
+25966
+25967
+25968
+25969
+25970
+25971
+25972
+25973
+25974
+25975
+25976
+25977
+25978
+25979
+25980
+25981
+25982
+25983
+25984
+25985
+25986
+25987
+25988
+25989
+25990
+25991
+25992
+25993
+25994
+25995
+25996
+25997
+25998
+25999
+26000
+26001
+26002
+26003
+26004
+26005
+26006
+26007
+26008
+26009
+26010
+26011
+26012
+26013
+26014
+26015
+26016
+26017
+26018
+26019
+26020
+26021
+26022
+26023
+26024
+26025
+26026
+26027
+26028
+26029
+26030
+26031
+26032
+26033
+26034
+26035
+26036
+26037
+26038
+26039
+26040
+26041
+26042
+26043
+26044
+26045
+26046
+26047
+26048
+26049
+26050
+26051
+26052
+26053
+26054
+26055
+26056
+26057
+26058
+26059
+26060
+26061
+26062
+26063
+26064
+26065
+26066
+26067
+26068
+26069
+26070
+26071
+26072
+26073
+26074
+26075
+26076
+26077
+26078
+26079
+26080
+26081
+26082
+26083
+26084
+26085
+26086
+26087
+26088
+26089
+26090
+26091
+26092
+26093
+26094
+26095
+26096
+26097
+26098
+26099
+26100
+26101
+26102
+26103
+26104
+26105
+26106
+26107
+26108
+26109
+26110
+26111
+26112
+26113
+26114
+26115
+26116
+26117
+26118
+26119
+26120
+26121
+26122
+26123
+26124
+26125
+26126
+26127
+26128
+26129
+26130
+26131
+26132
+26133
+26134
+26135
+26136
+26137
+26138
+26139
+26140
+26141
+26142
+26143
+26144
+26145
+26146
+26147
+26148
+26149
+26150
+26151
+26152
+26153
+26154
+26155
+26156
+26157
+26158
+26159
+26160
+26161
+26162
+26163
+26164
+26165
+26166
+26167
+26168
+26169
+26170
+26171
+26172
+26173
+26174
+26175
+26176
+26177
+26178
+26179
+26180
+26181
+26182
+26183
+26184
+26185
+26186
+26187
+26188
+26189
+26190
+26191
+26192
+26193
+26194
+26195
+26196
+26197
+26198
+26199
+26200
+26201
+26202
+26203
+26204
+26205
+26206
+26207
+26208
+26209
+26210
+26211
+26212
+26213
+26214
+26215
+26216
+26217
+26218
+26219
+26220
+26221
+26222
+26223
+26224
+26225
+26226
+26227
+26228
+26229
+26230
+26231
+26232
+26233
+26234
+26235
+26236
+26237
+26238
+26239
+26240
+26241
+26242
+26243
+26244
+26245
+26246
+26247
+26248
+26249
+26250
+26251
+26252
+26253
+26254
+26255
+26256
+26257
+26258
+26259
+26260
+26261
+26262
+26263
+26264
+26265
+26266
+26267
+26268
+26269
+26270
+26271
+26272
+26273
+26274
+26275
+26276
+26277
+26278
+26279
+26280
+26281
+26282
+26283
+26284
+26285
+26286
+26287
+26288
+26289
+26290
+26291
+26292
+26293
+26294
+26295
+26296
+26297
+26298
+26299
+26300
+26301
+26302
+26303
+26304
+26305
+26306
+26307
+26308
+26309
+26310
+26311
+26312
+26313
+26314
+26315
+26316
+26317
+26318
+26319
+26320
+26321
+26322
+26323
+26324
+26325
+26326
+26327
+26328
+26329
+26330
+26331
+26332
+26333
+26334
+26335
+26336
+26337
+26338
+26339
+26340
+26341
+26342
+26343
+26344
+26345
+26346
+26347
+26348
+26349
+26350
+26351
+26352
+26353
+26354
+26355
+26356
+26357
+26358
+26359
+26360
+26361
+26362
+26363
+26364
+26365
+26366
+26367
+26368
+26369
+26370
+26371
+26372
+26373
+26374
+26375
+26376
+26377
+26378
+26379
+26380
+26381
+26382
+26383
+26384
+26385
+26386
+26387
+26388
+26389
+26390
+26391
+26392
+26393
+26394
+26395
+26396
+26397
+26398
+26399
+26400
+26401
+26402
+26403
+26404
+26405
+26406
+26407
+26408
+26409
+26410
+26411
+26412
+26413
+26414
+26415
+26416
+26417
+26418
+26419
+26420
+26421
+26422
+26423
+26424
+26425
+26426
+26427
+26428
+26429
+26430
+26431
+26432
+26433
+26434
+26435
+26436
+26437
+26438
+26439
+26440
+26441
+26442
+26443
+26444
+26445
+26446
+26447
+26448
+26449
+26450
+26451
+26452
+26453
+26454
+26455
+26456
+26457
+26458
+26459
+26460
+26461
+26462
+26463
+26464
+26465
+26466
+26467
+26468
+26469
+26470
+26471
+26472
+26473
+26474
+26475
+26476
+26477
+26478
+26479
+26480
+26481
+26482
+26483
+26484
+26485
+26486
+26487
+26488
+26489
+26490
+26491
+26492
+26493
+26494
+26495
+26496
+26497
+26498
+26499
+26500
+26501
+26502
+26503
+26504
+26505
+26506
+26507
+26508
+26509
+26510
+26511
+26512
+26513
+26514
+26515
+26516
+26517
+26518
+26519
+26520
+26521
+26522
+26523
+26524
+26525
+26526
+26527
+26528
+26529
+26530
+26531
+26532
+26533
+26534
+26535
+26536
+26537
+26538
+26539
+26540
+26541
+26542
+26543
+26544
+26545
+26546
+26547
+26548
+26549
+26550
+26551
+26552
+26553
+26554
+26555
+26556
+26557
+26558
+26559
+26560
+26561
+26562
+26563
+26564
+26565
+26566
+26567
+26568
+26569
+26570
+26571
+26572
+26573
+26574
+26575
+26576
+26577
+26578
+26579
+26580
+26581
+26582
+26583
+26584
+26585
+26586
+26587
+26588
+26589
+26590
+26591
+26592
+26593
+26594
+26595
+26596
+26597
+26598
+26599
+26600
+26601
+26602
+26603
+26604
+26605
+26606
+26607
+26608
+26609
+26610
+26611
+26612
+26613
+26614
+26615
+26616
+26617
+26618
+26619
+26620
+26621
+26622
+26623
+26624
+26625
+26626
+26627
+26628
+26629
+26630
+26631
+26632
+26633
+26634
+26635
+26636
+26637
+26638
+26639
+26640
+26641
+26642
+26643
+26644
+26645
+26646
+26647
+26648
+26649
+26650
+26651
+26652
+26653
+26654
+26655
+26656
+26657
+26658
+26659
+26660
+26661
+26662
+26663
+26664
+26665
+26666
+26667
+26668
+26669
+26670
+26671
+26672
+26673
+26674
+26675
+26676
+26677
+26678
+26679
+26680
+26681
+26682
+26683
+26684
+26685
+26686
+26687
+26688
+26689
+26690
+26691
+26692
+26693
+26694
+26695
+26696
+26697
+26698
+26699
+26700
+26701
+26702
+26703
+26704
+26705
+26706
+26707
+26708
+26709
+26710
+26711
+26712
+26713
+26714
+26715
+26716
+26717
+26718
+26719
+26720
+26721
+26722
+26723
+26724
+26725
+26726
+26727
+26728
+26729
+26730
+26731
+26732
+26733
+26734
+26735
+26736
+26737
+26738
+26739
+26740
+26741
+26742
+26743
+26744
+26745
+26746
+26747
+26748
+26749
+26750
+26751
+26752
+26753
+26754
+26755
+26756
+26757
+26758
+26759
+26760
+26761
+26762
+26763
+26764
+26765
+26766
+26767
+26768
+26769
+26770
+26771
+26772
+26773
+26774
+26775
+26776
+26777
+26778
+26779
+26780
+26781
+26782
+26783
+26784
+26785
+26786
+26787
+26788
+26789
+26790
+26791
+26792
+26793
+26794
+26795
+26796
+26797
+26798
+26799
+26800
+26801
+26802
+26803
+26804
+26805
+26806
+26807
+26808
+26809
+26810
+26811
+26812
+26813
+26814
+26815
+26816
+26817
+26818
+26819
+26820
+26821
+26822
+26823
+26824
+26825
+26826
+26827
+26828
+26829
+26830
+26831
+26832
+26833
+26834
+26835
+26836
+26837
+26838
+26839
+26840
+26841
+26842
+26843
+26844
+26845
+26846
+26847
+26848
+26849
+26850
+26851
+26852
+26853
+26854
+26855
+26856
+26857
+26858
+26859
+26860
+26861
+26862
+26863
+26864
+26865
+26866
+26867
+26868
+26869
+26870
+26871
+26872
+26873
+26874
+26875
+26876
+26877
+26878
+26879
+26880
+26881
+26882
+26883
+26884
+26885
+26886
+26887
+26888
+26889
+26890
+26891
+26892
+26893
+26894
+26895
+26896
+26897
+26898
+26899
+26900
+26901
+26902
+26903
+26904
+26905
+26906
+26907
+26908
+26909
+26910
+26911
+26912
+26913
+26914
+26915
+26916
+26917
+26918
+26919
+26920
+26921
+26922
+26923
+26924
+26925
+26926
+26927
+26928
+26929
+26930
+26931
+26932
+26933
+26934
+26935
+26936
+26937
+26938
+26939
+26940
+26941
+26942
+26943
+26944
+26945
+26946
+26947
+26948
+26949
+26950
+26951
+26952
+26953
+26954
+26955
+26956
+26957
+26958
+26959
+26960
+26961
+26962
+26963
+26964
+26965
+26966
+26967
+26968
+26969
+26970
+26971
+26972
+26973
+26974
+26975
+26976
+26977
+26978
+26979
+26980
+26981
+26982
+26983
+26984
+26985
+26986
+26987
+26988
+26989
+26990
+26991
+26992
+26993
+26994
+26995
+26996
+26997
+26998
+26999
+27000
+27001
+27002
+27003
+27004
+27005
+27006
+27007
+27008
+27009
+27010
+27011
+27012
+27013
+27014
+27015
+27016
+27017
+27018
+27019
+27020
+27021
+27022
+27023
+27024
+27025
+27026
+27027
+27028
+27029
+27030
+27031
+27032
+27033
+27034
+27035
+27036
+27037
+27038
+27039
+27040
+27041
+27042
+27043
+27044
+27045
+27046
+27047
+27048
+27049
+27050
+27051
+27052
+27053
+27054
+27055
+27056
+27057
+27058
+27059
+27060
+27061
+27062
+27063
+27064
+27065
+27066
+27067
+27068
+27069
+27070
+27071
+27072
+27073
+27074
+27075
+27076
+27077
+27078
+27079
+27080
+27081
+27082
+27083
+27084
+27085
+27086
+27087
+27088
+27089
+27090
+27091
+27092
+27093
+27094
+27095
+27096
+27097
+27098
+27099
+27100
+27101
+27102
+27103
+27104
+27105
+27106
+27107
+27108
+27109
+27110
+27111
+27112
+27113
+27114
+27115
+27116
+27117
+27118
+27119
+27120
+27121
+27122
+27123
+27124
+27125
+27126
+27127
+27128
+27129
+27130
+27131
+27132
+27133
+27134
+27135
+27136
+27137
+27138
+27139
+27140
+27141
+27142
+27143
+27144
+27145
+27146
+27147
+27148
+27149
+27150
+27151
+27152
+27153
+27154
+27155
+27156
+27157
+27158
+27159
+27160
+27161
+27162
+27163
+27164
+27165
+27166
+27167
+27168
+27169
+27170
+27171
+27172
+27173
+27174
+27175
+27176
+27177
+27178
+27179
+27180
+27181
+27182
+27183
+27184
+27185
+27186
+27187
+27188
+27189
+27190
+27191
+27192
+27193
+27194
+27195
+27196
+27197
+27198
+27199
+27200
+27201
+27202
+27203
+27204
+27205
+27206
+27207
+27208
+27209
+27210
+27211
+27212
+27213
+27214
+27215
+27216
+27217
+27218
+27219
+27220
+27221
+27222
+27223
+27224
+27225
+27226
+27227
+27228
+27229
+27230
+27231
+27232
+27233
+27234
+27235
+27236
+27237
+27238
+27239
+27240
+27241
+27242
+27243
+27244
+27245
+27246
+27247
+27248
+27249
+27250
+27251
+27252
+27253
+27254
+27255
+27256
+27257
+27258
+27259
+27260
+27261
+27262
+27263
+27264
+27265
+27266
+27267
+27268
+27269
+27270
+27271
+27272
+27273
+27274
+27275
+27276
+27277
+27278
+27279
+27280
+27281
+27282
+27283
+27284
+27285
+27286
+27287
+27288
+27289
+27290
+27291
+27292
+27293
+27294
+27295
+27296
+27297
+27298
+27299
+27300
+27301
+27302
+27303
+27304
+27305
+27306
+27307
+27308
+27309
+27310
+27311
+27312
+27313
+27314
+27315
+27316
+27317
+27318
+27319
+27320
+27321
+27322
+27323
+27324
+27325
+27326
+27327
+27328
+27329
+27330
+27331
+27332
+27333
+27334
+27335
+27336
+27337
+27338
+27339
+27340
+27341
+27342
+27343
+27344
+27345
+27346
+27347
+27348
+27349
+27350
+27351
+27352
+27353
+27354
+27355
+27356
+27357
+27358
+27359
+27360
+27361
+27362
+27363
+27364
+27365
+27366
+27367
+27368
+27369
+27370
+27371
+27372
+27373
+27374
+27375
+27376
+27377
+27378
+27379
+27380
+27381
+27382
+27383
+27384
+27385
+27386
+27387
+27388
+27389
+27390
+27391
+27392
+27393
+27394
+27395
+27396
+27397
+27398
+27399
+27400
+27401
+27402
+27403
+27404
+27405
+27406
+27407
+27408
+27409
+27410
+27411
+27412
+27413
+27414
+27415
+27416
+27417
+27418
+27419
+27420
+27421
+27422
+27423
+27424
+27425
+27426
+27427
+27428
+27429
+27430
+27431
+27432
+27433
+27434
+27435
+27436
+27437
+27438
+27439
+27440
+27441
+27442
+27443
+27444
+27445
+27446
+27447
+27448
+27449
+27450
+27451
+27452
+27453
+27454
+27455
+27456
+27457
+27458
+27459
+27460
+27461
+27462
+27463
+27464
+27465
+27466
+27467
+27468
+27469
+27470
+27471
+27472
+27473
+27474
+27475
+27476
+27477
+27478
+27479
+27480
+27481
+27482
+27483
+27484
+27485
+27486
+27487
+27488
+27489
+27490
+27491
+27492
+27493
+27494
+27495
+27496
+27497
+27498
+27499
+27500
+27501
+27502
+27503
+27504
+27505
+27506
+27507
+27508
+27509
+27510
+27511
+27512
+27513
+27514
+27515
+27516
+27517
+27518
+27519
+27520
+27521
+27522
+27523
+27524
+27525
+27526
+27527
+27528
+27529
+27530
+27531
+27532
+27533
+27534
+27535
+27536
+27537
+27538
+27539
+27540
+27541
+27542
+27543
+27544
+27545
+27546
+27547
+27548
+27549
+27550
+27551
+27552
+27553
+27554
+27555
+27556
+27557
+27558
+27559
+27560
+27561
+27562
+27563
+27564
+27565
+27566
+27567
+27568
+27569
+27570
+27571
+27572
+27573
+27574
+27575
+27576
+27577
+27578
+27579
+27580
+27581
+27582
+27583
+27584
+27585
+27586
+27587
+27588
+27589
+27590
+27591
+27592
+27593
+27594
+27595
+27596
+27597
+27598
+27599
+27600
+27601
+27602
+27603
+27604
+27605
+27606
+27607
+27608
+27609
+27610
+27611
+27612
+27613
+27614
+27615
+27616
+27617
+27618
+27619
+27620
+27621
+27622
+27623
+27624
+27625
+27626
+27627
+27628
+27629
+27630
+27631
+27632
+27633
+27634
+27635
+27636
+27637
+27638
+27639
+27640
+27641
+27642
+27643
+27644
+27645
+27646
+27647
+27648
+27649
+27650
+27651
+27652
+27653
+27654
+27655
+27656
+27657
+27658
+27659
+27660
+27661
+27662
+27663
+27664
+27665
+27666
+27667
+27668
+27669
+27670
+27671
+27672
+27673
+27674
+27675
+27676
+27677
+27678
+27679
+27680
+27681
+27682
+27683
+27684
+27685
+27686
+27687
+27688
+27689
+27690
+27691
+27692
+27693
+27694
+27695
+27696
+27697
+27698
+27699
+27700
+27701
+27702
+27703
+27704
+27705
+27706
+27707
+27708
+27709
+27710
+27711
+27712
+27713
+27714
+27715
+27716
+27717
+27718
+27719
+27720
+27721
+27722
+27723
+27724
+27725
+27726
+27727
+27728
+27729
+27730
+27731
+27732
+27733
+27734
+27735
+27736
+27737
+27738
+27739
+27740
+27741
+27742
+27743
+27744
+27745
+27746
+27747
+27748
+27749
+27750
+27751
+27752
+27753
+27754
+27755
+27756
+27757
+27758
+27759
+27760
+27761
+27762
+27763
+27764
+27765
+27766
+27767
+27768
+27769
+27770
+27771
+27772
+27773
+27774
+27775
+27776
+27777
+27778
+27779
+27780
+27781
+27782
+27783
+27784
+27785
+27786
+27787
+27788
+27789
+27790
+27791
+27792
+27793
+27794
+27795
+27796
+27797
+27798
+27799
+27800
+27801
+27802
+27803
+27804
+27805
+27806
+27807
+27808
+27809
+27810
+27811
+27812
+27813
+27814
+27815
+27816
+27817
+27818
+27819
+27820
+27821
+27822
+27823
+27824
+27825
+27826
+27827
+27828
+27829
+27830
+27831
+27832
+27833
+27834
+27835
+27836
+27837
+27838
+27839
+27840
+27841
+27842
+27843
+27844
+27845
+27846
+27847
+27848
+27849
+27850
+27851
+27852
+27853
+27854
+27855
+27856
+27857
+27858
+27859
+27860
+27861
+27862
+27863
+27864
+27865
+27866
+27867
+27868
+27869
+27870
+27871
+27872
+27873
+27874
+27875
+27876
+27877
+27878
+27879
+27880
+27881
+27882
+27883
+27884
+27885
+27886
+27887
+27888
+27889
+27890
+27891
+27892
+27893
+27894
+27895
+27896
+27897
+27898
+27899
+27900
+27901
+27902
+27903
+27904
+27905
+27906
+27907
+27908
+27909
+27910
+27911
+27912
+27913
+27914
+27915
+27916
+27917
+27918
+27919
+27920
+27921
+27922
+27923
+27924
+27925
+27926
+27927
+27928
+27929
+27930
+27931
+27932
+27933
+27934
+27935
+27936
+27937
+27938
+27939
+27940
+27941
+27942
+27943
+27944
+27945
+27946
+27947
+27948
+27949
+27950
+27951
+27952
+27953
+27954
+27955
+27956
+27957
+27958
+27959
+27960
+27961
+27962
+27963
+27964
+27965
+27966
+27967
+27968
+27969
+27970
+27971
+27972
+27973
+27974
+27975
+27976
+27977
+27978
+27979
+27980
+27981
+27982
+27983
+27984
+27985
+27986
+27987
+27988
+27989
+27990
+27991
+27992
+27993
+27994
+27995
+27996
+27997
+27998
+27999
+28000
+28001
+28002
+28003
+28004
+28005
+28006
+28007
+28008
+28009
+28010
+28011
+28012
+28013
+28014
+28015
+28016
+28017
+28018
+28019
+28020
+28021
+28022
+28023
+28024
+28025
+28026
+28027
+28028
+28029
+28030
+28031
+28032
+28033
+28034
+28035
+28036
+28037
+28038
+28039
+28040
+28041
+28042
+28043
+28044
+28045
+28046
+28047
+28048
+28049
+28050
+28051
+28052
+28053
+28054
+28055
+28056
+28057
+28058
+28059
+28060
+28061
+28062
+28063
+28064
+28065
+28066
+28067
+28068
+28069
+28070
+28071
+28072
+28073
+28074
+28075
+28076
+28077
+28078
+28079
+28080
+28081
+28082
+28083
+28084
+28085
+28086
+28087
+28088
+28089
+28090
+28091
+28092
+28093
+28094
+28095
+28096
+28097
+28098
+28099
+28100
+28101
+28102
+28103
+28104
+28105
+28106
+28107
+28108
+28109
+28110
+28111
+28112
+28113
+28114
+28115
+28116
+28117
+28118
+28119
+28120
+28121
+28122
+28123
+28124
+28125
+28126
+28127
+28128
+28129
+28130
+28131
+28132
+28133
+28134
+28135
+28136
+28137
+28138
+28139
+28140
+28141
+28142
+28143
+28144
+28145
+28146
+28147
+28148
+28149
+28150
+28151
+28152
+28153
+28154
+28155
+28156
+28157
+28158
+28159
+28160
+28161
+28162
+28163
+28164
+28165
+28166
+28167
+28168
+28169
+28170
+28171
+28172
+28173
+28174
+28175
+28176
+28177
+28178
+28179
+28180
+28181
+28182
+28183
+28184
+28185
+28186
+28187
+28188
+28189
+28190
+28191
+28192
+28193
+28194
+28195
+28196
+28197
+28198
+28199
+28200
+28201
+28202
+28203
+28204
+28205
+28206
+28207
+28208
+28209
+28210
+28211
+28212
+28213
+28214
+28215
+28216
+28217
+28218
+28219
+28220
+28221
+28222
+28223
+28224
+28225
+28226
+28227
+28228
+28229
+28230
+28231
+28232
+28233
+28234
+28235
+28236
+28237
+28238
+28239
+28240
+28241
+28242
+28243
+28244
+28245
+28246
+28247
+28248
+28249
+28250
+28251
+28252
+28253
+28254
+28255
+28256
+28257
+28258
+28259
+28260
+28261
+28262
+28263
+28264
+28265
+28266
+28267
+28268
+28269
+28270
+28271
+28272
+28273
+28274
+28275
+28276
+28277
+28278
+28279
+28280
+28281
+28282
+28283
+28284
+28285
+28286
+28287
+28288
+28289
+28290
+28291
+28292
+28293
+28294
+28295
+28296
+28297
+28298
+28299
+28300
+28301
+28302
+28303
+28304
+28305
+28306
+28307
+28308
+28309
+28310
+28311
+28312
+28313
+28314
+28315
+28316
+28317
+28318
+28319
+28320
+28321
+28322
+28323
+28324
+28325
+28326
+28327
+28328
+28329
+28330
+28331
+28332
+28333
+28334
+28335
+28336
+28337
+28338
+28339
+28340
+28341
+28342
+28343
+28344
+28345
+28346
+28347
+28348
+28349
+28350
+28351
+28352
+28353
+28354
+28355
+28356
+28357
+28358
+28359
+28360
+28361
+28362
+28363
+28364
+28365
+28366
+28367
+28368
+28369
+28370
+28371
+28372
+28373
+28374
+28375
+28376
+28377
+28378
+28379
+28380
+28381
+28382
+28383
+28384
+28385
+28386
+28387
+28388
+28389
+28390
+28391
+28392
+28393
+28394
+28395
+28396
+28397
+28398
+28399
+28400
+28401
+28402
+28403
+28404
+28405
+28406
+28407
+28408
+28409
+28410
+28411
+28412
+28413
+28414
+28415
+28416
+28417
+28418
+28419
+28420
+28421
+28422
+28423
+28424
+28425
+28426
+28427
+28428
+28429
+28430
+28431
+28432
+28433
+28434
+28435
+28436
+28437
+28438
+28439
+28440
+28441
+28442
+28443
+28444
+28445
+28446
+28447
+28448
+28449
+28450
+28451
+28452
+28453
+28454
+28455
+28456
+28457
+28458
+28459
+28460
+28461
+28462
+28463
+28464
+28465
+28466
+28467
+28468
+28469
+28470
+28471
+28472
+28473
+28474
+28475
+28476
+28477
+28478
+28479
+28480
+28481
+28482
+28483
+28484
+28485
+28486
+28487
+28488
+28489
+28490
+28491
+28492
+28493
+28494
+28495
+28496
+28497
+28498
+28499
+28500
+28501
+28502
+28503
+28504
+28505
+28506
+28507
+28508
+28509
+28510
+28511
+28512
+28513
+28514
+28515
+28516
+28517
+28518
+28519
+28520
+28521
+28522
+28523
+28524
+28525
+28526
+28527
+28528
+28529
+28530
+28531
+28532
+28533
+28534
+28535
+28536
+28537
+28538
+28539
+28540
+28541
+28542
+28543
+28544
+28545
+28546
+28547
+28548
+28549
+28550
+28551
+28552
+28553
+28554
+28555
+28556
+28557
+28558
+28559
+28560
+28561
+28562
+28563
+28564
+28565
+28566
+28567
+28568
+28569
+28570
+28571
+28572
+28573
+28574
+28575
+28576
+28577
+28578
+28579
+28580
+28581
+28582
+28583
+28584
+28585
+28586
+28587
+28588
+28589
+28590
+28591
+28592
+28593
+28594
+28595
+28596
+28597
+28598
+28599
+28600
+28601
+28602
+28603
+28604
+28605
+28606
+28607
+28608
+28609
+28610
+28611
+28612
+28613
+28614
+28615
+28616
+28617
+28618
+28619
+28620
+28621
+28622
+28623
+28624
+28625
+28626
+28627
+28628
+28629
+28630
+28631
+28632
+28633
+28634
+28635
+28636
+28637
+28638
+28639
+28640
+28641
+28642
+28643
+28644
+28645
+28646
+28647
+28648
+28649
+28650
+28651
+28652
+28653
+28654
+28655
+28656
+28657
+28658
+28659
+28660
+28661
+28662
+28663
+28664
+28665
+28666
+28667
+28668
+28669
+28670
+28671
+28672
+28673
+28674
+28675
+28676
+28677
+28678
+28679
+28680
+28681
+28682
+28683
+28684
+28685
+28686
+28687
+28688
+28689
+28690
+28691
+28692
+28693
+28694
+28695
+28696
+28697
+28698
+28699
+28700
+28701
+28702
+28703
+28704
+28705
+28706
+28707
+28708
+28709
+28710
+28711
+28712
+28713
+28714
+28715
+28716
+28717
+28718
+28719
+28720
+28721
+28722
+28723
+28724
+28725
+28726
+28727
+28728
+28729
+28730
+28731
+28732
+28733
+28734
+28735
+28736
+28737
+28738
+28739
+28740
+28741
+28742
+28743
+28744
+28745
+28746
+28747
+28748
+28749
+28750
+28751
+28752
+28753
+28754
+28755
+28756
+28757
+28758
+28759
+28760
+28761
+28762
+28763
+28764
+28765
+28766
+28767
+28768
+28769
+28770
+28771
+28772
+28773
+28774
+28775
+28776
+28777
+28778
+28779
+28780
+28781
+28782
+28783
+28784
+28785
+28786
+28787
+28788
+28789
+28790
+28791
+28792
+28793
+28794
+28795
+28796
+28797
+28798
+28799
+28800
+28801
+28802
+28803
+28804
+28805
+28806
+28807
+28808
+28809
+28810
+28811
+28812
+28813
+28814
+28815
+28816
+28817
+28818
+28819
+28820
+28821
+28822
+28823
+28824
+28825
+28826
+28827
+28828
+28829
+28830
+28831
+28832
+28833
+28834
+28835
+28836
+28837
+28838
+28839
+28840
+28841
+28842
+28843
+28844
+28845
+28846
+28847
+28848
+28849
+28850
+28851
+28852
+28853
+28854
+28855
+28856
+28857
+28858
+28859
+28860
+28861
+28862
+28863
+28864
+28865
+28866
+28867
+28868
+28869
+28870
+28871
+28872
+28873
+28874
+28875
+28876
+28877
+28878
+28879
+28880
+28881
+28882
+28883
+28884
+28885
+28886
+28887
+28888
+28889
+28890
+28891
+28892
+28893
+28894
+28895
+28896
+28897
+28898
+28899
+28900
+28901
+28902
+28903
+28904
+28905
+28906
+28907
+28908
+28909
+28910
+28911
+28912
+28913
+28914
+28915
+28916
+28917
+28918
+28919
+28920
+28921
+28922
+28923
+28924
+28925
+28926
+28927
+28928
+28929
+28930
+28931
+28932
+28933
+28934
+28935
+28936
+28937
+28938
+28939
+28940
+28941
+28942
+28943
+28944
+28945
+28946
+28947
+28948
+28949
+28950
+28951
+28952
+28953
+28954
+28955
+28956
+28957
+28958
+28959
+28960
+28961
+28962
+28963
+28964
+28965
+28966
+28967
+28968
+28969
+28970
+28971
+28972
+28973
+28974
+28975
+28976
+28977
+28978
+28979
+28980
+28981
+28982
+28983
+28984
+28985
+28986
+28987
+28988
+28989
+28990
+28991
+28992
+28993
+28994
+28995
+28996
+28997
+28998
+28999
+29000
+29001
+29002
+29003
+29004
+29005
+29006
+29007
+29008
+29009
+29010
+29011
+29012
+29013
+29014
+29015
+29016
+29017
+29018
+29019
+29020
+29021
+29022
+29023
+29024
+29025
+29026
+29027
+29028
+29029
+29030
+29031
+29032
+29033
+29034
+29035
+29036
+29037
+29038
+29039
+29040
+29041
+29042
+29043
+29044
+29045
+29046
+29047
+29048
+29049
+29050
+29051
+29052
+29053
+29054
+29055
+29056
+29057
+29058
+29059
+29060
+29061
+29062
+29063
+29064
+29065
+29066
+29067
+29068
+29069
+29070
+29071
+29072
+29073
+29074
+29075
+29076
+29077
+29078
+29079
+29080
+29081
+29082
+29083
+29084
+29085
+29086
+29087
+29088
+29089
+29090
+29091
+29092
+29093
+29094
+29095
+29096
+29097
+29098
+29099
+29100
+29101
+29102
+29103
+29104
+29105
+29106
+29107
+29108
+29109
+29110
+29111
+29112
+29113
+29114
+29115
+29116
+29117
+29118
+29119
+29120
+29121
+29122
+29123
+29124
+29125
+29126
+29127
+29128
+29129
+29130
+29131
+29132
+29133
+29134
+29135
+29136
+29137
+29138
+29139
+29140
+29141
+29142
+29143
+29144
+29145
+29146
+29147
+29148
+29149
+29150
+29151
+29152
+29153
+29154
+29155
+29156
+29157
+29158
+29159
+29160
+29161
+29162
+29163
+29164
+29165
+29166
+29167
+29168
+29169
+29170
+29171
+29172
+29173
+29174
+29175
+29176
+29177
+29178
+29179
+29180
+29181
+29182
+29183
+29184
+29185
+29186
+29187
+29188
+29189
+29190
+29191
+29192
+29193
+29194
+29195
+29196
+29197
+29198
+29199
+29200
+29201
+29202
+29203
+29204
+29205
+29206
+29207
+29208
+29209
+29210
+29211
+29212
+29213
+29214
+29215
+29216
+29217
+29218
+29219
+29220
+29221
+29222
+29223
+29224
+29225
+29226
+29227
+29228
+29229
+29230
+29231
+29232
+29233
+29234
+29235
+29236
+29237
+29238
+29239
+29240
+29241
+29242
+29243
+29244
+29245
+29246
+29247
+29248
+29249
+29250
+29251
+29252
+29253
+29254
+29255
+29256
+29257
+29258
+29259
+29260
+29261
+29262
+29263
+29264
+29265
+29266
+29267
+29268
+29269
+29270
+29271
+29272
+29273
+29274
+29275
+29276
+29277
+29278
+29279
+29280
+29281
+29282
+29283
+29284
+29285
+29286
+29287
+29288
+29289
+29290
+29291
+29292
+29293
+29294
+29295
+29296
+29297
+29298
+29299
+29300
+29301
+29302
+29303
+29304
+29305
+29306
+29307
+29308
+29309
+29310
+29311
+29312
+29313
+29314
+29315
+29316
+29317
+29318
+29319
+29320
+29321
+29322
+29323
+29324
+29325
+29326
+29327
+29328
+29329
+29330
+29331
+29332
+29333
+29334
+29335
+29336
+29337
+29338
+29339
+29340
+29341
+29342
+29343
+29344
+29345
+29346
+29347
+29348
+29349
+29350
+29351
+29352
+29353
+29354
+29355
+29356
+29357
+29358
+29359
+29360
+29361
+29362
+29363
+29364
+29365
+29366
+29367
+29368
+29369
+29370
+29371
+29372
+29373
+29374
+29375
+29376
+29377
+29378
+29379
+29380
+29381
+29382
+29383
+29384
+29385
+29386
+29387
+29388
+29389
+29390
+29391
+29392
+29393
+29394
+29395
+29396
+29397
+29398
+29399
+29400
+29401
+29402
+29403
+29404
+29405
+29406
+29407
+29408
+29409
+29410
+29411
+29412
+29413
+29414
+29415
+29416
+29417
+29418
+29419
+29420
+29421
+29422
+29423
+29424
+29425
+29426
+29427
+29428
+29429
+29430
+29431
+29432
+29433
+29434
+29435
+29436
+29437
+29438
+29439
+29440
+29441
+29442
+29443
+29444
+29445
+29446
+29447
+29448
+29449
+29450
+29451
+29452
+29453
+29454
+29455
+29456
+29457
+29458
+29459
+29460
+29461
+29462
+29463
+29464
+29465
+29466
+29467
+29468
+29469
+29470
+29471
+29472
+29473
+29474
+29475
+29476
+29477
+29478
+29479
+29480
+29481
+29482
+29483
+29484
+29485
+29486
+29487
+29488
+29489
+29490
+29491
+29492
+29493
+29494
+29495
+29496
+29497
+29498
+29499
+29500
+29501
+29502
+29503
+29504
+29505
+29506
+29507
+29508
+29509
+29510
+29511
+29512
+29513
+29514
+29515
+29516
+29517
+29518
+29519
+29520
+29521
+29522
+29523
+29524
+29525
+29526
+29527
+29528
+29529
+29530
+29531
+29532
+29533
+29534
+29535
+29536
+29537
+29538
+29539
+29540
+29541
+29542
+29543
+29544
+29545
+29546
+29547
+29548
+29549
+29550
+29551
+29552
+29553
+29554
+29555
+29556
+29557
+29558
+29559
+29560
+29561
+29562
+29563
+29564
+29565
+29566
+29567
+29568
+29569
+29570
+29571
+29572
+29573
+29574
+29575
+29576
+29577
+29578
+29579
+29580
+29581
+29582
+29583
+29584
+29585
+29586
+29587
+29588
+29589
+29590
+29591
+29592
+29593
+29594
+29595
+29596
+29597
+29598
+29599
+29600
+29601
+29602
+29603
+29604
+29605
+29606
+29607
+29608
+29609
+29610
+29611
+29612
+29613
+29614
+29615
+29616
+29617
+29618
+29619
+29620
+29621
+29622
+29623
+29624
+29625
+29626
+29627
+29628
+29629
+29630
+29631
+29632
+29633
+29634
+29635
+29636
+29637
+29638
+29639
+29640
+29641
+29642
+29643
+29644
+29645
+29646
+29647
+29648
+29649
+29650
+29651
+29652
+29653
+29654
+29655
+29656
+29657
+29658
+29659
+29660
+29661
+29662
+29663
+29664
+29665
+29666
+29667
+29668
+29669
+29670
+29671
+29672
+29673
+29674
+29675
+29676
+29677
+29678
+29679
+29680
+29681
+29682
+29683
+29684
+29685
+29686
+29687
+29688
+29689
+29690
+29691
+29692
+29693
+29694
+29695
+29696
+29697
+29698
+29699
+29700
+29701
+29702
+29703
+29704
+29705
+29706
+29707
+29708
+29709
+29710
+29711
+29712
+29713
+29714
+29715
+29716
+29717
+29718
+29719
+29720
+29721
+29722
+29723
+29724
+29725
+29726
+29727
+29728
+29729
+29730
+29731
+29732
+29733
+29734
+29735
+29736
+29737
+29738
+29739
+29740
+29741
+29742
+29743
+29744
+29745
+29746
+29747
+29748
+29749
+29750
+29751
+29752
+29753
+29754
+29755
+29756
+29757
+29758
+29759
+29760
+29761
+29762
+29763
+29764
+29765
+29766
+29767
+29768
+29769
+29770
+29771
+29772
+29773
+29774
+29775
+29776
+29777
+29778
+29779
+29780
+29781
+29782
+29783
+29784
+29785
+29786
+29787
+29788
+29789
+29790
+29791
+29792
+29793
+29794
+29795
+29796
+29797
+29798
+29799
+29800
+29801
+29802
+29803
+29804
+29805
+29806
+29807
+29808
+29809
+29810
+29811
+29812
+29813
+29814
+29815
+29816
+29817
+29818
+29819
+29820
+29821
+29822
+29823
+29824
+29825
+29826
+29827
+29828
+29829
+29830
+29831
+29832
+29833
+29834
+29835
+29836
+29837
+29838
+29839
+29840
+29841
+29842
+29843
+29844
+29845
+29846
+29847
+29848
+29849
+29850
+29851
+29852
+29853
+29854
+29855
+29856
+29857
+29858
+29859
+29860
+29861
+29862
+29863
+29864
+29865
+29866
+29867
+29868
+29869
+29870
+29871
+29872
+29873
+29874
+29875
+29876
+29877
+29878
+29879
+29880
+29881
+29882
+29883
+29884
+29885
+29886
+29887
+29888
+29889
+29890
+29891
+29892
+29893
+29894
+29895
+29896
+29897
+29898
+29899
+29900
+29901
+29902
+29903
+29904
+29905
+29906
+29907
+29908
+29909
+29910
+29911
+29912
+29913
+29914
+29915
+29916
+29917
+29918
+29919
+29920
+29921
+29922
+29923
+29924
+29925
+29926
+29927
+29928
+29929
+29930
+29931
+29932
+29933
+29934
+29935
+29936
+29937
+29938
+29939
+29940
+29941
+29942
+29943
+29944
+29945
+29946
+29947
+29948
+29949
+29950
+29951
+29952
+29953
+29954
+29955
+29956
+29957
+29958
+29959
+29960
+29961
+29962
+29963
+29964
+29965
+29966
+29967
+29968
+29969
+29970
+29971
+29972
+29973
+29974
+29975
+29976
+29977
+29978
+29979
+29980
+29981
+29982
+29983
+29984
+29985
+29986
+29987
+29988
+29989
+29990
+29991
+29992
+29993
+29994
+29995
+29996
+29997
+29998
+29999
+30000
+30001
+30002
+30003
+30004
+30005
+30006
+30007
+30008
+30009
+30010
+30011
+30012
+30013
+30014
+30015
+30016
+30017
+30018
+30019
+30020
+30021
+30022
+30023
+30024
+30025
+30026
+30027
+30028
+30029
+30030
+30031
+30032
+30033
+30034
+30035
+30036
+30037
+30038
+30039
+30040
+30041
+30042
+30043
+30044
+30045
+30046
+30047
+30048
+30049
+30050
+30051
+30052
+30053
+30054
+30055
+30056
+30057
+30058
+30059
+30060
+30061
+30062
+30063
+30064
+30065
+30066
+30067
+30068
+30069
+30070
+30071
+30072
+30073
+30074
+30075
+30076
+30077
+30078
+30079
+30080
+30081
+30082
+30083
+30084
+30085
+30086
+30087
+30088
+30089
+30090
+30091
+30092
+30093
+30094
+30095
+30096
+30097
+30098
+30099
+30100
+30101
+30102
+30103
+30104
+30105
+30106
+30107
+30108
+30109
+30110
+30111
+30112
+30113
+30114
+30115
+30116
+30117
+30118
+30119
+30120
+30121
+30122
+30123
+30124
+30125
+30126
+30127
+30128
+30129
+30130
+30131
+30132
+30133
+30134
+30135
+30136
+30137
+30138
+30139
+30140
+30141
+30142
+30143
+30144
+30145
+30146
+30147
+30148
+30149
+30150
+30151
+30152
+30153
+30154
+30155
+30156
+30157
+30158
+30159
+30160
+30161
+30162
+30163
+30164
+30165
+30166
+30167
+30168
+30169
+30170
+30171
+30172
+30173
+30174
+30175
+30176
+30177
+30178
+30179
+30180
+30181
+30182
+30183
+30184
+30185
+30186
+30187
+30188
+30189
+30190
+30191
+30192
+30193
+30194
+30195
+30196
+30197
+30198
+30199
+30200
+30201
+30202
+30203
+30204
+30205
+30206
+30207
+30208
+30209
+30210
+30211
+30212
+30213
+30214
+30215
+30216
+30217
+30218
+30219
+30220
+30221
+30222
+30223
+30224
+30225
+30226
+30227
+30228
+30229
+30230
+30231
+30232
+30233
+30234
+30235
+30236
+30237
+30238
+30239
+30240
+30241
+30242
+30243
+30244
+30245
+30246
+30247
+30248
+30249
+30250
+30251
+30252
+30253
+30254
+30255
+30256
+30257
+30258
+30259
+30260
+30261
+30262
+30263
+30264
+30265
+30266
+30267
+30268
+30269
+30270
+30271
+30272
+30273
+30274
+30275
+30276
+30277
+30278
+30279
+30280
+30281
+30282
+30283
+30284
+30285
+30286
+30287
+30288
+30289
+30290
+30291
+30292
+30293
+30294
+30295
+30296
+30297
+30298
+30299
+30300
+30301
+30302
+30303
+30304
+30305
+30306
+30307
+30308
+30309
+30310
+30311
+30312
+30313
+30314
+30315
+30316
+30317
+30318
+30319
+30320
+30321
+30322
+30323
+30324
+30325
+30326
+30327
+30328
+30329
+30330
+30331
+30332
+30333
+30334
+30335
+30336
+30337
+30338
+30339
+30340
+30341
+30342
+30343
+30344
+30345
+30346
+30347
+30348
+30349
+30350
+30351
+30352
+30353
+30354
+30355
+30356
+30357
+30358
+30359
+30360
+30361
+30362
+30363
+30364
+30365
+30366
+30367
+30368
+30369
+30370
+30371
+30372
+30373
+30374
+30375
+30376
+30377
+30378
+30379
+30380
+30381
+30382
+30383
+30384
+30385
+30386
+30387
+30388
+30389
+30390
+30391
+30392
+30393
+30394
+30395
+30396
+30397
+30398
+30399
+30400
+30401
+30402
+30403
+30404
+30405
+30406
+30407
+30408
+30409
+30410
+30411
+30412
+30413
+30414
+30415
+30416
+30417
+30418
+30419
+30420
+30421
+30422
+30423
+30424
+30425
+30426
+30427
+30428
+30429
+30430
+30431
+30432
+30433
+30434
+30435
+30436
+30437
+30438
+30439
+30440
+30441
+30442
+30443
+30444
+30445
+30446
+30447
+30448
+30449
+30450
+30451
+30452
+30453
+30454
+30455
+30456
+30457
+30458
+30459
+30460
+30461
+30462
+30463
+30464
+30465
+30466
+30467
+30468
+30469
+30470
+30471
+30472
+30473
+30474
+30475
+30476
+30477
+30478
+30479
+30480
+30481
+30482
+30483
+30484
+30485
+30486
+30487
+30488
+30489
+30490
+30491
+30492
+30493
+30494
+30495
+30496
+30497
+30498
+30499
+30500
+30501
+30502
+30503
+30504
+30505
+30506
+30507
+30508
+30509
+30510
+30511
+30512
+30513
+30514
+30515
+30516
+30517
+30518
+30519
+30520
+30521
+30522
+30523
+30524
+30525
+30526
+30527
+30528
+30529
+30530
+30531
+30532
+30533
+30534
+30535
+30536
+30537
+30538
+30539
+30540
+30541
+30542
+30543
+30544
+30545
+30546
+30547
+30548
+30549
+30550
+30551
+30552
+30553
+30554
+30555
+30556
+30557
+30558
+30559
+30560
+30561
+30562
+30563
+30564
+30565
+30566
+30567
+30568
+30569
+30570
+30571
+30572
+30573
+30574
+30575
+30576
+30577
+30578
+30579
+30580
+30581
+30582
+30583
+30584
+30585
+30586
+30587
+30588
+30589
+30590
+30591
+30592
+30593
+30594
+30595
+30596
+30597
+30598
+30599
+30600
+30601
+30602
+30603
+30604
+30605
+30606
+30607
+30608
+30609
+30610
+30611
+30612
+30613
+30614
+30615
+30616
+30617
+30618
+30619
+30620
+30621
+30622
+30623
+30624
+30625
+30626
+30627
+30628
+30629
+30630
+30631
+30632
+30633
+30634
+30635
+30636
+30637
+30638
+30639
+30640
+30641
+30642
+30643
+30644
+30645
+30646
+30647
+30648
+30649
+30650
+30651
+30652
+30653
+30654
+30655
+30656
+30657
+30658
+30659
+30660
+30661
+30662
+30663
+30664
+30665
+30666
+30667
+30668
+30669
+30670
+30671
+30672
+30673
+30674
+30675
+30676
+30677
+30678
+30679
+30680
+30681
+30682
+30683
+30684
+30685
+30686
+30687
+30688
+30689
+30690
+30691
+30692
+30693
+30694
+30695
+30696
+30697
+30698
+30699
+30700
+30701
+30702
+30703
+30704
+30705
+30706
+30707
+30708
+30709
+30710
+30711
+30712
+30713
+30714
+30715
+30716
+30717
+30718
+30719
+30720
+30721
+30722
+30723
+30724
+30725
+30726
+30727
+30728
+30729
+30730
+30731
+30732
+30733
+30734
+30735
+30736
+30737
+30738
+30739
+30740
+30741
+30742
+30743
+30744
+30745
+30746
+30747
+30748
+30749
+30750
+30751
+30752
+30753
+30754
+30755
+30756
+30757
+30758
+30759
+30760
+30761
+30762
+30763
+30764
+30765
+30766
+30767
+30768
+30769
+30770
+30771
+30772
+30773
+30774
+30775
+30776
+30777
+30778
+30779
+30780
+30781
+30782
+30783
+30784
+30785
+30786
+30787
+30788
+30789
+30790
+30791
+30792
+30793
+30794
+30795
+30796
+30797
+30798
+30799
+30800
+30801
+30802
+30803
+30804
+30805
+30806
+30807
+30808
+30809
+30810
+30811
+30812
+30813
+30814
+30815
+30816
+30817
+30818
+30819
+30820
+30821
+30822
+30823
+30824
+30825
+30826
+30827
+30828
+30829
+30830
+30831
+30832
+30833
+30834
+30835
+30836
+30837
+30838
+30839
+30840
+30841
+30842
+30843
+30844
+30845
+30846
+30847
+30848
+30849
+30850
+30851
+30852
+30853
+30854
+30855
+30856
+30857
+30858
+30859
+30860
+30861
+30862
+30863
+30864
+30865
+30866
+30867
+30868
+30869
+30870
+30871
+30872
+30873
+30874
+30875
+30876
+30877
+30878
+30879
+30880
+30881
+30882
+30883
+30884
+30885
+30886
+30887
+30888
+30889
+30890
+30891
+30892
+30893
+30894
+30895
+30896
+30897
+30898
+30899
+30900
+30901
+30902
+30903
+30904
+30905
+30906
+30907
+30908
+30909
+30910
+30911
+30912
+30913
+30914
+30915
+30916
+30917
+30918
+30919
+30920
+30921
+30922
+30923
+30924
+30925
+30926
+30927
+30928
+30929
+30930
+30931
+30932
+30933
+30934
+30935
+30936
+30937
+30938
+30939
+30940
+30941
+30942
+30943
+30944
+30945
+30946
+30947
+30948
+30949
+30950
+30951
+30952
+30953
+30954
+30955
+30956
+30957
+30958
+30959
+30960
+30961
+30962
+30963
+30964
+30965
+30966
+30967
+30968
+30969
+30970
+30971
+30972
+30973
+30974
+30975
+30976
+30977
+30978
+30979
+30980
+30981
+30982
+30983
+30984
+30985
+30986
+30987
+30988
+30989
+30990
+30991
+30992
+30993
+30994
+30995
+30996
+30997
+30998
+30999
+31000
+31001
+31002
+31003
+31004
+31005
+31006
+31007
+31008
+31009
+31010
+31011
+31012
+31013
+31014
+31015
+31016
+31017
+31018
+31019
+31020
+31021
+31022
+31023
+31024
+31025
+31026
+31027
+31028
+31029
+31030
+31031
+31032
+31033
+31034
+31035
+31036
+31037
+31038
+31039
+31040
+31041
+31042
+31043
+31044
+31045
+31046
+31047
+31048
+31049
+31050
+31051
+31052
+31053
+31054
+31055
+31056
+31057
+31058
+31059
+31060
+31061
+31062
+31063
+31064
+31065
+31066
+31067
+31068
+31069
+31070
+31071
+31072
+31073
+31074
+31075
+31076
+31077
+31078
+31079
+31080
+31081
+31082
+31083
+31084
+31085
+31086
+31087
+31088
+31089
+31090
+31091
+31092
+31093
+31094
+31095
+31096
+31097
+31098
+31099
+31100
+31101
+31102
+31103
+31104
+31105
+31106
+31107
+31108
+31109
+31110
+31111
+31112
+31113
+31114
+31115
+31116
+31117
+31118
+31119
+31120
+31121
+31122
+31123
+31124
+31125
+31126
+31127
+31128
+31129
+31130
+31131
+31132
+31133
+31134
+31135
+31136
+31137
+31138
+31139
+31140
+31141
+31142
+31143
+31144
+31145
+31146
+31147
+31148
+31149
+31150
+31151
+31152
+31153
+31154
+31155
+31156
+31157
+31158
+31159
+31160
+31161
+31162
+31163
+31164
+31165
+31166
+31167
+31168
+31169
+31170
+31171
+31172
+31173
+31174
+31175
+31176
+31177
+31178
+31179
+31180
+31181
+31182
+31183
+31184
+31185
+31186
+31187
+31188
+31189
+31190
+31191
+31192
+31193
+31194
+31195
+31196
+31197
+31198
+31199
+31200
+31201
+31202
+31203
+31204
+31205
+31206
+31207
+31208
+31209
+31210
+31211
+31212
+31213
+31214
+31215
+31216
+31217
+31218
+31219
+31220
+31221
+31222
+31223
+31224
+31225
+31226
+31227
+31228
+31229
+31230
+31231
+31232
+31233
+31234
+31235
+31236
+31237
+31238
+31239
+31240
+31241
+31242
+31243
+31244
+31245
+31246
+31247
+31248
+31249
+31250
+31251
+31252
+31253
+31254
+31255
+31256
+31257
+31258
+31259
+31260
+31261
+31262
+31263
+31264
+31265
+31266
+31267
+31268
+31269
+31270
+31271
+31272
+31273
+31274
+31275
+31276
+31277
+31278
+31279
+31280
+31281
+31282
+31283
+31284
+31285
+31286
+31287
+31288
+31289
+31290
+31291
+31292
+31293
+31294
+31295
+31296
+31297
+31298
+31299
+31300
+31301
+31302
+31303
+31304
+31305
+31306
+31307
+31308
+31309
+31310
+31311
+31312
+31313
+31314
+31315
+31316
+31317
+31318
+31319
+31320
+31321
+31322
+31323
+31324
+31325
+31326
+31327
+31328
+31329
+31330
+31331
+31332
+31333
+31334
+31335
+31336
+31337
+31338
+31339
+31340
+31341
+31342
+31343
+31344
+31345
+31346
+31347
+31348
+31349
+31350
+31351
+31352
+31353
+31354
+31355
+31356
+31357
+31358
+31359
+31360
+31361
+31362
+31363
+31364
+31365
+31366
+31367
+31368
+31369
+31370
+31371
+31372
+31373
+31374
+31375
+31376
+31377
+31378
+31379
+31380
+31381
+31382
+31383
+31384
+31385
+31386
+31387
+31388
+31389
+31390
+31391
+31392
+31393
+31394
+31395
+31396
+31397
+31398
+31399
+31400
+31401
+31402
+31403
+31404
+31405
+31406
+31407
+31408
+31409
+31410
+31411
+31412
+31413
+31414
+31415
+31416
+31417
+31418
+31419
+31420
+31421
+31422
+31423
+31424
+31425
+31426
+31427
+31428
+31429
+31430
+31431
+31432
+31433
+31434
+31435
+31436
+31437
+31438
+31439
+31440
+31441
+31442
+31443
+31444
+31445
+31446
+31447
+31448
+31449
+31450
+31451
+31452
+31453
+31454
+31455
+31456
+31457
+31458
+31459
+31460
+31461
+31462
+31463
+31464
+31465
+31466
+31467
+31468
+31469
+31470
+31471
+31472
+31473
+31474
+31475
+31476
+31477
+31478
+31479
+31480
+31481
+31482
+31483
+31484
+31485
+31486
+31487
+31488
+31489
+31490
+31491
+31492
+31493
+31494
+31495
+31496
+31497
+31498
+31499
+31500
+31501
+31502
+31503
+31504
+31505
+31506
+31507
+31508
+31509
+31510
+31511
+31512
+31513
+31514
+31515
+31516
+31517
+31518
+31519
+31520
+31521
+31522
+31523
+31524
+31525
+31526
+31527
+31528
+31529
+31530
+31531
+31532
+31533
+31534
+31535
+31536
+31537
+31538
+31539
+31540
+31541
+31542
+31543
+31544
+31545
+31546
+31547
+31548
+31549
+31550
+31551
+31552
+31553
+31554
+31555
+31556
+31557
+31558
+31559
+31560
+31561
+31562
+31563
+31564
+31565
+31566
+31567
+31568
+31569
+31570
+31571
+31572
+31573
+31574
+31575
+31576
+31577
+31578
+31579
+31580
+31581
+31582
+31583
+31584
+31585
+31586
+31587
+31588
+31589
+31590
+31591
+31592
+31593
+31594
+31595
+31596
+31597
+31598
+31599
+31600
+31601
+31602
+31603
+31604
+31605
+31606
+31607
+31608
+31609
+31610
+31611
+31612
+31613
+31614
+31615
+31616
+31617
+31618
+31619
+31620
+31621
+31622
+31623
+31624
+31625
+31626
+31627
+31628
+31629
+31630
+31631
+31632
+31633
+31634
+31635
+31636
+31637
+31638
+31639
+31640
+31641
+31642
+31643
+31644
+31645
+31646
+31647
+31648
+31649
+31650
+31651
+31652
+31653
+31654
+31655
+31656
+31657
+31658
+31659
+31660
+31661
+31662
+31663
+31664
+31665
+31666
+31667
+31668
+31669
+31670
+31671
+31672
+31673
+31674
+31675
+31676
+31677
+31678
+31679
+31680
+31681
+31682
+31683
+31684
+31685
+31686
+31687
+31688
+31689
+31690
+31691
+31692
+31693
+31694
+31695
+31696
+31697
+31698
+31699
+31700
+31701
+31702
+31703
+31704
+31705
+31706
+31707
+31708
+31709
+31710
+31711
+31712
+31713
+31714
+31715
+31716
+31717
+31718
+31719
+31720
+31721
+31722
+31723
+31724
+31725
+31726
+31727
+31728
+31729
+31730
+31731
+31732
+31733
+31734
+31735
+31736
+31737
+31738
+31739
+31740
+31741
+31742
+31743
+31744
+31745
+31746
+31747
+31748
+31749
+31750
+31751
+31752
+31753
+31754
+31755
+31756
+31757
+31758
+31759
+31760
+31761
+31762
+31763
+31764
+31765
+31766
+31767
+31768
+31769
+31770
+31771
+31772
+31773
+31774
+31775
+31776
+31777
+31778
+31779
+31780
+31781
+31782
+31783
+31784
+31785
+31786
+31787
+31788
+31789
+31790
+31791
+31792
+31793
+31794
+31795
+31796
+31797
+31798
+31799
+31800
+31801
+31802
+31803
+31804
+31805
+31806
+31807
+31808
+31809
+31810
+31811
+31812
+31813
+31814
+31815
+31816
+31817
+31818
+31819
+31820
+31821
+31822
+31823
+31824
+31825
+31826
+31827
+31828
+31829
+31830
+31831
+31832
+31833
+31834
+31835
+31836
+31837
+31838
+31839
+31840
+31841
+31842
+31843
+31844
+31845
+31846
+31847
+31848
+31849
+31850
+31851
+31852
+31853
+31854
+31855
+31856
+31857
+31858
+31859
+31860
+31861
+31862
+31863
+31864
+31865
+31866
+31867
+31868
+31869
+31870
+31871
+31872
+31873
+31874
+31875
+31876
+31877
+31878
+31879
+31880
+31881
+31882
+31883
+31884
+31885
+31886
+31887
+31888
+31889
+31890
+31891
+31892
+31893
+31894
+31895
+31896
+31897
+31898
+31899
+31900
+31901
+31902
+31903
+31904
+31905
+31906
+31907
+31908
+31909
+31910
+31911
+31912
+31913
+31914
+31915
+31916
+31917
+31918
+31919
+31920
+31921
+31922
+31923
+31924
+31925
+31926
+31927
+31928
+31929
+31930
+31931
+31932
+31933
+31934
+31935
+31936
+31937
+31938
+31939
+31940
+31941
+31942
+31943
+31944
+31945
+31946
+31947
+31948
+31949
+31950
+31951
+31952
+31953
+31954
+31955
+31956
+31957
+31958
+31959
+31960
+31961
+31962
+31963
+31964
+31965
+31966
+31967
+31968
+31969
+31970
+31971
+31972
+31973
+31974
+31975
+31976
+31977
+31978
+31979
+31980
+31981
+31982
+31983
+31984
+31985
+31986
+31987
+31988
+31989
+31990
+31991
+31992
+31993
+31994
+31995
+31996
+31997
+31998
+31999
+32000
+32001
+32002
+32003
+32004
+32005
+32006
+32007
+32008
+32009
+32010
+32011
+32012
+32013
+32014
+32015
+32016
+32017
+32018
+32019
+32020
+32021
+32022
+32023
+32024
+32025
+32026
+32027
+32028
+32029
+32030
+32031
+32032
+32033
+32034
+32035
+32036
+32037
+32038
+32039
+32040
+32041
+32042
+32043
+32044
+32045
+32046
+32047
+32048
+32049
+32050
+32051
+32052
+32053
+32054
+32055
+32056
+32057
+32058
+32059
+32060
+32061
+32062
+32063
+32064
+32065
+32066
+32067
+32068
+32069
+32070
+32071
+32072
+32073
+32074
+32075
+32076
+32077
+32078
+32079
+32080
+32081
+32082
+32083
+32084
+32085
+32086
+32087
+32088
+32089
+32090
+32091
+32092
+32093
+32094
+32095
+32096
+32097
+32098
+32099
+32100
+32101
+32102
+32103
+32104
+32105
+32106
+32107
+32108
+32109
+32110
+32111
+32112
+32113
+32114
+32115
+32116
+32117
+32118
+32119
+32120
+32121
+32122
+32123
+32124
+32125
+32126
+32127
+32128
+32129
+32130
+32131
+32132
+32133
+32134
+32135
+32136
+32137
+32138
+32139
+32140
+32141
+32142
+32143
+32144
+32145
+32146
+32147
+32148
+32149
+32150
+32151
+32152
+32153
+32154
+32155
+32156
+32157
+32158
+32159
+32160
+32161
+32162
+32163
+32164
+32165
+32166
+32167
+32168
+32169
+32170
+32171
+32172
+32173
+32174
+32175
+32176
+32177
+32178
+32179
+32180
+32181
+32182
+32183
+32184
+32185
+32186
+32187
+32188
+32189
+32190
+32191
+32192
+32193
+32194
+32195
+32196
+32197
+32198
+32199
+32200
+32201
+32202
+32203
+32204
+32205
+32206
+32207
+32208
+32209
+32210
+32211
+32212
+32213
+32214
+32215
+32216
+32217
+32218
+32219
+32220
+32221
+32222
+32223
+32224
+32225
+32226
+32227
+32228
+32229
+32230
+32231
+32232
+32233
+32234
+32235
+32236
+32237
+32238
+32239
+32240
+32241
+32242
+32243
+32244
+32245
+32246
+32247
+32248
+32249
+32250
+32251
+32252
+32253
+32254
+32255
+32256
+32257
+32258
+32259
+32260
+32261
+32262
+32263
+32264
+32265
+32266
+32267
+32268
+32269
+32270
+32271
+32272
+32273
+32274
+32275
+32276
+32277
+32278
+32279
+32280
+32281
+32282
+32283
+32284
+32285
+32286
+32287
+32288
+32289
+32290
+32291
+32292
+32293
+32294
+32295
+32296
+32297
+32298
+32299
+32300
+32301
+32302
+32303
+32304
+32305
+32306
+32307
+32308
+32309
+32310
+32311
+32312
+32313
+32314
+32315
+32316
+32317
+32318
+32319
+32320
+32321
+32322
+32323
+32324
+32325
+32326
+32327
+32328
+32329
+32330
+32331
+32332
+32333
+32334
+32335
+32336
+32337
+32338
+32339
+32340
+32341
+32342
+32343
+32344
+32345
+32346
+32347
+32348
+32349
+32350
+32351
+32352
+32353
+32354
+32355
+32356
+32357
+32358
+32359
+32360
+32361
+32362
+32363
+32364
+32365
+32366
+32367
+32368
+32369
+32370
+32371
+32372
+32373
+32374
+32375
+32376
+32377
+32378
+32379
+32380
+32381
+32382
+32383
+32384
+32385
+32386
+32387
+32388
+32389
+32390
+32391
+32392
+32393
+32394
+32395
+32396
+32397
+32398
+32399
+32400
+32401
+32402
+32403
+32404
+32405
+32406
+32407
+32408
+32409
+32410
+32411
+32412
+32413
+32414
+32415
+32416
+32417
+32418
+32419
+32420
+32421
+32422
+32423
+32424
+32425
+32426
+32427
+32428
+32429
+32430
+32431
+32432
+32433
+32434
+32435
+32436
+32437
+32438
+32439
+32440
+32441
+32442
+32443
+32444
+32445
+32446
+32447
+32448
+32449
+32450
+32451
+32452
+32453
+32454
+32455
+32456
+32457
+32458
+32459
+32460
+32461
+32462
+32463
+32464
+32465
+32466
+32467
+32468
+32469
+32470
+32471
+32472
+32473
+32474
+32475
+32476
+32477
+32478
+32479
+32480
+32481
+32482
+32483
+32484
+32485
+32486
+32487
+32488
+32489
+32490
+32491
+32492
+32493
+32494
+32495
+32496
+32497
+32498
+32499
+32500
+32501
+32502
+32503
+32504
+32505
+32506
+32507
+32508
+32509
+32510
+32511
+32512
+32513
+32514
+32515
+32516
+32517
+32518
+32519
+32520
+32521
+32522
+32523
+32524
+32525
+32526
+32527
+32528
+32529
+32530
+32531
+32532
+32533
+32534
+32535
+32536
+32537
+32538
+32539
+32540
+32541
+32542
+32543
+32544
+32545
+32546
+32547
+32548
+32549
+32550
+32551
+32552
+32553
+32554
+32555
+32556
+32557
+32558
+32559
+32560
+32561
+32562
+32563
+32564
+32565
+32566
+32567
+32568
+32569
+32570
+32571
+32572
+32573
+32574
+32575
+32576
+32577
+32578
+32579
+32580
+32581
+32582
+32583
+32584
+32585
+32586
+32587
+32588
+32589
+32590
+32591
+32592
+32593
+32594
+32595
+32596
+32597
+32598
+32599
+32600
+32601
+32602
+32603
+32604
+32605
+32606
+32607
+32608
+32609
+32610
+32611
+32612
+32613
+32614
+32615
+32616
+32617
+32618
+32619
+32620
+32621
+32622
+32623
+32624
+32625
+32626
+32627
+32628
+32629
+32630
+32631
+32632
+32633
+32634
+32635
+32636
+32637
+32638
+32639
+32640
+32641
+32642
+32643
+32644
+32645
+32646
+32647
+32648
+32649
+32650
+32651
+32652
+32653
+32654
+32655
+32656
+32657
+32658
+32659
+32660
+32661
+32662
+32663
+32664
+32665
+32666
+32667
+32668
+32669
+32670
+32671
+32672
+32673
+32674
+32675
+32676
+32677
+32678
+32679
+32680
+32681
+32682
+32683
+32684
+32685
+32686
+32687
+32688
+32689
+32690
+32691
+32692
+32693
+32694
+32695
+32696
+32697
+32698
+32699
+32700
+32701
+32702
+32703
+32704
+32705
+32706
+32707
+32708
+32709
+32710
+32711
+32712
+32713
+32714
+32715
+32716
+32717
+32718
+32719
+32720
+32721
+32722
+32723
+32724
+32725
+32726
+32727
+32728
+32729
+32730
+32731
+32732
+32733
+32734
+32735
+32736
+32737
+32738
+32739
+32740
+32741
+32742
+32743
+32744
+32745
+32746
+32747
+32748
+32749
+32750
+32751
+32752
+32753
+32754
+32755
+32756
+32757
+32758
+32759
+32760
+32761
+32762
+32763
+32764
+32765
+32766
+32767
+32768
+32769
+32770
+32771
+32772
+32773
+32774
+32775
+32776
+32777
+32778
+32779
+32780
+32781
+32782
+32783
+32784
+32785
+32786
+32787
+32788
+32789
+32790
+32791
+32792
+32793
+32794
+32795
+32796
+32797
+32798
+32799
+32800
+32801
+32802
+32803
+32804
+32805
+32806
+32807
+32808
+32809
+32810
+32811
+32812
+32813
+32814
+32815
+32816
+32817
+32818
+32819
+32820
+32821
+32822
+32823
+32824
+32825
+32826
+32827
+32828
+32829
+32830
+32831
+32832
+32833
+32834
+32835
+32836
+32837
+32838
+32839
+32840
+32841
+32842
+32843
+32844
+32845
+32846
+32847
+32848
+32849
+32850
+32851
+32852
+32853
+32854
+32855
+32856
+32857
+32858
+32859
+32860
+32861
+32862
+32863
+32864
+32865
+32866
+32867
+32868
+32869
+32870
+32871
+32872
+32873
+32874
+32875
+32876
+32877
+32878
+32879
+32880
+32881
+32882
+32883
+32884
+32885
+32886
+32887
+32888
+32889
+32890
+32891
+32892
+32893
+32894
+32895
+32896
+32897
+32898
+32899
+32900
+32901
+32902
+32903
+32904
+32905
+32906
+32907
+32908
+32909
+32910
+32911
+32912
+32913
+32914
+32915
+32916
+32917
+32918
+32919
+32920
+32921
+32922
+32923
+32924
+32925
+32926
+32927
+32928
+32929
+32930
+32931
+32932
+32933
+32934
+32935
+32936
+32937
+32938
+32939
+32940
+32941
+32942
+32943
+32944
+32945
+32946
+32947
+32948
+32949
+32950
+32951
+32952
+32953
+32954
+32955
+32956
+32957
+32958
+32959
+32960
+32961
+32962
+32963
+32964
+32965
+32966
+32967
+32968
+32969
+32970
+32971
+32972
+32973
+32974
+32975
+32976
+32977
+32978
+32979
+32980
+32981
+32982
+32983
+32984
+32985
+32986
+32987
+32988
+32989
+32990
+32991
+32992
+32993
+32994
+32995
+32996
+32997
+32998
+32999
+33000
+33001
+33002
+33003
+33004
+33005
+33006
+33007
+33008
+33009
+33010
+33011
+33012
+33013
+33014
+33015
+33016
+33017
+33018
+33019
+33020
+33021
+33022
+33023
+33024
+33025
+33026
+33027
+33028
+33029
+33030
+33031
+33032
+33033
+33034
+33035
+33036
+33037
+33038
+33039
+33040
+33041
+33042
+33043
+33044
+33045
+33046
+33047
+33048
+33049
+33050
+33051
+33052
+33053
+33054
+33055
+33056
+33057
+33058
+33059
+33060
+33061
+33062
+33063
+33064
+33065
+33066
+33067
+33068
+33069
+33070
+33071
+33072
+33073
+33074
+33075
+33076
+33077
+33078
+33079
+33080
+33081
+33082
+33083
+33084
+33085
+33086
+33087
+33088
+33089
+33090
+33091
+33092
+33093
+33094
+33095
+33096
+33097
+33098
+33099
+33100
+33101
+33102
+33103
+33104
+33105
+33106
+33107
+33108
+33109
+33110
+33111
+33112
+33113
+33114
+33115
+33116
+33117
+33118
+33119
+33120
+33121
+33122
+33123
+33124
+33125
+33126
+33127
+33128
+33129
+33130
+33131
+33132
+33133
+33134
+33135
+33136
+33137
+33138
+33139
+33140
+33141
+33142
+33143
+33144
+33145
+33146
+33147
+33148
+33149
+33150
+33151
+33152
+33153
+33154
+33155
+33156
+33157
+33158
+33159
+33160
+33161
+33162
+33163
+33164
+33165
+33166
+33167
+33168
+33169
+33170
+33171
+33172
+33173
+33174
+33175
+33176
+33177
+33178
+33179
+33180
+33181
+33182
+33183
+33184
+33185
+33186
+33187
+33188
+33189
+33190
+33191
+33192
+33193
+33194
+33195
+33196
+33197
+33198
+33199
+33200
+33201
+33202
+33203
+33204
+33205
+33206
+33207
+33208
+33209
+33210
+33211
+33212
+33213
+33214
+33215
+33216
+33217
+33218
+33219
+33220
+33221
+33222
+33223
+33224
+33225
+33226
+33227
+33228
+33229
+33230
+33231
+33232
+33233
+33234
+33235
+33236
+33237
+33238
+33239
+33240
+33241
+33242
+33243
+33244
+33245
+33246
+33247
+33248
+33249
+33250
+33251
+33252
+33253
+33254
+33255
+33256
+33257
+33258
+33259
+33260
+33261
+33262
+33263
+33264
+33265
+33266
+33267
+33268
+33269
+33270
+33271
+33272
+33273
+33274
+33275
+33276
+33277
+33278
+33279
+33280
+33281
+33282
+33283
+33284
+33285
+33286
+33287
+33288
+33289
+33290
+33291
+33292
+33293
+33294
+33295
+33296
+33297
+33298
+33299
+33300
+33301
+33302
+33303
+33304
+33305
+33306
+33307
+33308
+33309
+33310
+33311
+33312
+33313
+33314
+33315
+33316
+33317
+33318
+33319
+33320
+33321
+33322
+33323
+33324
+33325
+33326
+33327
+33328
+33329
+33330
+33331
+33332
+33333
+33334
+33335
+33336
+33337
+33338
+33339
+33340
+33341
+33342
+33343
+33344
+33345
+33346
+33347
+33348
+33349
+33350
+33351
+33352
+33353
+33354
+33355
+33356
+33357
+33358
+33359
+33360
+33361
+33362
+33363
+33364
+33365
+33366
+33367
+33368
+33369
+33370
+33371
+33372
+33373
+33374
+33375
+33376
+33377
+33378
+33379
+33380
+33381
+33382
+33383
+33384
+33385
+33386
+33387
+33388
+33389
+33390
+33391
+33392
+33393
+33394
+33395
+33396
+33397
+33398
+33399
+33400
+33401
+33402
+33403
+33404
+33405
+33406
+33407
+33408
+33409
+33410
+33411
+33412
+33413
+33414
+33415
+33416
+33417
+33418
+33419
+33420
+33421
+33422
+33423
+33424
+33425
+33426
+33427
+33428
+33429
+33430
+33431
+33432
+33433
+33434
+33435
+33436
+33437
+33438
+33439
+33440
+33441
+33442
+33443
+33444
+33445
+33446
+33447
+33448
+33449
+33450
+33451
+33452
+33453
+33454
+33455
+33456
+33457
+33458
+33459
+33460
+33461
+33462
+33463
+33464
+33465
+33466
+33467
+33468
+33469
+33470
+33471
+33472
+33473
+33474
+33475
+33476
+33477
+33478
+33479
+33480
+33481
+33482
+33483
+33484
+33485
+33486
+33487
+33488
+33489
+33490
+33491
+33492
+33493
+33494
+33495
+33496
+33497
+33498
+33499
+33500
+33501
+33502
+33503
+33504
+33505
+33506
+33507
+33508
+33509
+33510
+33511
+33512
+33513
+33514
+33515
+33516
+33517
+33518
+33519
+33520
+33521
+33522
+33523
+33524
+33525
+33526
+33527
+33528
+33529
+33530
+33531
+33532
+33533
+33534
+33535
+33536
+33537
+33538
+33539
+33540
+33541
+33542
+33543
+33544
+33545
+33546
+33547
+33548
+33549
+33550
+33551
+33552
+33553
+33554
+33555
+33556
+33557
+33558
+33559
+33560
+33561
+33562
+33563
+33564
+33565
+33566
+33567
+33568
+33569
+33570
+33571
+33572
+33573
+33574
+33575
+33576
+33577
+33578
+33579
+33580
+33581
+33582
+33583
+33584
+33585
+33586
+33587
+33588
+33589
+33590
+33591
+33592
+33593
+33594
+33595
+33596
+33597
+33598
+33599
+33600
+33601
+33602
+33603
+33604
+33605
+33606
+33607
+33608
+33609
+33610
+33611
+33612
+33613
+33614
+33615
+33616
+33617
+33618
+33619
+33620
+33621
+33622
+33623
+33624
+33625
+33626
+33627
+33628
+33629
+33630
+33631
+33632
+33633
+33634
+33635
+33636
+33637
+33638
+33639
+33640
+33641
+33642
+33643
+33644
+33645
+33646
+33647
+33648
+33649
+33650
+33651
+33652
+33653
+33654
+33655
+33656
+33657
+33658
+33659
+33660
+33661
+33662
+33663
+33664
+33665
+33666
+33667
+33668
+33669
+33670
+33671
+33672
+33673
+33674
+33675
+33676
+33677
+33678
+33679
+33680
+33681
+33682
+33683
+33684
+33685
+33686
+33687
+33688
+33689
+33690
+33691
+33692
+33693
+33694
+33695
+33696
+33697
+33698
+33699
+33700
+33701
+33702
+33703
+33704
+33705
+33706
+33707
+33708
+33709
+33710
+33711
+33712
+33713
+33714
+33715
+33716
+33717
+33718
+33719
+33720
+33721
+33722
+33723
+33724
+33725
+33726
+33727
+33728
+33729
+33730
+33731
+33732
+33733
+33734
+33735
+33736
+33737
+33738
+33739
+33740
+33741
+33742
+33743
+33744
+33745
+33746
+33747
+33748
+33749
+33750
+33751
+33752
+33753
+33754
+33755
+33756
+33757
+33758
+33759
+33760
+33761
+33762
+33763
+33764
+33765
+33766
+33767
+33768
+33769
+33770
+33771
+33772
+33773
+33774
+33775
+33776
+33777
+33778
+33779
+33780
+33781
+33782
+33783
+33784
+33785
+33786
+33787
+33788
+33789
+33790
+33791
+33792
+33793
+33794
+33795
+33796
+33797
+33798
+33799
+33800
+33801
+33802
+33803
+33804
+33805
+33806
+33807
+33808
+33809
+33810
+33811
+33812
+33813
+33814
+33815
+33816
+33817
+33818
+33819
+33820
+33821
+33822
+33823
+33824
+33825
+33826
+33827
+33828
+33829
+33830
+33831
+33832
+33833
+33834
+33835
+33836
+33837
+33838
+33839
+33840
+33841
+33842
+33843
+33844
+33845
+33846
+33847
+33848
+33849
+33850
+33851
+33852
+33853
+33854
+33855
+33856
+33857
+33858
+33859
+33860
+33861
+33862
+33863
+33864
+33865
+33866
+33867
+33868
+33869
+33870
+33871
+33872
+33873
+33874
+33875
+33876
+33877
+33878
+33879
+33880
+33881
+33882
+33883
+33884
+33885
+33886
+33887
+33888
+33889
+33890
+33891
+33892
+33893
+33894
+33895
+33896
+33897
+33898
+33899
+33900
+33901
+33902
+33903
+33904
+33905
+33906
+33907
+33908
+33909
+33910
+33911
+33912
+33913
+33914
+33915
+33916
+33917
+33918
+33919
+33920
+33921
+33922
+33923
+33924
+33925
+33926
+33927
+33928
+33929
+33930
+33931
+33932
+33933
+33934
+33935
+33936
+33937
+33938
+33939
+33940
+33941
+33942
+33943
+33944
+33945
+33946
+33947
+33948
+33949
+33950
+33951
+33952
+33953
+33954
+33955
+33956
+33957
+33958
+33959
+33960
+33961
+33962
+33963
+33964
+33965
+33966
+33967
+33968
+33969
+33970
+33971
+33972
+33973
+33974
+33975
+33976
+33977
+33978
+33979
+33980
+33981
+33982
+33983
+33984
+33985
+33986
+33987
+33988
+33989
+33990
+33991
+33992
+33993
+33994
+33995
+33996
+33997
+33998
+33999
+34000
+34001
+34002
+34003
+34004
+34005
+34006
+34007
+34008
+34009
+34010
+34011
+34012
+34013
+34014
+34015
+34016
+34017
+34018
+34019
+34020
+34021
+34022
+34023
+34024
+34025
+34026
+34027
+34028
+34029
+34030
+34031
+34032
+34033
+34034
+34035
+34036
+34037
+34038
+34039
+34040
+34041
+34042
+34043
+34044
+34045
+34046
+34047
+34048
+34049
+34050
+34051
+34052
+34053
+34054
+34055
+34056
+34057
+34058
+34059
+34060
+34061
+34062
+34063
+34064
+34065
+34066
+34067
+34068
+34069
+34070
+34071
+34072
+34073
+34074
+34075
+34076
+34077
+34078
+34079
+34080
+34081
+34082
+34083
+34084
+34085
+34086
+34087
+34088
+34089
+34090
+34091
+34092
+34093
+34094
+34095
+34096
+34097
+34098
+34099
+34100
+34101
+34102
+34103
+34104
+34105
+34106
+34107
+34108
+34109
+34110
+34111
+34112
+34113
+34114
+34115
+34116
+34117
+34118
+34119
+34120
+34121
+34122
+34123
+34124
+34125
+34126
+34127
+34128
+34129
+34130
+34131
+34132
+34133
+34134
+34135
+34136
+34137
+34138
+34139
+34140
+34141
+34142
+34143
+34144
+34145
+34146
+34147
+34148
+34149
+34150
+34151
+34152
+34153
+34154
+34155
+34156
+34157
+34158
+34159
+34160
+34161
+34162
+34163
+34164
+34165
+34166
+34167
+34168
+34169
+34170
+34171
+34172
+34173
+34174
+34175
+34176
+34177
+34178
+34179
+34180
+34181
+34182
+34183
+34184
+34185
+34186
+34187
+34188
+34189
+34190
+34191
+34192
+34193
+34194
+34195
+34196
+34197
+34198
+34199
+34200
+34201
+34202
+34203
+34204
+34205
+34206
+34207
+34208
+34209
+34210
+34211
+34212
+34213
+34214
+34215
+34216
+34217
+34218
+34219
+34220
+34221
+34222
+34223
+34224
+34225
+34226
+34227
+34228
+34229
+34230
+34231
+34232
+34233
+34234
+34235
+34236
+34237
+34238
+34239
+34240
+34241
+34242
+34243
+34244
+34245
+34246
+34247
+34248
+34249
+34250
+34251
+34252
+34253
+34254
+34255
+34256
+34257
+34258
+34259
+34260
+34261
+34262
+34263
+34264
+34265
+34266
+34267
+34268
+34269
+34270
+34271
+34272
+34273
+34274
+34275
+34276
+34277
+34278
+34279
+34280
+34281
+34282
+34283
+34284
+34285
+34286
+34287
+34288
+34289
+34290
+34291
+34292
+34293
+34294
+34295
+34296
+34297
+34298
+34299
+34300
+34301
+34302
+34303
+34304
+34305
+34306
+34307
+34308
+34309
+34310
+34311
+34312
+34313
+34314
+34315
+34316
+34317
+34318
+34319
+34320
+34321
+34322
+34323
+34324
+34325
+34326
+34327
+34328
+34329
+34330
+34331
+34332
+34333
+34334
+34335
+34336
+34337
+34338
+34339
+34340
+34341
+34342
+34343
+34344
+34345
+34346
+34347
+34348
+34349
+34350
+34351
+34352
+34353
+34354
+34355
+34356
+34357
+34358
+34359
+34360
+34361
+34362
+34363
+34364
+34365
+34366
+34367
+34368
+34369
+34370
+34371
+34372
+34373
+34374
+34375
+34376
+34377
+34378
+34379
+34380
+34381
+34382
+34383
+34384
+34385
+34386
+34387
+34388
+34389
+34390
+34391
+34392
+34393
+34394
+34395
+34396
+34397
+34398
+34399
+34400
+34401
+34402
+34403
+34404
+34405
+34406
+34407
+34408
+34409
+34410
+34411
+34412
+34413
+34414
+34415
+34416
+34417
+34418
+34419
+34420
+34421
+34422
+34423
+34424
+34425
+34426
+34427
+34428
+34429
+34430
+34431
+34432
+34433
+34434
+34435
+34436
+34437
+34438
+34439
+34440
+34441
+34442
+34443
+34444
+34445
+34446
+34447
+34448
+34449
+34450
+34451
+34452
+34453
+34454
+34455
+34456
+34457
+34458
+34459
+34460
+34461
+34462
+34463
+34464
+34465
+34466
+34467
+34468
+34469
+34470
+34471
+34472
+34473
+34474
+34475
+34476
+34477
+34478
+34479
+34480
+34481
+34482
+34483
+34484
+34485
+34486
+34487
+34488
+34489
+34490
+34491
+34492
+34493
+34494
+34495
+34496
+34497
+34498
+34499
+34500
+34501
+34502
+34503
+34504
+34505
+34506
+34507
+34508
+34509
+34510
+34511
+34512
+34513
+34514
+34515
+34516
+34517
+34518
+34519
+34520
+34521
+34522
+34523
+34524
+34525
+34526
+34527
+34528
+34529
+34530
+34531
+34532
+34533
+34534
+34535
+34536
+34537
+34538
+34539
+34540
+34541
+34542
+34543
+34544
+34545
+34546
+34547
+34548
+34549
+34550
+34551
+34552
+34553
+34554
+34555
+34556
+34557
+34558
+34559
+34560
+34561
+34562
+34563
+34564
+34565
+34566
+34567
+34568
+34569
+34570
+34571
+34572
+34573
+34574
+34575
+34576
+34577
+34578
+34579
+34580
+34581
+34582
+34583
+34584
+34585
+34586
+34587
+34588
+34589
+34590
+34591
+34592
+34593
+34594
+34595
+34596
+34597
+34598
+34599
+34600
+34601
+34602
+34603
+34604
+34605
+34606
+34607
+34608
+34609
+34610
+34611
+34612
+34613
+34614
+34615
+34616
+34617
+34618
+34619
+34620
+34621
+34622
+34623
+34624
+34625
+34626
+34627
+34628
+34629
+34630
+34631
+34632
+34633
+34634
+34635
+34636
+34637
+34638
+34639
+34640
+34641
+34642
+34643
+34644
+34645
+34646
+34647
+34648
+34649
+34650
+34651
+34652
+34653
+34654
+34655
+34656
+34657
+34658
+34659
+34660
+34661
+34662
+34663
+34664
+34665
+34666
+34667
+34668
+34669
+34670
+34671
+34672
+34673
+34674
+34675
+34676
+34677
+34678
+34679
+34680
+34681
+34682
+34683
+34684
+34685
+34686
+34687
+34688
+34689
+34690
+34691
+34692
+34693
+34694
+34695
+34696
+34697
+34698
+34699
+34700
+34701
+34702
+34703
+34704
+34705
+34706
+34707
+34708
+34709
+34710
+34711
+34712
+34713
+34714
+34715
+34716
+34717
+34718
+34719
+34720
+34721
+34722
+34723
+34724
+34725
+34726
+34727
+34728
+34729
+34730
+34731
+34732
+34733
+34734
+34735
+34736
+34737
+34738
+34739
+34740
+34741
+34742
+34743
+34744
+34745
+34746
+34747
+34748
+34749
+34750
+34751
+34752
+34753
+34754
+34755
+34756
+34757
+34758
+34759
+34760
+34761
+34762
+34763
+34764
+34765
+34766
+34767
+34768
+34769
+34770
+34771
+34772
+34773
+34774
+34775
+34776
+34777
+34778
+34779
+34780
+34781
+34782
+34783
+34784
+34785
+34786
+34787
+34788
+34789
+34790
+34791
+34792
+34793
+34794
+34795
+34796
+34797
+34798
+34799
+34800
+34801
+34802
+34803
+34804
+34805
+34806
+34807
+34808
+34809
+34810
+34811
+34812
+34813
+34814
+34815
+34816
+34817
+34818
+34819
+34820
+34821
+34822
+34823
+34824
+34825
+34826
+34827
+34828
+34829
+34830
+34831
+34832
+34833
+34834
+34835
+34836
+34837
+34838
+34839
+34840
+34841
+34842
+34843
+34844
+34845
+34846
+34847
+34848
+34849
+34850
+34851
+34852
+34853
+34854
+34855
+34856
+34857
+34858
+34859
+34860
+34861
+34862
+34863
+34864
+34865
+34866
+34867
+34868
+34869
+34870
+34871
+34872
+34873
+34874
+34875
+34876
+34877
+34878
+34879
+34880
+34881
+34882
+34883
+34884
+34885
+34886
+34887
+34888
+34889
+34890
+34891
+34892
+34893
+34894
+34895
+34896
+34897
+34898
+34899
+34900
+34901
+34902
+34903
+34904
+34905
+34906
+34907
+34908
+34909
+34910
+34911
+34912
+34913
+34914
+34915
+34916
+34917
+34918
+34919
+34920
+34921
+34922
+34923
+34924
+34925
+34926
+34927
+34928
+34929
+34930
+34931
+34932
+34933
+34934
+34935
+34936
+34937
+34938
+34939
+34940
+34941
+34942
+34943
+34944
+34945
+34946
+34947
+34948
+34949
+34950
+34951
+34952
+34953
+34954
+34955
+34956
+34957
+34958
+34959
+34960
+34961
+34962
+34963
+34964
+34965
+34966
+34967
+34968
+34969
+34970
+34971
+34972
+34973
+34974
+34975
+34976
+34977
+34978
+34979
+34980
+34981
+34982
+34983
+34984
+34985
+34986
+34987
+34988
+34989
+34990
+34991
+34992
+34993
+34994
+34995
+34996
+34997
+34998
+34999
+35000
+35001
+35002
+35003
+35004
+35005
+35006
+35007
+35008
+35009
+35010
+35011
+35012
+35013
+35014
+35015
+35016
+35017
+35018
+35019
+35020
+35021
+35022
+35023
+35024
+35025
+35026
+35027
+35028
+35029
+35030
+35031
+35032
+35033
+35034
+35035
+35036
+35037
+35038
+35039
+35040
+35041
+35042
+35043
+35044
+35045
+35046
+35047
+35048
+35049
+35050
+35051
+35052
+35053
+35054
+35055
+35056
+35057
+35058
+35059
+35060
+35061
+35062
+35063
+35064
+35065
+35066
+35067
+35068
+35069
+35070
+35071
+35072
+35073
+35074
+35075
+35076
+35077
+35078
+35079
+35080
+35081
+35082
+35083
+35084
+35085
+35086
+35087
+35088
+35089
+35090
+35091
+35092
+35093
+35094
+35095
+35096
+35097
+35098
+35099
+35100
+35101
+35102
+35103
+35104
+35105
+35106
+35107
+35108
+35109
+35110
+35111
+35112
+35113
+35114
+35115
+35116
+35117
+35118
+35119
+35120
+35121
+35122
+35123
+35124
+35125
+35126
+35127
+35128
+35129
+35130
+35131
+35132
+35133
+35134
+35135
+35136
+35137
+35138
+35139
+35140
+35141
+35142
+35143
+35144
+35145
+35146
+35147
+35148
+35149
+35150
+35151
+35152
+35153
+35154
+35155
+35156
+35157
+35158
+35159
+35160
+35161
+35162
+35163
+35164
+35165
+35166
+35167
+35168
+35169
+35170
+35171
+35172
+35173
+35174
+35175
+35176
+35177
+35178
+35179
+35180
+35181
+35182
+35183
+35184
+35185
+35186
+35187
+35188
+35189
+35190
+35191
+35192
+35193
+35194
+35195
+35196
+35197
+35198
+35199
+35200
+35201
+35202
+35203
+35204
+35205
+35206
+35207
+35208
+35209
+35210
+35211
+35212
+35213
+35214
+35215
+35216
+35217
+35218
+35219
+35220
+35221
+35222
+35223
+35224
+35225
+35226
+35227
+35228
+35229
+35230
+35231
+35232
+35233
+35234
+35235
+35236
+35237
+35238
+35239
+35240
+35241
+35242
+35243
+35244
+35245
+35246
+35247
+35248
+35249
+35250
+35251
+35252
+35253
+35254
+35255
+35256
+35257
+35258
+35259
+35260
+35261
+35262
+35263
+35264
+35265
+35266
+35267
+35268
+35269
+35270
+35271
+35272
+35273
+35274
+35275
+35276
+35277
+35278
+35279
+35280
+35281
+35282
+35283
+35284
+35285
+35286
+35287
+35288
+35289
+35290
+35291
+35292
+35293
+35294
+35295
+35296
+35297
+35298
+35299
+35300
+35301
+35302
+35303
+35304
+35305
+35306
+35307
+35308
+35309
+35310
+35311
+35312
+35313
+35314
+35315
+35316
+35317
+35318
+35319
+35320
+35321
+35322
+35323
+35324
+35325
+35326
+35327
+35328
+35329
+35330
+35331
+35332
+35333
+35334
+35335
+35336
+35337
+35338
+35339
+35340
+35341
+35342
+35343
+35344
+35345
+35346
+35347
+35348
+35349
+35350
+35351
+35352
+35353
+35354
+35355
+35356
+35357
+35358
+35359
+35360
+35361
+35362
+35363
+35364
+35365
+35366
+35367
+35368
+35369
+35370
+35371
+35372
+35373
+35374
+35375
+35376
+35377
+35378
+35379
+35380
+35381
+35382
+35383
+35384
+35385
+35386
+35387
+35388
+35389
+35390
+35391
+35392
+35393
+35394
+35395
+35396
+35397
+35398
+35399
+35400
+35401
+35402
+35403
+35404
+35405
+35406
+35407
+35408
+35409
+35410
+35411
+35412
+35413
+35414
+35415
+35416
+35417
+35418
+35419
+35420
+35421
+35422
+35423
+35424
+35425
+35426
+35427
+35428
+35429
+35430
+35431
+35432
+35433
+35434
+35435
+35436
+35437
+35438
+35439
+35440
+35441
+35442
+35443
+35444
+35445
+35446
+35447
+35448
+35449
+35450
+35451
+35452
+35453
+35454
+35455
+35456
+35457
+35458
+35459
+35460
+35461
+35462
+35463
+35464
+35465
+35466
+35467
+35468
+35469
+35470
+35471
+35472
+35473
+35474
+35475
+35476
+35477
+35478
+35479
+35480
+35481
+35482
+35483
+35484
+35485
+35486
+35487
+35488
+35489
+35490
+35491
+35492
+35493
+35494
+35495
+35496
+35497
+35498
+35499
+35500
+35501
+35502
+35503
+35504
+35505
+35506
+35507
+35508
+35509
+35510
+35511
+35512
+35513
+35514
+35515
+35516
+35517
+35518
+35519
+35520
+35521
+35522
+35523
+35524
+35525
+35526
+35527
+35528
+35529
+35530
+35531
+35532
+35533
+35534
+35535
+35536
+35537
+35538
+35539
+35540
+35541
+35542
+35543
+35544
+35545
+35546
+35547
+35548
+35549
+35550
+35551
+35552
+35553
+35554
+35555
+35556
+35557
+35558
+35559
+35560
+35561
+35562
+35563
+35564
+35565
+35566
+35567
+35568
+35569
+35570
+35571
+35572
+35573
+35574
+35575
+35576
+35577
+35578
+35579
+35580
+35581
+35582
+35583
+35584
+35585
+35586
+35587
+35588
+35589
+35590
+35591
+35592
+35593
+35594
+35595
+35596
+35597
+35598
+35599
+35600
+35601
+35602
+35603
+35604
+35605
+35606
+35607
+35608
+35609
+35610
+35611
+35612
+35613
+35614
+35615
+35616
+35617
+35618
+35619
+35620
+35621
+35622
+35623
+35624
+35625
+35626
+35627
+35628
+35629
+35630
+35631
+35632
+35633
+35634
+35635
+35636
+35637
+35638
+35639
+35640
+35641
+35642
+35643
+35644
+35645
+35646
+35647
+35648
+35649
+35650
+35651
+35652
+35653
+35654
+35655
+35656
+35657
+35658
+35659
+35660
+35661
+35662
+35663
+35664
+35665
+35666
+35667
+35668
+35669
+35670
+35671
+35672
+35673
+35674
+35675
+35676
+35677
+35678
+35679
+35680
+35681
+35682
+35683
+35684
+35685
+35686
+35687
+35688
+35689
+35690
+35691
+35692
+35693
+35694
+35695
+35696
+35697
+35698
+35699
+35700
+35701
+35702
+35703
+35704
+35705
+35706
+35707
+35708
+35709
+35710
+35711
+35712
+35713
+35714
+35715
+35716
+35717
+35718
+35719
+35720
+35721
+35722
+35723
+35724
+35725
+35726
+35727
+35728
+35729
+35730
+35731
+35732
+35733
+35734
+35735
+35736
+35737
+35738
+35739
+35740
+35741
+35742
+35743
+35744
+35745
+35746
+35747
+35748
+35749
+35750
+35751
+35752
+35753
+35754
+35755
+35756
+35757
+35758
+35759
+35760
+35761
+35762
+35763
+35764
+35765
+35766
+35767
+35768
+35769
+35770
+35771
+35772
+35773
+35774
+35775
+35776
+35777
+35778
+35779
+35780
+35781
+35782
+35783
+35784
+35785
+35786
+35787
+35788
+35789
+35790
+35791
+35792
+35793
+35794
+35795
+35796
+35797
+35798
+35799
+35800
+35801
+35802
+35803
+35804
+35805
+35806
+35807
+35808
+35809
+35810
+35811
+35812
+35813
+35814
+35815
+35816
+35817
+35818
+35819
+35820
+35821
+35822
+35823
+35824
+35825
+35826
+35827
+35828
+35829
+35830
+35831
+35832
+35833
+35834
+35835
+35836
+35837
+35838
+35839
+35840
+35841
+35842
+35843
+35844
+35845
+35846
+35847
+35848
+35849
+35850
+35851
+35852
+35853
+35854
+35855
+35856
+35857
+35858
+35859
+35860
+35861
+35862
+35863
+35864
+35865
+35866
+35867
+35868
+35869
+35870
+35871
+35872
+35873
+35874
+35875
+35876
+35877
+35878
+35879
+35880
+35881
+35882
+35883
+35884
+35885
+35886
+35887
+35888
+35889
+35890
+35891
+35892
+35893
+35894
+35895
+35896
+35897
+35898
+35899
+35900
+35901
+35902
+35903
+35904
+35905
+35906
+35907
+35908
+35909
+35910
+35911
+35912
+35913
+35914
+35915
+35916
+35917
+35918
+35919
+35920
+35921
+35922
+35923
+35924
+35925
+35926
+35927
+35928
+35929
+35930
+35931
+35932
+35933
+35934
+35935
+35936
+35937
+35938
+35939
+35940
+35941
+35942
+35943
+35944
+35945
+35946
+35947
+35948
+35949
+35950
+35951
+35952
+35953
+35954
+35955
+35956
+35957
+35958
+35959
+35960
+35961
+35962
+35963
+35964
+35965
+35966
+35967
+35968
+35969
+35970
+35971
+35972
+35973
+35974
+35975
+35976
+35977
+35978
+35979
+35980
+35981
+35982
+35983
+35984
+35985
+35986
+35987
+35988
+35989
+35990
+35991
+35992
+35993
+35994
+35995
+35996
+35997
+35998
+35999
+36000
+36001
+36002
+36003
+36004
+36005
+36006
+36007
+36008
+36009
+36010
+36011
+36012
+36013
+36014
+36015
+36016
+36017
+36018
+36019
+36020
+36021
+36022
+36023
+36024
+36025
+36026
+36027
+36028
+36029
+36030
+36031
+36032
+36033
+36034
+36035
+36036
+36037
+36038
+36039
+36040
+36041
+36042
+36043
+36044
+36045
+36046
+36047
+36048
+36049
+36050
+36051
+36052
+36053
+36054
+36055
+36056
+36057
+36058
+36059
+36060
+36061
+36062
+36063
+36064
+36065
+36066
+36067
+36068
+36069
+36070
+36071
+36072
+36073
+36074
+36075
+36076
+36077
+36078
+36079
+36080
+36081
+36082
+36083
+36084
+36085
+36086
+36087
+36088
+36089
+36090
+36091
+36092
+36093
+36094
+36095
+36096
+36097
+36098
+36099
+36100
+36101
+36102
+36103
+36104
+36105
+36106
+36107
+36108
+36109
+36110
+36111
+36112
+36113
+36114
+36115
+36116
+36117
+36118
+36119
+36120
+36121
+36122
+36123
+36124
+36125
+36126
+36127
+36128
+36129
+36130
+36131
+36132
+36133
+36134
+36135
+36136
+36137
+36138
+36139
+36140
+36141
+36142
+36143
+36144
+36145
+36146
+36147
+36148
+36149
+36150
+36151
+36152
+36153
+36154
+36155
+36156
+36157
+36158
+36159
+36160
+36161
+36162
+36163
+36164
+36165
+36166
+36167
+36168
+36169
+36170
+36171
+36172
+36173
+36174
+36175
+36176
+36177
+36178
+36179
+36180
+36181
+36182
+36183
+36184
+36185
+36186
+36187
+36188
+36189
+36190
+36191
+36192
+36193
+36194
+36195
+36196
+36197
+36198
+36199
+36200
+36201
+36202
+36203
+36204
+36205
+36206
+36207
+36208
+36209
+36210
+36211
+36212
+36213
+36214
+36215
+36216
+36217
+36218
+36219
+36220
+36221
+36222
+36223
+36224
+36225
+36226
+36227
+36228
+36229
+36230
+36231
+36232
+36233
+36234
+36235
+36236
+36237
+36238
+36239
+36240
+36241
+36242
+36243
+36244
+36245
+36246
+36247
+36248
+36249
+36250
+36251
+36252
+36253
+36254
+36255
+36256
+36257
+36258
+36259
+36260
+36261
+36262
+36263
+36264
+36265
+36266
+36267
+36268
+36269
+36270
+36271
+36272
+36273
+36274
+36275
+36276
+36277
+36278
+36279
+36280
+36281
+36282
+36283
+36284
+36285
+36286
+36287
+36288
+36289
+36290
+36291
+36292
+36293
+36294
+36295
+36296
+36297
+36298
+36299
+36300
+36301
+36302
+36303
+36304
+36305
+36306
+36307
+36308
+36309
+36310
+36311
+36312
+36313
+36314
+36315
+36316
+36317
+36318
+36319
+36320
+36321
+36322
+36323
+36324
+36325
+36326
+36327
+36328
+36329
+36330
+36331
+36332
+36333
+36334
+36335
+36336
+36337
+36338
+36339
+36340
+36341
+36342
+36343
+36344
+36345
+36346
+36347
+36348
+36349
+36350
+36351
+36352
+36353
+36354
+36355
+36356
+36357
+36358
+36359
+36360
+36361
+36362
+36363
+36364
+36365
+36366
+36367
+36368
+36369
+36370
+36371
+36372
+36373
+36374
+36375
+36376
+36377
+36378
+36379
+36380
+36381
+36382
+36383
+36384
+36385
+36386
+36387
+36388
+36389
+36390
+36391
+36392
+36393
+36394
+36395
+36396
+36397
+36398
+36399
+36400
+36401
+36402
+36403
+36404
+36405
+36406
+36407
+36408
+36409
+36410
+36411
+36412
+36413
+36414
+36415
+36416
+36417
+36418
+36419
+36420
+36421
+36422
+36423
+36424
+36425
+36426
+36427
+36428
+36429
+36430
+36431
+36432
+36433
+36434
+36435
+36436
+36437
+36438
+36439
+36440
+36441
+36442
+36443
+36444
+36445
+36446
+36447
+36448
+36449
+36450
+36451
+36452
+36453
+36454
+36455
+36456
+36457
+36458
+36459
+36460
+36461
+36462
+36463
+36464
+36465
+36466
+36467
+36468
+36469
+36470
+36471
+36472
+36473
+36474
+36475
+36476
+36477
+36478
+36479
+36480
+36481
+36482
+36483
+36484
+36485
+36486
+36487
+36488
+36489
+36490
+36491
+36492
+36493
+36494
+36495
+36496
+36497
+36498
+36499
+36500
+36501
+36502
+36503
+36504
+36505
+36506
+36507
+36508
+36509
+36510
+36511
+36512
+36513
+36514
+36515
+36516
+36517
+36518
+36519
+36520
+36521
+36522
+36523
+36524
+36525
+36526
+36527
+36528
+36529
+36530
+36531
+36532
+36533
+36534
+36535
+36536
+36537
+36538
+36539
+36540
+36541
+36542
+36543
+36544
+36545
+36546
+36547
+36548
+36549
+36550
+36551
+36552
+36553
+36554
+36555
+36556
+36557
+36558
+36559
+36560
+36561
+36562
+36563
+36564
+36565
+36566
+36567
+36568
+36569
+36570
+36571
+36572
+36573
+36574
+36575
+36576
+36577
+36578
+36579
+36580
+36581
+36582
+36583
+36584
+36585
+36586
+36587
+36588
+36589
+36590
+36591
+36592
+36593
+36594
+36595
+36596
+36597
+36598
+36599
+36600
+36601
+36602
+36603
+36604
+36605
+36606
+36607
+36608
+36609
+36610
+36611
+36612
+36613
+36614
+36615
+36616
+36617
+36618
+36619
+36620
+36621
+36622
+36623
+36624
+36625
+36626
+36627
+36628
+36629
+36630
+36631
+36632
+36633
+36634
+36635
+36636
+36637
+36638
+36639
+36640
+36641
+36642
+36643
+36644
+36645
+36646
+36647
+36648
+36649
+36650
+36651
+36652
+36653
+36654
+36655
+36656
+36657
+36658
+36659
+36660
+36661
+36662
+36663
+36664
+36665
+36666
+36667
+36668
+36669
+36670
+36671
+36672
+36673
+36674
+36675
+36676
+36677
+36678
+36679
+36680
+36681
+36682
+36683
+36684
+36685
+36686
+36687
+36688
+36689
+36690
+36691
+36692
+36693
+36694
+36695
+36696
+36697
+36698
+36699
+36700
+36701
+36702
+36703
+36704
+36705
+36706
+36707
+36708
+36709
+36710
+36711
+36712
+36713
+36714
+36715
+36716
+36717
+36718
+36719
+36720
+36721
+36722
+36723
+36724
+36725
+36726
+36727
+36728
+36729
+36730
+36731
+36732
+36733
+36734
+36735
+36736
+36737
+36738
+36739
+36740
+36741
+36742
+36743
+36744
+36745
+36746
+36747
+36748
+36749
+36750
+36751
+36752
+36753
+36754
+36755
+36756
+36757
+36758
+36759
+36760
+36761
+36762
+36763
+36764
+36765
+36766
+36767
+36768
+36769
+36770
+36771
+36772
+36773
+36774
+36775
+36776
+36777
+36778
+36779
+36780
+36781
+36782
+36783
+36784
+36785
+36786
+36787
+36788
+36789
+36790
+36791
+36792
+36793
+36794
+36795
+36796
+36797
+36798
+36799
+36800
+36801
+36802
+36803
+36804
+36805
+36806
+36807
+36808
+36809
+36810
+36811
+36812
+36813
+36814
+36815
+36816
+36817
+36818
+36819
+36820
+36821
+36822
+36823
+36824
+36825
+36826
+36827
+36828
+36829
+36830
+36831
+36832
+36833
+36834
+36835
+36836
+36837
+36838
+36839
+36840
+36841
+36842
+36843
+36844
+36845
+36846
+36847
+36848
+36849
+36850
+36851
+36852
+36853
+36854
+36855
+36856
+36857
+36858
+36859
+36860
+36861
+36862
+36863
+36864
+36865
+36866
+36867
+36868
+36869
+36870
+36871
+36872
+36873
+36874
+36875
+36876
+36877
+36878
+36879
+36880
+36881
+36882
+36883
+36884
+36885
+36886
+36887
+36888
+36889
+36890
+36891
+36892
+36893
+36894
+36895
+36896
+36897
+36898
+36899
+36900
+36901
+36902
+36903
+36904
+36905
+36906
+36907
+36908
+36909
+36910
+36911
+36912
+36913
+36914
+36915
+36916
+36917
+36918
+36919
+36920
+36921
+36922
+36923
+36924
+36925
+36926
+36927
+36928
+36929
+36930
+36931
+36932
+36933
+36934
+36935
+36936
+36937
+36938
+36939
+36940
+36941
+36942
+36943
+36944
+36945
+36946
+36947
+36948
+36949
+36950
+36951
+36952
+36953
+36954
+36955
+36956
+36957
+36958
+36959
+36960
+36961
+36962
+36963
+36964
+36965
+36966
+36967
+36968
+36969
+36970
+36971
+36972
+36973
+36974
+36975
+36976
+36977
+36978
+36979
+36980
+36981
+36982
+36983
+36984
+36985
+36986
+36987
+36988
+36989
+36990
+36991
+36992
+36993
+36994
+36995
+36996
+36997
+36998
+36999
+37000
+37001
+37002
+37003
+37004
+37005
+37006
+37007
+37008
+37009
+37010
+37011
+37012
+37013
+37014
+37015
+37016
+37017
+37018
+37019
+37020
+37021
+37022
+37023
+37024
+37025
+37026
+37027
+37028
+37029
+37030
+37031
+37032
+37033
+37034
+37035
+37036
+37037
+37038
+37039
+37040
+37041
+37042
+37043
+37044
+37045
+37046
+37047
+37048
+37049
+37050
+37051
+37052
+37053
+37054
+37055
+37056
+37057
+37058
+37059
+37060
+37061
+37062
+37063
+37064
+37065
+37066
+37067
+37068
+37069
+37070
+37071
+37072
+37073
+37074
+37075
+37076
+37077
+37078
+37079
+37080
+37081
+37082
+37083
+37084
+37085
+37086
+37087
+37088
+37089
+37090
+37091
+37092
+37093
+37094
+37095
+37096
+37097
+37098
+37099
+37100
+37101
+37102
+37103
+37104
+37105
+37106
+37107
+37108
+37109
+37110
+37111
+37112
+37113
+37114
+37115
+37116
+37117
+37118
+37119
+37120
+37121
+37122
+37123
+37124
+37125
+37126
+37127
+37128
+37129
+37130
+37131
+37132
+37133
+37134
+37135
+37136
+37137
+37138
+37139
+37140
+37141
+37142
+37143
+37144
+37145
+37146
+37147
+37148
+37149
+37150
+37151
+37152
+37153
+37154
+37155
+37156
+37157
+37158
+37159
+37160
+37161
+37162
+37163
+37164
+37165
+37166
+37167
+37168
+37169
+37170
+37171
+37172
+37173
+37174
+37175
+37176
+37177
+37178
+37179
+37180
+37181
+37182
+37183
+37184
+37185
+37186
+37187
+37188
+37189
+37190
+37191
+37192
+37193
+37194
+37195
+37196
+37197
+37198
+37199
+37200
+37201
+37202
+37203
+37204
+37205
+37206
+37207
+37208
+37209
+37210
+37211
+37212
+37213
+37214
+37215
+37216
+37217
+37218
+37219
+37220
+37221
+37222
+37223
+37224
+37225
+37226
+37227
+37228
+37229
+37230
+37231
+37232
+37233
+37234
+37235
+37236
+37237
+37238
+37239
+37240
+37241
+37242
+37243
+37244
+37245
+37246
+37247
+37248
+37249
+37250
+37251
+37252
+37253
+37254
+37255
+37256
+37257
+37258
+37259
+37260
+37261
+37262
+37263
+37264
+37265
+37266
+37267
+37268
+37269
+37270
+37271
+37272
+37273
+37274
+37275
+37276
+37277
+37278
+37279
+37280
+37281
+37282
+37283
+37284
+37285
+37286
+37287
+37288
+37289
+37290
+37291
+37292
+37293
+37294
+37295
+37296
+37297
+37298
+37299
+37300
+37301
+37302
+37303
+37304
+37305
+37306
+37307
+37308
+37309
+37310
+37311
+37312
+37313
+37314
+37315
+37316
+37317
+37318
+37319
+37320
+37321
+37322
+37323
+37324
+37325
+37326
+37327
+37328
+37329
+37330
+37331
+37332
+37333
+37334
+37335
+37336
+37337
+37338
+37339
+37340
+37341
+37342
+37343
+37344
+37345
+37346
+37347
+37348
+37349
+37350
+37351
+37352
+37353
+37354
+37355
+37356
+37357
+37358
+37359
+37360
+37361
+37362
+37363
+37364
+37365
+37366
+37367
+37368
+37369
+37370
+37371
+37372
+37373
+37374
+37375
+37376
+37377
+37378
+37379
+37380
+37381
+37382
+37383
+37384
+37385
+37386
+37387
+37388
+37389
+37390
+37391
+37392
+37393
+37394
+37395
+37396
+37397
+37398
+37399
+37400
+37401
+37402
+37403
+37404
+37405
+37406
+37407
+37408
+37409
+37410
+37411
+37412
+37413
+37414
+37415
+37416
+37417
+37418
+37419
+37420
+37421
+37422
+37423
+37424
+37425
+37426
+37427
+37428
+37429
+37430
+37431
+37432
+37433
+37434
+37435
+37436
+37437
+37438
+37439
+37440
+37441
+37442
+37443
+37444
+37445
+37446
+37447
+37448
+37449
+37450
+37451
+37452
+37453
+37454
+37455
+37456
+37457
+37458
+37459
+37460
+37461
+37462
+37463
+37464
+37465
+37466
+37467
+37468
+37469
+37470
+37471
+37472
+37473
+37474
+37475
+37476
+37477
+37478
+37479
+37480
+37481
+37482
+37483
+37484
+37485
+37486
+37487
+37488
+37489
+37490
+37491
+37492
+37493
+37494
+37495
+37496
+37497
+37498
+37499
+37500
+37501
+37502
+37503
+37504
+37505
+37506
+37507
+37508
+37509
+37510
+37511
+37512
+37513
+37514
+37515
+37516
+37517
+37518
+37519
+37520
+37521
+37522
+37523
+37524
+37525
+37526
+37527
+37528
+37529
+37530
+37531
+37532
+37533
+37534
+37535
+37536
+37537
+37538
+37539
+37540
+37541
+37542
+37543
+37544
+37545
+37546
+37547
+37548
+37549
+37550
+37551
+37552
+37553
+37554
+37555
+37556
+37557
+37558
+37559
+37560
+37561
+37562
+37563
+37564
+37565
+37566
+37567
+37568
+37569
+37570
+37571
+37572
+37573
+37574
+37575
+37576
+37577
+37578
+37579
+37580
+37581
+37582
+37583
+37584
+37585
+37586
+37587
+37588
+37589
+37590
+37591
+37592
+37593
+37594
+37595
+37596
+37597
+37598
+37599
+37600
+37601
+37602
+37603
+37604
+37605
+37606
+37607
+37608
+37609
+37610
+37611
+37612
+37613
+37614
+37615
+37616
+37617
+37618
+37619
+37620
+37621
+37622
+37623
+37624
+37625
+37626
+37627
+37628
+37629
+37630
+37631
+37632
+37633
+37634
+37635
+37636
+37637
+37638
+37639
+37640
+37641
+37642
+37643
+37644
+37645
+37646
+37647
+37648
+37649
+37650
+37651
+37652
+37653
+37654
+37655
+37656
+37657
+37658
+37659
+37660
+37661
+37662
+37663
+37664
+37665
+37666
+37667
+37668
+37669
+37670
+37671
+37672
+37673
+37674
+37675
+37676
+37677
+37678
+37679
+37680
+37681
+37682
+37683
+37684
+37685
+37686
+37687
+37688
+37689
+37690
+37691
+37692
+37693
+37694
+37695
+37696
+37697
+37698
+37699
+37700
+37701
+37702
+37703
+37704
+37705
+37706
+37707
+37708
+37709
+37710
+37711
+37712
+37713
+37714
+37715
+37716
+37717
+37718
+37719
+37720
+37721
+37722
+37723
+37724
+37725
+37726
+37727
+37728
+37729
+37730
+37731
+37732
+37733
+37734
+37735
+37736
+37737
+37738
+37739
+37740
+37741
+37742
+37743
+37744
+37745
+37746
+37747
+37748
+37749
+37750
+37751
+37752
+37753
+37754
+37755
+37756
+37757
+37758
+37759
+37760
+37761
+37762
+37763
+37764
+37765
+37766
+37767
+37768
+37769
+37770
+37771
+37772
+37773
+37774
+37775
+37776
+37777
+37778
+37779
+37780
+37781
+37782
+37783
+37784
+37785
+37786
+37787
+37788
+37789
+37790
+37791
+37792
+37793
+37794
+37795
+37796
+37797
+37798
+37799
+37800
+37801
+37802
+37803
+37804
+37805
+37806
+37807
+37808
+37809
+37810
+37811
+37812
+37813
+37814
+37815
+37816
+37817
+37818
+37819
+37820
+37821
+37822
+37823
+37824
+37825
+37826
+37827
+37828
+37829
+37830
+37831
+37832
+37833
+37834
+37835
+37836
+37837
+37838
+37839
+37840
+37841
+37842
+37843
+37844
+37845
+37846
+37847
+37848
+37849
+37850
+37851
+37852
+37853
+37854
+37855
+37856
+37857
+37858
+37859
+37860
+37861
+37862
+37863
+37864
+37865
+37866
+37867
+37868
+37869
+37870
+37871
+37872
+37873
+37874
+37875
+37876
+37877
+37878
+37879
+37880
+37881
+37882
+37883
+37884
+37885
+37886
+37887
+37888
+37889
+37890
+37891
+37892
+37893
+37894
+37895
+37896
+37897
+37898
+37899
+37900
+37901
+37902
+37903
+37904
+37905
+37906
+37907
+37908
+37909
+37910
+37911
+37912
+37913
+37914
+37915
+37916
+37917
+37918
+37919
+37920
+37921
+37922
+37923
+37924
+37925
+37926
+37927
+37928
+37929
+37930
+37931
+37932
+37933
+37934
+37935
+37936
+37937
+37938
+37939
+37940
+37941
+37942
+37943
+37944
+37945
+37946
+37947
+37948
+37949
+37950
+37951
+37952
+37953
+37954
+37955
+37956
+37957
+37958
+37959
+37960
+37961
+37962
+37963
+37964
+37965
+37966
+37967
+37968
+37969
+37970
+37971
+37972
+37973
+37974
+37975
+37976
+37977
+37978
+37979
+37980
+37981
+37982
+37983
+37984
+37985
+37986
+37987
+37988
+37989
+37990
+37991
+37992
+37993
+37994
+37995
+37996
+37997
+37998
+37999
+38000
+38001
+38002
+38003
+38004
+38005
+38006
+38007
+38008
+38009
+38010
+38011
+38012
+38013
+38014
+38015
+38016
+38017
+38018
+38019
+38020
+38021
+38022
+38023
+38024
+38025
+38026
+38027
+38028
+38029
+38030
+38031
+38032
+38033
+38034
+38035
+38036
+38037
+38038
+38039
+38040
+38041
+38042
+38043
+38044
+38045
+38046
+38047
+38048
+38049
+38050
+38051
+38052
+38053
+38054
+38055
+38056
+38057
+38058
+38059
+38060
+38061
+38062
+38063
+38064
+38065
+38066
+38067
+38068
+38069
+38070
+38071
+38072
+38073
+38074
+38075
+38076
+38077
+38078
+38079
+38080
+38081
+38082
+38083
+38084
+38085
+38086
+38087
+38088
+38089
+38090
+38091
+38092
+38093
+38094
+38095
+38096
+38097
+38098
+38099
+38100
+38101
+38102
+38103
+38104
+38105
+38106
+38107
+38108
+38109
+38110
+38111
+38112
+38113
+38114
+38115
+38116
+38117
+38118
+38119
+38120
+38121
+38122
+38123
+38124
+38125
+38126
+38127
+38128
+38129
+38130
+38131
+38132
+38133
+38134
+38135
+38136
+38137
+38138
+38139
+38140
+38141
+38142
+38143
+38144
+38145
+38146
+38147
+38148
+38149
+38150
+38151
+38152
+38153
+38154
+38155
+38156
+38157
+38158
+38159
+38160
+38161
+38162
+38163
+38164
+38165
+38166
+38167
+38168
+38169
+38170
+38171
+38172
+38173
+38174
+38175
+38176
+38177
+38178
+38179
+38180
+38181
+38182
+38183
+38184
+38185
+38186
+38187
+38188
+38189
+38190
+38191
+38192
+38193
+38194
+38195
+38196
+38197
+38198
+38199
+38200
+38201
+38202
+38203
+38204
+38205
+38206
+38207
+38208
+38209
+38210
+38211
+38212
+38213
+38214
+38215
+38216
+38217
+38218
+38219
+38220
+38221
+38222
+38223
+38224
+38225
+38226
+38227
+38228
+38229
+38230
+38231
+38232
+38233
+38234
+38235
+38236
+38237
+38238
+38239
+38240
+38241
+38242
+38243
+38244
+38245
+38246
+38247
+38248
+38249
+38250
+38251
+38252
+38253
+38254
+38255
+38256
+38257
+38258
+38259
+38260
+38261
+38262
+38263
+38264
+38265
+38266
+38267
+38268
+38269
+38270
+38271
+38272
+38273
+38274
+38275
+38276
+38277
+38278
+38279
+38280
+38281
+38282
+38283
+38284
+38285
+38286
+38287
+38288
+38289
+38290
+38291
+38292
+38293
+38294
+38295
+38296
+38297
+38298
+38299
+38300
+38301
+38302
+38303
+38304
+38305
+38306
+38307
+38308
+38309
+38310
+38311
+38312
+38313
+38314
+38315
+38316
+38317
+38318
+38319
+38320
+38321
+38322
+38323
+38324
+38325
+38326
+38327
+38328
+38329
+38330
+38331
+38332
+38333
+38334
+38335
+38336
+38337
+38338
+38339
+38340
+38341
+38342
+38343
+38344
+38345
+38346
+38347
+38348
+38349
+38350
+38351
+38352
+38353
+38354
+38355
+38356
+38357
+38358
+38359
+38360
+38361
+38362
+38363
+38364
+38365
+38366
+38367
+38368
+38369
+38370
+38371
+38372
+38373
+38374
+38375
+38376
+38377
+38378
+38379
+38380
+38381
+38382
+38383
+38384
+38385
+38386
+38387
+38388
+38389
+38390
+38391
+38392
+38393
+38394
+38395
+38396
+38397
+38398
+38399
+38400
+38401
+38402
+38403
+38404
+38405
+38406
+38407
+38408
+38409
+38410
+38411
+38412
+38413
+38414
+38415
+38416
+38417
+38418
+38419
+38420
+38421
+38422
+38423
+38424
+38425
+38426
+38427
+38428
+38429
+38430
+38431
+38432
+38433
+38434
+38435
+38436
+38437
+38438
+38439
+38440
+38441
+38442
+38443
+38444
+38445
+38446
+38447
+38448
+38449
+38450
+38451
+38452
+38453
+38454
+38455
+38456
+38457
+38458
+38459
+38460
+38461
+38462
+38463
+38464
+38465
+38466
+38467
+38468
+38469
+38470
+38471
+38472
+38473
+38474
+38475
+38476
+38477
+38478
+38479
+38480
+38481
+38482
+38483
+38484
+38485
+38486
+38487
+38488
+38489
+38490
+38491
+38492
+38493
+38494
+38495
+38496
+38497
+38498
+38499
+38500
+38501
+38502
+38503
+38504
+38505
+38506
+38507
+38508
+38509
+38510
+38511
+38512
+38513
+38514
+38515
+38516
+38517
+38518
+38519
+38520
+38521
+38522
+38523
+38524
+38525
+38526
+38527
+38528
+38529
+38530
+38531
+38532
+38533
+38534
+38535
+38536
+38537
+38538
+38539
+38540
+38541
+38542
+38543
+38544
+38545
+38546
+38547
+38548
+38549
+38550
+38551
+38552
+38553
+38554
+38555
+38556
+38557
+38558
+38559
+38560
+38561
+38562
+38563
+38564
+38565
+38566
+38567
+38568
+38569
+38570
+38571
+38572
+38573
+38574
+38575
+38576
+38577
+38578
+38579
+38580
+38581
+38582
+38583
+38584
+38585
+38586
+38587
+38588
+38589
+38590
+38591
+38592
+38593
+38594
+38595
+38596
+38597
+38598
+38599
+38600
+38601
+38602
+38603
+38604
+38605
+38606
+38607
+38608
+38609
+38610
+38611
+38612
+38613
+38614
+38615
+38616
+38617
+38618
+38619
+38620
+38621
+38622
+38623
+38624
+38625
+38626
+38627
+38628
+38629
+38630
+38631
+38632
+38633
+38634
+38635
+38636
+38637
+38638
+38639
+38640
+38641
+38642
+38643
+38644
+38645
+38646
+38647
+38648
+38649
+38650
+38651
+38652
+38653
+38654
+38655
+38656
+38657
+38658
+38659
+38660
+38661
+38662
+38663
+38664
+38665
+38666
+38667
+38668
+38669
+38670
+38671
+38672
+38673
+38674
+38675
+38676
+38677
+38678
+38679
+38680
+38681
+38682
+38683
+38684
+38685
+38686
+38687
+38688
+38689
+38690
+38691
+38692
+38693
+38694
+38695
+38696
+38697
+38698
+38699
+38700
+38701
+38702
+38703
+38704
+38705
+38706
+38707
+38708
+38709
+38710
+38711
+38712
+38713
+38714
+38715
+38716
+38717
+38718
+38719
+38720
+38721
+38722
+38723
+38724
+38725
+38726
+38727
+38728
+38729
+38730
+38731
+38732
+38733
+38734
+38735
+38736
+38737
+38738
+38739
+38740
+38741
+38742
+38743
+38744
+38745
+38746
+38747
+38748
+38749
+38750
+38751
+38752
+38753
+38754
+38755
+38756
+38757
+38758
+38759
+38760
+38761
+38762
+38763
+38764
+38765
+38766
+38767
+38768
+38769
+38770
+38771
+38772
+38773
+38774
+38775
+38776
+38777
+38778
+38779
+38780
+38781
+38782
+38783
+38784
+38785
+38786
+38787
+38788
+38789
+38790
+38791
+38792
+38793
+38794
+38795
+38796
+38797
+38798
+38799
+38800
+38801
+38802
+38803
+38804
+38805
+38806
+38807
+38808
+38809
+38810
+38811
+38812
+38813
+38814
+38815
+38816
+38817
+38818
+38819
+38820
+38821
+38822
+38823
+38824
+38825
+38826
+38827
+38828
+38829
+38830
+38831
+38832
+38833
+38834
+38835
+38836
+38837
+38838
+38839
+38840
+38841
+38842
+38843
+38844
+38845
+38846
+38847
+38848
+38849
+38850
+38851
+38852
+38853
+38854
+38855
+38856
+38857
+38858
+38859
+38860
+38861
+38862
+38863
+38864
+38865
+38866
+38867
+38868
+38869
+38870
+38871
+38872
+38873
+38874
+38875
+38876
+38877
+38878
+38879
+38880
+38881
+38882
+38883
+38884
+38885
+38886
+38887
+38888
+38889
+38890
+38891
+38892
+38893
+38894
+38895
+38896
+38897
+38898
+38899
+38900
+38901
+38902
+38903
+38904
+38905
+38906
+38907
+38908
+38909
+38910
+38911
+38912
+38913
+38914
+38915
+38916
+38917
+38918
+38919
+38920
+38921
+38922
+38923
+38924
+38925
+38926
+38927
+38928
+38929
+38930
+38931
+38932
+38933
+38934
+38935
+38936
+38937
+38938
+38939
+38940
+38941
+38942
+38943
+38944
+38945
+38946
+38947
+38948
+38949
+38950
+38951
+38952
+38953
+38954
+38955
+38956
+38957
+38958
+38959
+38960
+38961
+38962
+38963
+38964
+38965
+38966
+38967
+38968
+38969
+38970
+38971
+38972
+38973
+38974
+38975
+38976
+38977
+38978
+38979
+38980
+38981
+38982
+38983
+38984
+38985
+38986
+38987
+38988
+38989
+38990
+38991
+38992
+38993
+38994
+38995
+38996
+38997
+38998
+38999
+39000
+39001
+39002
+39003
+39004
+39005
+39006
+39007
+39008
+39009
+39010
+39011
+39012
+39013
+39014
+39015
+39016
+39017
+39018
+39019
+39020
+39021
+39022
+39023
+39024
+39025
+39026
+39027
+39028
+39029
+39030
+39031
+39032
+39033
+39034
+39035
+39036
+39037
+39038
+39039
+39040
+39041
+39042
+39043
+39044
+39045
+39046
+39047
+39048
+39049
+39050
+39051
+39052
+39053
+39054
+39055
+39056
+39057
+39058
+39059
+39060
+39061
+39062
+39063
+39064
+39065
+39066
+39067
+39068
+39069
+39070
+39071
+39072
+39073
+39074
+39075
+39076
+39077
+39078
+39079
+39080
+39081
+39082
+39083
+39084
+39085
+39086
+39087
+39088
+39089
+39090
+39091
+39092
+39093
+39094
+39095
+39096
+39097
+39098
+39099
+39100
+39101
+39102
+39103
+39104
+39105
+39106
+39107
+39108
+39109
+39110
+39111
+39112
+39113
+39114
+39115
+39116
+39117
+39118
+39119
+39120
+39121
+39122
+39123
+39124
+39125
+39126
+39127
+39128
+39129
+39130
+39131
+39132
+39133
+39134
+39135
+39136
+39137
+39138
+39139
+39140
+39141
+39142
+39143
+39144
+39145
+39146
+39147
+39148
+39149
+39150
+39151
+39152
+39153
+39154
+39155
+39156
+39157
+39158
+39159
+39160
+39161
+39162
+39163
+39164
+39165
+39166
+39167
+39168
+39169
+39170
+39171
+39172
+39173
+39174
+39175
+39176
+39177
+39178
+39179
+39180
+39181
+39182
+39183
+39184
+39185
+39186
+39187
+39188
+39189
+39190
+39191
+39192
+39193
+39194
+39195
+39196
+39197
+39198
+39199
+39200
+39201
+39202
+39203
+39204
+39205
+39206
+39207
+39208
+39209
+39210
+39211
+39212
+39213
+39214
+39215
+39216
+39217
+39218
+39219
+39220
+39221
+39222
+39223
+39224
+39225
+39226
+39227
+39228
+39229
+39230
+39231
+39232
+39233
+39234
+39235
+39236
+39237
+39238
+39239
+39240
+39241
+39242
+39243
+39244
+39245
+39246
+39247
+39248
+39249
+39250
+39251
+39252
+39253
+39254
+39255
+39256
+39257
+39258
+39259
+39260
+39261
+39262
+39263
+39264
+39265
+39266
+39267
+39268
+39269
+39270
+39271
+39272
+39273
+39274
+39275
+39276
+39277
+39278
+39279
+39280
+39281
+39282
+39283
+39284
+39285
+39286
+39287
+39288
+39289
+39290
+39291
+39292
+39293
+39294
+39295
+39296
+39297
+39298
+39299
+39300
+39301
+39302
+39303
+39304
+39305
+39306
+39307
+39308
+39309
+39310
+39311
+39312
+39313
+39314
+39315
+39316
+39317
+39318
+39319
+39320
+39321
+39322
+39323
+39324
+39325
+39326
+39327
+39328
+39329
+39330
+39331
+39332
+39333
+39334
+39335
+39336
+39337
+39338
+39339
+39340
+39341
+39342
+39343
+39344
+39345
+39346
+39347
+39348
+39349
+39350
+39351
+39352
+39353
+39354
+39355
+39356
+39357
+39358
+39359
+39360
+39361
+39362
+39363
+39364
+39365
+39366
+39367
+39368
+39369
+39370
+39371
+39372
+39373
+39374
+39375
+39376
+39377
+39378
+39379
+39380
+39381
+39382
+39383
+39384
+39385
+39386
+39387
+39388
+39389
+39390
+39391
+39392
+39393
+39394
+39395
+39396
+39397
+39398
+39399
+39400
+39401
+39402
+39403
+39404
+39405
+39406
+39407
+39408
+39409
+39410
+39411
+39412
+39413
+39414
+39415
+39416
+39417
+39418
+39419
+39420
+39421
+39422
+39423
+39424
+39425
+39426
+39427
+39428
+39429
+39430
+39431
+39432
+39433
+39434
+39435
+39436
+39437
+39438
+39439
+39440
+39441
+39442
+39443
+39444
+39445
+39446
+39447
+39448
+39449
+39450
+39451
+39452
+39453
+39454
+39455
+39456
+39457
+39458
+39459
+39460
+39461
+39462
+39463
+39464
+39465
+39466
+39467
+39468
+39469
+39470
+39471
+39472
+39473
+39474
+39475
+39476
+39477
+39478
+39479
+39480
+39481
+39482
+39483
+39484
+39485
+39486
+39487
+39488
+39489
+39490
+39491
+39492
+39493
+39494
+39495
+39496
+39497
+39498
+39499
+39500
+39501
+39502
+39503
+39504
+39505
+39506
+39507
+39508
+39509
+39510
+39511
+39512
+39513
+39514
+39515
+39516
+39517
+39518
+39519
+39520
+39521
+39522
+39523
+39524
+39525
+39526
+39527
+39528
+39529
+39530
+39531
+39532
+39533
+39534
+39535
+39536
+39537
+39538
+39539
+39540
+39541
+39542
+39543
+39544
+39545
+39546
+39547
+39548
+39549
+39550
+39551
+39552
+39553
+39554
+39555
+39556
+39557
+39558
+39559
+39560
+39561
+39562
+39563
+39564
+39565
+39566
+39567
+39568
+39569
+39570
+39571
+39572
+39573
+39574
+39575
+39576
+39577
+39578
+39579
+39580
+39581
+39582
+39583
+39584
+39585
+39586
+39587
+39588
+39589
+39590
+39591
+39592
+39593
+39594
+39595
+39596
+39597
+39598
+39599
+39600
+39601
+39602
+39603
+39604
+39605
+39606
+39607
+39608
+39609
+39610
+39611
+39612
+39613
+39614
+39615
+39616
+39617
+39618
+39619
+39620
+39621
+39622
+39623
+39624
+39625
+39626
+39627
+39628
+39629
+39630
+39631
+39632
+39633
+39634
+39635
+39636
+39637
+39638
+39639
+39640
+39641
+39642
+39643
+39644
+39645
+39646
+39647
+39648
+39649
+39650
+39651
+39652
+39653
+39654
+39655
+39656
+39657
+39658
+39659
+39660
+39661
+39662
+39663
+39664
+39665
+39666
+39667
+39668
+39669
+39670
+39671
+39672
+39673
+39674
+39675
+39676
+39677
+39678
+39679
+39680
+39681
+39682
+39683
+39684
+39685
+39686
+39687
+39688
+39689
+39690
+39691
+39692
+39693
+39694
+39695
+39696
+39697
+39698
+39699
+39700
+39701
+39702
+39703
+39704
+39705
+39706
+39707
+39708
+39709
+39710
+39711
+39712
+39713
+39714
+39715
+39716
+39717
+39718
+39719
+39720
+39721
+39722
+39723
+39724
+39725
+39726
+39727
+39728
+39729
+39730
+39731
+39732
+39733
+39734
+39735
+39736
+39737
+39738
+39739
+39740
+39741
+39742
+39743
+39744
+39745
+39746
+39747
+39748
+39749
+39750
+39751
+39752
+39753
+39754
+39755
+39756
+39757
+39758
+39759
+39760
+39761
+39762
+39763
+39764
+39765
+39766
+39767
+39768
+39769
+39770
+39771
+39772
+39773
+39774
+39775
+39776
+39777
+39778
+39779
+39780
+39781
+39782
+39783
+39784
+39785
+39786
+39787
+39788
+39789
+39790
+39791
+39792
+39793
+39794
+39795
+39796
+39797
+39798
+39799
+39800
+39801
+39802
+39803
+39804
+39805
+39806
+39807
+39808
+39809
+39810
+39811
+39812
+39813
+39814
+39815
+39816
+39817
+39818
+39819
+39820
+39821
+39822
+39823
+39824
+39825
+39826
+39827
+39828
+39829
+39830
+39831
+39832
+39833
+39834
+39835
+39836
+39837
+39838
+39839
+39840
+39841
+39842
+39843
+39844
+39845
+39846
+39847
+39848
+39849
+39850
+39851
+39852
+39853
+39854
+39855
+39856
+39857
+39858
+39859
+39860
+39861
+39862
+39863
+39864
+39865
+39866
+39867
+39868
+39869
+39870
+39871
+39872
+39873
+39874
+39875
+39876
+39877
+39878
+39879
+39880
+39881
+39882
+39883
+39884
+39885
+39886
+39887
+39888
+39889
+39890
+39891
+39892
+39893
+39894
+39895
+39896
+39897
+39898
+39899
+39900
+39901
+39902
+39903
+39904
+39905
+39906
+39907
+39908
+39909
+39910
+39911
+39912
+39913
+39914
+39915
+39916
+39917
+39918
+39919
+39920
+39921
+39922
+39923
+39924
+39925
+39926
+39927
+39928
+39929
+39930
+39931
+39932
+39933
+39934
+39935
+39936
+39937
+39938
+39939
+39940
+39941
+39942
+39943
+39944
+39945
+39946
+39947
+39948
+39949
+39950
+39951
+39952
+39953
+39954
+39955
+39956
+39957
+39958
+39959
+39960
+39961
+39962
+39963
+39964
+39965
+39966
+39967
+39968
+39969
+39970
+39971
+39972
+39973
+39974
+39975
+39976
+39977
+39978
+39979
+39980
+39981
+39982
+39983
+39984
+39985
+39986
+39987
+39988
+39989
+39990
+39991
+39992
+39993
+39994
+39995
+39996
+39997
+39998
+39999
+40000
+40001
+40002
+40003
+40004
+40005
+40006
+40007
+40008
+40009
+40010
+40011
+40012
+40013
+40014
+40015
+40016
+40017
+40018
+40019
+40020
+40021
+40022
+40023
+40024
+40025
+40026
+40027
+40028
+40029
+40030
+40031
+40032
+40033
+40034
+40035
+40036
+40037
+40038
+40039
+40040
+40041
+40042
+40043
+40044
+40045
+40046
+40047
+40048
+40049
+40050
+40051
+40052
+40053
+40054
+40055
+40056
+40057
+40058
+40059
+40060
+40061
+40062
+40063
+40064
+40065
+40066
+40067
+40068
+40069
+40070
+40071
+40072
+40073
+40074
+40075
+40076
+40077
+40078
+40079
+40080
+40081
+40082
+40083
+40084
+40085
+40086
+40087
+40088
+40089
+40090
+40091
+40092
+40093
+40094
+40095
+40096
+40097
+40098
+40099
+40100
+40101
+40102
+40103
+40104
+40105
+40106
+40107
+40108
+40109
+40110
+40111
+40112
+40113
+40114
+40115
+40116
+40117
+40118
+40119
+40120
+40121
+40122
+40123
+40124
+40125
+40126
+40127
+40128
+40129
+40130
+40131
+40132
+40133
+40134
+40135
+40136
+40137
+40138
+40139
+40140
+40141
+40142
+40143
+40144
+40145
+40146
+40147
+40148
+40149
+40150
+40151
+40152
+40153
+40154
+40155
+40156
+40157
+40158
+40159
+40160
+40161
+40162
+40163
+40164
+40165
+40166
+40167
+40168
+40169
+40170
+40171
+40172
+40173
+40174
+40175
+40176
+40177
+40178
+40179
+40180
+40181
+40182
+40183
+40184
+40185
+40186
+40187
+40188
+40189
+40190
+40191
+40192
+40193
+40194
+40195
+40196
+40197
+40198
+40199
+40200
+40201
+40202
+40203
+40204
+40205
+40206
+40207
+40208
+40209
+40210
+40211
+40212
+40213
+40214
+40215
+40216
+40217
+40218
+40219
+40220
+40221
+40222
+40223
+40224
+40225
+40226
+40227
+40228
+40229
+40230
+40231
+40232
+40233
+40234
+40235
+40236
+40237
+40238
+40239
+40240
+40241
+40242
+40243
+40244
+40245
+40246
+40247
+40248
+40249
+40250
+40251
+40252
+40253
+40254
+40255
+40256
+40257
+40258
+40259
+40260
+40261
+40262
+40263
+40264
+40265
+40266
+40267
+40268
+40269
+40270
+40271
+40272
+40273
+40274
+40275
+40276
+40277
+40278
+40279
+40280
+40281
+40282
+40283
+40284
+40285
+40286
+40287
+40288
+40289
+40290
+40291
+40292
+40293
+40294
+40295
+40296
+40297
+40298
+40299
+40300
+40301
+40302
+40303
+40304
+40305
+40306
+40307
+40308
+40309
+40310
+40311
+40312
+40313
+40314
+40315
+40316
+40317
+40318
+40319
+40320
+40321
+40322
+40323
+40324
+40325
+40326
+40327
+40328
+40329
+40330
+40331
+40332
+40333
+40334
+40335
+40336
+40337
+40338
+40339
+40340
+40341
+40342
+40343
+40344
+40345
+40346
+40347
+40348
+40349
+40350
+40351
+40352
+40353
+40354
+40355
+40356
+40357
+40358
+40359
+40360
+40361
+40362
+40363
+40364
+40365
+40366
+40367
+40368
+40369
+40370
+40371
+40372
+40373
+40374
+40375
+40376
+40377
+40378
+40379
+40380
+40381
+40382
+40383
+40384
+40385
+40386
+40387
+40388
+40389
+40390
+40391
+40392
+40393
+40394
+40395
+40396
+40397
+40398
+40399
+40400
+40401
+40402
+40403
+40404
+40405
+40406
+40407
+40408
+40409
+40410
+40411
+40412
+40413
+40414
+40415
+40416
+40417
+40418
+40419
+40420
+40421
+40422
+40423
+40424
+40425
+40426
+40427
+40428
+40429
+40430
+40431
+40432
+40433
+40434
+40435
+40436
+40437
+40438
+40439
+40440
+40441
+40442
+40443
+40444
+40445
+40446
+40447
+40448
+40449
+40450
+40451
+40452
+40453
+40454
+40455
+40456
+40457
+40458
+40459
+40460
+40461
+40462
+40463
+40464
+40465
+40466
+40467
+40468
+40469
+40470
+40471
+40472
+40473
+40474
+40475
+40476
+40477
+40478
+40479
+40480
+40481
+40482
+40483
+40484
+40485
+40486
+40487
+40488
+40489
+40490
+40491
+40492
+40493
+40494
+40495
+40496
+40497
+40498
+40499
+40500
+40501
+40502
+40503
+40504
+40505
+40506
+40507
+40508
+40509
+40510
+40511
+40512
+40513
+40514
+40515
+40516
+40517
+40518
+40519
+40520
+40521
+40522
+40523
+40524
+40525
+40526
+40527
+40528
+40529
+40530
+40531
+40532
+40533
+40534
+40535
+40536
+40537
+40538
+40539
+40540
+40541
+40542
+40543
+40544
+40545
+40546
+40547
+40548
+40549
+40550
+40551
+40552
+40553
+40554
+40555
+40556
+40557
+40558
+40559
+40560
+40561
+40562
+40563
+40564
+40565
+40566
+40567
+40568
+40569
+40570
+40571
+40572
+40573
+40574
+40575
+40576
+40577
+40578
+40579
+40580
+40581
+40582
+40583
+40584
+40585
+40586
+40587
+40588
+40589
+40590
+40591
+40592
+40593
+40594
+40595
+40596
+40597
+40598
+40599
+40600
+40601
+40602
+40603
+40604
+40605
+40606
+40607
+40608
+40609
+40610
+40611
+40612
+40613
+40614
+40615
+40616
+40617
+40618
+40619
+40620
+40621
+40622
+40623
+40624
+40625
+40626
+40627
+40628
+40629
+40630
+40631
+40632
+40633
+40634
+40635
+40636
+40637
+40638
+40639
+40640
+40641
+40642
+40643
+40644
+40645
+40646
+40647
+40648
+40649
+40650
+40651
+40652
+40653
+40654
+40655
+40656
+40657
+40658
+40659
+40660
+40661
+40662
+40663
+40664
+40665
+40666
+40667
+40668
+40669
+40670
+40671
+40672
+40673
+40674
+40675
+40676
+40677
+40678
+40679
+40680
+40681
+40682
+40683
+40684
+40685
+40686
+40687
+40688
+40689
+40690
+40691
+40692
+40693
+40694
+40695
+40696
+40697
+40698
+40699
+40700
+40701
+40702
+40703
+40704
+40705
+40706
+40707
+40708
+40709
+40710
+40711
+40712
+40713
+40714
+40715
+40716
+40717
+40718
+40719
+40720
+40721
+40722
+40723
+40724
+40725
+40726
+40727
+40728
+40729
+40730
+40731
+40732
+40733
+40734
+40735
+40736
+40737
+40738
+40739
+40740
+40741
+40742
+40743
+40744
+40745
+40746
+40747
+40748
+40749
+40750
+40751
+40752
+40753
+40754
+40755
+40756
+40757
+40758
+40759
+40760
+40761
+40762
+40763
+40764
+40765
+40766
+40767
+40768
+40769
+40770
+40771
+40772
+40773
+40774
+40775
+40776
+40777
+40778
+40779
+40780
+40781
+40782
+40783
+40784
+40785
+40786
+40787
+40788
+40789
+40790
+40791
+40792
+40793
+40794
+40795
+40796
+40797
+40798
+40799
+40800
+40801
+40802
+40803
+40804
+40805
+40806
+40807
+40808
+40809
+40810
+40811
+40812
+40813
+40814
+40815
+40816
+40817
+40818
+40819
+40820
+40821
+40822
+40823
+40824
+40825
+40826
+40827
+40828
+40829
+40830
+40831
+40832
+40833
+40834
+40835
+40836
+40837
+40838
+40839
+40840
+40841
+40842
+40843
+40844
+40845
+40846
+40847
+40848
+40849
+40850
+40851
+40852
+40853
+40854
+40855
+40856
+40857
+40858
+40859
+40860
+40861
+40862
+40863
+40864
+40865
+40866
+40867
+40868
+40869
+40870
+40871
+40872
+40873
+40874
+40875
+40876
+40877
+40878
+40879
+40880
+40881
+40882
+40883
+40884
+40885
+40886
+40887
+40888
+40889
+40890
+40891
+40892
+40893
+40894
+40895
+40896
+40897
+40898
+40899
+40900
+40901
+40902
+40903
+40904
+40905
+40906
+40907
+40908
+40909
+40910
+40911
+40912
+40913
+40914
+40915
+40916
+40917
+40918
+40919
+40920
+40921
+40922
+40923
+40924
+40925
+40926
+40927
+40928
+40929
+40930
+40931
+40932
+40933
+40934
+40935
+40936
+40937
+40938
+40939
+40940
+40941
+40942
+40943
+40944
+40945
+40946
+40947
+40948
+40949
+40950
+40951
+40952
+40953
+40954
+40955
+40956
+40957
+40958
+40959
+40960
+40961
+40962
+40963
+40964
+40965
+40966
+40967
+40968
+40969
+40970
+40971
+40972
+40973
+40974
+40975
+40976
+40977
+40978
+40979
+40980
+40981
+40982
+40983
+40984
+40985
+40986
+40987
+40988
+40989
+40990
+40991
+40992
+40993
+40994
+40995
+40996
+40997
+40998
+40999
+41000
+41001
+41002
+41003
+41004
+41005
+41006
+41007
+41008
+41009
+41010
+41011
+41012
+41013
+41014
+41015
+41016
+41017
+41018
+41019
+41020
+41021
+41022
+41023
+41024
+41025
+41026
+41027
+41028
+41029
+41030
+41031
+41032
+41033
+41034
+41035
+41036
+41037
+41038
+41039
+41040
+41041
+41042
+41043
+41044
+41045
+41046
+41047
+41048
+41049
+41050
+41051
+41052
+41053
+41054
+41055
+41056
+41057
+41058
+41059
+41060
+41061
+41062
+41063
+41064
+41065
+41066
+41067
+41068
+41069
+41070
+41071
+41072
+41073
+41074
+41075
+41076
+41077
+41078
+41079
+41080
+41081
+41082
+41083
+41084
+41085
+41086
+41087
+41088
+41089
+41090
+41091
+41092
+41093
+41094
+41095
+41096
+41097
+41098
+41099
+41100
+41101
+41102
+41103
+41104
+41105
+41106
+41107
+41108
+41109
+41110
+41111
+41112
+41113
+41114
+41115
+41116
+41117
+41118
+41119
+41120
+41121
+41122
+41123
+41124
+41125
+41126
+41127
+41128
+41129
+41130
+41131
+41132
+41133
+41134
+41135
+41136
+41137
+41138
+41139
+41140
+41141
+41142
+41143
+41144
+41145
+41146
+41147
+41148
+41149
+41150
+41151
+41152
+41153
+41154
+41155
+41156
+41157
+41158
+41159
+41160
+41161
+41162
+41163
+41164
+41165
+41166
+41167
+41168
+41169
+41170
+41171
+41172
+41173
+41174
+41175
+41176
+41177
+41178
+41179
+41180
+41181
+41182
+41183
+41184
+41185
+41186
+41187
+41188
+41189
+41190
+41191
+41192
+41193
+41194
+41195
+41196
+41197
+41198
+41199
+41200
+41201
+41202
+41203
+41204
+41205
+41206
+41207
+41208
+41209
+41210
+41211
+41212
+41213
+41214
+41215
+41216
+41217
+41218
+41219
+41220
+41221
+41222
+41223
+41224
+41225
+41226
+41227
+41228
+41229
+41230
+41231
+41232
+41233
+41234
+41235
+41236
+41237
+41238
+41239
+41240
+41241
+41242
+41243
+41244
+41245
+41246
+41247
+41248
+41249
+41250
+41251
+41252
+41253
+41254
+41255
+41256
+41257
+41258
+41259
+41260
+41261
+41262
+41263
+41264
+41265
+41266
+41267
+41268
+41269
+41270
+41271
+41272
+41273
+41274
+41275
+41276
+41277
+41278
+41279
+41280
+41281
+41282
+41283
+41284
+41285
+41286
+41287
+41288
+41289
+41290
+41291
+41292
+41293
+41294
+41295
+41296
+41297
+41298
+41299
+41300
+41301
+41302
+41303
+41304
+41305
+41306
+41307
+41308
+41309
+41310
+41311
+41312
+41313
+41314
+41315
+41316
+41317
+41318
+41319
+41320
+41321
+41322
+41323
+41324
+41325
+41326
+41327
+41328
+41329
+41330
+41331
+41332
+41333
+41334
+41335
+41336
+41337
+41338
+41339
+41340
+41341
+41342
+41343
+41344
+41345
+41346
+41347
+41348
+41349
+41350
+41351
+41352
+41353
+41354
+41355
+41356
+41357
+41358
+41359
+41360
+41361
+41362
+41363
+41364
+41365
+41366
+41367
+41368
+41369
+41370
+41371
+41372
+41373
+41374
+41375
+41376
+41377
+41378
+41379
+41380
+41381
+41382
+41383
+41384
+41385
+41386
+41387
+41388
+41389
+41390
+41391
+41392
+41393
+41394
+41395
+41396
+41397
+41398
+41399
+41400
+41401
+41402
+41403
+41404
+41405
+41406
+41407
+41408
+41409
+41410
+41411
+41412
+41413
+41414
+41415
+41416
+41417
+41418
+41419
+41420
+41421
+41422
+41423
+41424
+41425
+41426
+41427
+41428
+41429
+41430
+41431
+41432
+41433
+41434
+41435
+41436
+41437
+41438
+41439
+41440
+41441
+41442
+41443
+41444
+41445
+41446
+41447
+41448
+41449
+41450
+41451
+41452
+41453
+41454
+41455
+41456
+41457
+41458
+41459
+41460
+41461
+41462
+41463
+41464
+41465
+41466
+41467
+41468
+41469
+41470
+41471
+41472
+41473
+41474
+41475
+41476
+41477
+41478
+41479
+41480
+41481
+41482
+41483
+41484
+41485
+41486
+41487
+41488
+41489
+41490
+41491
+41492
+41493
+41494
+41495
+41496
+41497
+41498
+41499
+41500
+41501
+41502
+41503
+41504
+41505
+41506
+41507
+41508
+41509
+41510
+41511
+41512
+41513
+41514
+41515
+41516
+41517
+41518
+41519
+41520
+41521
+41522
+41523
+41524
+41525
+41526
+41527
+41528
+41529
+41530
+41531
+41532
+41533
+41534
+41535
+41536
+41537
+41538
+41539
+41540
+41541
+41542
+41543
+41544
+41545
+41546
+41547
+41548
+41549
+41550
+41551
+41552
+41553
+41554
+41555
+41556
+41557
+41558
+41559
+41560
+41561
+41562
+41563
+41564
+41565
+41566
+41567
+41568
+41569
+41570
+41571
+41572
+41573
+41574
+41575
+41576
+41577
+41578
+41579
+41580
+41581
+41582
+41583
+41584
+41585
+41586
+41587
+41588
+41589
+41590
+41591
+41592
+41593
+41594
+41595
+41596
+41597
+41598
+41599
+41600
+41601
+41602
+41603
+41604
+41605
+41606
+41607
+41608
+41609
+41610
+41611
+41612
+41613
+41614
+41615
+41616
+41617
+41618
+41619
+41620
+41621
+41622
+41623
+41624
+41625
+41626
+41627
+41628
+41629
+41630
+41631
+41632
+41633
+41634
+41635
+41636
+41637
+41638
+41639
+41640
+41641
+41642
+41643
+41644
+41645
+41646
+41647
+41648
+41649
+41650
+41651
+41652
+41653
+41654
+41655
+41656
+41657
+41658
+41659
+41660
+41661
+41662
+41663
+41664
+41665
+41666
+41667
+41668
+41669
+41670
+41671
+41672
+41673
+41674
+41675
+41676
+41677
+41678
+41679
+41680
+41681
+41682
+41683
+41684
+41685
+41686
+41687
+41688
+41689
+41690
+41691
+41692
+41693
+41694
+41695
+41696
+41697
+41698
+41699
+41700
+41701
+41702
+41703
+41704
+41705
+41706
+41707
+41708
+41709
+41710
+41711
+41712
+41713
+41714
+41715
+41716
+41717
+41718
+41719
+41720
+41721
+41722
+41723
+41724
+41725
+41726
+41727
+41728
+41729
+41730
+41731
+41732
+41733
+41734
+41735
+41736
+41737
+41738
+41739
+41740
+41741
+41742
+41743
+41744
+41745
+41746
+41747
+41748
+41749
+41750
+41751
+41752
+41753
+41754
+41755
+41756
+41757
+41758
+41759
+41760
+41761
+41762
+41763
+41764
+41765
+41766
+41767
+41768
+41769
+41770
+41771
+41772
+41773
+41774
+41775
+41776
+41777
+41778
+41779
+41780
+41781
+41782
+41783
+41784
+41785
+41786
+41787
+41788
+41789
+41790
+41791
+41792
+41793
+41794
+41795
+41796
+41797
+41798
+41799
+41800
+41801
+41802
+41803
+41804
+41805
+41806
+41807
+41808
+41809
+41810
+41811
+41812
+41813
+41814
+41815
+41816
+41817
+41818
+41819
+41820
+41821
+41822
+41823
+41824
+41825
+41826
+41827
+41828
+41829
+41830
+41831
+41832
+41833
+41834
+41835
+41836
+41837
+41838
+41839
+41840
+41841
+41842
+41843
+41844
+41845
+41846
+41847
+41848
+41849
+41850
+41851
+41852
+41853
+41854
+41855
+41856
+41857
+41858
+41859
+41860
+41861
+41862
+41863
+41864
+41865
+41866
+41867
+41868
+41869
+41870
+41871
+41872
+41873
+41874
+41875
+41876
+41877
+41878
+41879
+41880
+41881
+41882
+41883
+41884
+41885
+41886
+41887
+41888
+41889
+41890
+41891
+41892
+41893
+41894
+41895
+41896
+41897
+41898
+41899
+41900
+41901
+41902
+41903
+41904
+41905
+41906
+41907
+41908
+41909
+41910
+41911
+41912
+41913
+41914
+41915
+41916
+41917
+41918
+41919
+41920
+41921
+41922
+41923
+41924
+41925
+41926
+41927
+41928
+41929
+41930
+41931
+41932
+41933
+41934
+41935
+41936
+41937
+41938
+41939
+41940
+41941
+41942
+41943
+41944
+41945
+41946
+41947
+41948
+41949
+41950
+41951
+41952
+41953
+41954
+41955
+41956
+41957
+41958
+41959
+41960
+41961
+41962
+41963
+41964
+41965
+41966
+41967
+41968
+41969
+41970
+41971
+41972
+41973
+41974
+41975
+41976
+41977
+41978
+41979
+41980
+41981
+41982
+41983
+41984
+41985
+41986
+41987
+41988
+41989
+41990
+41991
+41992
+41993
+41994
+41995
+41996
+41997
+41998
+41999
+42000
+42001
+42002
+42003
+42004
+42005
+42006
+42007
+42008
+42009
+42010
+42011
+42012
+42013
+42014
+42015
+42016
+42017
+42018
+42019
+42020
+42021
+42022
+42023
+42024
+42025
+42026
+42027
+42028
+42029
+42030
+42031
+42032
+42033
+42034
+42035
+42036
+42037
+42038
+42039
+42040
+42041
+42042
+42043
+42044
+42045
+42046
+42047
+42048
+42049
+42050
+42051
+42052
+42053
+42054
+42055
+42056
+42057
+42058
+42059
+42060
+42061
+42062
+42063
+42064
+42065
+42066
+42067
+42068
+42069
+42070
+42071
+42072
+42073
+42074
+42075
+42076
+42077
+42078
+42079
+42080
+42081
+42082
+42083
+42084
+42085
+42086
+42087
+42088
+42089
+42090
+42091
+42092
+42093
+42094
+42095
+42096
+42097
+42098
+42099
+42100
+42101
+42102
+42103
+42104
+42105
+42106
+42107
+42108
+42109
+42110
+42111
+42112
+42113
+42114
+42115
+42116
+42117
+42118
+42119
+42120
+42121
+42122
+42123
+42124
+42125
+42126
+42127
+42128
+42129
+42130
+42131
+42132
+42133
+42134
+42135
+42136
+42137
+42138
+42139
+42140
+42141
+42142
+42143
+42144
+42145
+42146
+42147
+42148
+42149
+42150
+42151
+42152
+42153
+42154
+42155
+42156
+42157
+42158
+42159
+42160
+42161
+42162
+42163
+42164
+42165
+42166
+42167
+42168
+42169
+42170
+42171
+42172
+42173
+42174
+42175
+42176
+42177
+42178
+42179
+42180
+42181
+42182
+42183
+42184
+42185
+42186
+42187
+42188
+42189
+42190
+42191
+42192
+42193
+42194
+42195
+42196
+42197
+42198
+42199
+42200
+42201
+42202
+42203
+42204
+42205
+42206
+42207
+42208
+42209
+42210
+42211
+42212
+42213
+42214
+42215
+42216
+42217
+42218
+42219
+42220
+42221
+42222
+42223
+42224
+42225
+42226
+42227
+42228
+42229
+42230
+42231
+42232
+42233
+42234
+42235
+42236
+42237
+42238
+42239
+42240
+42241
+42242
+42243
+42244
+42245
+42246
+42247
+42248
+42249
+42250
+42251
+42252
+42253
+42254
+42255
+42256
+42257
+42258
+42259
+42260
+42261
+42262
+42263
+42264
+42265
+42266
+42267
+42268
+42269
+42270
+42271
+42272
+42273
+42274
+42275
+42276
+42277
+42278
+42279
+42280
+42281
+42282
+42283
+42284
+42285
+42286
+42287
+42288
+42289
+42290
+42291
+42292
+42293
+42294
+42295
+42296
+42297
+42298
+42299
+42300
+42301
+42302
+42303
+42304
+42305
+42306
+42307
+42308
+42309
+42310
+42311
+42312
+42313
+42314
+42315
+42316
+42317
+42318
+42319
+42320
+42321
+42322
+42323
+42324
+42325
+42326
+42327
+42328
+42329
+42330
+42331
+42332
+42333
+42334
+42335
+42336
+42337
+42338
+42339
+42340
+42341
+42342
+42343
+42344
+42345
+42346
+42347
+42348
+42349
+42350
+42351
+42352
+42353
+42354
+42355
+42356
+42357
+42358
+42359
+42360
+42361
+42362
+42363
+42364
+42365
+42366
+42367
+42368
+42369
+42370
+42371
+42372
+42373
+42374
+42375
+42376
+42377
+42378
+42379
+42380
+42381
+42382
+42383
+42384
+42385
+42386
+42387
+42388
+42389
+42390
+42391
+42392
+42393
+42394
+42395
+42396
+42397
+42398
+42399
+42400
+42401
+42402
+42403
+42404
+42405
+42406
+42407
+42408
+42409
+42410
+42411
+42412
+42413
+42414
+42415
+42416
+42417
+42418
+42419
+42420
+42421
+42422
+42423
+42424
+42425
+42426
+42427
+42428
+42429
+42430
+42431
+42432
+42433
+42434
+42435
+42436
+42437
+42438
+42439
+42440
+42441
+42442
+42443
+42444
+42445
+42446
+42447
+42448
+42449
+42450
+42451
+42452
+42453
+42454
+42455
+42456
+42457
+42458
+42459
+42460
+42461
+42462
+42463
+42464
+42465
+42466
+42467
+42468
+42469
+42470
+42471
+42472
+42473
+42474
+42475
+42476
+42477
+42478
+42479
+42480
+42481
+42482
+42483
+42484
+42485
+42486
+42487
+42488
+42489
+42490
+42491
+42492
+42493
+42494
+42495
+42496
+42497
+42498
+42499
+42500
+42501
+42502
+42503
+42504
+42505
+42506
+42507
+42508
+42509
+42510
+42511
+42512
+42513
+42514
+42515
+42516
+42517
+42518
+42519
+42520
+42521
+42522
+42523
+42524
+42525
+42526
+42527
+42528
+42529
+42530
+42531
+42532
+42533
+42534
+42535
+42536
+42537
+42538
+42539
+42540
+42541
+42542
+42543
+42544
+42545
+42546
+42547
+42548
+42549
+42550
+42551
+42552
+42553
+42554
+42555
+42556
+42557
+42558
+42559
+42560
+42561
+42562
+42563
+42564
+42565
+42566
+42567
+42568
+42569
+42570
+42571
+42572
+42573
+42574
+42575
+42576
+42577
+42578
+42579
+42580
+42581
+42582
+42583
+42584
+42585
+42586
+42587
+42588
+42589
+42590
+42591
+42592
+42593
+42594
+42595
+42596
+42597
+42598
+42599
+42600
+42601
+42602
+42603
+42604
+42605
+42606
+42607
+42608
+42609
+42610
+42611
+42612
+42613
+42614
+42615
+42616
+42617
+42618
+42619
+42620
+42621
+42622
+42623
+42624
+42625
+42626
+42627
+42628
+42629
+42630
+42631
+42632
+42633
+42634
+42635
+42636
+42637
+42638
+42639
+42640
+42641
+42642
+42643
+42644
+42645
+42646
+42647
+42648
+42649
+42650
+42651
+42652
+42653
+42654
+42655
+42656
+42657
+42658
+42659
+42660
+42661
+42662
+42663
+42664
+42665
+42666
+42667
+42668
+42669
+42670
+42671
+42672
+42673
+42674
+42675
+42676
+42677
+42678
+42679
+42680
+42681
+42682
+42683
+42684
+42685
+42686
+42687
+42688
+42689
+42690
+42691
+42692
+42693
+42694
+42695
+42696
+42697
+42698
+42699
+42700
+42701
+42702
+42703
+42704
+42705
+42706
+42707
+42708
+42709
+42710
+42711
+42712
+42713
+42714
+42715
+42716
+42717
+42718
+42719
+42720
+42721
+42722
+42723
+42724
+42725
+42726
+42727
+42728
+42729
+42730
+42731
+42732
+42733
+42734
+42735
+42736
+42737
+42738
+42739
+42740
+42741
+42742
+42743
+42744
+42745
+42746
+42747
+42748
+42749
+42750
+42751
+42752
+42753
+42754
+42755
+42756
+42757
+42758
+42759
+42760
+42761
+42762
+42763
+42764
+42765
+42766
+42767
+42768
+42769
+42770
+42771
+42772
+42773
+42774
+42775
+42776
+42777
+42778
+42779
+42780
+42781
+42782
+42783
+42784
+42785
+42786
+42787
+42788
+42789
+42790
+42791
+42792
+42793
+42794
+42795
+42796
+42797
+42798
+42799
+42800
+42801
+42802
+42803
+42804
+42805
+42806
+42807
+42808
+42809
+42810
+42811
+42812
+42813
+42814
+42815
+42816
+42817
+42818
+42819
+42820
+42821
+42822
+42823
+42824
+42825
+42826
+42827
+42828
+42829
+42830
+42831
+42832
+42833
+42834
+42835
+42836
+42837
+42838
+42839
+42840
+42841
+42842
+42843
+42844
+42845
+42846
+42847
+42848
+42849
+42850
+42851
+42852
+42853
+42854
+42855
+42856
+42857
+42858
+42859
+42860
+42861
+42862
+42863
+42864
+42865
+42866
+42867
+42868
+42869
+42870
+42871
+42872
+42873
+42874
+42875
+42876
+42877
+42878
+42879
+42880
+42881
+42882
+42883
+42884
+42885
+42886
+42887
+42888
+42889
+42890
+42891
+42892
+42893
+42894
+42895
+42896
+42897
+42898
+42899
+42900
+42901
+42902
+42903
+42904
+42905
+42906
+42907
+42908
+42909
+42910
+42911
+42912
+42913
+42914
+42915
+42916
+42917
+42918
+42919
+42920
+42921
+42922
+42923
+42924
+42925
+42926
+42927
+42928
+42929
+42930
+42931
+42932
+42933
+42934
+42935
+42936
+42937
+42938
+42939
+42940
+42941
+42942
+42943
+42944
+42945
+42946
+42947
+42948
+42949
+42950
+42951
+42952
+42953
+42954
+42955
+42956
+42957
+42958
+42959
+42960
+42961
+42962
+42963
+42964
+42965
+42966
+42967
+42968
+42969
+42970
+42971
+42972
+42973
+42974
+42975
+42976
+42977
+42978
+42979
+42980
+42981
+42982
+42983
+42984
+42985
+42986
+42987
+42988
+42989
+42990
+42991
+42992
+42993
+42994
+42995
+42996
+42997
+42998
+42999
+43000
+43001
+43002
+43003
+43004
+43005
+43006
+43007
+43008
+43009
+43010
+43011
+43012
+43013
+43014
+43015
+43016
+43017
+43018
+43019
+43020
+43021
+43022
+43023
+43024
+43025
+43026
+43027
+43028
+43029
+43030
+43031
+43032
+43033
+43034
+43035
+43036
+43037
+43038
+43039
+43040
+43041
+43042
+43043
+43044
+43045
+43046
+43047
+43048
+43049
+43050
+43051
+43052
+43053
+43054
+43055
+43056
+43057
+43058
+43059
+43060
+43061
+43062
+43063
+43064
+43065
+43066
+43067
+43068
+43069
+43070
+43071
+43072
+43073
+43074
+43075
+43076
+43077
+43078
+43079
+43080
+43081
+43082
+43083
+43084
+43085
+43086
+43087
+43088
+43089
+43090
+43091
+43092
+43093
+43094
+43095
+43096
+43097
+43098
+43099
+43100
+43101
+43102
+43103
+43104
+43105
+43106
+43107
+43108
+43109
+43110
+43111
+43112
+43113
+43114
+43115
+43116
+43117
+43118
+43119
+43120
+43121
+43122
+43123
+43124
+43125
+43126
+43127
+43128
+43129
+43130
+43131
+43132
+43133
+43134
+43135
+43136
+43137
+43138
+43139
+43140
+43141
+43142
+43143
+43144
+43145
+43146
+43147
+43148
+43149
+43150
+43151
+43152
+43153
+43154
+43155
+43156
+43157
+43158
+43159
+43160
+43161
+43162
+43163
+43164
+43165
+43166
+43167
+43168
+43169
+43170
+43171
+43172
+43173
+43174
+43175
+43176
+43177
+43178
+43179
+43180
+43181
+43182
+43183
+43184
+43185
+43186
+43187
+43188
+43189
+43190
+43191
+43192
+43193
+43194
+43195
+43196
+43197
+43198
+43199
+43200
+43201
+43202
+43203
+43204
+43205
+43206
+43207
+43208
+43209
+43210
+43211
+43212
+43213
+43214
+43215
+43216
+43217
+43218
+43219
+43220
+43221
+43222
+43223
+43224
+43225
+43226
+43227
+43228
+43229
+43230
+43231
+43232
+43233
+43234
+43235
+43236
+43237
+43238
+43239
+43240
+43241
+43242
+43243
+43244
+43245
+43246
+43247
+43248
+43249
+43250
+43251
+43252
+43253
+43254
+43255
+43256
+43257
+43258
+43259
+43260
+43261
+43262
+43263
+43264
+43265
+43266
+43267
+43268
+43269
+43270
+43271
+43272
+43273
+43274
+43275
+43276
+43277
+43278
+43279
+43280
+43281
+43282
+43283
+43284
+43285
+43286
+43287
+43288
+43289
+43290
+43291
+43292
+43293
+43294
+43295
+43296
+43297
+43298
+43299
+43300
+43301
+43302
+43303
+43304
+43305
+43306
+43307
+43308
+43309
+43310
+43311
+43312
+43313
+43314
+43315
+43316
+43317
+43318
+43319
+43320
+43321
+43322
+43323
+43324
+43325
+43326
+43327
+43328
+43329
+43330
+43331
+43332
+43333
+43334
+43335
+43336
+43337
+43338
+43339
+43340
+43341
+43342
+43343
+43344
+43345
+43346
+43347
+43348
+43349
+43350
+43351
+43352
+43353
+43354
+43355
+43356
+43357
+43358
+43359
+43360
+43361
+43362
+43363
+43364
+43365
+43366
+43367
+43368
+43369
+43370
+43371
+43372
+43373
+43374
+43375
+43376
+43377
+43378
+43379
+43380
+43381
+43382
+43383
+43384
+43385
+43386
+43387
+43388
+43389
+43390
+43391
+43392
+43393
+43394
+43395
+43396
+43397
+43398
+43399
+43400
+43401
+43402
+43403
+43404
+43405
+43406
+43407
+43408
+43409
+43410
+43411
+43412
+43413
+43414
+43415
+43416
+43417
+43418
+43419
+43420
+43421
+43422
+43423
+43424
+43425
+43426
+43427
+43428
+43429
+43430
+43431
+43432
+43433
+43434
+43435
+43436
+43437
+43438
+43439
+43440
+43441
+43442
+43443
+43444
+43445
+43446
+43447
+43448
+43449
+43450
+43451
+43452
+43453
+43454
+43455
+43456
+43457
+43458
+43459
+43460
+43461
+43462
+43463
+43464
+43465
+43466
+43467
+43468
+43469
+43470
+43471
+43472
+43473
+43474
+43475
+43476
+43477
+43478
+43479
+43480
+43481
+43482
+43483
+43484
+43485
+43486
+43487
+43488
+43489
+43490
+43491
+43492
+43493
+43494
+43495
+43496
+43497
+43498
+43499
+43500
+43501
+43502
+43503
+43504
+43505
+43506
+43507
+43508
+43509
+43510
+43511
+43512
+43513
+43514
+43515
+43516
+43517
+43518
+43519
+43520
+43521
+43522
+43523
+43524
+43525
+43526
+43527
+43528
+43529
+43530
+43531
+43532
+43533
+43534
+43535
+43536
+43537
+43538
+43539
+43540
+43541
+43542
+43543
+43544
+43545
+43546
+43547
+43548
+43549
+43550
+43551
+43552
+43553
+43554
+43555
+43556
+43557
+43558
+43559
+43560
+43561
+43562
+43563
+43564
+43565
+43566
+43567
+43568
+43569
+43570
+43571
+43572
+43573
+43574
+43575
+43576
+43577
+43578
+43579
+43580
+43581
+43582
+43583
+43584
+43585
+43586
+43587
+43588
+43589
+43590
+43591
+43592
+43593
+43594
+43595
+43596
+43597
+43598
+43599
+43600
+43601
+43602
+43603
+43604
+43605
+43606
+43607
+43608
+43609
+43610
+43611
+43612
+43613
+43614
+43615
+43616
+43617
+43618
+43619
+43620
+43621
+43622
+43623
+43624
+43625
+43626
+43627
+43628
+43629
+43630
+43631
+43632
+43633
+43634
+43635
+43636
+43637
+43638
+43639
+43640
+43641
+43642
+43643
+43644
+43645
+43646
+43647
+43648
+43649
+43650
+43651
+43652
+43653
+43654
+43655
+43656
+43657
+43658
+43659
+43660
+43661
+43662
+43663
+43664
+43665
+43666
+43667
+43668
+43669
+43670
+43671
+43672
+43673
+43674
+43675
+43676
+43677
+43678
+43679
+43680
+43681
+43682
+43683
+43684
+43685
+43686
+43687
+43688
+43689
+43690
+43691
+43692
+43693
+43694
+43695
+43696
+43697
+43698
+43699
+43700
+43701
+43702
+43703
+43704
+43705
+43706
+43707
+43708
+43709
+43710
+43711
+43712
+43713
+43714
+43715
+43716
+43717
+43718
+43719
+43720
+43721
+43722
+43723
+43724
+43725
+43726
+43727
+43728
+43729
+43730
+43731
+43732
+43733
+43734
+43735
+43736
+43737
+43738
+43739
+43740
+43741
+43742
+43743
+43744
+43745
+43746
+43747
+43748
+43749
+43750
+43751
+43752
+43753
+43754
+43755
+43756
+43757
+43758
+43759
+43760
+43761
+43762
+43763
+43764
+43765
+43766
+43767
+43768
+43769
+43770
+43771
+43772
+43773
+43774
+43775
+43776
+43777
+43778
+43779
+43780
+43781
+43782
+43783
+43784
+43785
+43786
+43787
+43788
+43789
+43790
+43791
+43792
+43793
+43794
+43795
+43796
+43797
+43798
+43799
+43800
+43801
+43802
+43803
+43804
+43805
+43806
+43807
+43808
+43809
+43810
+43811
+43812
+43813
+43814
+43815
+43816
+43817
+43818
+43819
+43820
+43821
+43822
+43823
+43824
+43825
+43826
+43827
+43828
+43829
+43830
+43831
+43832
+43833
+43834
+43835
+43836
+43837
+43838
+43839
+43840
+43841
+43842
+43843
+43844
+43845
+43846
+43847
+43848
+43849
+43850
+43851
+43852
+43853
+43854
+43855
+43856
+43857
+43858
+43859
+43860
+43861
+43862
+43863
+43864
+43865
+43866
+43867
+43868
+43869
+43870
+43871
+43872
+43873
+43874
+43875
+43876
+43877
+43878
+43879
+43880
+43881
+43882
+43883
+43884
+43885
+43886
+43887
+43888
+43889
+43890
+43891
+43892
+43893
+43894
+43895
+43896
+43897
+43898
+43899
+43900
+43901
+43902
+43903
+43904
+43905
+43906
+43907
+43908
+43909
+43910
+43911
+43912
+43913
+43914
+43915
+43916
+43917
+43918
+43919
+43920
+43921
+43922
+43923
+43924
+43925
+43926
+43927
+43928
+43929
+43930
+43931
+43932
+43933
+43934
+43935
+43936
+43937
+43938
+43939
+43940
+43941
+43942
+43943
+43944
+43945
+43946
+43947
+43948
+43949
+43950
+43951
+43952
+43953
+43954
+43955
+43956
+43957
+43958
+43959
+43960
+43961
+43962
+43963
+43964
+43965
+43966
+43967
+43968
+43969
+43970
+43971
+43972
+43973
+43974
+43975
+43976
+43977
+43978
+43979
+43980
+43981
+43982
+43983
+43984
+43985
+43986
+43987
+43988
+43989
+43990
+43991
+43992
+43993
+43994
+43995
+43996
+43997
+43998
+43999
+44000
+44001
+44002
+44003
+44004
+44005
+44006
+44007
+44008
+44009
+44010
+44011
+44012
+44013
+44014
+44015
+44016
+44017
+44018
+44019
+44020
+44021
+44022
+44023
+44024
+44025
+44026
+44027
+44028
+44029
+44030
+44031
+44032
+44033
+44034
+44035
+44036
+44037
+44038
+44039
+44040
+44041
+44042
+44043
+44044
+44045
+44046
+44047
+44048
+44049
+44050
+44051
+44052
+44053
+44054
+44055
+44056
+44057
+44058
+44059
+44060
+44061
+44062
+44063
+44064
+44065
+44066
+44067
+44068
+44069
+44070
+44071
+44072
+44073
+44074
+44075
+44076
+44077
+44078
+44079
+44080
+44081
+44082
+44083
+44084
+44085
+44086
+44087
+44088
+44089
+44090
+44091
+44092
+44093
+44094
+44095
+44096
+44097
+44098
+44099
+44100
+44101
+44102
+44103
+44104
+44105
+44106
+44107
+44108
+44109
+44110
+44111
+44112
+44113
+44114
+44115
+44116
+44117
+44118
+44119
+44120
+44121
+44122
+44123
+44124
+44125
+44126
+44127
+44128
+44129
+44130
+44131
+44132
+44133
+44134
+44135
+44136
+44137
+44138
+44139
+44140
+44141
+44142
+44143
+44144
+44145
+44146
+44147
+44148
+44149
+44150
+44151
+44152
+44153
+44154
+44155
+44156
+44157
+44158
+44159
+44160
+44161
+44162
+44163
+44164
+44165
+44166
+44167
+44168
+44169
+44170
+44171
+44172
+44173
+44174
+44175
+44176
+44177
+44178
+44179
+44180
+44181
+44182
+44183
+44184
+44185
+44186
+44187
+44188
+44189
+44190
+44191
+44192
+44193
+44194
+44195
+44196
+44197
+44198
+44199
+44200
+44201
+44202
+44203
+44204
+44205
+44206
+44207
+44208
+44209
+44210
+44211
+44212
+44213
+44214
+44215
+44216
+44217
+44218
+44219
+44220
+44221
+44222
+44223
+44224
+44225
+44226
+44227
+44228
+44229
+44230
+44231
+44232
+44233
+44234
+44235
+44236
+44237
+44238
+44239
+44240
+44241
+44242
+44243
+44244
+44245
+44246
+44247
+44248
+44249
+44250
+44251
+44252
+44253
+44254
+44255
+44256
+44257
+44258
+44259
+44260
+44261
+44262
+44263
+44264
+44265
+44266
+44267
+44268
+44269
+44270
+44271
+44272
+44273
+44274
+44275
+44276
+44277
+44278
+44279
+44280
+44281
+44282
+44283
+44284
+44285
+44286
+44287
+44288
+44289
+44290
+44291
+44292
+44293
+44294
+44295
+44296
+44297
+44298
+44299
+44300
+44301
+44302
+44303
+44304
+44305
+44306
+44307
+44308
+44309
+44310
+44311
+44312
+44313
+44314
+44315
+44316
+44317
+44318
+44319
+44320
+44321
+44322
+44323
+44324
+44325
+44326
+44327
+44328
+44329
+44330
+44331
+44332
+44333
+44334
+44335
+44336
+44337
+44338
+44339
+44340
+44341
+44342
+44343
+44344
+44345
+44346
+44347
+44348
+44349
+44350
+44351
+44352
+44353
+44354
+44355
+44356
+44357
+44358
+44359
+44360
+44361
+44362
+44363
+44364
+44365
+44366
+44367
+44368
+44369
+44370
+44371
+44372
+44373
+44374
+44375
+44376
+44377
+44378
+44379
+44380
+44381
+44382
+44383
+44384
+44385
+44386
+44387
+44388
+44389
+44390
+44391
+44392
+44393
+44394
+44395
+44396
+44397
+44398
+44399
+44400
+44401
+44402
+44403
+44404
+44405
+44406
+44407
+44408
+44409
+44410
+44411
+44412
+44413
+44414
+44415
+44416
+44417
+44418
+44419
+44420
+44421
+44422
+44423
+44424
+44425
+44426
+44427
+44428
+44429
+44430
+44431
+44432
+44433
+44434
+44435
+44436
+44437
+44438
+44439
+44440
+44441
+44442
+44443
+44444
+44445
+44446
+44447
+44448
+44449
+44450
+44451
+44452
+44453
+44454
+44455
+44456
+44457
+44458
+44459
+44460
+44461
+44462
+44463
+44464
+44465
+44466
+44467
+44468
+44469
+44470
+44471
+44472
+44473
+44474
+44475
+44476
+44477
+44478
+44479
+44480
+44481
+44482
+44483
+44484
+44485
+44486
+44487
+44488
+44489
+44490
+44491
+44492
+44493
+44494
+44495
+44496
+44497
+44498
+44499
+44500
+44501
+44502
+44503
+44504
+44505
+44506
+44507
+44508
+44509
+44510
+44511
+44512
+44513
+44514
+44515
+44516
+44517
+44518
+44519
+44520
+44521
+44522
+44523
+44524
+44525
+44526
+44527
+44528
+44529
+44530
+44531
+44532
+44533
+44534
+44535
+44536
+44537
+44538
+44539
+44540
+44541
+44542
+44543
+44544
+44545
+44546
+44547
+44548
+44549
+44550
+44551
+44552
+44553
+44554
+44555
+44556
+44557
+44558
+44559
+44560
+44561
+44562
+44563
+44564
+44565
+44566
+44567
+44568
+44569
+44570
+44571
+44572
+44573
+44574
+44575
+44576
+44577
+44578
+44579
+44580
+44581
+44582
+44583
+44584
+44585
+44586
+44587
+44588
+44589
+44590
+44591
+44592
+44593
+44594
+44595
+44596
+44597
+44598
+44599
+44600
+44601
+44602
+44603
+44604
+44605
+44606
+44607
+44608
+44609
+44610
+44611
+44612
+44613
+44614
+44615
+44616
+44617
+44618
+44619
+44620
+44621
+44622
+44623
+44624
+44625
+44626
+44627
+44628
+44629
+44630
+44631
+44632
+44633
+44634
+44635
+44636
+44637
+44638
+44639
+44640
+44641
+44642
+44643
+44644
+44645
+44646
+44647
+44648
+44649
+44650
+44651
+44652
+44653
+44654
+44655
+44656
+44657
+44658
+44659
+44660
+44661
+44662
+44663
+44664
+44665
+44666
+44667
+44668
+44669
+44670
+44671
+44672
+44673
+44674
+44675
+44676
+44677
+44678
+44679
+44680
+44681
+44682
+44683
+44684
+44685
+44686
+44687
+44688
+44689
+44690
+44691
+44692
+44693
+44694
+44695
+44696
+44697
+44698
+44699
+44700
+44701
+44702
+44703
+44704
+44705
+44706
+44707
+44708
+44709
+44710
+44711
+44712
+44713
+44714
+44715
+44716
+44717
+44718
+44719
+44720
+44721
+44722
+44723
+44724
+44725
+44726
+44727
+44728
+44729
+44730
+44731
+44732
+44733
+44734
+44735
+44736
+44737
+44738
+44739
+44740
+44741
+44742
+44743
+44744
+44745
+44746
+44747
+44748
+44749
+44750
+44751
+44752
+44753
+44754
+44755
+44756
+44757
+44758
+44759
+44760
+44761
+44762
+44763
+44764
+44765
+44766
+44767
+44768
+44769
+44770
+44771
+44772
+44773
+44774
+44775
+44776
+44777
+44778
+44779
+44780
+44781
+44782
+44783
+44784
+44785
+44786
+44787
+44788
+44789
+44790
+44791
+44792
+44793
+44794
+44795
+44796
+44797
+44798
+44799
+44800
+44801
+44802
+44803
+44804
+44805
+44806
+44807
+44808
+44809
+44810
+44811
+44812
+44813
+44814
+44815
+44816
+44817
+44818
+44819
+44820
+44821
+44822
+44823
+44824
+44825
+44826
+44827
+44828
+44829
+44830
+44831
+44832
+44833
+44834
+44835
+44836
+44837
+44838
+44839
+44840
+44841
+44842
+44843
+44844
+44845
+44846
+44847
+44848
+44849
+44850
+44851
+44852
+44853
+44854
+44855
+44856
+44857
+44858
+44859
+44860
+44861
+44862
+44863
+44864
+44865
+44866
+44867
+44868
+44869
+44870
+44871
+44872
+44873
+44874
+44875
+44876
+44877
+44878
+44879
+44880
+44881
+44882
+44883
+44884
+44885
+44886
+44887
+44888
+44889
+44890
+44891
+44892
+44893
+44894
+44895
+44896
+44897
+44898
+44899
+44900
+44901
+44902
+44903
+44904
+44905
+44906
+44907
+44908
+44909
+44910
+44911
+44912
+44913
+44914
+44915
+44916
+44917
+44918
+44919
+44920
+44921
+44922
+44923
+44924
+44925
+44926
+44927
+44928
+44929
+44930
+44931
+44932
+44933
+44934
+44935
+44936
+44937
+44938
+44939
+44940
+44941
+44942
+44943
+44944
+44945
+44946
+44947
+44948
+44949
+44950
+44951
+44952
+44953
+44954
+44955
+44956
+44957
+44958
+44959
+44960
+44961
+44962
+44963
+44964
+44965
+44966
+44967
+44968
+44969
+44970
+44971
+44972
+44973
+44974
+44975
+44976
+44977
+44978
+44979
+44980
+44981
+44982
+44983
+44984
+44985
+44986
+44987
+44988
+44989
+44990
+44991
+44992
+44993
+44994
+44995
+44996
+44997
+44998
+44999
+45000
+45001
+45002
+45003
+45004
+45005
+45006
+45007
+45008
+45009
+45010
+45011
+45012
+45013
+45014
+45015
+45016
+45017
+45018
+45019
+45020
+45021
+45022
+45023
+45024
+45025
+45026
+45027
+45028
+45029
+45030
+45031
+45032
+45033
+45034
+45035
+45036
+45037
+45038
+45039
+45040
+45041
+45042
+45043
+45044
+45045
+45046
+45047
+45048
+45049
+45050
+45051
+45052
+45053
+45054
+45055
+45056
+45057
+45058
+45059
+45060
+45061
+45062
+45063
+45064
+45065
+45066
+45067
+45068
+45069
+45070
+45071
+45072
+45073
+45074
+45075
+45076
+45077
+45078
+45079
+45080
+45081
+45082
+45083
+45084
+45085
+45086
+45087
+45088
+45089
+45090
+45091
+45092
+45093
+45094
+45095
+45096
+45097
+45098
+45099
+45100
+45101
+45102
+45103
+45104
+45105
+45106
+45107
+45108
+45109
+45110
+45111
+45112
+45113
+45114
+45115
+45116
+45117
+45118
+45119
+45120
+45121
+45122
+45123
+45124
+45125
+45126
+45127
+45128
+45129
+45130
+45131
+45132
+45133
+45134
+45135
+45136
+45137
+45138
+45139
+45140
+45141
+45142
+45143
+45144
+45145
+45146
+45147
+45148
+45149
+45150
+45151
+45152
+45153
+45154
+45155
+45156
+45157
+45158
+45159
+45160
+45161
+45162
+45163
+45164
+45165
+45166
+45167
+45168
+45169
+45170
+45171
+45172
+45173
+45174
+45175
+45176
+45177
+45178
+45179
+45180
+45181
+45182
+45183
+45184
+45185
+45186
+45187
+45188
+45189
+45190
+45191
+45192
+45193
+45194
+45195
+45196
+45197
+45198
+45199
+45200
+45201
+45202
+45203
+45204
+45205
+45206
+45207
+45208
+45209
+45210
+45211
+45212
+45213
+45214
+45215
+45216
+45217
+45218
+45219
+45220
+45221
+45222
+45223
+45224
+45225
+45226
+45227
+45228
+45229
+45230
+45231
+45232
+45233
+45234
+45235
+45236
+45237
+45238
+45239
+45240
+45241
+45242
+45243
+45244
+45245
+45246
+45247
+45248
+45249
+45250
+45251
+45252
+45253
+45254
+45255
+45256
+45257
+45258
+45259
+45260
+45261
+45262
+45263
+45264
+45265
+45266
+45267
+45268
+45269
+45270
+45271
+45272
+45273
+45274
+45275
+45276
+45277
+45278
+45279
+45280
+45281
+45282
+45283
+45284
+45285
+45286
+45287
+45288
+45289
+45290
+45291
+45292
+45293
+45294
+45295
+45296
+45297
+45298
+45299
+45300
+45301
+45302
+45303
+45304
+45305
+45306
+45307
+45308
+45309
+45310
+45311
+45312
+45313
+45314
+45315
+45316
+45317
+45318
+45319
+45320
+45321
+45322
+45323
+45324
+45325
+45326
+45327
+45328
+45329
+45330
+45331
+45332
+45333
+45334
+45335
+45336
+45337
+45338
+45339
+45340
+45341
+45342
+45343
+45344
+45345
+45346
+45347
+45348
+45349
+45350
+45351
+45352
+45353
+45354
+45355
+45356
+45357
+45358
+45359
+45360
+45361
+45362
+45363
+45364
+45365
+45366
+45367
+45368
+45369
+45370
+45371
+45372
+45373
+45374
+45375
+45376
+45377
+45378
+45379
+45380
+45381
+45382
+45383
+45384
+45385
+45386
+45387
+45388
+45389
+45390
+45391
+45392
+45393
+45394
+45395
+45396
+45397
+45398
+45399
+45400
+45401
+45402
+45403
+45404
+45405
+45406
+45407
+45408
+45409
+45410
+45411
+45412
+45413
+45414
+45415
+45416
+45417
+45418
+45419
+45420
+45421
+45422
+45423
+45424
+45425
+45426
+45427
+45428
+45429
+45430
+45431
+45432
+45433
+45434
+45435
+45436
+45437
+45438
+45439
+45440
+45441
+45442
+45443
+45444
+45445
+45446
+45447
+45448
+45449
+45450
+45451
+45452
+45453
+45454
+45455
+45456
+45457
+45458
+45459
+45460
+45461
+45462
+45463
+45464
+45465
+45466
+45467
+45468
+45469
+45470
+45471
+45472
+45473
+45474
+45475
+45476
+45477
+45478
+45479
+45480
+45481
+45482
+45483
+45484
+45485
+45486
+45487
+45488
+45489
+45490
+45491
+45492
+45493
+45494
+45495
+45496
+45497
+45498
+45499
+45500
+45501
+45502
+45503
+45504
+45505
+45506
+45507
+45508
+45509
+45510
+45511
+45512
+45513
+45514
+45515
+45516
+45517
+45518
+45519
+45520
+45521
+45522
+45523
+45524
+45525
+45526
+45527
+45528
+45529
+45530
+45531
+45532
+45533
+45534
+45535
+45536
+45537
+45538
+45539
+45540
+45541
+45542
+45543
+45544
+45545
+45546
+45547
+45548
+45549
+45550
+45551
+45552
+45553
+45554
+45555
+45556
+45557
+45558
+45559
+45560
+45561
+45562
+45563
+45564
+45565
+45566
+45567
+45568
+45569
+45570
+45571
+45572
+45573
+45574
+45575
+45576
+45577
+45578
+45579
+45580
+45581
+45582
+45583
+45584
+45585
+45586
+45587
+45588
+45589
+45590
+45591
+45592
+45593
+45594
+45595
+45596
+45597
+45598
+45599
+45600
+45601
+45602
+45603
+45604
+45605
+45606
+45607
+45608
+45609
+45610
+45611
+45612
+45613
+45614
+45615
+45616
+45617
+45618
+45619
+45620
+45621
+45622
+45623
+45624
+45625
+45626
+45627
+45628
+45629
+45630
+45631
+45632
+45633
+45634
+45635
+45636
+45637
+45638
+45639
+45640
+45641
+45642
+45643
+45644
+45645
+45646
+45647
+45648
+45649
+45650
+45651
+45652
+45653
+45654
+45655
+45656
+45657
+45658
+45659
+45660
+45661
+45662
+45663
+45664
+45665
+45666
+45667
+45668
+45669
+45670
+45671
+45672
+45673
+45674
+45675
+45676
+45677
+45678
+45679
+45680
+45681
+45682
+45683
+45684
+45685
+45686
+45687
+45688
+45689
+45690
+45691
+45692
+45693
+45694
+45695
+45696
+45697
+45698
+45699
+45700
+45701
+45702
+45703
+45704
+45705
+45706
+45707
+45708
+45709
+45710
+45711
+45712
+45713
+45714
+45715
+45716
+45717
+45718
+45719
+45720
+45721
+45722
+45723
+45724
+45725
+45726
+45727
+45728
+45729
+45730
+45731
+45732
+45733
+45734
+45735
+45736
+45737
+45738
+45739
+45740
+45741
+45742
+45743
+45744
+45745
+45746
+45747
+45748
+45749
+45750
+45751
+45752
+45753
+45754
+45755
+45756
+45757
+45758
+45759
+45760
+45761
+45762
+45763
+45764
+45765
+45766
+45767
+45768
+45769
+45770
+45771
+45772
+45773
+45774
+45775
+45776
+45777
+45778
+45779
+45780
+45781
+45782
+45783
+45784
+45785
+45786
+45787
+45788
+45789
+45790
+45791
+45792
+45793
+45794
+45795
+45796
+45797
+45798
+45799
+45800
+45801
+45802
+45803
+45804
+45805
+45806
+45807
+45808
+45809
+45810
+45811
+45812
+45813
+45814
+45815
+45816
+45817
+45818
+45819
+45820
+45821
+45822
+45823
+45824
+45825
+45826
+45827
+45828
+45829
+45830
+45831
+45832
+45833
+45834
+45835
+45836
+45837
+45838
+45839
+45840
+45841
+45842
+45843
+45844
+45845
+45846
+45847
+45848
+45849
+45850
+45851
+45852
+45853
+45854
+45855
+45856
+45857
+45858
+45859
+45860
+45861
+45862
+45863
+45864
+45865
+45866
+45867
+45868
+45869
+45870
+45871
+45872
+45873
+45874
+45875
+45876
+45877
+45878
+45879
+45880
+45881
+45882
+45883
+45884
+45885
+45886
+45887
+45888
+45889
+45890
+45891
+45892
+45893
+45894
+45895
+45896
+45897
+45898
+45899
+45900
+45901
+45902
+45903
+45904
+45905
+45906
+45907
+45908
+45909
+45910
+45911
+45912
+45913
+45914
+45915
+45916
+45917
+45918
+45919
+45920
+45921
+45922
+45923
+45924
+45925
+45926
+45927
+45928
+45929
+45930
+45931
+45932
+45933
+45934
+45935
+45936
+45937
+45938
+45939
+45940
+45941
+45942
+45943
+45944
+45945
+45946
+45947
+45948
+45949
+45950
+45951
+45952
+45953
+45954
+45955
+45956
+45957
+45958
+45959
+45960
+45961
+45962
+45963
+45964
+45965
+45966
+45967
+45968
+45969
+45970
+45971
+45972
+45973
+45974
+45975
+45976
+45977
+45978
+45979
+45980
+45981
+45982
+45983
+45984
+45985
+45986
+45987
+45988
+45989
+45990
+45991
+45992
+45993
+45994
+45995
+45996
+45997
+45998
+45999
+46000
+46001
+46002
+46003
+46004
+46005
+46006
+46007
+46008
+46009
+46010
+46011
+46012
+46013
+46014
+46015
+46016
+46017
+46018
+46019
+46020
+46021
+46022
+46023
+46024
+46025
+46026
+46027
+46028
+46029
+46030
+46031
+46032
+46033
+46034
+46035
+46036
+46037
+46038
+46039
+46040
+46041
+46042
+46043
+46044
+46045
+46046
+46047
+46048
+46049
+46050
+46051
+46052
+46053
+46054
+46055
+46056
+46057
+46058
+46059
+46060
+46061
+46062
+46063
+46064
+46065
+46066
+46067
+46068
+46069
+46070
+46071
+46072
+46073
+46074
+46075
+46076
+46077
+46078
+46079
+46080
+46081
+46082
+46083
+46084
+46085
+46086
+46087
+46088
+46089
+46090
+46091
+46092
+46093
+46094
+46095
+46096
+46097
+46098
+46099
+46100
+46101
+46102
+46103
+46104
+46105
+46106
+46107
+46108
+46109
+46110
+46111
+46112
+46113
+46114
+46115
+46116
+46117
+46118
+46119
+46120
+46121
+46122
+46123
+46124
+46125
+46126
+46127
+46128
+46129
+46130
+46131
+46132
+46133
+46134
+46135
+46136
+46137
+46138
+46139
+46140
+46141
+46142
+46143
+46144
+46145
+46146
+46147
+46148
+46149
+46150
+46151
+46152
+46153
+46154
+46155
+46156
+46157
+46158
+46159
+46160
+46161
+46162
+46163
+46164
+46165
+46166
+46167
+46168
+46169
+46170
+46171
+46172
+46173
+46174
+46175
+46176
+46177
+46178
+46179
+46180
+46181
+46182
+46183
+46184
+46185
+46186
+46187
+46188
+46189
+46190
+46191
+46192
+46193
+46194
+46195
+46196
+46197
+46198
+46199
+46200
+46201
+46202
+46203
+46204
+46205
+46206
+46207
+46208
+46209
+46210
+46211
+46212
+46213
+46214
+46215
+46216
+46217
+46218
+46219
+46220
+46221
+46222
+46223
+46224
+46225
+46226
+46227
+46228
+46229
+46230
+46231
+46232
+46233
+46234
+46235
+46236
+46237
+46238
+46239
+46240
+46241
+46242
+46243
+46244
+46245
+46246
+46247
+46248
+46249
+46250
+46251
+46252
+46253
+46254
+46255
+46256
+46257
+46258
+46259
+46260
+46261
+46262
+46263
+46264
+46265
+46266
+46267
+46268
+46269
+46270
+46271
+46272
+46273
+46274
+46275
+46276
+46277
+46278
+46279
+46280
+46281
+46282
+46283
+46284
+46285
+46286
+46287
+46288
+46289
+46290
+46291
+46292
+46293
+46294
+46295
+46296
+46297
+46298
+46299
+46300
+46301
+46302
+46303
+46304
+46305
+46306
+46307
+46308
+46309
+46310
+46311
+46312
+46313
+46314
+46315
+46316
+46317
+46318
+46319
+46320
+46321
+46322
+46323
+46324
+46325
+46326
+46327
+46328
+46329
+46330
+46331
+46332
+46333
+46334
+46335
+46336
+46337
+46338
+46339
+46340
+46341
+46342
+46343
+46344
+46345
+46346
+46347
+46348
+46349
+46350
+46351
+46352
+46353
+46354
+46355
+46356
+46357
+46358
+46359
+46360
+46361
+46362
+46363
+46364
+46365
+46366
+46367
+46368
+46369
+46370
+46371
+46372
+46373
+46374
+46375
+46376
+46377
+46378
+46379
+46380
+46381
+46382
+46383
+46384
+46385
+46386
+46387
+46388
+46389
+46390
+46391
+46392
+46393
+46394
+46395
+46396
+46397
+46398
+46399
+46400
+46401
+46402
+46403
+46404
+46405
+46406
+46407
+46408
+46409
+46410
+46411
+46412
+46413
+46414
+46415
+46416
+46417
+46418
+46419
+46420
+46421
+46422
+46423
+46424
+46425
+46426
+46427
+46428
+46429
+46430
+46431
+46432
+46433
+46434
+46435
+46436
+46437
+46438
+46439
+46440
+46441
+46442
+46443
+46444
+46445
+46446
+46447
+46448
+46449
+46450
+46451
+46452
+46453
+46454
+46455
+46456
+46457
+46458
+46459
+46460
+46461
+46462
+46463
+46464
+46465
+46466
+46467
+46468
+46469
+46470
+46471
+46472
+46473
+46474
+46475
+46476
+46477
+46478
+46479
+46480
+46481
+46482
+46483
+46484
+46485
+46486
+46487
+46488
+46489
+46490
+46491
+46492
+46493
+46494
+46495
+46496
+46497
+46498
+46499
+46500
+46501
+46502
+46503
+46504
+46505
+46506
+46507
+46508
+46509
+46510
+46511
+46512
+46513
+46514
+46515
+46516
+46517
+46518
+46519
+46520
+46521
+46522
+46523
+46524
+46525
+46526
+46527
+46528
+46529
+46530
+46531
+46532
+46533
+46534
+46535
+46536
+46537
+46538
+46539
+46540
+46541
+46542
+46543
+46544
+46545
+46546
+46547
+46548
+46549
+46550
+46551
+46552
+46553
+46554
+46555
+46556
+46557
+46558
+46559
+46560
+46561
+46562
+46563
+46564
+46565
+46566
+46567
+46568
+46569
+46570
+46571
+46572
+46573
+46574
+46575
+46576
+46577
+46578
+46579
+46580
+46581
+46582
+46583
+46584
+46585
+46586
+46587
+46588
+46589
+46590
+46591
+46592
+46593
+46594
+46595
+46596
+46597
+46598
+46599
+46600
+46601
+46602
+46603
+46604
+46605
+46606
+46607
+46608
+46609
+46610
+46611
+46612
+46613
+46614
+46615
+46616
+46617
+46618
+46619
+46620
+46621
+46622
+46623
+46624
+46625
+46626
+46627
+46628
+46629
+46630
+46631
+46632
+46633
+46634
+46635
+46636
+46637
+46638
+46639
+46640
+46641
+46642
+46643
+46644
+46645
+46646
+46647
+46648
+46649
+46650
+46651
+46652
+46653
+46654
+46655
+46656
+46657
+46658
+46659
+46660
+46661
+46662
+46663
+46664
+46665
+46666
+46667
+46668
+46669
+46670
+46671
+46672
+46673
+46674
+46675
+46676
+46677
+46678
+46679
+46680
+46681
+46682
+46683
+46684
+46685
+46686
+46687
+46688
+46689
+46690
+46691
+46692
+46693
+46694
+46695
+46696
+46697
+46698
+46699
+46700
+46701
+46702
+46703
+46704
+46705
+46706
+46707
+46708
+46709
+46710
+46711
+46712
+46713
+46714
+46715
+46716
+46717
+46718
+46719
+46720
+46721
+46722
+46723
+46724
+46725
+46726
+46727
+46728
+46729
+46730
+46731
+46732
+46733
+46734
+46735
+46736
+46737
+46738
+46739
+46740
+46741
+46742
+46743
+46744
+46745
+46746
+46747
+46748
+46749
+46750
+46751
+46752
+46753
+46754
+46755
+46756
+46757
+46758
+46759
+46760
+46761
+46762
+46763
+46764
+46765
+46766
+46767
+46768
+46769
+46770
+46771
+46772
+46773
+46774
+46775
+46776
+46777
+46778
+46779
+46780
+46781
+46782
+46783
+46784
+46785
+46786
+46787
+46788
+46789
+46790
+46791
+46792
+46793
+46794
+46795
+46796
+46797
+46798
+46799
+46800
+46801
+46802
+46803
+46804
+46805
+46806
+46807
+46808
+46809
+46810
+46811
+46812
+46813
+46814
+46815
+46816
+46817
+46818
+46819
+46820
+46821
+46822
+46823
+46824
+46825
+46826
+46827
+46828
+46829
+46830
+46831
+46832
+46833
+46834
+46835
+46836
+46837
+46838
+46839
+46840
+46841
+46842
+46843
+46844
+46845
+46846
+46847
+46848
+46849
+46850
+46851
+46852
+46853
+46854
+46855
+46856
+46857
+46858
+46859
+46860
+46861
+46862
+46863
+46864
+46865
+46866
+46867
+46868
+46869
+46870
+46871
+46872
+46873
+46874
+46875
+46876
+46877
+46878
+46879
+46880
+46881
+46882
+46883
+46884
+46885
+46886
+46887
+46888
+46889
+46890
+46891
+46892
+46893
+46894
+46895
+46896
+46897
+46898
+46899
+46900
+46901
+46902
+46903
+46904
+46905
+46906
+46907
+46908
+46909
+46910
+46911
+46912
+46913
+46914
+46915
+46916
+46917
+46918
+46919
+46920
+46921
+46922
+46923
+46924
+46925
+46926
+46927
+46928
+46929
+46930
+46931
+46932
+46933
+46934
+46935
+46936
+46937
+46938
+46939
+46940
+46941
+46942
+46943
+46944
+46945
+46946
+46947
+46948
+46949
+46950
+46951
+46952
+46953
+46954
+46955
+46956
+46957
+46958
+46959
+46960
+46961
+46962
+46963
+46964
+46965
+46966
+46967
+46968
+46969
+46970
+46971
+46972
+46973
+46974
+46975
+46976
+46977
+46978
+46979
+46980
+46981
+46982
+46983
+46984
+46985
+46986
+46987
+46988
+46989
+46990
+46991
+46992
+46993
+46994
+46995
+46996
+46997
+46998
+46999
+47000
+47001
+47002
+47003
+47004
+47005
+47006
+47007
+47008
+47009
+47010
+47011
+47012
+47013
+47014
+47015
+47016
+47017
+47018
+47019
+47020
+47021
+47022
+47023
+47024
+47025
+47026
+47027
+47028
+47029
+47030
+47031
+47032
+47033
+47034
+47035
+47036
+47037
+47038
+47039
+47040
+47041
+47042
+47043
+47044
+47045
+47046
+47047
+47048
+47049
+47050
+47051
+47052
+47053
+47054
+47055
+47056
+47057
+47058
+47059
+47060
+47061
+47062
+47063
+47064
+47065
+47066
+47067
+47068
+47069
+47070
+47071
+47072
+47073
+47074
+47075
+47076
+47077
+47078
+47079
+47080
+47081
+47082
+47083
+47084
+47085
+47086
+47087
+47088
+47089
+47090
+47091
+47092
+47093
+47094
+47095
+47096
+47097
+47098
+47099
+47100
+47101
+47102
+47103
+47104
+47105
+47106
+47107
+47108
+47109
+47110
+47111
+47112
+47113
+47114
+47115
+47116
+47117
+47118
+47119
+47120
+47121
+47122
+47123
+47124
+47125
+47126
+47127
+47128
+47129
+47130
+47131
+47132
+47133
+47134
+47135
+47136
+47137
+47138
+47139
+47140
+47141
+47142
+47143
+47144
+47145
+47146
+47147
+47148
+47149
+47150
+47151
+47152
+47153
+47154
+47155
+47156
+47157
+47158
+47159
+47160
+47161
+47162
+47163
+47164
+47165
+47166
+47167
+47168
+47169
+47170
+47171
+47172
+47173
+47174
+47175
+47176
+47177
+47178
+47179
+47180
+47181
+47182
+47183
+47184
+47185
+47186
+47187
+47188
+47189
+47190
+47191
+47192
+47193
+47194
+47195
+47196
+47197
+47198
+47199
+47200
+47201
+47202
+47203
+47204
+47205
+47206
+47207
+47208
+47209
+47210
+47211
+47212
+47213
+47214
+47215
+47216
+47217
+47218
+47219
+47220
+47221
+47222
+47223
+47224
+47225
+47226
+47227
+47228
+47229
+47230
+47231
+47232
+47233
+47234
+47235
+47236
+47237
+47238
+47239
+47240
+47241
+47242
+47243
+47244
+47245
+47246
+47247
+47248
+47249
+47250
+47251
+47252
+47253
+47254
+47255
+47256
+47257
+47258
+47259
+47260
+47261
+47262
+47263
+47264
+47265
+47266
+47267
+47268
+47269
+47270
+47271
+47272
+47273
+47274
+47275
+47276
+47277
+47278
+47279
+47280
+47281
+47282
+47283
+47284
+47285
+47286
+47287
+47288
+47289
+47290
+47291
+47292
+47293
+47294
+47295
+47296
+47297
+47298
+47299
+47300
+47301
+47302
+47303
+47304
+47305
+47306
+47307
+47308
+47309
+47310
+47311
+47312
+47313
+47314
+47315
+47316
+47317
+47318
+47319
+47320
+47321
+47322
+47323
+47324
+47325
+47326
+47327
+47328
+47329
+47330
+47331
+47332
+47333
+47334
+47335
+47336
+47337
+47338
+47339
+47340
+47341
+47342
+47343
+47344
+47345
+47346
+47347
+47348
+47349
+47350
+47351
+47352
+47353
+47354
+47355
+47356
+47357
+47358
+47359
+47360
+47361
+47362
+47363
+47364
+47365
+47366
+47367
+47368
+47369
+47370
+47371
+47372
+47373
+47374
+47375
+47376
+47377
+47378
+47379
+47380
+47381
+47382
+47383
+47384
+47385
+47386
+47387
+47388
+47389
+47390
+47391
+47392
+47393
+47394
+47395
+47396
+47397
+47398
+47399
+47400
+47401
+47402
+47403
+47404
+47405
+47406
+47407
+47408
+47409
+47410
+47411
+47412
+47413
+47414
+47415
+47416
+47417
+47418
+47419
+47420
+47421
+47422
+47423
+47424
+47425
+47426
+47427
+47428
+47429
+47430
+47431
+47432
+47433
+47434
+47435
+47436
+47437
+47438
+47439
+47440
+47441
+47442
+47443
+47444
+47445
+47446
+47447
+47448
+47449
+47450
+47451
+47452
+47453
+47454
+47455
+47456
+47457
+47458
+47459
+47460
+47461
+47462
+47463
+47464
+47465
+47466
+47467
+47468
+47469
+47470
+47471
+47472
+47473
+47474
+47475
+47476
+47477
+47478
+47479
+47480
+47481
+47482
+47483
+47484
+47485
+47486
+47487
+47488
+47489
+47490
+47491
+47492
+47493
+47494
+47495
+47496
+47497
+47498
+47499
+47500
+47501
+47502
+47503
+47504
+47505
+47506
+47507
+47508
+47509
+47510
+47511
+47512
+47513
+47514
+47515
+47516
+47517
+47518
+47519
+47520
+47521
+47522
+47523
+47524
+47525
+47526
+47527
+47528
+47529
+47530
+47531
+47532
+47533
+47534
+47535
+47536
+47537
+47538
+47539
+47540
+47541
+47542
+47543
+47544
+47545
+47546
+47547
+47548
+47549
+47550
+47551
+47552
+47553
+47554
+47555
+47556
+47557
+47558
+47559
+47560
+47561
+47562
+47563
+47564
+47565
+47566
+47567
+47568
+47569
+47570
+47571
+47572
+47573
+47574
+47575
+47576
+47577
+47578
+47579
+47580
+47581
+47582
+47583
+47584
+47585
+47586
+47587
+47588
+47589
+47590
+47591
+47592
+47593
+47594
+47595
+47596
+47597
+47598
+47599
+47600
+47601
+47602
+47603
+47604
+47605
+47606
+47607
+47608
+47609
+47610
+47611
+47612
+47613
+47614
+47615
+47616
+47617
+47618
+47619
+47620
+47621
+47622
+47623
+47624
+47625
+47626
+47627
+47628
+47629
+47630
+47631
+47632
+47633
+47634
+47635
+47636
+47637
+47638
+47639
+47640
+47641
+47642
+47643
+47644
+47645
+47646
+47647
+47648
+47649
+47650
+47651
+47652
+47653
+47654
+47655
+47656
+47657
+47658
+47659
+47660
+47661
+47662
+47663
+47664
+47665
+47666
+47667
+47668
+47669
+47670
+47671
+47672
+47673
+47674
+47675
+47676
+47677
+47678
+47679
+47680
+47681
+47682
+47683
+47684
+47685
+47686
+47687
+47688
+47689
+47690
+47691
+47692
+47693
+47694
+47695
+47696
+47697
+47698
+47699
+47700
+47701
+47702
+47703
+47704
+47705
+47706
+47707
+47708
+47709
+47710
+47711
+47712
+47713
+47714
+47715
+47716
+47717
+47718
+47719
+47720
+47721
+47722
+47723
+47724
+47725
+47726
+47727
+47728
+47729
+47730
+47731
+47732
+47733
+47734
+47735
+47736
+47737
+47738
+47739
+47740
+47741
+47742
+47743
+47744
+47745
+47746
+47747
+47748
+47749
+47750
+47751
+47752
+47753
+47754
+47755
+47756
+47757
+47758
+47759
+47760
+47761
+47762
+47763
+47764
+47765
+47766
+47767
+47768
+47769
+47770
+47771
+47772
+47773
+47774
+47775
+47776
+47777
+47778
+47779
+47780
+47781
+47782
+47783
+47784
+47785
+47786
+47787
+47788
+47789
+47790
+47791
+47792
+47793
+47794
+47795
+47796
+47797
+47798
+47799
+47800
+47801
+47802
+47803
+47804
+47805
+47806
+47807
+47808
+47809
+47810
+47811
+47812
+47813
+47814
+47815
+47816
+47817
+47818
+47819
+47820
+47821
+47822
+47823
+47824
+47825
+47826
+47827
+47828
+47829
+47830
+47831
+47832
+47833
+47834
+47835
+47836
+47837
+47838
+47839
+47840
+47841
+47842
+47843
+47844
+47845
+47846
+47847
+47848
+47849
+47850
+47851
+47852
+47853
+47854
+47855
+47856
+47857
+47858
+47859
+47860
+47861
+47862
+47863
+47864
+47865
+47866
+47867
+47868
+47869
+47870
+47871
+47872
+47873
+47874
+47875
+47876
+47877
+47878
+47879
+47880
+47881
+47882
+47883
+47884
+47885
+47886
+47887
+47888
+47889
+47890
+47891
+47892
+47893
+47894
+47895
+47896
+47897
+47898
+47899
+47900
+47901
+47902
+47903
+47904
+47905
+47906
+47907
+47908
+47909
+47910
+47911
+47912
+47913
+47914
+47915
+47916
+47917
+47918
+47919
+47920
+47921
+47922
+47923
+47924
+47925
+47926
+47927
+47928
+47929
+47930
+47931
+47932
+47933
+47934
+47935
+47936
+47937
+47938
+47939
+47940
+47941
+47942
+47943
+47944
+47945
+47946
+47947
+47948
+47949
+47950
+47951
+47952
+47953
+47954
+47955
+47956
+47957
+47958
+47959
+47960
+47961
+47962
+47963
+47964
+47965
+47966
+47967
+47968
+47969
+47970
+47971
+47972
+47973
+47974
+47975
+47976
+47977
+47978
+47979
+47980
+47981
+47982
+47983
+47984
+47985
+47986
+47987
+47988
+47989
+47990
+47991
+47992
+47993
+47994
+47995
+47996
+47997
+47998
+47999
+48000
+48001
+48002
+48003
+48004
+48005
+48006
+48007
+48008
+48009
+48010
+48011
+48012
+48013
+48014
+48015
+48016
+48017
+48018
+48019
+48020
+48021
+48022
+48023
+48024
+48025
+48026
+48027
+48028
+48029
+48030
+48031
+48032
+48033
+48034
+48035
+48036
+48037
+48038
+48039
+48040
+48041
+48042
+48043
+48044
+48045
+48046
+48047
+48048
+48049
+48050
+48051
+48052
+48053
+48054
+48055
+48056
+48057
+48058
+48059
+48060
+48061
+48062
+48063
+48064
+48065
+48066
+48067
+48068
+48069
+48070
+48071
+48072
+48073
+48074
+48075
+48076
+48077
+48078
+48079
+48080
+48081
+48082
+48083
+48084
+48085
+48086
+48087
+48088
+48089
+48090
+48091
+48092
+48093
+48094
+48095
+48096
+48097
+48098
+48099
+48100
+48101
+48102
+48103
+48104
+48105
+48106
+48107
+48108
+48109
+48110
+48111
+48112
+48113
+48114
+48115
+48116
+48117
+48118
+48119
+48120
+48121
+48122
+48123
+48124
+48125
+48126
+48127
+48128
+48129
+48130
+48131
+48132
+48133
+48134
+48135
+48136
+48137
+48138
+48139
+48140
+48141
+48142
+48143
+48144
+48145
+48146
+48147
+48148
+48149
+48150
+48151
+48152
+48153
+48154
+48155
+48156
+48157
+48158
+48159
+48160
+48161
+48162
+48163
+48164
+48165
+48166
+48167
+48168
+48169
+48170
+48171
+48172
+48173
+48174
+48175
+48176
+48177
+48178
+48179
+48180
+48181
+48182
+48183
+48184
+48185
+48186
+48187
+48188
+48189
+48190
+48191
+48192
+48193
+48194
+48195
+48196
+48197
+48198
+48199
+48200
+48201
+48202
+48203
+48204
+48205
+48206
+48207
+48208
+48209
+48210
+48211
+48212
+48213
+48214
+48215
+48216
+48217
+48218
+48219
+48220
+48221
+48222
+48223
+48224
+48225
+48226
+48227
+48228
+48229
+48230
+48231
+48232
+48233
+48234
+48235
+48236
+48237
+48238
+48239
+48240
+48241
+48242
+48243
+48244
+48245
+48246
+48247
+48248
+48249
+48250
+48251
+48252
+48253
+48254
+48255
+48256
+48257
+48258
+48259
+48260
+48261
+48262
+48263
+48264
+48265
+48266
+48267
+48268
+48269
+48270
+48271
+48272
+48273
+48274
+48275
+48276
+48277
+48278
+48279
+48280
+48281
+48282
+48283
+48284
+48285
+48286
+48287
+48288
+48289
+48290
+48291
+48292
+48293
+48294
+48295
+48296
+48297
+48298
+48299
+48300
+48301
+48302
+48303
+48304
+48305
+48306
+48307
+48308
+48309
+48310
+48311
+48312
+48313
+48314
+48315
+48316
+48317
+48318
+48319
+48320
+48321
+48322
+48323
+48324
+48325
+48326
+48327
+48328
+48329
+48330
+48331
+48332
+48333
+48334
+48335
+48336
+48337
+48338
+48339
+48340
+48341
+48342
+48343
+48344
+48345
+48346
+48347
+48348
+48349
+48350
+48351
+48352
+48353
+48354
+48355
+48356
+48357
+48358
+48359
+48360
+48361
+48362
+48363
+48364
+48365
+48366
+48367
+48368
+48369
+48370
+48371
+48372
+48373
+48374
+48375
+48376
+48377
+48378
+48379
+48380
+48381
+48382
+48383
+48384
+48385
+48386
+48387
+48388
+48389
+48390
+48391
+48392
+48393
+48394
+48395
+48396
+48397
+48398
+48399
+48400
+48401
+48402
+48403
+48404
+48405
+48406
+48407
+48408
+48409
+48410
+48411
+48412
+48413
+48414
+48415
+48416
+48417
+48418
+48419
+48420
+48421
+48422
+48423
+48424
+48425
+48426
+48427
+48428
+48429
+48430
+48431
+48432
+48433
+48434
+48435
+48436
+48437
+48438
+48439
+48440
+48441
+48442
+48443
+48444
+48445
+48446
+48447
+48448
+48449
+48450
+48451
+48452
+48453
+48454
+48455
+48456
+48457
+48458
+48459
+48460
+48461
+48462
+48463
+48464
+48465
+48466
+48467
+48468
+48469
+48470
+48471
+48472
+48473
+48474
+48475
+48476
+48477
+48478
+48479
+48480
+48481
+48482
+48483
+48484
+48485
+48486
+48487
+48488
+48489
+48490
+48491
+48492
+48493
+48494
+48495
+48496
+48497
+48498
+48499
+48500
+48501
+48502
+48503
+48504
+48505
+48506
+48507
+48508
+48509
+48510
+48511
+48512
+48513
+48514
+48515
+48516
+48517
+48518
+48519
+48520
+48521
+48522
+48523
+48524
+48525
+48526
+48527
+48528
+48529
+48530
+48531
+48532
+48533
+48534
+48535
+48536
+48537
+48538
+48539
+48540
+48541
+48542
+48543
+48544
+48545
+48546
+48547
+48548
+48549
+48550
+48551
+48552
+48553
+48554
+48555
+48556
+48557
+48558
+48559
+48560
+48561
+48562
+48563
+48564
+48565
+48566
+48567
+48568
+48569
+48570
+48571
+48572
+48573
+48574
+48575
+48576
+48577
+48578
+48579
+48580
+48581
+48582
+48583
+48584
+48585
+48586
+48587
+48588
+48589
+48590
+48591
+48592
+48593
+48594
+48595
+48596
+48597
+48598
+48599
+48600
+48601
+48602
+48603
+48604
+48605
+48606
+48607
+48608
+48609
+48610
+48611
+48612
+48613
+48614
+48615
+48616
+48617
+48618
+48619
+48620
+48621
+48622
+48623
+48624
+48625
+48626
+48627
+48628
+48629
+48630
+48631
+48632
+48633
+48634
+48635
+48636
+48637
+48638
+48639
+48640
+48641
+48642
+48643
+48644
+48645
+48646
+48647
+48648
+48649
+48650
+48651
+48652
+48653
+48654
+48655
+48656
+48657
+48658
+48659
+48660
+48661
+48662
+48663
+48664
+48665
+48666
+48667
+48668
+48669
+48670
+48671
+48672
+48673
+48674
+48675
+48676
+48677
+48678
+48679
+48680
+48681
+48682
+48683
+48684
+48685
+48686
+48687
+48688
+48689
+48690
+48691
+48692
+48693
+48694
+48695
+48696
+48697
+48698
+48699
+48700
+48701
+48702
+48703
+48704
+48705
+48706
+48707
+48708
+48709
+48710
+48711
+48712
+48713
+48714
+48715
+48716
+48717
+48718
+48719
+48720
+48721
+48722
+48723
+48724
+48725
+48726
+48727
+48728
+48729
+48730
+48731
+48732
+48733
+48734
+48735
+48736
+48737
+48738
+48739
+48740
+48741
+48742
+48743
+48744
+48745
+48746
+48747
+48748
+48749
+48750
+48751
+48752
+48753
+48754
+48755
+48756
+48757
+48758
+48759
+48760
+48761
+48762
+48763
+48764
+48765
+48766
+48767
+48768
+48769
+48770
+48771
+48772
+48773
+48774
+48775
+48776
+48777
+48778
+48779
+48780
+48781
+48782
+48783
+48784
+48785
+48786
+48787
+48788
+48789
+48790
+48791
+48792
+48793
+48794
+48795
+48796
+48797
+48798
+48799
+48800
+48801
+48802
+48803
+48804
+48805
+48806
+48807
+48808
+48809
+48810
+48811
+48812
+48813
+48814
+48815
+48816
+48817
+48818
+48819
+48820
+48821
+48822
+48823
+48824
+48825
+48826
+48827
+48828
+48829
+48830
+48831
+48832
+48833
+48834
+48835
+48836
+48837
+48838
+48839
+48840
+48841
+48842
+48843
+48844
+48845
+48846
+48847
+48848
+48849
+48850
+48851
+48852
+48853
+48854
+48855
+48856
+48857
+48858
+48859
+48860
+48861
+48862
+48863
+48864
+48865
+48866
+48867
+48868
+48869
+48870
+48871
+48872
+48873
+48874
+48875
+48876
+48877
+48878
+48879
+48880
+48881
+48882
+48883
+48884
+48885
+48886
+48887
+48888
+48889
+48890
+48891
+48892
+48893
+48894
+48895
+48896
+48897
+48898
+48899
+48900
+48901
+48902
+48903
+48904
+48905
+48906
+48907
+48908
+48909
+48910
+48911
+48912
+48913
+48914
+48915
+48916
+48917
+48918
+48919
+48920
+48921
+48922
+48923
+48924
+48925
+48926
+48927
+48928
+48929
+48930
+48931
+48932
+48933
+48934
+48935
+48936
+48937
+48938
+48939
+48940
+48941
+48942
+48943
+48944
+48945
+48946
+48947
+48948
+48949
+48950
+48951
+48952
+48953
+48954
+48955
+48956
+48957
+48958
+48959
+48960
+48961
+48962
+48963
+48964
+48965
+48966
+48967
+48968
+48969
+48970
+48971
+48972
+48973
+48974
+48975
+48976
+48977
+48978
+48979
+48980
+48981
+48982
+48983
+48984
+48985
+48986
+48987
+48988
+48989
+48990
+48991
+48992
+48993
+48994
+48995
+48996
+48997
+48998
+48999
+49000
+49001
+49002
+49003
+49004
+49005
+49006
+49007
+49008
+49009
+49010
+49011
+49012
+49013
+49014
+49015
+49016
+49017
+49018
+49019
+49020
+49021
+49022
+49023
+49024
+49025
+49026
+49027
+49028
+49029
+49030
+49031
+49032
+49033
+49034
+49035
+49036
+49037
+49038
+49039
+49040
+49041
+49042
+49043
+49044
+49045
+49046
+49047
+49048
+49049
+49050
+49051
+49052
+49053
+49054
+49055
+49056
+49057
+49058
+49059
+49060
+49061
+49062
+49063
+49064
+49065
+49066
+49067
+49068
+49069
+49070
+49071
+49072
+49073
+49074
+49075
+49076
+49077
+49078
+49079
+49080
+49081
+49082
+49083
+49084
+49085
+49086
+49087
+49088
+49089
+49090
+49091
+49092
+49093
+49094
+49095
+49096
+49097
+49098
+49099
+49100
+49101
+49102
+49103
+49104
+49105
+49106
+49107
+49108
+49109
+49110
+49111
+49112
+49113
+49114
+49115
+49116
+49117
+49118
+49119
+49120
+49121
+49122
+49123
+49124
+49125
+49126
+49127
+49128
+49129
+49130
+49131
+49132
+49133
+49134
+49135
+49136
+49137
+49138
+49139
+49140
+49141
+49142
+49143
+49144
+49145
+49146
+49147
+49148
+49149
+49150
+49151
+49152
+49153
+49154
+49155
+49156
+49157
+49158
+49159
+49160
+49161
+49162
+49163
+49164
+49165
+49166
+49167
+49168
+49169
+49170
+49171
+49172
+49173
+49174
+49175
+49176
+49177
+49178
+49179
+49180
+49181
+49182
+49183
+49184
+49185
+49186
+49187
+49188
+49189
+49190
+49191
+49192
+49193
+49194
+49195
+49196
+49197
+49198
+49199
+49200
+49201
+49202
+49203
+49204
+49205
+49206
+49207
+49208
+49209
+49210
+49211
+49212
+49213
+49214
+49215
+49216
+49217
+49218
+49219
+49220
+49221
+49222
+49223
+49224
+49225
+49226
+49227
+49228
+49229
+49230
+49231
+49232
+49233
+49234
+49235
+49236
+49237
+49238
+49239
+49240
+49241
+49242
+49243
+49244
+49245
+49246
+49247
+49248
+49249
+49250
+49251
+49252
+49253
+49254
+49255
+49256
+49257
+49258
+49259
+49260
+49261
+49262
+49263
+49264
+49265
+49266
+49267
+49268
+49269
+49270
+49271
+49272
+49273
+49274
+49275
+49276
+49277
+49278
+49279
+49280
+49281
+49282
+49283
+49284
+49285
+49286
+49287
+49288
+49289
+49290
+49291
+49292
+49293
+49294
+49295
+49296
+49297
+49298
+49299
+49300
+49301
+49302
+49303
+49304
+49305
+49306
+49307
+49308
+49309
+49310
+49311
+49312
+49313
+49314
+49315
+49316
+49317
+49318
+49319
+49320
+49321
+49322
+49323
+49324
+49325
+49326
+49327
+49328
+49329
+49330
+49331
+49332
+49333
+49334
+49335
+49336
+49337
+49338
+49339
+49340
+49341
+49342
+49343
+49344
+49345
+49346
+49347
+49348
+49349
+49350
+49351
+49352
+49353
+49354
+49355
+49356
+49357
+49358
+49359
+49360
+49361
+49362
+49363
+49364
+49365
+49366
+49367
+49368
+49369
+49370
+49371
+49372
+49373
+49374
+49375
+49376
+49377
+49378
+49379
+49380
+49381
+49382
+49383
+49384
+49385
+49386
+49387
+49388
+49389
+49390
+49391
+49392
+49393
+49394
+49395
+49396
+49397
+49398
+49399
+49400
+49401
+49402
+49403
+49404
+49405
+49406
+49407
+49408
+49409
+49410
+49411
+49412
+49413
+49414
+49415
+49416
+49417
+49418
+49419
+49420
+49421
+49422
+49423
+49424
+49425
+49426
+49427
+49428
+49429
+49430
+49431
+49432
+49433
+49434
+49435
+49436
+49437
+49438
+49439
+49440
+49441
+49442
+49443
+49444
+49445
+49446
+49447
+49448
+49449
+49450
+49451
+49452
+49453
+49454
+49455
+49456
+49457
+49458
+49459
+49460
+49461
+49462
+49463
+49464
+49465
+49466
+49467
+49468
+49469
+49470
+49471
+49472
+49473
+49474
+49475
+49476
+49477
+49478
+49479
+49480
+49481
+49482
+49483
+49484
+49485
+49486
+49487
+49488
+49489
+49490
+49491
+49492
+49493
+49494
+49495
+49496
+49497
+49498
+49499
+49500
+49501
+49502
+49503
+49504
+49505
+49506
+49507
+49508
+49509
+49510
+49511
+49512
+49513
+49514
+49515
+49516
+49517
+49518
+49519
+49520
+49521
+49522
+49523
+49524
+49525
+49526
+49527
+49528
+49529
+49530
+49531
+49532
+49533
+49534
+49535
+49536
+49537
+49538
+49539
+49540
+49541
+49542
+49543
+49544
+49545
+49546
+49547
+49548
+49549
+49550
+49551
+49552
+49553
+49554
+49555
+49556
+49557
+49558
+49559
+49560
+49561
+49562
+49563
+49564
+49565
+49566
+49567
+49568
+49569
+49570
+49571
+49572
+49573
+49574
+49575
+49576
+49577
+49578
+49579
+49580
+49581
+49582
+49583
+49584
+49585
+49586
+49587
+49588
+49589
+49590
+49591
+49592
+49593
+49594
+49595
+49596
+49597
+49598
+49599
+49600
+49601
+49602
+49603
+49604
+49605
+49606
+49607
+49608
+49609
+49610
+49611
+49612
+49613
+49614
+49615
+49616
+49617
+49618
+49619
+49620
+49621
+49622
+49623
+49624
+49625
+49626
+49627
+49628
+49629
+49630
+49631
+49632
+49633
+49634
+49635
+49636
+49637
+49638
+49639
+49640
+49641
+49642
+49643
+49644
+49645
+49646
+49647
+49648
+49649
+49650
+49651
+49652
+49653
+49654
+49655
+49656
+49657
+49658
+49659
+49660
+49661
+49662
+49663
+49664
+49665
+49666
+49667
+49668
+49669
+49670
+49671
+49672
+49673
+49674
+49675
+49676
+49677
+49678
+49679
+49680
+49681
+49682
+49683
+49684
+49685
+49686
+49687
+49688
+49689
+49690
+49691
+49692
+49693
+49694
+49695
+49696
+49697
+49698
+49699
+49700
+49701
+49702
+49703
+49704
+49705
+49706
+49707
+49708
+49709
+49710
+49711
+49712
+49713
+49714
+49715
+49716
+49717
+49718
+49719
+49720
+49721
+49722
+49723
+49724
+49725
+49726
+49727
+49728
+49729
+49730
+49731
+49732
+49733
+49734
+49735
+49736
+49737
+49738
+49739
+49740
+49741
+49742
+49743
+49744
+49745
+49746
+49747
+49748
+49749
+49750
+49751
+49752
+49753
+49754
+49755
+49756
+49757
+49758
+49759
+49760
+49761
+49762
+49763
+49764
+49765
+49766
+49767
+49768
+49769
+49770
+49771
+49772
+49773
+49774
+49775
+49776
+49777
+49778
+49779
+49780
+49781
+49782
+49783
+49784
+49785
+49786
+49787
+49788
+49789
+49790
+49791
+49792
+49793
+49794
+49795
+49796
+49797
+49798
+49799
+49800
+49801
+49802
+49803
+49804
+49805
+49806
+49807
+49808
+49809
+49810
+49811
+49812
+49813
+49814
+49815
+49816
+49817
+49818
+49819
+49820
+49821
+49822
+49823
+49824
+49825
+49826
+49827
+49828
+49829
+49830
+49831
+49832
+49833
+49834
+49835
+49836
+49837
+49838
+49839
+49840
+49841
+49842
+49843
+49844
+49845
+49846
+49847
+49848
+49849
+49850
+49851
+49852
+49853
+49854
+49855
+49856
+49857
+49858
+49859
+49860
+49861
+49862
+49863
+49864
+49865
+49866
+49867
+49868
+49869
+49870
+49871
+49872
+49873
+49874
+49875
+49876
+49877
+49878
+49879
+49880
+49881
+49882
+49883
+49884
+49885
+49886
+49887
+49888
+49889
+49890
+49891
+49892
+49893
+49894
+49895
+49896
+49897
+49898
+49899
+49900
+49901
+49902
+49903
+49904
+49905
+49906
+49907
+49908
+49909
+49910
+49911
+49912
+49913
+49914
+49915
+49916
+49917
+49918
+49919
+49920
+49921
+49922
+49923
+49924
+49925
+49926
+49927
+49928
+49929
+49930
+49931
+49932
+49933
+49934
+49935
+49936
+49937
+49938
+49939
+49940
+49941
+49942
+49943
+49944
+49945
+49946
+49947
+49948
+49949
+49950
+49951
+49952
+49953
+49954
+49955
+49956
+49957
+49958
+49959
+49960
+49961
+49962
+49963
+49964
+49965
+49966
+49967
+49968
+49969
+49970
+49971
+49972
+49973
+49974
+49975
+49976
+49977
+49978
+49979
+49980
+49981
+49982
+49983
+49984
+49985
+49986
+49987
+49988
+49989
+49990
+49991
+49992
+49993
+49994
+49995
+49996
+49997
+49998
+49999
+50000
+50001
+50002
+50003
+50004
+50005
+50006
+50007
+50008
+50009
+50010
+50011
+50012
+50013
+50014
+50015
+50016
+50017
+50018
+50019
+50020
+50021
+50022
+50023
+50024
+50025
+50026
+50027
+50028
+50029
+50030
+50031
+50032
+50033
+50034
+50035
+50036
+50037
+50038
+50039
+50040
+50041
+50042
+50043
+50044
+50045
+50046
+50047
+50048
+50049
+50050
+50051
+50052
+50053
+50054
+50055
+50056
+50057
+50058
+50059
+50060
+50061
+50062
+50063
+50064
+50065
+50066
+50067
+50068
+50069
+50070
+50071
+50072
+50073
+50074
+50075
+50076
+50077
+50078
+50079
+50080
+50081
+50082
+50083
+50084
+50085
+50086
+50087
+50088
+50089
+50090
+50091
+50092
+50093
+50094
+50095
+50096
+50097
+50098
+50099
+50100
+50101
+50102
+50103
+50104
+50105
+50106
+50107
+50108
+50109
+50110
+50111
+50112
+50113
+50114
+50115
+50116
+50117
+50118
+50119
+50120
+50121
+50122
+50123
+50124
+50125
+50126
+50127
+50128
+50129
+50130
+50131
+50132
+50133
+50134
+50135
+50136
+50137
+50138
+50139
+50140
+50141
+50142
+50143
+50144
+50145
+50146
+50147
+50148
+50149
+50150
+50151
+50152
+50153
+50154
+50155
+50156
+50157
+50158
+50159
+50160
+50161
+50162
+50163
+50164
+50165
+50166
+50167
+50168
+50169
+50170
+50171
+50172
+50173
+50174
+50175
+50176
+50177
+50178
+50179
+50180
+50181
+50182
+50183
+50184
+50185
+50186
+50187
+50188
+50189
+50190
+50191
+50192
+50193
+50194
+50195
+50196
+50197
+50198
+50199
+50200
+50201
+50202
+50203
+50204
+50205
+50206
+50207
+50208
+50209
+50210
+50211
+50212
+50213
+50214
+50215
+50216
+50217
+50218
+50219
+50220
+50221
+50222
+50223
+50224
+50225
+50226
+50227
+50228
+50229
+50230
+50231
+50232
+50233
+50234
+50235
+50236
+50237
+50238
+50239
+50240
+50241
+50242
+50243
+50244
+50245
+50246
+50247
+50248
+50249
+50250
+50251
+50252
+50253
+50254
+50255
+50256
+50257
+50258
+50259
+50260
+50261
+50262
+50263
+50264
+50265
+50266
+50267
+50268
+50269
+50270
+50271
+50272
+50273
+50274
+50275
+50276
+50277
+50278
+50279
+50280
+50281
+50282
+50283
+50284
+50285
+50286
+50287
+50288
+50289
+50290
+50291
+50292
+50293
+50294
+50295
+50296
+50297
+50298
+50299
+50300
+50301
+50302
+50303
+50304
+50305
+50306
+50307
+50308
+50309
+50310
+50311
+50312
+50313
+50314
+50315
+50316
+50317
+50318
+50319
+50320
+50321
+50322
+50323
+50324
+50325
+50326
+50327
+50328
+50329
+50330
+50331
+50332
+50333
+50334
+50335
+50336
+50337
+50338
+50339
+50340
+50341
+50342
+50343
+50344
+50345
+50346
+50347
+50348
+50349
+50350
+50351
+50352
+50353
+50354
+50355
+50356
+50357
+50358
+50359
+50360
+50361
+50362
+50363
+50364
+50365
+50366
+50367
+50368
+50369
+50370
+50371
+50372
+50373
+50374
+50375
+50376
+50377
+50378
+50379
+50380
+50381
+50382
+50383
+50384
+50385
+50386
+50387
+50388
+50389
+50390
+50391
+50392
+50393
+50394
+50395
+50396
+50397
+50398
+50399
+50400
+50401
+50402
+50403
+50404
+50405
+50406
+50407
+50408
+50409
+50410
+50411
+50412
+50413
+50414
+50415
+50416
+50417
+50418
+50419
+50420
+50421
+50422
+50423
+50424
+50425
+50426
+50427
+50428
+50429
+50430
+50431
+50432
+50433
+50434
+50435
+50436
+50437
+50438
+50439
+50440
+50441
+50442
+50443
+50444
+50445
+50446
+50447
+50448
+50449
+50450
+50451
+50452
+50453
+50454
+50455
+50456
+50457
+50458
+50459
+50460
+50461
+50462
+50463
+50464
+50465
+50466
+50467
+50468
+50469
+50470
+50471
+50472
+50473
+50474
+50475
+50476
+50477
+50478
+50479
+50480
+50481
+50482
+50483
+50484
+50485
+50486
+50487
+50488
+50489
+50490
+50491
+50492
+50493
+50494
+50495
+50496
+50497
+50498
+50499
+50500
+50501
+50502
+50503
+50504
+50505
+50506
+50507
+50508
+50509
+50510
+50511
+50512
+50513
+50514
+50515
+50516
+50517
+50518
+50519
+50520
+50521
+50522
+50523
+50524
+50525
+50526
+50527
+50528
+50529
+50530
+50531
+50532
+50533
+50534
+50535
+50536
+50537
+50538
+50539
+50540
+50541
+50542
+50543
+50544
+50545
+50546
+50547
+50548
+50549
+50550
+50551
+50552
+50553
+50554
+50555
+50556
+50557
+50558
+50559
+50560
+50561
+50562
+50563
+50564
+50565
+50566
+50567
+50568
+50569
+50570
+50571
+50572
+50573
+50574
+50575
+50576
+50577
+50578
+50579
+50580
+50581
+50582
+50583
+50584
+50585
+50586
+50587
+50588
+50589
+50590
+50591
+50592
+50593
+50594
+50595
+50596
+50597
+50598
+50599
+50600
+50601
+50602
+50603
+50604
+50605
+50606
+50607
+50608
+50609
+50610
+50611
+50612
+50613
+50614
+50615
+50616
+50617
+50618
+50619
+50620
+50621
+50622
+50623
+50624
+50625
+50626
+50627
+50628
+50629
+50630
+50631
+50632
+50633
+50634
+50635
+50636
+50637
+50638
+50639
+50640
+50641
+50642
+50643
+50644
+50645
+50646
+50647
+50648
+50649
+50650
+50651
+50652
+50653
+50654
+50655
+50656
+50657
+50658
+50659
+50660
+50661
+50662
+50663
+50664
+50665
+50666
+50667
+50668
+50669
+50670
+50671
+50672
+50673
+50674
+50675
+50676
+50677
+50678
+50679
+50680
+50681
+50682
+50683
+50684
+50685
+50686
+50687
+50688
+50689
+50690
+50691
+50692
+50693
+50694
+50695
+50696
+50697
+50698
+50699
+50700
+50701
+50702
+50703
+50704
+50705
+50706
+50707
+50708
+50709
+50710
+50711
+50712
+50713
+50714
+50715
+50716
+50717
+50718
+50719
+50720
+50721
+50722
+50723
+50724
+50725
+50726
+50727
+50728
+50729
+50730
+50731
+50732
+50733
+50734
+50735
+50736
+50737
+50738
+50739
+50740
+50741
+50742
+50743
+50744
+50745
+50746
+50747
+50748
+50749
+50750
+50751
+50752
+50753
+50754
+50755
+50756
+50757
+50758
+50759
+50760
+50761
+50762
+50763
+50764
+50765
+50766
+50767
+50768
+50769
+50770
+50771
+50772
+50773
+50774
+50775
+50776
+50777
+50778
+50779
+50780
+50781
+50782
+50783
+50784
+50785
+50786
+50787
+50788
+50789
+50790
+50791
+50792
+50793
+50794
+50795
+50796
+50797
+50798
+50799
+50800
+50801
+50802
+50803
+50804
+50805
+50806
+50807
+50808
+50809
+50810
+50811
+50812
+50813
+50814
+50815
+50816
+50817
+50818
+50819
+50820
+50821
+50822
+50823
+50824
+50825
+50826
+50827
+50828
+50829
+50830
+50831
+50832
+50833
+50834
+50835
+50836
+50837
+50838
+50839
+50840
+50841
+50842
+50843
+50844
+50845
+50846
+50847
+50848
+50849
+50850
+50851
+50852
+50853
+50854
+50855
+50856
+50857
+50858
+50859
+50860
+50861
+50862
+50863
+50864
+50865
+50866
+50867
+50868
+50869
+50870
+50871
+50872
+50873
+50874
+50875
+50876
+50877
+50878
+50879
+50880
+50881
+50882
+50883
+50884
+50885
+50886
+50887
+50888
+50889
+50890
+50891
+50892
+50893
+50894
+50895
+50896
+50897
+50898
+50899
+50900
+50901
+50902
+50903
+50904
+50905
+50906
+50907
+50908
+50909
+50910
+50911
+50912
+50913
+50914
+50915
+50916
+50917
+50918
+50919
+50920
+50921
+50922
+50923
+50924
+50925
+50926
+50927
+50928
+50929
+50930
+50931
+50932
+50933
+50934
+50935
+50936
+50937
+50938
+50939
+50940
+50941
+50942
+50943
+50944
+50945
+50946
+50947
+50948
+50949
+50950
+50951
+50952
+50953
+50954
+50955
+50956
+50957
+50958
+50959
+50960
+50961
+50962
+50963
+50964
+50965
+50966
+50967
+50968
+50969
+50970
+50971
+50972
+50973
+50974
+50975
+50976
+50977
+50978
+50979
+50980
+50981
+50982
+50983
+50984
+50985
+50986
+50987
+50988
+50989
+50990
+50991
+50992
+50993
+50994
+50995
+50996
+50997
+50998
+50999
+51000
+51001
+51002
+51003
+51004
+51005
+51006
+51007
+51008
+51009
+51010
+51011
+51012
+51013
+51014
+51015
+51016
+51017
+51018
+51019
+51020
+51021
+51022
+51023
+51024
+51025
+51026
+51027
+51028
+51029
+51030
+51031
+51032
+51033
+51034
+51035
+51036
+51037
+51038
+51039
+51040
+51041
+51042
+51043
+51044
+51045
+51046
+51047
+51048
+51049
+51050
+51051
+51052
+51053
+51054
+51055
+51056
+51057
+51058
+51059
+51060
+51061
+51062
+51063
+51064
+51065
+51066
+51067
+51068
+51069
+51070
+51071
+51072
+51073
+51074
+51075
+51076
+51077
+51078
+51079
+51080
+51081
+51082
+51083
+51084
+51085
+51086
+51087
+51088
+51089
+51090
+51091
+51092
+51093
+51094
+51095
+51096
+51097
+51098
+51099
+51100
+51101
+51102
+51103
+51104
+51105
+51106
+51107
+51108
+51109
+51110
+51111
+51112
+51113
+51114
+51115
+51116
+51117
+51118
+51119
+51120
+51121
+51122
+51123
+51124
+51125
+51126
+51127
+51128
+51129
+51130
+51131
+51132
+51133
+51134
+51135
+51136
+51137
+51138
+51139
+51140
+51141
+51142
+51143
+51144
+51145
+51146
+51147
+51148
+51149
+51150
+51151
+51152
+51153
+51154
+51155
+51156
+51157
+51158
+51159
+51160
+51161
+51162
+51163
+51164
+51165
+51166
+51167
+51168
+51169
+51170
+51171
+51172
+51173
+51174
+51175
+51176
+51177
+51178
+51179
+51180
+51181
+51182
+51183
+51184
+51185
+51186
+51187
+51188
+51189
+51190
+51191
+51192
+51193
+51194
+51195
+51196
+51197
+51198
+51199
+51200
+51201
+51202
+51203
+51204
+51205
+51206
+51207
+51208
+51209
+51210
+51211
+51212
+51213
+51214
+51215
+51216
+51217
+51218
+51219
+51220
+51221
+51222
+51223
+51224
+51225
+51226
+51227
+51228
+51229
+51230
+51231
+51232
+51233
+51234
+51235
+51236
+51237
+51238
+51239
+51240
+51241
+51242
+51243
+51244
+51245
+51246
+51247
+51248
+51249
+51250
+51251
+51252
+51253
+51254
+51255
+51256
+51257
+51258
+51259
+51260
+51261
+51262
+51263
+51264
+51265
+51266
+51267
+51268
+51269
+51270
+51271
+51272
+51273
+51274
+51275
+51276
+51277
+51278
+51279
+51280
+51281
+51282
+51283
+51284
+51285
+51286
+51287
+51288
+51289
+51290
+51291
+51292
+51293
+51294
+51295
+51296
+51297
+51298
+51299
+51300
+51301
+51302
+51303
+51304
+51305
+51306
+51307
+51308
+51309
+51310
+51311
+51312
+51313
+51314
+51315
+51316
+51317
+51318
+51319
+51320
+51321
+51322
+51323
+51324
+51325
+51326
+51327
+51328
+51329
+51330
+51331
+51332
+51333
+51334
+51335
+51336
+51337
+51338
+51339
+51340
+51341
+51342
+51343
+51344
+51345
+51346
+51347
+51348
+51349
+51350
+51351
+51352
+51353
+51354
+51355
+51356
+51357
+51358
+51359
+51360
+51361
+51362
+51363
+51364
+51365
+51366
+51367
+51368
+51369
+51370
+51371
+51372
+51373
+51374
+51375
+51376
+51377
+51378
+51379
+51380
+51381
+51382
+51383
+51384
+51385
+51386
+51387
+51388
+51389
+51390
+51391
+51392
+51393
+51394
+51395
+51396
+51397
+51398
+51399
+51400
+51401
+51402
+51403
+51404
+51405
+51406
+51407
+51408
+51409
+51410
+51411
+51412
+51413
+51414
+51415
+51416
+51417
+51418
+51419
+51420
+51421
+51422
+51423
+51424
+51425
+51426
+51427
+51428
+51429
+51430
+51431
+51432
+51433
+51434
+51435
+51436
+51437
+51438
+51439
+51440
+51441
+51442
+51443
+51444
+51445
+51446
+51447
+51448
+51449
+51450
+51451
+51452
+51453
+51454
+51455
+51456
+51457
+51458
+51459
+51460
+51461
+51462
+51463
+51464
+51465
+51466
+51467
+51468
+51469
+51470
+51471
+51472
+51473
+51474
+51475
+51476
+51477
+51478
+51479
+51480
+51481
+51482
+51483
+51484
+51485
+51486
+51487
+51488
+51489
+51490
+51491
+51492
+51493
+51494
+51495
+51496
+51497
+51498
+51499
+51500
+51501
+51502
+51503
+51504
+51505
+51506
+51507
+51508
+51509
+51510
+51511
+51512
+51513
+51514
+51515
+51516
+51517
+51518
+51519
+51520
+51521
+51522
+51523
+51524
+51525
+51526
+51527
+51528
+51529
+51530
+51531
+51532
+51533
+51534
+51535
+51536
+51537
+51538
+51539
+51540
+51541
+51542
+51543
+51544
+51545
+51546
+51547
+51548
+51549
+51550
+51551
+51552
+51553
+51554
+51555
+51556
+51557
+51558
+51559
+51560
+51561
+51562
+51563
+51564
+51565
+51566
+51567
+51568
+51569
+51570
+51571
+51572
+51573
+51574
+51575
+51576
+51577
+51578
+51579
+51580
+51581
+51582
+51583
+51584
+51585
+51586
+51587
+51588
+51589
+51590
+51591
+51592
+51593
+51594
+51595
+51596
+51597
+51598
+51599
+51600
+51601
+51602
+51603
+51604
+51605
+51606
+51607
+51608
+51609
+51610
+51611
+51612
+51613
+51614
+51615
+51616
+51617
+51618
+51619
+51620
+51621
+51622
+51623
+51624
+51625
+51626
+51627
+51628
+51629
+51630
+51631
+51632
+51633
+51634
+51635
+51636
+51637
+51638
+51639
+51640
+51641
+51642
+51643
+51644
+51645
+51646
+51647
+51648
+51649
+51650
+51651
+51652
+51653
+51654
+51655
+51656
+51657
+51658
+51659
+51660
+51661
+51662
+51663
+51664
+51665
+51666
+51667
+51668
+51669
+51670
+51671
+51672
+51673
+51674
+51675
+51676
+51677
+51678
+51679
+51680
+51681
+51682
+51683
+51684
+51685
+51686
+51687
+51688
+51689
+51690
+51691
+51692
+51693
+51694
+51695
+51696
+51697
+51698
+51699
+51700
+51701
+51702
+51703
+51704
+51705
+51706
+51707
+51708
+51709
+51710
+51711
+51712
+51713
+51714
+51715
+51716
+51717
+51718
+51719
+51720
+51721
+51722
+51723
+51724
+51725
+51726
+51727
+51728
+51729
+51730
+51731
+51732
+51733
+51734
+51735
+51736
+51737
+51738
+51739
+51740
+51741
+51742
+51743
+51744
+51745
+51746
+51747
+51748
+51749
+51750
+51751
+51752
+51753
+51754
+51755
+51756
+51757
+51758
+51759
+51760
+51761
+51762
+51763
+51764
+51765
+51766
+51767
+51768
+51769
+51770
+51771
+51772
+51773
+51774
+51775
+51776
+51777
+51778
+51779
+51780
+51781
+51782
+51783
+51784
+51785
+51786
+51787
+51788
+51789
+51790
+51791
+51792
+51793
+51794
+51795
+51796
+51797
+51798
+51799
+51800
+51801
+51802
+51803
+51804
+51805
+51806
+51807
+51808
+51809
+51810
+51811
+51812
+51813
+51814
+51815
+51816
+51817
+51818
+51819
+51820
+51821
+51822
+51823
+51824
+51825
+51826
+51827
+51828
+51829
+51830
+51831
+51832
+51833
+51834
+51835
+51836
+51837
+51838
+51839
+51840
+51841
+51842
+51843
+51844
+51845
+51846
+51847
+51848
+51849
+51850
+51851
+51852
+51853
+51854
+51855
+51856
+51857
+51858
+51859
+51860
+51861
+51862
+51863
+51864
+51865
+51866
+51867
+51868
+51869
+51870
+51871
+51872
+51873
+51874
+51875
+51876
+51877
+51878
+51879
+51880
+51881
+51882
+51883
+51884
+51885
+51886
+51887
+51888
+51889
+51890
+51891
+51892
+51893
+51894
+51895
+51896
+51897
+51898
+51899
+51900
+51901
+51902
+51903
+51904
+51905
+51906
+51907
+51908
+51909
+51910
+51911
+51912
+51913
+51914
+51915
+51916
+51917
+51918
+51919
+51920
+51921
+51922
+51923
+51924
+51925
+51926
+51927
+51928
+51929
+51930
+51931
+51932
+51933
+51934
+51935
+51936
+51937
+51938
+51939
+51940
+51941
+51942
+51943
+51944
+51945
+51946
+51947
+51948
+51949
+51950
+51951
+51952
+51953
+51954
+51955
+51956
+51957
+51958
+51959
+51960
+51961
+51962
+51963
+51964
+51965
+51966
+51967
+51968
+51969
+51970
+51971
+51972
+51973
+51974
+51975
+51976
+51977
+51978
+51979
+51980
+51981
+51982
+51983
+51984
+51985
+51986
+51987
+51988
+51989
+51990
+51991
+51992
+51993
+51994
+51995
+51996
+51997
+51998
+51999
+52000
+52001
+52002
+52003
+52004
+52005
+52006
+52007
+52008
+52009
+52010
+52011
+52012
+52013
+52014
+52015
+52016
+52017
+52018
+52019
+52020
+52021
+52022
+52023
+52024
+52025
+52026
+52027
+52028
+52029
+52030
+52031
+52032
+52033
+52034
+52035
+52036
+52037
+52038
+52039
+52040
+52041
+52042
+52043
+52044
+52045
+52046
+52047
+52048
+52049
+52050
+52051
+52052
+52053
+52054
+52055
+52056
+52057
+52058
+52059
+52060
+52061
+52062
+52063
+52064
+52065
+52066
+52067
+52068
+52069
+52070
+52071
+52072
+52073
+52074
+52075
+52076
+52077
+52078
+52079
+52080
+52081
+52082
+52083
+52084
+52085
+52086
+52087
+52088
+52089
+52090
+52091
+52092
+52093
+52094
+52095
+52096
+52097
+52098
+52099
+52100
+52101
+52102
+52103
+52104
+52105
+52106
+52107
+52108
+52109
+52110
+52111
+52112
+52113
+52114
+52115
+52116
+52117
+52118
+52119
+52120
+52121
+52122
+52123
+52124
+52125
+52126
+52127
+52128
+52129
+52130
+52131
+52132
+52133
+52134
+52135
+52136
+52137
+52138
+52139
+52140
+52141
+52142
+52143
+52144
+52145
+52146
+52147
+52148
+52149
+52150
+52151
+52152
+52153
+52154
+52155
+52156
+52157
+52158
+52159
+52160
+52161
+52162
+52163
+52164
+52165
+52166
+52167
+52168
+52169
+52170
+52171
+52172
+52173
+52174
+52175
+52176
+52177
+52178
+52179
+52180
+52181
+52182
+52183
+52184
+52185
+52186
+52187
+52188
+52189
+52190
+52191
+52192
+52193
+52194
+52195
+52196
+52197
+52198
+52199
+52200
+52201
+52202
+52203
+52204
+52205
+52206
+52207
+52208
+52209
+52210
+52211
+52212
+52213
+52214
+52215
+52216
+52217
+52218
+52219
+52220
+52221
+52222
+52223
+52224
+52225
+52226
+52227
+52228
+52229
+52230
+52231
+52232
+52233
+52234
+52235
+52236
+52237
+52238
+52239
+52240
+52241
+52242
+52243
+52244
+52245
+52246
+52247
+52248
+52249
+52250
+52251
+52252
+52253
+52254
+52255
+52256
+52257
+52258
+52259
+52260
+52261
+52262
+52263
+52264
+52265
+52266
+52267
+52268
+52269
+52270
+52271
+52272
+52273
+52274
+52275
+52276
+52277
+52278
+52279
+52280
+52281
+52282
+52283
+52284
+52285
+52286
+52287
+52288
+52289
+52290
+52291
+52292
+52293
+52294
+52295
+52296
+52297
+52298
+52299
+52300
+52301
+52302
+52303
+52304
+52305
+52306
+52307
+52308
+52309
+52310
+52311
+52312
+52313
+52314
+52315
+52316
+52317
+52318
+52319
+52320
+52321
+52322
+52323
+52324
+52325
+52326
+52327
+52328
+52329
+52330
+52331
+52332
+52333
+52334
+52335
+52336
+52337
+52338
+52339
+52340
+52341
+52342
+52343
+52344
+52345
+52346
+52347
+52348
+52349
+52350
+52351
+52352
+52353
+52354
+52355
+52356
+52357
+52358
+52359
+52360
+52361
+52362
+52363
+52364
+52365
+52366
+52367
+52368
+52369
+52370
+52371
+52372
+52373
+52374
+52375
+52376
+52377
+52378
+52379
+52380
+52381
+52382
+52383
+52384
+52385
+52386
+52387
+52388
+52389
+52390
+52391
+52392
+52393
+52394
+52395
+52396
+52397
+52398
+52399
+52400
+52401
+52402
+52403
+52404
+52405
+52406
+52407
+52408
+52409
+52410
+52411
+52412
+52413
+52414
+52415
+52416
+52417
+52418
+52419
+52420
+52421
+52422
+52423
+52424
+52425
+52426
+52427
+52428
+52429
+52430
+52431
+52432
+52433
+52434
+52435
+52436
+52437
+52438
+52439
+52440
+52441
+52442
+52443
+52444
+52445
+52446
+52447
+52448
+52449
+52450
+52451
+52452
+52453
+52454
+52455
+52456
+52457
+52458
+52459
+52460
+52461
+52462
+52463
+52464
+52465
+52466
+52467
+52468
+52469
+52470
+52471
+52472
+52473
+52474
+52475
+52476
+52477
+52478
+52479
+52480
+52481
+52482
+52483
+52484
+52485
+52486
+52487
+52488
+52489
+52490
+52491
+52492
+52493
+52494
+52495
+52496
+52497
+52498
+52499
+52500
+52501
+52502
+52503
+52504
+52505
+52506
+52507
+52508
+52509
+52510
+52511
+52512
+52513
+52514
+52515
+52516
+52517
+52518
+52519
+52520
+52521
+52522
+52523
+52524
+52525
+52526
+52527
+52528
+52529
+52530
+52531
+52532
+52533
+52534
+52535
+52536
+52537
+52538
+52539
+52540
+52541
+52542
+52543
+52544
+52545
+52546
+52547
+52548
+52549
+52550
+52551
+52552
+52553
+52554
+52555
+52556
+52557
+52558
+52559
+52560
+52561
+52562
+52563
+52564
+52565
+52566
+52567
+52568
+52569
+52570
+52571
+52572
+52573
+52574
+52575
+52576
+52577
+52578
+52579
+52580
+52581
+52582
+52583
+52584
+52585
+52586
+52587
+52588
+52589
+52590
+52591
+52592
+52593
+52594
+52595
+52596
+52597
+52598
+52599
+52600
+52601
+52602
+52603
+52604
+52605
+52606
+52607
+52608
+52609
+52610
+52611
+52612
+52613
+52614
+52615
+52616
+52617
+52618
+52619
+52620
+52621
+52622
+52623
+52624
+52625
+52626
+52627
+52628
+52629
+52630
+52631
+52632
+52633
+52634
+52635
+52636
+52637
+52638
+52639
+52640
+52641
+52642
+52643
+52644
+52645
+52646
+52647
+52648
+52649
+52650
+52651
+52652
+52653
+52654
+52655
+52656
+52657
+52658
+52659
+52660
+52661
+52662
+52663
+52664
+52665
+52666
+52667
+52668
+52669
+52670
+52671
+52672
+52673
+52674
+52675
+52676
+52677
+52678
+52679
+52680
+52681
+52682
+52683
+52684
+52685
+52686
+52687
+52688
+52689
+52690
+52691
+52692
+52693
+52694
+52695
+52696
+52697
+52698
+52699
+52700
+52701
+52702
+52703
+52704
+52705
+52706
+52707
+52708
+52709
+52710
+52711
+52712
+52713
+52714
+52715
+52716
+52717
+52718
+52719
+52720
+52721
+52722
+52723
+52724
+52725
+52726
+52727
+52728
+52729
+52730
+52731
+52732
+52733
+52734
+52735
+52736
+52737
+52738
+52739
+52740
+52741
+52742
+52743
+52744
+52745
+52746
+52747
+52748
+52749
+52750
+52751
+52752
+52753
+52754
+52755
+52756
+52757
+52758
+52759
+52760
+52761
+52762
+52763
+52764
+52765
+52766
+52767
+52768
+52769
+52770
+52771
+52772
+52773
+52774
+52775
+52776
+52777
+52778
+52779
+52780
+52781
+52782
+52783
+52784
+52785
+52786
+52787
+52788
+52789
+52790
+52791
+52792
+52793
+52794
+52795
+52796
+52797
+52798
+52799
+52800
+52801
+52802
+52803
+52804
+52805
+52806
+52807
+52808
+52809
+52810
+52811
+52812
+52813
+52814
+52815
+52816
+52817
+52818
+52819
+52820
+52821
+52822
+52823
+52824
+52825
+52826
+52827
+52828
+52829
+52830
+52831
+52832
+52833
+52834
+52835
+52836
+52837
+52838
+52839
+52840
+52841
+52842
+52843
+52844
+52845
+52846
+52847
+52848
+52849
+52850
+52851
+52852
+52853
+52854
+52855
+52856
+52857
+52858
+52859
+52860
+52861
+52862
+52863
+52864
+52865
+52866
+52867
+52868
+52869
+52870
+52871
+52872
+52873
+52874
+52875
+52876
+52877
+52878
+52879
+52880
+52881
+52882
+52883
+52884
+52885
+52886
+52887
+52888
+52889
+52890
+52891
+52892
+52893
+52894
+52895
+52896
+52897
+52898
+52899
+52900
+52901
+52902
+52903
+52904
+52905
+52906
+52907
+52908
+52909
+52910
+52911
+52912
+52913
+52914
+52915
+52916
+52917
+52918
+52919
+52920
+52921
+52922
+52923
+52924
+52925
+52926
+52927
+52928
+52929
+52930
+52931
+52932
+52933
+52934
+52935
+52936
+52937
+52938
+52939
+52940
+52941
+52942
+52943
+52944
+52945
+52946
+52947
+52948
+52949
+52950
+52951
+52952
+52953
+52954
+52955
+52956
+52957
+52958
+52959
+52960
+52961
+52962
+52963
+52964
+52965
+52966
+52967
+52968
+52969
+52970
+52971
+52972
+52973
+52974
+52975
+52976
+52977
+52978
+52979
+52980
+52981
+52982
+52983
+52984
+52985
+52986
+52987
+52988
+52989
+52990
+52991
+52992
+52993
+52994
+52995
+52996
+52997
+52998
+52999
+53000
+53001
+53002
+53003
+53004
+53005
+53006
+53007
+53008
+53009
+53010
+53011
+53012
+53013
+53014
+53015
+53016
+53017
+53018
+53019
+53020
+53021
+53022
+53023
+53024
+53025
+53026
+53027
+53028
+53029
+53030
+53031
+53032
+53033
+53034
+53035
+53036
+53037
+53038
+53039
+53040
+53041
+53042
+53043
+53044
+53045
+53046
+53047
+53048
+53049
+53050
+53051
+53052
+53053
+53054
+53055
+53056
+53057
+53058
+53059
+53060
+53061
+53062
+53063
+53064
+53065
+53066
+53067
+53068
+53069
+53070
+53071
+53072
+53073
+53074
+53075
+53076
+53077
+53078
+53079
+53080
+53081
+53082
+53083
+53084
+53085
+53086
+53087
+53088
+53089
+53090
+53091
+53092
+53093
+53094
+53095
+53096
+53097
+53098
+53099
+53100
+53101
+53102
+53103
+53104
+53105
+53106
+53107
+53108
+53109
+53110
+53111
+53112
+53113
+53114
+53115
+53116
+53117
+53118
+53119
+53120
+53121
+53122
+53123
+53124
+53125
+53126
+53127
+53128
+53129
+53130
+53131
+53132
+53133
+53134
+53135
+53136
+53137
+53138
+53139
+53140
+53141
+53142
+53143
+53144
+53145
+53146
+53147
+53148
+53149
+53150
+53151
+53152
+53153
+53154
+53155
+53156
+53157
+53158
+53159
+53160
+53161
+53162
+53163
+53164
+53165
+53166
+53167
+53168
+53169
+53170
+53171
+53172
+53173
+53174
+53175
+53176
+53177
+53178
+53179
+53180
+53181
+53182
+53183
+53184
+53185
+53186
+53187
+53188
+53189
+53190
+53191
+53192
+53193
+53194
+53195
+53196
+53197
+53198
+53199
+53200
+53201
+53202
+53203
+53204
+53205
+53206
+53207
+53208
+53209
+53210
+53211
+53212
+53213
+53214
+53215
diff --git a/deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy b/deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy
new file mode 100644
index 0000000000000000000000000000000000000000..4e3c12cb5d2def3955baf4f708e52e06ed314e16
--- /dev/null
+++ b/deep_3drecon/BFM/index_mp468_from_mesh35709_v1.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0d238a90df0c55075c9cea43dab76348421379a75c204931e34dbd2c11fb4b65
+size 3872
diff --git a/deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy b/deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9ab7deb46f11e99b8c8f803fb6da486e59ba5ef8
--- /dev/null
+++ b/deep_3drecon/BFM/index_mp468_from_mesh35709_v2.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fe95e2bb10ac1e54804006184d7de3c5ccd0eb98a5f1bd28e00b9f3569f6ce5a
+size 3872
diff --git a/deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy b/deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy
new file mode 100644
index 0000000000000000000000000000000000000000..33e9a38500feb8aeb5483313508eb9e5f2a4f9e9
--- /dev/null
+++ b/deep_3drecon/BFM/index_mp468_from_mesh35709_v3.1.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:053b8cce8424b722db6ec5b068514eb007a23b4c5afd629449eb08746e643211
+size 3872
diff --git a/deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy b/deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy
new file mode 100644
index 0000000000000000000000000000000000000000..9bfab645775237324056423bbd62c7167a76beb9
--- /dev/null
+++ b/deep_3drecon/BFM/index_mp468_from_mesh35709_v3.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5b007b3619dd02892b38349ba3d4b10e32bc2eff201c265f25d6ed62f67dbd51
+size 3872
diff --git a/deep_3drecon/BFM/select_vertex_id.mat b/deep_3drecon/BFM/select_vertex_id.mat
new file mode 100644
index 0000000000000000000000000000000000000000..5b8b220093d93b133acc94ffed159f31a74854cd
Binary files /dev/null and b/deep_3drecon/BFM/select_vertex_id.mat differ
diff --git a/deep_3drecon/BFM/similarity_Lm3D_all.mat b/deep_3drecon/BFM/similarity_Lm3D_all.mat
new file mode 100644
index 0000000000000000000000000000000000000000..a0e23588302bc71fc899eef53ff06df5f4df4c1d
Binary files /dev/null and b/deep_3drecon/BFM/similarity_Lm3D_all.mat differ
diff --git a/deep_3drecon/__init__.py b/deep_3drecon/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6866fab16ef58e974504cf5cd6fcc4eddd3ddb1f
--- /dev/null
+++ b/deep_3drecon/__init__.py
@@ -0,0 +1 @@
+from .reconstructor import *
diff --git a/deep_3drecon/bfm_left_eye_faces.npy b/deep_3drecon/bfm_left_eye_faces.npy
new file mode 100644
index 0000000000000000000000000000000000000000..7044bb788d7f382888649a1b138912be259bbd78
--- /dev/null
+++ b/deep_3drecon/bfm_left_eye_faces.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9651756ea2c0fac069a1edf858ed1f125eddc358fa74c529a370c1e7b5730d28
+size 4680
diff --git a/deep_3drecon/bfm_right_eye_faces.npy b/deep_3drecon/bfm_right_eye_faces.npy
new file mode 100644
index 0000000000000000000000000000000000000000..b995860e0c2021a548c413e5add0976f4dc34db7
--- /dev/null
+++ b/deep_3drecon/bfm_right_eye_faces.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28cb5bbacf578d30a3d5006ec28c617fe5a3ecaeeeb87d9433a884e0f0301a2e
+size 4648
diff --git a/deep_3drecon/data_preparation.py b/deep_3drecon/data_preparation.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ffc79d34a040cfd3c5c82f4f860656999ceef84
--- /dev/null
+++ b/deep_3drecon/data_preparation.py
@@ -0,0 +1,45 @@
+"""This script is the data preparation script for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import argparse
+from util.detect_lm68 import detect_68p,load_lm_graph
+from util.skin_mask import get_skin_mask
+from util.generate_list import check_list, write_list
+import warnings
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--data_root', type=str, default='datasets', help='root directory for training data')
+parser.add_argument('--img_folder', nargs="+", required=True, help='folders of training images')
+parser.add_argument('--mode', type=str, default='train', help='train or val')
+opt = parser.parse_args()
+
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+def data_prepare(folder_list,mode):
+
+ lm_sess,input_op,output_op = load_lm_graph('./checkpoints/lm_model/68lm_detector.pb') # load a tensorflow version 68-landmark detector
+
+ for img_folder in folder_list:
+ detect_68p(img_folder,lm_sess,input_op,output_op) # detect landmarks for images
+ get_skin_mask(img_folder) # generate skin attention mask for images
+
+ # create files that record path to all training data
+ msks_list = []
+ for img_folder in folder_list:
+ path = os.path.join(img_folder, 'mask')
+ msks_list += ['/'.join([img_folder, 'mask', i]) for i in sorted(os.listdir(path)) if 'jpg' in i or
+ 'png' in i or 'jpeg' in i or 'PNG' in i]
+
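+    # Derive matching image and 68-landmark paths from each mask path by string substitution:
+    # '<folder>/mask/x.jpg' -> image '<folder>/x.jpg' and landmarks '<folder>/landmarks/x.txt'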
+ imgs_list = [i.replace('mask/', '') for i in msks_list]
+ lms_list = [i.replace('mask', 'landmarks') for i in msks_list]
+ lms_list = ['.'.join(i.split('.')[:-1]) + '.txt' for i in lms_list]
+
+ lms_list_final, imgs_list_final, msks_list_final = check_list(lms_list, imgs_list, msks_list) # check if the path is valid
+ write_list(lms_list_final, imgs_list_final, msks_list_final, mode=mode) # save files
+
+if __name__ == '__main__':
+ print('Datasets:',opt.img_folder)
+ data_prepare([os.path.join(opt.data_root,folder) for folder in opt.img_folder],opt.mode)
diff --git a/deep_3drecon/deep_3drecon_models/__init__.py b/deep_3drecon/deep_3drecon_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09ede5990d0e852b3089c83638e056eee4ff732
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel that inherits from BaseModel.
+You need to implement the following five functions:
+    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>: unpack data from the dataset and apply preprocessing.
+    -- <forward>: produce intermediate results.
+    -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+    -- self.loss_names (str list): specify the training losses that you want to plot and save.
+    -- self.model_names (str list): define the networks used in training.
+    -- self.visual_names (str list): specify the images that you want to display and save.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from .base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+ """Import the module "models/[model_name]_model.py".
+
+ In the file, the class called DatasetNameModel() will
+ be instantiated. It has to be a subclass of BaseModel,
+ and it is case-insensitive.
+ """
+ model_filename = "deep_3drecon_models." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit(0)
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+    This function wraps the lookup and instantiation of the model class.
+    It is the main interface between this package and 'train.py'/'test.py'.
+
+ Example:
+ >>> from models import create_model
+ >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
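+
+
+# NOTE (illustrative sketch only, not used anywhere in this package): following the
+# package docstring above, a hypothetical 'dummy_model.py' added next to this file
+# could look roughly like the commented example below. The torch calls and option
+# fields shown here are assumptions for illustration, not part of this repository.
+#
+#   import torch
+#   from .base_model import BaseModel
+#
+#   class DummyModel(BaseModel):
+#       def __init__(self, opt):
+#           BaseModel.__init__(self, opt)            # always call the base initializer first
+#           self.loss_names = ['dummy']              # losses to plot/save
+#           self.model_names = ['net']               # networks used in training
+#           self.visual_names = ['output']           # images to display/save
+#           self.net = torch.nn.Linear(3, 3)
+#           self.optimizers = [torch.optim.Adam(self.net.parameters(), lr=1e-4)]
+#
+#       def set_input(self, data):                   # unpack data and apply preprocessing
+#           self.input = data['image']
+#
+#       def forward(self):                           # produce intermediate results
+#           self.output = self.net(self.input)
+#
+#       def optimize_parameters(self):               # compute losses, backprop, update weights
+#           self.forward()
+#           self.loss_dummy = self.output.mean()
+#           self.optimizers[0].zero_grad()
+#           self.loss_dummy.backward()
+#           self.optimizers[0].step()
+#
+# Such a model would then be selected with the flag '--model dummy' via create_model().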
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/README.md b/deep_3drecon/deep_3drecon_models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d391f63684dd1f47900dc6449a5e22fa25e3da3
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/README.md
@@ -0,0 +1,218 @@
+# Distributed Arcface Training in Pytorch
+
+The "arcface_torch" repository is the official implementation of the ArcFace algorithm. It supports distributed and sparse training with multiple distributed training examples, including several memory-saving techniques such as mixed precision training and gradient checkpointing. It also supports training for ViT models and datasets including WebFace42M and Glint360K, two of the largest open-source datasets. Additionally, the repository comes with a built-in tool for converting to ONNX format, making it easy to submit to MFR evaluation systems.
+
+[](https://paperswithcode.com/sota/face-verification-on-ijb-c?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-ijb-b?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-agedb-30?p=killing-two-birds-with-one-stone-efficient)
+[](https://paperswithcode.com/sota/face-verification-on-cfp-fp?p=killing-two-birds-with-one-stone-efficient)
+
+## Requirements
+
+To take advantage of the latest PyTorch features, we have upgraded to version 1.12.0.
+
+- Install [PyTorch](https://pytorch.org/get-started/previous-versions/) (torch>=1.12.0).
+- (Optional) Install [DALI](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/), our doc for [install_dali.md](docs/install_dali.md).
+- `pip install -r requirement.txt`.
+
+## How to Train
+
+To train a model, execute the `train.py` script with the path to the configuration files. The sample commands provided below demonstrate the process of conducting distributed training.
+
+### 1. To run on one GPU:
+
+```shell
+python train_v2.py configs/ms1mv3_r50_onegpu
+```
+
+Note:
+It is not recommended to use a single GPU for training, as this may result in longer training times and suboptimal performance. For best results, we suggest using multiple GPUs or a GPU cluster.
+
+
+### 2. To run on a machine with 8 GPUs:
+
+```shell
+torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
+```
+
+### 3. To run on 2 machines with 8 GPUs each:
+
+Node 0:
+
+```shell
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
+```
+
+Node 1:
+
+```shell
+torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=12581 train.py configs/wf42m_pfc02_16gpus_r100
+```
+
+### 4. To run ViT-B on a machine with a 24k batch size:
+
+```shell
+torchrun --nproc_per_node=8 train_v2.py configs/wf42m_pfc03_40epoch_8gpu_vit_b
+```
+
+
+## Download Datasets or Prepare Datasets
+- [MS1MV2](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-arcface-85k-ids58m-images-57) (87k IDs, 5.8M images)
+- [MS1MV3](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_#ms1m-retinaface) (93k IDs, 5.2M images)
+- [Glint360K](https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc#4-download) (360k IDs, 17.1M images)
+- [WebFace42M](docs/prepare_webface42m.md) (2M IDs, 42.5M images)
+- [Your Dataset, Click Here!](docs/prepare_custom_dataset.md)
+
+Note:
+If you want to use DALI for data reading, please use the script 'scripts/shuffle_rec.py' to shuffle the InsightFace-style .rec file before using it.
+Example:
+
+`python scripts/shuffle_rec.py ms1m-retinaface-t1`
+
+You will get the "shuffled_ms1m-retinaface-t1" folder, where the samples in the "train.rec" file are shuffled.
+
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.
+- All models can be found here:
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
+- [OneDrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on IJB-C and [**ICCV2021-MFR**](https://github.com/deepinsight/insightface/blob/master/challenges/mfr/README.md)
+
+The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
+recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
+As a result, we can fairly evaluate the performance of different algorithms.
+
+For the **ICCV2021-MFR-ALL** set, TAR is measured under an all-to-all 1:1 protocol with FAR less than 0.000001 (1e-6). The
+globalised multi-racial test set contains 242,143 identities and 1,624,305 images.
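+
+As a rough illustration of this metric (not the official evaluation code), TAR at a given FAR can be computed from genuine-pair and impostor-pair similarity scores along the lines of the minimal NumPy sketch below; all names are illustrative.
+
+```python
+import numpy as np
+
+def tar_at_far(genuine_scores, impostor_scores, far_target=1e-6):
+    """Return (TAR, threshold) at the score threshold where the impostor FAR is ~far_target."""
+    impostor_sorted = np.sort(np.asarray(impostor_scores))[::-1]   # high to low
+    # keep at most far_target * N impostor pairs above the threshold
+    k = max(int(far_target * len(impostor_sorted)), 1)
+    threshold = impostor_sorted[k - 1]
+    tar = float(np.mean(np.asarray(genuine_scores) > threshold))
+    return tar, threshold
+```
+
+In practice the IJB-C and MFR protocols define the exact pair lists and thresholding, so numbers from such a sketch are only indicative.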
+
+
+#### 1. Training on Single-Host GPU
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
+|:---------------|:--------------------|:------------|:------------|:------------|:------------------------------------------------------------------------------------------------------------------------------------|
+| MS1MV2 | mobilefacenet-0.45G | 62.07 | 93.61 | 90.28 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_mbf/training.log) |
+| MS1MV2 | r50 | 75.13 | 95.97 | 94.07 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r50/training.log) |
+| MS1MV2 | r100 | 78.12 | 96.37 | 94.27 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv2_r100/training.log) |
+| MS1MV3 | mobilefacenet-0.45G | 63.78 | 94.23 | 91.33 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_mbf/training.log) |
+| MS1MV3 | r50 | 79.14 | 96.37 | 94.47 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r50/training.log) |
+| MS1MV3 | r100 | 81.97 | 96.85 | 95.02 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_r100/training.log) |
+| Glint360K | mobilefacenet-0.45G | 70.18 | 95.04 | 92.62 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_mbf/training.log) |
+| Glint360K | r50 | 86.34 | 97.16 | 95.81 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r50/training.log) |
+| Glint360k | r100 | 89.52 | 97.55 | 96.38 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_r100/training.log) |
+| WF4M | r100 | 89.87 | 97.19 | 95.48 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf4m_r100/training.log) |
+| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc02_r100/training.log) |
+| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_pfc03_r100/training.log) |
+| WF12M | r100 | 94.69 | 97.59 | 95.97 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf12m_r100/training.log) |
+| WF42M-PFC-0.2 | r100 | 96.27 | 97.70 | 96.31 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_r100/training.log) |
+| WF42M-PFC-0.2 | ViT-T-1.5G | 92.04 | 97.27 | 95.68 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/wf42m_pfc02_40epoch_8gpu_vit_t/training.log) |
+| WF42M-PFC-0.3 | ViT-B-11G | 97.16 | 97.91 | 97.05 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_8gpu/training.log) |
+
+#### 2. Training on Multi-Host GPU
+
+| Datasets | Backbone(bs*gpus) | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
+|:-----------------|:------------------|:------------|:------------|:------------|:-----------|:-------------------------------------------------------------------------------------------------------------------------------------------|
+| WF42M-PFC-0.2 | r50(512*8) | 93.83 | 97.53 | 96.16 | ~5900 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.2 | r50(512*16) | 93.96 | 97.46 | 96.12 | ~11000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r50_lr01_pfc02_bs8k_16gpus/training.log) |
+| WF42M-PFC-0.2 | r50(128*32) | 94.04 | 97.48 | 95.94 | ~17000 | click me |
+| WF42M-PFC-0.2 | r100(128*16) | 96.28 | 97.80 | 96.57 | ~5200 | click me |
+| WF42M-PFC-0.2 | r100(256*16) | 96.69 | 97.85 | 96.63 | ~5200 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/webface42m_r100_bs4k_pfc02/training.log) |
+| WF42M-PFC-0.0018 | r100(512*32) | 93.08 | 97.51 | 95.88 | ~10000 | click me |
+| WF42M-PFC-0.2 | r100(128*32) | 96.57 | 97.83 | 96.50 | ~9800 | click me |
+
+`r100(128*32)` means the backbone is r100, the batch size per GPU is 128, and the number of GPUs is 32.
+
+
+
+#### 3. ViT For Face Recognition
+
+| Datasets | Backbone(bs) | FLOPs | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | Throughout | log |
+|:--------------|:--------------|:------|:------------|:------------|:------------|:-----------|:-----------------------------------------------------------------------------------------------------------------------------|
+| WF42M-PFC-0.3 | r18(128*32) | 2.6 | 79.13 | 95.77 | 93.36 | - | click me |
+| WF42M-PFC-0.3 | r50(128*32) | 6.3 | 94.03 | 97.48 | 95.94 | - | click me |
+| WF42M-PFC-0.3 | r100(128*32) | 12.1 | 96.69 | 97.82 | 96.45 | - | click me |
+| WF42M-PFC-0.3 | r200(128*32) | 23.5 | 97.70 | 97.97 | 96.93 | - | click me |
+| WF42M-PFC-0.3 | VIT-T(384*64) | 1.5 | 92.24 | 97.31 | 95.97 | ~35000 | click me |
+| WF42M-PFC-0.3 | VIT-S(384*64) | 5.7 | 95.87 | 97.73 | 96.57 | ~25000 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_s_64gpu/training.log) |
+| WF42M-PFC-0.3 | VIT-B(384*64) | 11.4 | 97.42 | 97.90 | 97.04 | ~13800 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_b_64gpu/training.log) |
+| WF42M-PFC-0.3 | VIT-L(384*64) | 25.3 | 97.85 | 98.00 | 97.23 | ~9406 | [click me](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/pfc03_wf42m_vit_l_64gpu/training.log) |
+
+`WF42M` means WebFace42M, and `PFC-0.3` means the negative class-center sampling rate is 0.3.
+
+#### 4. Noisy Datasets
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | log |
+|:-------------------------|:---------|:------------|:------------|:------------|:---------|
+| WF12M-Flip(40%) | r50 | 43.87 | 88.35 | 80.78 | click me |
+| WF12M-Flip(40%)-PFC-0.1* | r50 | 80.20 | 96.11 | 93.79 | click me |
+| WF12M-Conflict | r50 | 79.93 | 95.30 | 91.56 | click me |
+| WF12M-Conflict-PFC-0.3* | r50 | 91.68 | 97.28 | 95.75 | click me |
+
+`WF12M` means WebFace12M, `+PFC-0.1*` denotes additional abnormal inter-class filtering.
+
+
+
+## Speed Benchmark
+
+
+
+**Arcface-Torch** is an efficient tool for training on large-scale face recognition datasets. When the number of classes exceeds one million, the partial FC sampling strategy maintains the same accuracy while providing several times faster training and lower GPU memory utilization. Partial FC is a sparse variant of the model-parallel architecture for large-scale face recognition: it uses a sparse softmax that dynamically samples a subset of class centers for each training batch, so in each iteration only a sparse portion of the parameters is updated, significantly reducing GPU memory requirements and computational demands. With partial FC it is possible to train on datasets with up to 29 million identities, the largest to date. The method also supports multi-machine distributed training and mixed precision training.
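+
+The snippet below is a minimal, single-GPU sketch of the sampling idea described above (the batch's positive class centers plus randomly sampled negative centers). It is illustrative only: it omits the ArcFace margin and the model-parallel logic, and it does not correspond to the actual `partial_fc.py` implementation. All names and the cosine scale are assumptions.
+
+```python
+import torch
+import torch.nn.functional as F
+
+class PartialFCSketch(torch.nn.Module):
+    """Toy partial-FC-style head: sample a subset of class centers for each batch."""
+
+    def __init__(self, embedding_size, num_classes, sample_rate=0.1):
+        super().__init__()
+        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (num_classes, embedding_size)))
+        self.num_classes = num_classes
+        self.num_sample = max(int(sample_rate * num_classes), 1)
+
+    def forward(self, embeddings, labels):
+        # always keep the class centers that appear in this batch ("positive" centers)
+        positive = labels.unique()
+        # fill the remaining budget with randomly chosen negative centers
+        perm = torch.randperm(self.num_classes, device=labels.device)
+        is_negative = torch.ones(self.num_classes, dtype=torch.bool, device=labels.device)
+        is_negative[positive] = False
+        budget = max(self.num_sample - positive.numel(), 0)
+        negative = perm[is_negative[perm]][:budget]
+        sampled = torch.cat([positive, negative])
+        # remap the original labels to positions inside the sampled subset
+        index_map = torch.full((self.num_classes,), -1, dtype=torch.long, device=labels.device)
+        index_map[sampled] = torch.arange(sampled.numel(), device=labels.device)
+        local_labels = index_map[labels]
+        # cosine logits against the sampled centers only, so only those rows receive gradients
+        logits = F.linear(F.normalize(embeddings), F.normalize(self.weight[sampled]))
+        return F.cross_entropy(logits * 64.0, local_labels)  # 64.0: a typical cosine scale (assumption)
+```
+
+In the real implementation the class centers are additionally sharded across GPUs, and the margin-based ArcFace loss replaces the plain cross-entropy used here.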
+
+
+
+See [speed_benchmark.md](docs/speed_benchmark.md) in the docs folder for more details.
+
+> 1. Training Speed of Various Parallel Techniques (Samples per Second) on a Tesla V100 32GB x 8 System (Higher is Better)
+
+`-` means training failed because of GPU memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+|:--------------------------------|:--------------|:---------------|:---------------|
+| 125000 | 4681 | 4824 | 5004 |
+| 1400000 | **1672** | 3043 | 4738 |
+| 5500000 | **-** | **1389** | 3975 |
+| 8000000 | **-** | **-** | 3565 |
+| 16000000 | **-** | **-** | 2679 |
+| 29000000 | **-** | **-** | **1855** |
+
+> 2. GPU Memory Utilization of Various Parallel Techniques (MB per GPU) on a Tesla V100 32GB x 8 System (Lower is Better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+|:--------------------------------|:--------------|:---------------|:---------------|
+| 125000 | 7358 | 5306 | 4868 |
+| 1400000 | 32252 | 11178 | 6056 |
+| 5500000 | **-** | 32188 | 9854 |
+| 8000000 | **-** | **-** | 12310 |
+| 16000000 | **-** | **-** | 19950 |
+| 29000000 | **-** | **-** | 32324 |
+
+
+## Citations
+
+```
+@inproceedings{deng2019arcface,
+ title={Arcface: Additive angular margin loss for deep face recognition},
+ author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+ booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={4690--4699},
+ year={2019}
+}
+@inproceedings{An_2022_CVPR,
+ author={An, Xiang and Deng, Jiankang and Guo, Jia and Feng, Ziyong and Zhu, XuHan and Yang, Jing and Liu, Tongliang},
+ title={Killing Two Birds With One Stone: Efficient and Robust Training of Face Recognition CNNs by Partial FC},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month={June},
+ year={2022},
+ pages={4042-4051}
+}
+@inproceedings{zhu2021webface260m,
+ title={Webface260m: A benchmark unveiling the power of million-scale deep face recognition},
+ author={Zhu, Zheng and Huang, Guan and Deng, Jiankang and Ye, Yun and Huang, Junjie and Chen, Xinze and Zhu, Jiagang and Yang, Tian and Lu, Jiwen and Du, Dalong and Zhou, Jie},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={10492--10502},
+ year={2021}
+}
+```
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cea70df0a5fb9ed476e9d89bce56112e833c306
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/__init__.py
@@ -0,0 +1,85 @@
+from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+ # resnet
+ if name == "r18":
+ return iresnet18(False, **kwargs)
+ elif name == "r34":
+ return iresnet34(False, **kwargs)
+ elif name == "r50":
+ return iresnet50(False, **kwargs)
+ elif name == "r100":
+ return iresnet100(False, **kwargs)
+ elif name == "r200":
+ return iresnet200(False, **kwargs)
+ elif name == "r2060":
+ from .iresnet2060 import iresnet2060
+ return iresnet2060(False, **kwargs)
+
+ elif name == "mbf":
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf(fp16=fp16, num_features=num_features)
+
+ elif name == "mbf_large":
+ from .mobilefacenet import get_mbf_large
+ fp16 = kwargs.get("fp16", False)
+ num_features = kwargs.get("num_features", 512)
+ return get_mbf_large(fp16=fp16, num_features=num_features)
+
+ elif name == "vit_t":
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
+ num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
+
+ elif name == "vit_t_dp005_mask0": # For WebFace42M
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12,
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
+
+ elif name == "vit_s":
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
+ num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1)
+
+ elif name == "vit_s_dp005_mask_0": # For WebFace42M
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12,
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0)
+
+ elif name == "vit_b":
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
+ num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True)
+
+ elif name == "vit_b_dp005_mask_005": # For WebFace42M
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24,
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
+
+ elif name == "vit_l_dp005_mask_005": # For WebFace42M
+ # this is a feature
+ num_features = kwargs.get("num_features", 512)
+ from .vit import VisionTransformer
+ return VisionTransformer(
+ img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24,
+ num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True)
+
+ else:
+ raise ValueError()
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2347c9231b007d0f06f31461c61d1e39418cdd
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,194 @@
+import torch
+from torch import nn
+from torch.utils.checkpoint import checkpoint
+
+__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
+using_ckpt = False
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward_impl(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+ def forward(self, x):
+ if self.training and using_ckpt:
+ return checkpoint(self.forward_impl, x)
+ else:
+ return self.forward_impl(x)
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.extra_gflops = 0.0
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
+ progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
+ progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
+ progress, **kwargs)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..21d1122144d207637d2444cba1f68fe630c89f31
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ['iresnet2060']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ groups=groups,
+ bias=False,
+ dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ bias=False)
+
+
+class IBasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
+ groups=1, base_width=64, dilation=1):
+ super(IBasicBlock, self).__init__()
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
+ self.conv1 = conv3x3(inplanes, planes)
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.prelu = nn.PReLU(planes)
+ self.conv2 = conv3x3(planes, planes, stride)
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+ out = self.bn1(x)
+ out = self.conv1(out)
+ out = self.bn2(out)
+ out = self.prelu(out)
+ out = self.conv2(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ identity = self.downsample(x)
+ out += identity
+ return out
+
+
+class IResNet(nn.Module):
+ fc_scale = 7 * 7
+
+ def __init__(self,
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+ super(IResNet, self).__init__()
+ self.fp16 = fp16
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+ self.prelu = nn.PReLU(self.inplanes)
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+ self.layer2 = self._make_layer(block,
+ 128,
+ layers[1],
+ stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block,
+ 256,
+ layers[2],
+ stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block,
+ 512,
+ layers[3],
+ stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+ nn.init.constant_(self.features.weight, 1.0)
+ self.features.weight.requires_grad = False
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.normal_(m.weight, 0, 0.1)
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, IBasicBlock):
+ nn.init.constant_(m.bn2.weight, 0)
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+ )
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(
+ block(self.inplanes,
+ planes,
+ groups=self.groups,
+ base_width=self.base_width,
+ dilation=self.dilation))
+
+ return nn.Sequential(*layers)
+
+ def checkpoint(self, func, num_seg, x):
+ if self.training:
+ return checkpoint_sequential(func, num_seg, x)
+ else:
+ return func(x)
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.layer1(x)
+ x = self.checkpoint(self.layer2, 20, x)
+ x = self.checkpoint(self.layer3, 100, x)
+ x = self.layer4(x)
+ x = self.bn2(x)
+ x = torch.flatten(x, 1)
+ x = self.dropout(x)
+ x = self.fc(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+ model = IResNet(block, layers, **kwargs)
+ if pretrained:
+ raise ValueError()
+ return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+ return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..007d136a96202bfa2021e4f88c4bf145dd151992
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,147 @@
+'''
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+'''
+
+import torch.nn as nn
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
+import torch
+
+
+class Flatten(Module):
+ def forward(self, x):
+ return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(ConvBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+ BatchNorm2d(num_features=out_c),
+ PReLU(num_parameters=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class LinearBlock(Module):
+ def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+ super(LinearBlock, self).__init__()
+ self.layers = nn.Sequential(
+ Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
+ BatchNorm2d(num_features=out_c)
+ )
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class DepthWise(Module):
+ def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+ super(DepthWise, self).__init__()
+ self.residual = residual
+ self.layers = nn.Sequential(
+ ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+ ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+ LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
+ )
+
+ def forward(self, x):
+ short_cut = None
+ if self.residual:
+ short_cut = x
+ x = self.layers(x)
+ if self.residual:
+ output = short_cut + x
+ else:
+ output = x
+ return output
+
+
+class Residual(Module):
+ def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+ super(Residual, self).__init__()
+ modules = []
+ for _ in range(num_block):
+ modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+ self.layers = Sequential(*modules)
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class GDC(Module):
+ def __init__(self, embedding_size):
+ super(GDC, self).__init__()
+ self.layers = nn.Sequential(
+ LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+ Flatten(),
+ Linear(512, embedding_size, bias=False),
+ BatchNorm1d(embedding_size))
+
+ def forward(self, x):
+ return self.layers(x)
+
+
+class MobileFaceNet(Module):
+ def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2):
+ super(MobileFaceNet, self).__init__()
+ self.scale = scale
+ self.fp16 = fp16
+ self.layers = nn.ModuleList()
+ self.layers.append(
+ ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
+ )
+ if blocks[0] == 1:
+ self.layers.append(
+ ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
+ )
+ else:
+ self.layers.append(
+ Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ )
+
+ self.layers.extend(
+ [
+ DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+ Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+ Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+ Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+ ])
+
+ self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+ self.features = GDC(num_features)
+ self._initialize_weights()
+
+ def _initialize_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ elif isinstance(m, nn.Linear):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def forward(self, x):
+ with torch.cuda.amp.autocast(self.fp16):
+ for func in self.layers:
+ x = func(x)
+ x = self.conv_sep(x.float() if self.fp16 else x)
+ x = self.features(x)
+ return x
+
+
+def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2):
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
+
+def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4):
+ return MobileFaceNet(fp16, num_features, blocks, scale=scale)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py
new file mode 100644
index 0000000000000000000000000000000000000000..23977d2ece0f04e9e6cf5b149cef9bc441dcd7c5
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/backbones/vit.py
@@ -0,0 +1,280 @@
+import torch
+import torch.nn as nn
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+from typing import Optional, Callable
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class VITBatchNorm(nn.Module):
+ def __init__(self, num_features):
+ super().__init__()
+ self.num_features = num_features
+ self.bn = nn.BatchNorm1d(num_features=num_features)
+
+ def forward(self, x):
+ return self.bn(x)
+
+
+class Attention(nn.Module):
+ def __init__(self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.):
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+
+ with torch.cuda.amp.autocast(True):
+ batch_size, num_token, embed_dim = x.shape
+ #qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads]
+ qkv = self.qkv(x).reshape(
+ batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
+ with torch.cuda.amp.autocast(False):
+ q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float()
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
+ with torch.cuda.amp.autocast(True):
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class Block(nn.Module):
+
+ def __init__(self,
+ dim: int,
+ num_heads: int,
+ num_patches: int,
+ mlp_ratio: float = 4.,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ drop: float = 0.,
+ attn_drop: float = 0.,
+ drop_path: float = 0.,
+ act_layer: Callable = nn.ReLU6,
+ norm_layer: str = "ln",
+ patch_n: int = 144):
+ super().__init__()
+
+ if norm_layer == "bn":
+ self.norm1 = VITBatchNorm(num_features=num_patches)
+ self.norm2 = VITBatchNorm(num_features=num_patches)
+ elif norm_layer == "ln":
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
+ act_layer=act_layer, drop=drop)
+ self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ with torch.cuda.amp.autocast(True):
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+class PatchEmbed(nn.Module):
+ def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ num_patches = (img_size[1] // patch_size[1]) * \
+ (img_size[0] // patch_size[0])
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.num_patches = num_patches
+ self.proj = nn.Conv2d(in_channels, embed_dim,
+ kernel_size=patch_size, stride=patch_size)
+
+ def forward(self, x):
+ batch_size, channels, height, width = x.shape
+ assert height == self.img_size[0] and width == self.img_size[1], \
+ f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2)
+ return x
+
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer with support for patch or hybrid CNN input stage
+ """
+
+ def __init__(self,
+ img_size: int = 112,
+ patch_size: int = 16,
+ in_channels: int = 3,
+ num_classes: int = 1000,
+ embed_dim: int = 768,
+ depth: int = 12,
+ num_heads: int = 12,
+ mlp_ratio: float = 4.,
+ qkv_bias: bool = False,
+ qk_scale: Optional[None] = None,
+ drop_rate: float = 0.,
+ attn_drop_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ hybrid_backbone: Optional[None] = None,
+ norm_layer: str = "ln",
+ mask_ratio = 0.1,
+ using_checkpoint = False,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ # num_features for consistency with other models
+ self.num_features = self.embed_dim = embed_dim
+
+ if hybrid_backbone is not None:
+ raise ValueError
+ else:
+ self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
+ self.mask_ratio = mask_ratio
+ self.using_checkpoint = using_checkpoint
+ num_patches = self.patch_embed.num_patches
+ self.num_patches = num_patches
+
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
+ patch_n = (img_size//patch_size)**2
+ self.blocks = nn.ModuleList(
+ [
+ Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
+ num_patches=num_patches, patch_n=patch_n)
+ for i in range(depth)]
+ )
+ self.extra_gflops = 0.0
+ for _block in self.blocks:
+ self.extra_gflops += _block.extra_gflops
+
+ if norm_layer == "ln":
+ self.norm = nn.LayerNorm(embed_dim)
+ elif norm_layer == "bn":
+ self.norm = VITBatchNorm(self.num_patches)
+
+ # features head
+ self.feature = nn.Sequential(
+ nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
+ nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
+ nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
+ nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
+ )
+
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ torch.nn.init.normal_(self.mask_token, std=.02)
+ trunc_normal_(self.pos_embed, std=.02)
+ # trunc_normal_(self.cls_token, std=.02)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'pos_embed', 'cls_token'}
+
+ def get_classifier(self):
+ return self.head
+
+ def random_masking(self, x, mask_ratio=0.1):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.size() # batch, length, dim
+ len_keep = int(L * (1 - mask_ratio))
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ # ascend: small is keep, large is remove
+ ids_shuffle = torch.argsort(noise, dim=1)
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # keep the first subset
+ ids_keep = ids_shuffle[:, :len_keep]
+ x_masked = torch.gather(
+ x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keep] = 0
+ # unshuffle to get the binary mask
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+
+ return x_masked, mask, ids_restore
+
+ def forward_features(self, x):
+ B = x.shape[0]
+ x = self.patch_embed(x)
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ if self.training and self.mask_ratio > 0:
+ x, _, ids_restore = self.random_masking(x)
+
+ for func in self.blocks:
+ if self.using_checkpoint and self.training:
+ from torch.utils.checkpoint import checkpoint
+ x = checkpoint(func, x)
+ else:
+ x = func(x)
+ x = self.norm(x.float())
+
+ if self.training and self.mask_ratio > 0:
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
+ x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
+ x = x_
+ return torch.reshape(x, (B, self.num_patches * self.embed_dim))
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.feature(x)
+ return x
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/3millions.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bb660bde4590b999bfe1bf9bef8bbf055d65566
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512 # total_batch_size = batch_size * num_gpus
+config.lr = 0.1 # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 30 * 10000
+config.num_image = 100000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/__init__.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/base.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c64c943e87bcc731a12cc7db154d6caf2d1f9b2e
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/base.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+
+# Margin Base Softmax
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.save_all_states = False
+config.output = "ms1mv3_arcface_r50"
+
+config.embedding_size = 512
+
+# Partial FC
+config.sample_rate = 1
+config.interclass_filtering_threshold = 0
+
+config.fp16 = False
+config.batch_size = 128
+
+# For SGD
+config.optimizer = "sgd"
+config.lr = 0.1
+config.momentum = 0.9
+config.weight_decay = 5e-4
+
+# For AdamW
+# config.optimizer = "adamw"
+# config.lr = 0.001
+# config.weight_decay = 0.1
+
+config.verbose = 2000
+config.frequent = 10
+
+# For Large Scale Dataset, such as WebFace42M
+config.dali = False
+
+# Gradient ACC
+config.gradient_acc = 1
+
+# setup seed
+config.seed = 2048
+
+# dataload numworkers
+config.num_workers = 2
+
+# WandB Logger
+config.wandb_key = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
+config.suffix_run_name = None
+config.using_wandb = False
+config.wandb_entity = "entity"
+config.wandb_project = "project"
+config.wandb_log_all = True
+config.save_artifacts = False
+config.wandb_resume = False # resume a wandb run: only if you want to resume the last run that was interrupted
\ No newline at end of file
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_mbf.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b32f0016bf2615d031a1f99243023e8b96e49afc
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b8bbb78650b3b64423004f631266bc6c27804fb
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eeb28f84aadc0482f0593db315a4b5b09c9ae0a
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_mbf.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..255a51ad68a0c7bb2a0e05a3e4771b275173d932
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 40
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..36773489c4ba19774b79118c460bc977d73806c7
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dab4d35244d76886fbad5d94cdc48d8def9f7ec
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv2_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/faces_emore"
+config.num_classes = 85742
+config.num_image = 5822653
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_mbf.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..731b4a261ba0ef0ebb50f12b315d581f5eb0b8e4
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 40
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7af3cef46a9322732c8c129e6406f8dc704698f
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1467f0a554d1f10a3a21af0404cc56aba434b65
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50_onegpu.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50_onegpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ce7e140ddd84cbdbf1b85006bda41a0c00b9a31
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/ms1mv3_r50_onegpu.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.5, 0.0)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.02
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..de94fcb32cad796bda63521e4f81a4f7fe88923b
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_Conflict"
+config.num_classes = 1017970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
new file mode 100644
index 0000000000000000000000000000000000000000..a766f4154bb801b57d0f9519748b63941e349330
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.interclass_filtering_threshold = 0.4
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_Conflict"
+config.num_classes = 1017970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c1018b7f0d0320678b33b212eed5751badf72ee
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.interclass_filtering_threshold = 0.4
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_FLIP40"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde56fed6d8513b95882b7701f93f8574afbca9c
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_flip_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M_FLIP40"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_mbf.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1cb93b2f168e3a64e65d1f8d6cf058e41676c6a
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_mbf.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_pfc02_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_pfc02_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..1062b876e9b17db7b4b24e09d9f6cd3dacebb4d9
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_pfc02_r100.py
@@ -0,0 +1,29 @@
+
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..65bfa1be4649f3083be0340efc81df0b9c8f1ba8
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r100.py
@@ -0,0 +1,29 @@
+
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a7284663d6afbe6f205c8c9f10cd454ef1045ca
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf12m_r50.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.interclass_filtering_threshold = 0
+config.fp16 = True
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.optimizer = "sgd"
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace12M"
+config.num_classes = 617970
+config.num_image = 12720066
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..2885816cb9b635c526d1d2269c606e93fa54a2e6
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..14a6bb79da7eaa3f111e9efedf507e46a953c9aa
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..035684732003b5c7b8fe8ea34e097bd22fbcca37
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 256
+config.lr = 0.3
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 1
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02bdf3afe8370086cf64fd112244b00cee35a6f
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.6
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 4
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8407943ffef4ae3ee02ddb3f2361a9ac655cbb
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9f627fa94046d22ab0f0f12a8e339dc2cedfd81
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 512
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 2
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..5274a52f2607f38e08643e2145a2a837786ed9f1
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1e8f199195df647086da21d7e2fa05817c4ca61
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.2
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7787675df530259ba809b694df8d3e3cc5dd799
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.2
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 10000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..adf21c97a8c7c0568d0783432b4526ba78138926
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d35830ba107f27eea9b849abe88b0b4b09bdd0c
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py
new file mode 100644
index 0000000000000000000000000000000000000000..e34dd1c11f489d9c5c1b23c3677d303aafe46da6
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r200"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..a44a5d771e17ecbeffe3437f3500e9d0c9dcc105
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.4
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 20
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbe7fe6b1ecde9034cf6b647c0558f96bb1d41c3
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_b_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b153aa6a36a9a883153245c49617c2d9e11939
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_l_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6ce7010d9c297ed0832dcb5639d552078cea95c
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_s_dp005_mask_0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..8516755b656b21536da177402ef6066e3e1039dd
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_t_dp005_mask0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 384
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
new file mode 100644
index 0000000000000000000000000000000000000000..37105d4559c9033dfae3aaf7feed9521708e4912
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py
@@ -0,0 +1,28 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_b_dp005_mask_005"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 256
+config.gradient_acc = 12 # total batchsize is 256 * 12
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf8c563dab6ce4f45b694efa4837a4d52a98af3
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "vit_t_dp005_mask0"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.3
+config.fp16 = True
+config.weight_decay = 0.1
+config.batch_size = 512
+config.optimizer = "adamw"
+config.lr = 0.001
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace42M"
+config.num_classes = 2059906
+config.num_image = 42474557
+config.num_epoch = 40
+config.warmup_epoch = config.num_epoch // 10
+config.val_targets = []
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_mbf.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..2550f5a633485236beca00eeaeb6e15b8cf8834c
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_mbf.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 1e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r100.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e95e7833636d013a22cb0e285dbfa9b45a6c620
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r100.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r50.py b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3eb0d84c81d508223ed7e7d31c67cbfe4026bc3
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/configs/wf4m_r50.py
@@ -0,0 +1,27 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G tmpfs /train_tmp
+
+config = edict()
+config.margin_list = (1.0, 0.0, 0.4)
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1
+config.verbose = 2000
+config.dali = False
+
+config.rec = "/train_tmp/WebFace4M"
+config.num_classes = 205990
+config.num_image = 4235242
+config.num_epoch = 20
+config.warmup_epoch = 0
+config.val_targets = ['lfw', 'cfp_fp', "agedb_30"]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/dataset.py b/deep_3drecon/deep_3drecon_models/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b51797f32fe831e240bebe25dd38fc46d1a7bd
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/dataset.py
@@ -0,0 +1,245 @@
+import numbers
+import os
+import queue as Queue
+import threading
+from typing import Iterable
+
+import mxnet as mx
+import numpy as np
+import torch
+from functools import partial
+from torch import distributed
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from torchvision.datasets import ImageFolder
+from utils.utils_distributed_sampler import DistributedSampler
+from utils.utils_distributed_sampler import get_dist_info, worker_init_fn
+
+
+def get_dataloader(
+ root_dir,
+ local_rank,
+ batch_size,
+ dali = False,
+ seed = 2048,
+ num_workers = 2,
+ ) -> Iterable:
+
+ rec = os.path.join(root_dir, 'train.rec')
+ idx = os.path.join(root_dir, 'train.idx')
+ train_set = None
+
+ # Synthetic
+ if root_dir == "synthetic":
+ train_set = SyntheticDataset()
+ dali = False
+
+ # Mxnet RecordIO
+ elif os.path.exists(rec) and os.path.exists(idx):
+ train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
+
+ # Image Folder
+ else:
+ transform = transforms.Compose([
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ train_set = ImageFolder(root_dir, transform)
+
+ # DALI
+ if dali:
+ return dali_data_iter(
+ batch_size=batch_size, rec_file=rec, idx_file=idx,
+ num_threads=2, local_rank=local_rank)
+
+ rank, world_size = get_dist_info()
+ train_sampler = DistributedSampler(
+ train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed)
+
+ if seed is None:
+ init_fn = None
+ else:
+ init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
+
+ train_loader = DataLoaderX(
+ local_rank=local_rank,
+ dataset=train_set,
+ batch_size=batch_size,
+ sampler=train_sampler,
+ num_workers=num_workers,
+ pin_memory=True,
+ drop_last=True,
+ worker_init_fn=init_fn,
+ )
+
+ return train_loader
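+
+# Example usage (illustrative; assumes torch.distributed is already initialized
+# and the dataset path is a placeholder):
+#   loader = get_dataloader("/train_tmp/WebFace42M", local_rank=0, batch_size=128)
+#   for img, label in loader:
+#       ...  # batches arrive already on GPU `local_rank` via DataLoaderX below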
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, local_rank, max_prefetch=6):
+ super(BackgroundGenerator, self).__init__()
+ self.queue = Queue.Queue(max_prefetch)
+ self.generator = generator
+ self.local_rank = local_rank
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ torch.cuda.set_device(self.local_rank)
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
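+# DataLoaderX wraps a regular DataLoader with BackgroundGenerator and a dedicated
+# CUDA stream: the next batch is prefetched and copied to GPU `local_rank`
+# asynchronously while the current batch is being consumed, so host-to-device
+# transfers overlap with compute.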
+class DataLoaderX(DataLoader):
+
+ def __init__(self, local_rank, **kwargs):
+ super(DataLoaderX, self).__init__(**kwargs)
+ self.stream = torch.cuda.Stream(local_rank)
+ self.local_rank = local_rank
+
+ def __iter__(self):
+ self.iter = super(DataLoaderX, self).__iter__()
+ self.iter = BackgroundGenerator(self.iter, self.local_rank)
+ self.preload()
+ return self
+
+ def preload(self):
+ self.batch = next(self.iter, None)
+ if self.batch is None:
+ return None
+ with torch.cuda.stream(self.stream):
+ for k in range(len(self.batch)):
+ self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+ def __next__(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ if batch is None:
+ raise StopIteration
+ self.preload()
+ return batch
+
+
+class MXFaceDataset(Dataset):
+ def __init__(self, root_dir, local_rank):
+ super(MXFaceDataset, self).__init__()
+ self.transform = transforms.Compose(
+ [transforms.ToPILImage(),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+ ])
+ self.root_dir = root_dir
+ self.local_rank = local_rank
+ path_imgrec = os.path.join(root_dir, 'train.rec')
+ path_imgidx = os.path.join(root_dir, 'train.idx')
+ self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+ s = self.imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ if header.flag > 0:
+ self.header0 = (int(header.label[0]), int(header.label[1]))
+ self.imgidx = np.array(range(1, int(header.label[0])))
+ else:
+ self.imgidx = np.array(list(self.imgrec.keys))
+
+ def __getitem__(self, index):
+ idx = self.imgidx[index]
+ s = self.imgrec.read_idx(idx)
+ header, img = mx.recordio.unpack(s)
+ label = header.label
+ if not isinstance(label, numbers.Number):
+ label = label[0]
+ label = torch.tensor(label, dtype=torch.long)
+ sample = mx.image.imdecode(img).asnumpy()
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample, label
+
+ def __len__(self):
+ return len(self.imgidx)
+
+
+class SyntheticDataset(Dataset):
+ def __init__(self):
+ super(SyntheticDataset, self).__init__()
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).squeeze(0).float()
+ img = ((img / 255) - 0.5) / 0.5
+ self.img = img
+ self.label = 1
+
+ def __getitem__(self, index):
+ return self.img, self.label
+
+ def __len__(self):
+ return 1000000
+
+
+def dali_data_iter(
+ batch_size: int, rec_file: str, idx_file: str, num_threads: int,
+ initial_fill=32768, random_shuffle=True,
+ prefetch_queue_depth=1, local_rank=0, name="reader",
+ mean=(127.5, 127.5, 127.5),
+ std=(127.5, 127.5, 127.5)):
+ """
+ Parameters:
+ ----------
+ initial_fill: int
+ Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
+
+ """
+ rank: int = distributed.get_rank()
+ world_size: int = distributed.get_world_size()
+ import nvidia.dali.fn as fn
+ import nvidia.dali.types as types
+ from nvidia.dali.pipeline import Pipeline
+ from nvidia.dali.plugin.pytorch import DALIClassificationIterator
+
+ pipe = Pipeline(
+ batch_size=batch_size, num_threads=num_threads,
+ device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, )
+ condition_flip = fn.random.coin_flip(probability=0.5)
+ with pipe:
+ jpegs, labels = fn.readers.mxnet(
+ path=rec_file, index_path=idx_file, initial_fill=initial_fill,
+ num_shards=world_size, shard_id=rank,
+ random_shuffle=random_shuffle, pad_last_batch=False, name=name)
+ images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
+ images = fn.crop_mirror_normalize(
+ images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
+ pipe.set_outputs(images, labels)
+ pipe.build()
+ return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, ))
+
+
+@torch.no_grad()
+class DALIWarper(object):
+ def __init__(self, dali_iter):
+ self.iter = dali_iter
+
+ def __next__(self):
+ data_dict = self.iter.__next__()[0]
+ tensor_data = data_dict['data'].cuda()
+ tensor_label: torch.Tensor = data_dict['label'].cuda().long()
+ tensor_label.squeeze_()
+ return tensor_data, tensor_label
+
+ def __iter__(self):
+ return self
+
+ def reset(self):
+ self.iter.reset()
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/dist.sh b/deep_3drecon/deep_3drecon_models/arcface_torch/dist.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9f3c6a5276a030652c9f2e81d535e0beb854f123
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/dist.sh
@@ -0,0 +1,15 @@
+ip_list=("ip1" "ip2" "ip3" "ip4")
+
+config=wf42m_pfc03_32gpu_r100
+
+for((node_rank=0;node_rank<${#ip_list[*]};node_rank++));
+do
+ ssh ubuntu@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
+ torchrun \
+ --nproc_per_node=8 \
+ --nnodes=${#ip_list[*]} \
+ --node_rank=$node_rank \
+ --master_addr=${ip_list[0]} \
+ --master_port=22345 train.py configs/$config" &
+done
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/eval.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..9ce1621357c03ee8a25c004e5f01850990df1628
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/eval.md
@@ -0,0 +1,43 @@
+## Eval on ICCV2021-MFR
+
+Coming soon.
+
+
+## Eval IJBC
+You can evaluate IJB-C with either PyTorch or ONNX.
+
+
+1. Eval IJBC with ONNX
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJBC with PyTorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
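+
+If you prefer to compute an embedding from Python rather than via the CLI, a minimal sketch is shown below. It assumes `backbones.get_model` takes the network name plus an `fp16` flag and that the checkpoint is a plain `state_dict`; the paths and the image file are placeholders.
+
+```python
+import cv2
+import numpy as np
+import torch
+
+from backbones import get_model
+
+# Build the backbone and load the trained weights.
+net = get_model("r50", fp16=False)
+net.load_state_dict(torch.load("ms1mv3_arcface_r50/backbone.pth", map_location="cpu"))
+net.eval()
+
+# Read an aligned 112x112 face, BGR->RGB, HWC->CHW, then normalize to [-1, 1].
+img = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)
+img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float().unsqueeze(0)
+img = (img / 255.0 - 0.5) / 0.5
+
+with torch.no_grad():
+    embedding = net(img)  # shape: (1, 512)
+```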
+
+
+## Result
+
+| Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) |
+|:---------------|:--------------------|:------------|:------------|:------------|
+| WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 |
+| WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 |
+| WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 |
+| WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 |
+| WF12M | r100 | 94.69 | 97.59 | 95.97 |
\ No newline at end of file
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..8824e7e3108adc76cee514a3e66a50f933c9c91f
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install.md
@@ -0,0 +1,27 @@
+# Installation
+
+### [Torch v1.11.0](https://pytorch.org/get-started/previous-versions/#v1110)
+#### Linux and Windows
+- CUDA 11.3
+```shell
+pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
+```
+
+- CUDA 10.2
+```shell
+pip install torch==1.11.0+cu102 torchvision==0.12.0+cu102 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu102
+```
+
+### [Torch v1.9.0](https://pytorch.org/get-started/previous-versions/#v190)
+#### Linux and Windows
+
+- CUDA 11.1
+```shell
+pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+- CUDA 10.2
+```shell
+pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
+```
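+
+After installing, a quick sanity check (a generic PyTorch snippet, not specific to this repo):
+
+```python
+import torch
+
+print(torch.__version__)          # e.g. 1.11.0+cu113
+print(torch.version.cuda)         # CUDA version the wheel was built against
+print(torch.cuda.is_available())  # should be True if the driver matches the CUDA build
+```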
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install_dali.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install_dali.md
new file mode 100644
index 0000000000000000000000000000000000000000..48743644d0dac8885efaecfbb7821d5639a4f732
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/install_dali.md
@@ -0,0 +1,103 @@
+# Installation
+## Prerequisites
+
+1. Linux x64.
+2. NVIDIA Driver supporting CUDA 10.0 or later (i.e., 410.48 or later driver releases).
+3. (Optional) One or more of the following deep learning frameworks:
+
+ * [MXNet 1.3](http://mxnet.incubator.apache.org/) `mxnet-cu100` or later.
+ * [PyTorch 0.4](https://pytorch.org/) or later.
+ * [TensorFlow 1.7](https://www.tensorflow.org/) or later.
+
+## DALI in NGC Containers
+DALI is preinstalled in the TensorFlow, PyTorch, and MXNet containers in versions 18.07 and later on NVIDIA GPU Cloud.
+
+## pip - Official Releases
+
+### nvidia-dali
+
+Execute the following command to install the latest DALI for the specified CUDA version (please check the support matrix to see whether your platform is supported):
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-cuda110
+ ```
+
+
+> Note: The CUDA 11.0 build uses CUDA toolkit enhanced compatibility. It is built with the latest CUDA 11.x toolkit but can run on the latest stable CUDA 11.0 capable drivers (450.80 or later). Using the latest driver may enable additional functionality. More details can be found in the [enhanced CUDA compatibility guide](https://docs.nvidia.com/deploy/cuda-compatibility/index.html#enhanced-compat-minor-releases).
+
+> Note: Please always use the latest available version of pip (at least 19.3) and upgrade it when possible with `pip install --upgrade pip`.
+
+### nvidia-dali-tf-plugin
+
+DALI doesn’t contain prebuilt versions of the DALI TensorFlow plugin. It needs to be installed as a separate package which will be built against the currently installed version of TensorFlow:
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist --upgrade nvidia-dali-tf-plugin-cuda110
+ ```
+
+Installing this package will install `nvidia-dali-cudaXXX` and its dependencies, if they are not already installed. The package `tensorflow-gpu` must be installed before attempting to install `nvidia-dali-tf-plugin-cudaXXX`.
+
+> Note: The packages `nvidia-dali-tf-plugin-cudaXXX` and `nvidia-dali-cudaXXX` must be exactly the same version. Therefore, installing the latest `nvidia-dali-tf-plugin-cudaXXX` will replace any older `nvidia-dali-cudaXXX` version already installed. To work with older versions of DALI, provide the version explicitly to the `pip install` command.
+
+### pip - Nightly and Weekly Releases
+
+> Note: While the binaries available from nightly and weekly builds include the most recent changes on GitHub, some functionality may not work or may perform worse than the official releases. These builds are intended for early adopters who want the newest features and are ready to boldly go where no one has gone before.
+
+> Note: It is recommended to uninstall the regular DALI and TensorFlow plugin packages before installing nightly or weekly builds, as they are installed in the same path.
+
+#### Nightly Builds
+To access the most recent nightly builds, use the following release channel:
+
+* For CUDA 10.2:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda102
+ ```
+
+  ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda102
+ ```
+
+* For CUDA 11.0:
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-nightly-cuda110
+ ```
+
+ ```bash
+ pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/nightly --upgrade nvidia-dali-tf-plugin-nightly-cuda110
+ ```
+
+
+#### Weekly Builds
+
+There is also a weekly release channel with more thorough testing. To access the most recent weekly builds, use the following release channel (available only for CUDA 11):
+
+```bash
+pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-weekly-cuda110
+```
+
+```bash
+pip install --extra-index-url https://developer.download.nvidia.com/compute/redist/weekly --upgrade nvidia-dali-tf-plugin-weekly-cuda110
+```
+
+
+---
+
+### For more information about DALI and its installation, please refer to the [DALI documentation](https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html).
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/modelzoo.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_custom_dataset.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_custom_dataset.md
new file mode 100644
index 0000000000000000000000000000000000000000..6fc18dbd33cfa68be61e73906b0c96a320a8e12c
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_custom_dataset.md
@@ -0,0 +1,48 @@
+First, detect and align your face images so they are ready for processing. Then place all images belonging to the same identity (id) into their own folder, as shown below.
+
+
+```shell
+# directory layout of your dataset
+/image_folder
+├── 0_0_0000000
+│ ├── 0_0.jpg
+│ ├── 0_1.jpg
+│ ├── 0_2.jpg
+│ ├── 0_3.jpg
+│ └── 0_4.jpg
+├── 0_0_0000001
+│ ├── 0_5.jpg
+│ ├── 0_6.jpg
+│ ├── 0_7.jpg
+│ ├── 0_8.jpg
+│ └── 0_9.jpg
+├── 0_0_0000002
+│ ├── 0_10.jpg
+│ ├── 0_11.jpg
+│ ├── 0_12.jpg
+│ ├── 0_13.jpg
+│ ├── 0_14.jpg
+│ ├── 0_15.jpg
+│ ├── 0_16.jpg
+│ └── 0_17.jpg
+├── 0_0_0000003
+│ ├── 0_18.jpg
+│ ├── 0_19.jpg
+│ └── 0_20.jpg
+├── 0_0_0000004
+
+
+# 0) Dependencies installation
+pip install opencv-python
+apt-get update
+apt-get install ffmpeg libsm6 libxext6 -y
+
+
+# 1) create train.lst using the following command
+python -m mxnet.tools.im2rec --list --recursive train image_folder
+
+# 2) create train.rec and train.idx from train.lst using the following command
+python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train image_folder
+```
+
+Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training.
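+
+To sanity-check the generated record file, you can read a few entries back with MXNet (the same API used by `MXFaceDataset` in this repo). A minimal sketch, assuming train.rec and train.idx are in the current directory:
+
+```python
+import mxnet as mx
+
+# Open the indexed record file produced by im2rec.
+imgrec = mx.recordio.MXIndexedRecordIO("train.idx", "train.rec", "r")
+
+# Decode the first few records and print their labels and image shapes.
+for idx in list(imgrec.keys)[:5]:
+    header, raw = mx.recordio.unpack(imgrec.read_idx(idx))
+    img = mx.image.imdecode(raw).asnumpy()
+    print(idx, header.label, img.shape)
+```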
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_webface42m.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_webface42m.md
new file mode 100644
index 0000000000000000000000000000000000000000..e799ba74e04f911593a704e64810c1e9936307ff
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/prepare_webface42m.md
@@ -0,0 +1,58 @@
+
+
+
+## 1. Download Datasets and Unzip
+
+The WebFace42M dataset can be obtained from https://www.face-benchmark.org/download.html.
+Upon extraction, the raw WebFace42M data consists of 10 directories, numbered 0 to 9, which correspond to its 10 sub-datasets; WebFace4M is directory 0 and WebFace12M is directories 0, 1, and 2.
+
+## 2. Create Shuffled Rec File for DALI
+
+Shuffled .rec files are crucial for DALI; training on unshuffled .rec files can hurt performance. .rec files generated in the original InsightFace style are not compatible with NVIDIA DALI, so you need to use the [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) command to generate a shuffled .rec file.
+
+
+```shell
+# directory layout of your dataset
+/WebFace42M_Root
+├── 0_0_0000000
+│ ├── 0_0.jpg
+│ ├── 0_1.jpg
+│ ├── 0_2.jpg
+│ ├── 0_3.jpg
+│ └── 0_4.jpg
+├── 0_0_0000001
+│ ├── 0_5.jpg
+│ ├── 0_6.jpg
+│ ├── 0_7.jpg
+│ ├── 0_8.jpg
+│ └── 0_9.jpg
+├── 0_0_0000002
+│ ├── 0_10.jpg
+│ ├── 0_11.jpg
+│ ├── 0_12.jpg
+│ ├── 0_13.jpg
+│ ├── 0_14.jpg
+│ ├── 0_15.jpg
+│ ├── 0_16.jpg
+│ └── 0_17.jpg
+├── 0_0_0000003
+│ ├── 0_18.jpg
+│ ├── 0_19.jpg
+│ └── 0_20.jpg
+├── 0_0_0000004
+
+
+# 0) Dependencies installation
+pip install opencv-python
+apt-get update
+apt-get install ffmpeg libsm6 libxext6 -y
+
+
+# 1) create train.lst using the following command
+python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root
+
+# 2) create train.rec and train.idx from train.lst using the following command
+python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root
+```
+
+Finally, you will obtain three files: train.lst, train.rec, and train.idx, where train.idx and train.rec are utilized for training.
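+
+Once train.rec and train.idx are in place, point a training config at the directory that contains them. A minimal sketch based on the configs in this repo (the path is a placeholder):
+
+```python
+# excerpt from a training config, e.g. configs/wf42m_pfc02_r100.py
+config.rec = "/train_tmp/WebFace42M"  # directory containing train.rec and train.idx
+config.dali = True                    # use the DALI reader, which requires the shuffled .rec
+```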
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/docs/speed_benchmark.md b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..055aee0defe2c43a523ced48260242f0f99b7cea
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+Use the following two commands to test Partial FC training performance.
+The number of identities is **3 million** (synthetic data), mixed-precision training is enabled, the backbone is ResNet-50,
+and the batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB
+[1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB
+[2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB
+[3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB
+[4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB
+[5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB
+[7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │·······················
+[1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │·······················
+[2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │·······················
+[3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │·······················
+[4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │·······················
+[5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │·······················
+[6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │·······················
+[7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) training.log
+Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+
+# (Partial FC 0.1) training.log
+Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100
+Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150
+Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200
+Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250
+Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300
+```
+
+In this test case, Partial FC 0.1 uses only about one third of the GPU memory of model parallel (roughly 10.4 GB vs. 29 GB per GPU),
+and it trains about 2.3 times faster (roughly 5300 vs. 2270 samples/sec in the logs above).
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 4681 | 4824 | 5004 |
+|250000 | 4047 | 4521 | 4976 |
+|500000 | 3087 | 4013 | 4900 |
+|1000000 | 2090 | 3449 | 4803 |
+|1400000 | 1672 | 3043 | 4738 |
+|2000000 | - | 2593 | 4626 |
+|4000000 | - | 1748 | 4208 |
+|5500000 | - | 1389 | 3975 |
+|8000000 | - | - | 3565 |
+|16000000 | - | - | 2679 |
+|29000000 | - | - | 1855 |
+
+2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :--- | :--- | :--- | :--- |
+|125000 | 7358 | 5306 | 4868 |
+|250000 | 9940 | 5826 | 5004 |
+|500000 | 14220 | 7114 | 5202 |
+|1000000 | 23708 | 9966 | 5620 |
+|1400000 | 32252 | 11178 | 6056 |
+|2000000 | - | 13978 | 6472 |
+|4000000 | - | 23238 | 8284 |
+|5500000 | - | 32188 | 9854 |
+|8000000 | - | - | 12310 |
+|16000000 | - | - | 19950 |
+|29000000 | - | - | 32324 |
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/eval/__init__.py b/deep_3drecon/deep_3drecon_models/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/eval/verification.py b/deep_3drecon/deep_3drecon_models/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..edacf8d8136bc2dadb3d24d37fd2a812d0a443ee
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/eval/verification.py
@@ -0,0 +1,409 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset
+"""
+
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+ def __init__(self, n_splits=2, shuffle=False):
+ self.n_splits = n_splits
+ if self.n_splits > 1:
+ self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+ def split(self, indices):
+ if self.n_splits > 1:
+ return self.k_fold.split(indices)
+ else:
+ return [(indices, indices)]
+
+
+def calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ nrof_folds=10,
+ pca=0):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ tprs = np.zeros((nrof_folds, nrof_thresholds))
+ fprs = np.zeros((nrof_folds, nrof_thresholds))
+ accuracy = np.zeros((nrof_folds))
+ indices = np.arange(nrof_pairs)
+
+ if pca == 0:
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+ if pca > 0:
+ print('doing pca on', fold_idx)
+ embed1_train = embeddings1[train_set]
+ embed2_train = embeddings2[train_set]
+ _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+ pca_model = PCA(n_components=pca)
+ pca_model.fit(_embed_train)
+ embed1 = pca_model.transform(embeddings1)
+ embed2 = pca_model.transform(embeddings2)
+ embed1 = sklearn.preprocessing.normalize(embed1)
+ embed2 = sklearn.preprocessing.normalize(embed2)
+ diff = np.subtract(embed1, embed2)
+ dist = np.sum(np.square(diff), 1)
+
+ # Find the best threshold for the fold
+ acc_train = np.zeros((nrof_thresholds))
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, _, acc_train[threshold_idx] = calculate_accuracy(
+ threshold, dist[train_set], actual_issame[train_set])
+ best_threshold_index = np.argmax(acc_train)
+ for threshold_idx, threshold in enumerate(thresholds):
+ tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+ threshold, dist[test_set],
+ actual_issame[test_set])
+ _, _, accuracy[fold_idx] = calculate_accuracy(
+ thresholds[best_threshold_index], dist[test_set],
+ actual_issame[test_set])
+
+ tpr = np.mean(tprs, 0)
+ fpr = np.mean(fprs, 0)
+ return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ tp = np.sum(np.logical_and(predict_issame, actual_issame))
+ fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ tn = np.sum(
+ np.logical_and(np.logical_not(predict_issame),
+ np.logical_not(actual_issame)))
+ fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+ tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+ fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+ acc = float(tp + tn) / dist.size
+ return tpr, fpr, acc
+
+
+def calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ actual_issame,
+ far_target,
+ nrof_folds=10):
+ assert (embeddings1.shape[0] == embeddings2.shape[0])
+ assert (embeddings1.shape[1] == embeddings2.shape[1])
+ nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+ nrof_thresholds = len(thresholds)
+ k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+ val = np.zeros(nrof_folds)
+ far = np.zeros(nrof_folds)
+
+ diff = np.subtract(embeddings1, embeddings2)
+ dist = np.sum(np.square(diff), 1)
+ indices = np.arange(nrof_pairs)
+
+ for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+ # Find the threshold that gives FAR = far_target
+ far_train = np.zeros(nrof_thresholds)
+ for threshold_idx, threshold in enumerate(thresholds):
+ _, far_train[threshold_idx] = calculate_val_far(
+ threshold, dist[train_set], actual_issame[train_set])
+ if np.max(far_train) >= far_target:
+ f = interpolate.interp1d(far_train, thresholds, kind='slinear')
+ threshold = f(far_target)
+ else:
+ threshold = 0.0
+
+ val[fold_idx], far[fold_idx] = calculate_val_far(
+ threshold, dist[test_set], actual_issame[test_set])
+
+ val_mean = np.mean(val)
+ far_mean = np.mean(far)
+ val_std = np.std(val)
+ return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+ predict_issame = np.less(dist, threshold)
+ true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+ false_accept = np.sum(
+ np.logical_and(predict_issame, np.logical_not(actual_issame)))
+ n_same = np.sum(actual_issame)
+ n_diff = np.sum(np.logical_not(actual_issame))
+ # print(true_accept, false_accept)
+ # print(n_same, n_diff)
+ val = float(true_accept) / float(n_same)
+ far = float(false_accept) / float(n_diff)
+ return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+ # Calculate evaluation metrics
+ thresholds = np.arange(0, 4, 0.01)
+ embeddings1 = embeddings[0::2]
+ embeddings2 = embeddings[1::2]
+ tpr, fpr, accuracy = calculate_roc(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ nrof_folds=nrof_folds,
+ pca=pca)
+ thresholds = np.arange(0, 4, 0.001)
+ val, val_std, far = calculate_val(thresholds,
+ embeddings1,
+ embeddings2,
+ np.asarray(actual_issame),
+ 1e-3,
+ nrof_folds=nrof_folds)
+ return tpr, fpr, accuracy, val, val_std, far
+
+@torch.no_grad()
+def load_bin(path, image_size):
+ try:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f) # py2
+ except UnicodeDecodeError as e:
+ with open(path, 'rb') as f:
+ bins, issame_list = pickle.load(f, encoding='bytes') # py3
+ data_list = []
+ for flip in [0, 1]:
+ data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+ data_list.append(data)
+ for idx in range(len(issame_list) * 2):
+ _bin = bins[idx]
+ img = mx.image.imdecode(_bin)
+ if img.shape[1] != image_size[0]:
+ img = mx.image.resize_short(img, image_size[0])
+ img = nd.transpose(img, axes=(2, 0, 1))
+ for flip in [0, 1]:
+ if flip == 1:
+ img = mx.ndarray.flip(data=img, axis=2)
+ data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+ if idx % 1000 == 0:
+ print('loading bin', idx)
+ print(data_list[0].shape)
+ return data_list, issame_list
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+ print('testing verification..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+ _data = data[bb - batch_size: bb]
+ time0 = datetime.datetime.now()
+ img = ((_data / 255) - 0.5) / 0.5
+ net_out: torch.Tensor = backbone(img)
+ _embeddings = net_out.detach().cpu().numpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+
+ _xnorm = 0.0
+ _xnorm_cnt = 0
+ for embed in embeddings_list:
+ for i in range(embed.shape[0]):
+ _em = embed[i]
+ _norm = np.linalg.norm(_em)
+ _xnorm += _norm
+ _xnorm_cnt += 1
+ _xnorm /= _xnorm_cnt
+
+ embeddings = embeddings_list[0].copy()
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ acc1 = 0.0
+ std1 = 0.0
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ print(embeddings.shape)
+ print('infer time', time_consumed)
+ _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+ acc2, std2 = np.mean(accuracy), np.std(accuracy)
+ return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
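+# NOTE: legacy MXNet-based dump helper kept for reference; it references names
+# (`model`, `_label`, `_data_extra`) that are not defined in this module.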
+def dumpR(data_set,
+ backbone,
+ batch_size,
+ name='',
+ data_extra=None,
+ label_shape=None):
+ print('dump verification embedding..')
+ data_list = data_set[0]
+ issame_list = data_set[1]
+ embeddings_list = []
+ time_consumed = 0.0
+ for i in range(len(data_list)):
+ data = data_list[i]
+ embeddings = None
+ ba = 0
+ while ba < data.shape[0]:
+ bb = min(ba + batch_size, data.shape[0])
+ count = bb - ba
+
+ _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+ time0 = datetime.datetime.now()
+ if data_extra is None:
+ db = mx.io.DataBatch(data=(_data,), label=(_label,))
+ else:
+ db = mx.io.DataBatch(data=(_data, _data_extra),
+ label=(_label,))
+ model.forward(db, is_train=False)
+ net_out = model.get_outputs()
+ _embeddings = net_out[0].asnumpy()
+ time_now = datetime.datetime.now()
+ diff = time_now - time0
+ time_consumed += diff.total_seconds()
+ if embeddings is None:
+ embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+ embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+ ba = bb
+ embeddings_list.append(embeddings)
+ embeddings = embeddings_list[0] + embeddings_list[1]
+ embeddings = sklearn.preprocessing.normalize(embeddings)
+ actual_issame = np.asarray(issame_list)
+ outname = os.path.join('temp.bin')
+ with open(outname, 'wb') as f:
+ pickle.dump((embeddings, issame_list),
+ f,
+ protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+# parser = argparse.ArgumentParser(description='do verification')
+# # general
+# parser.add_argument('--data-dir', default='', help='')
+# parser.add_argument('--model',
+# default='../model/softmax,50',
+# help='path to load model.')
+# parser.add_argument('--target',
+# default='lfw,cfp_ff,cfp_fp,agedb_30',
+# help='test targets.')
+# parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+# parser.add_argument('--batch-size', default=32, type=int, help='')
+# parser.add_argument('--max', default='', type=str, help='')
+# parser.add_argument('--mode', default=0, type=int, help='')
+# parser.add_argument('--nfolds', default=10, type=int, help='')
+# args = parser.parse_args()
+# image_size = [112, 112]
+# print('image_size', image_size)
+# ctx = mx.gpu(args.gpu)
+# nets = []
+# vec = args.model.split(',')
+# prefix = args.model.split(',')[0]
+# epochs = []
+# if len(vec) == 1:
+# pdir = os.path.dirname(prefix)
+# for fname in os.listdir(pdir):
+# if not fname.endswith('.params'):
+# continue
+# _file = os.path.join(pdir, fname)
+# if _file.startswith(prefix):
+# epoch = int(fname.split('.')[0].split('-')[1])
+# epochs.append(epoch)
+# epochs = sorted(epochs, reverse=True)
+# if len(args.max) > 0:
+# _max = [int(x) for x in args.max.split(',')]
+# assert len(_max) == 2
+# if len(epochs) > _max[1]:
+# epochs = epochs[_max[0]:_max[1]]
+#
+# else:
+# epochs = [int(x) for x in vec[1].split('|')]
+# print('model number', len(epochs))
+# time0 = datetime.datetime.now()
+# for epoch in epochs:
+# print('loading', prefix, epoch)
+# sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+# # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+# all_layers = sym.get_internals()
+# sym = all_layers['fc1_output']
+# model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+# # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+# model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+# image_size[1]))])
+# model.set_params(arg_params, aux_params)
+# nets.append(model)
+# time_now = datetime.datetime.now()
+# diff = time_now - time0
+# print('model loading time', diff.total_seconds())
+#
+# ver_list = []
+# ver_name_list = []
+# for name in args.target.split(','):
+# path = os.path.join(args.data_dir, name + ".bin")
+# if os.path.exists(path):
+# print('loading.. ', name)
+# data_set = load_bin(path, image_size)
+# ver_list.append(data_set)
+# ver_name_list.append(name)
+#
+# if args.mode == 0:
+# for i in range(len(ver_list)):
+# results = []
+# for model in nets:
+# acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+# ver_list[i], model, args.batch_size, args.nfolds)
+# print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+# print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+# print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+# results.append(acc2)
+# print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+# elif args.mode == 1:
+# raise ValueError
+# else:
+# model = nets[0]
+# dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/eval_ijbc.py b/deep_3drecon/deep_3drecon_models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c5a650d486d18eb02d6f60d448fc3b315261f5d
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,483 @@
+# coding: utf-8
+
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description='do ijb test')
+# general
+parser.add_argument('--model-prefix', default='', help='path to load model.')
+parser.add_argument('--image-path', default='', type=str, help='')
+parser.add_argument('--result-dir', default='.', type=str, help='')
+parser.add_argument('--batch-size', default=128, type=int, help='')
+parser.add_argument('--network', default='iresnet50', type=str, help='')
+parser.add_argument('--job', default='insightface', type=str, help='job name')
+parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True  # if True, TestMode(N1)
+use_detector_score = True  # if True, TestMode(D1)
+use_flip_test = True  # if True, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+ def __init__(self, prefix, data_shape, batch_size=1):
+ image_size = (112, 112)
+ self.image_size = image_size
+ weight = torch.load(prefix)
+ resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+ resnet.load_state_dict(weight)
+ model = torch.nn.DataParallel(resnet)
+ self.model = model
+ self.model.eval()
+ src = np.array([
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]], dtype=np.float32)
+ src[:, 0] += 8.0
+ self.src = src
+ self.batch_size = batch_size
+ self.data_shape = data_shape
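+ # Note: `src` is the standard 5-point landmark template (eye centers, nose tip, mouth corners)
+ # used for ArcFace-style alignment; the +8.0 x-offset shifts the 96-pixel-wide template to the
+ # center of the 112x112 crop. get() below warps a face onto this template and returns the
+ # aligned image together with its horizontal flip as a (2, 3, 112, 112) uint8 blob for forward_db().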
+
+ def get(self, rimg, landmark):
+
+ assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+ assert landmark.shape[1] == 2
+ if landmark.shape[0] == 68:
+ landmark5 = np.zeros((5, 2), dtype=np.float32)
+ landmark5[0] = (landmark[36] + landmark[39]) / 2
+ landmark5[1] = (landmark[42] + landmark[45]) / 2
+ landmark5[2] = landmark[30]
+ landmark5[3] = landmark[48]
+ landmark5[4] = landmark[54]
+ else:
+ landmark5 = landmark
+ tform = trans.SimilarityTransform()
+ tform.estimate(landmark5, self.src)
+ M = tform.params[0:2, :]
+ img = cv2.warpAffine(rimg,
+ M, (self.image_size[1], self.image_size[0]),
+ borderValue=0.0)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img_flip = np.fliplr(img)
+ img = np.transpose(img, (2, 0, 1)) # 3*112*112, RGB
+ img_flip = np.transpose(img_flip, (2, 0, 1))
+ input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+ input_blob[0] = img
+ input_blob[1] = img_flip
+ return input_blob
+
+ @torch.no_grad()
+ def forward_db(self, batch_data):
+ imgs = torch.Tensor(batch_data).cuda()
+ imgs.div_(255).sub_(0.5).div_(0.5)
+ feat = self.model(imgs)
+ feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+ return feat.cpu().numpy()
+
+
+# Split a list as evenly as possible into n sub-lists (len(result) == n); if n exceeds the number
+# of elements in the original list, the surplus sub-lists are left empty.
+def divideIntoNstrand(listTemp, n):
+ twoList = [[] for i in range(n)]
+ for i, e in enumerate(listTemp):
+ twoList[i % n].append(e)
+ return twoList
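+# Example (round-robin assignment): divideIntoNstrand([1, 2, 3, 4, 5], 2) -> [[1, 3, 5], [2, 4]]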
+
+
+def read_template_media_list(path):
+ # ijb_meta = np.loadtxt(path, dtype=str)
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+ templates = ijb_meta[:, 1].astype(np.int_)
+ medias = ijb_meta[:, 2].astype(np.int_)
+ return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+ # pairs = np.loadtxt(path, dtype=str)
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ # print(pairs.shape)
+ # print(pairs[:, 0].astype(np.int))
+ t1 = pairs[:, 0].astype(np.int_)
+ t2 = pairs[:, 1].astype(np.int_)
+ label = pairs[:, 2].astype(np.int_)
+ return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+ batch_size = args.batch_size
+ data_shape = (3, 112, 112)
+
+ files = files_list
+ print('files:', len(files))
+ rare_size = len(files) % batch_size
+ faceness_scores = []
+ batch = 0
+ img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+ batch_data = np.empty((2 * batch_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, batch_size)
+ for img_index, each_line in enumerate(files[:len(files) - rare_size]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+
+ batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+ batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+ if (img_index + 1) % batch_size == 0:
+ print('batch', batch)
+ img_feats[batch * batch_size:batch * batch_size +
+ batch_size][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+
+ batch_data = np.empty((2 * rare_size, 3, 112, 112))
+ embedding = Embedding(model_path, data_shape, rare_size)
+ for img_index, each_line in enumerate(files[len(files) - rare_size:]):
+ name_lmk_score = each_line.strip().split(' ')
+ img_name = os.path.join(img_path, name_lmk_score[0])
+ img = cv2.imread(img_name)
+ lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+ dtype=np.float32)
+ lmk = lmk.reshape((5, 2))
+ input_blob = embedding.get(img, lmk)
+ batch_data[2 * img_index][:] = input_blob[0]
+ batch_data[2 * img_index + 1][:] = input_blob[1]
+ if (img_index + 1) % rare_size == 0:
+ print('batch', batch)
+ img_feats[len(files) -
+ rare_size:][:] = embedding.forward_db(batch_data)
+ batch += 1
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+ # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+ return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+ # ==========================================================
+ # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+ # 2. compute media feature.
+ # 3. compute template feature.
+ # ==========================================================
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+ for count_template, uqt in enumerate(unique_templates):
+
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias,
+ return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [
+ np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
+ ]
+ media_norm_feats = np.array(media_norm_feats)
+ # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+ template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+ # print(template_norm_feats.shape)
+ return template_norm_feats, unique_templates
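+
+# Illustrative toy walk-through of image2template_feature (hypothetical values, not IJB data):
+#   img_feats has 4 rows, templates = [0, 0, 1, 1], medias = [0, 0, 1, 2]
+#   template 0: both rows share media 0, so they are averaged into a single media feature,
+#               which then becomes the template sum
+#   template 1: its rows come from medias 1 and 2, so they are kept separately and summed
+# The summed template features are finally L2-normalized with sklearn.preprocessing.normalize.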
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ # ==========================================================
+ # Compute set-to-set Similarity Score.
+ # ==========================================================
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000  # score in small batches rather than all pairs at once, due to the memory limitation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
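+
+# Note: template_norm_feats is L2-normalized, so the inner product above is the cosine similarity
+# between template pairs; pairs are scored in chunks of `batchsize` to bound memory usage.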
+
+
+# In[ ]:
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000  # score in small batches rather than all pairs at once, due to the memory limitation
+ sublists = [
+ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+ ]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def read_score(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+# # Step1: Load Meta Data
+
+# In[ ]:
+
+assert target == 'IJBC' or target == 'IJBB'
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id, mid --> media id
+# format:
+# image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_face_tid_mid.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id, label : 1/0
+# format:
+# tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+# img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = '%s/loose_crop' % image_path
+img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list,
+ model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
+ img_feats.shape[1]))
+
+# # Step3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+ # concat --- F1
+ # img_input_feats = img_feats
+ # add --- F2
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] //
+ 2] + img_feats[:, img_feats.shape[1] // 2:]
+else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
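+
+# (Each row of img_feats is [original_embedding, flipped_embedding], as produced by forward_db;
+# the F2 "add" mode above sums the two halves, whereas F1 would keep the concatenation.)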
+
+if use_norm_score:
+ img_input_feats = img_input_feats
+else:
+ # normalise features to remove norm information
+ img_input_feats = img_input_feats / np.sqrt(
+ np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+ img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+ os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+ methods.append(Path(file).stem)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
+print(tpr_fpr_table)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/flops.py b/deep_3drecon/deep_3drecon_models/arcface_torch/flops.py
new file mode 100644
index 0000000000000000000000000000000000000000..e704b7b584a27d85fa51623d70828d0d42cfa853
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/flops.py
@@ -0,0 +1,20 @@
+from ptflops import get_model_complexity_info
+from backbones import get_model
+import argparse
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='')
+ parser.add_argument('n', type=str, default="r100")
+ args = parser.parse_args()
+ net = get_model(args.n)
+ macs, params = get_model_complexity_info(
+ net, (3, 112, 112), as_strings=False,
+ print_per_layer_stat=True, verbose=True)
+ gmacs = macs / (1000**3)
+ print("%.3f GFLOPs"%gmacs)
+ print("%.3f Mparams"%(params/(1000**2)))
+
+ if hasattr(net, "extra_gflops"):
+ print("%.3f Extra-GFLOPs"%net.extra_gflops)
+ print("%.3f Total-GFLOPs"%(gmacs+net.extra_gflops))
+
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/inference.py b/deep_3drecon/deep_3drecon_models/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e5156e8d649954837e397c2ff15ec29995e7502
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/inference.py
@@ -0,0 +1,35 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+ if img is None:
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+ else:
+ img = cv2.imread(img)
+ img = cv2.resize(img, (112, 112))
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = np.transpose(img, (2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+ img.div_(255).sub_(0.5).div_(0.5)
+ net = get_model(name, fp16=False)
+ net.load_state_dict(torch.load(weight))
+ net.eval()
+ feat = net(img).numpy()
+ print(feat)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+ parser.add_argument('--network', type=str, default='r50', help='backbone network')
+ parser.add_argument('--weight', type=str, default='')
+ parser.add_argument('--img', type=str, default=None)
+ args = parser.parse_args()
+ inference(args.weight, args.network, args.img)
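+
+# Example invocation (hypothetical paths; the flags match the argparse definitions above):
+#   python inference.py --network r50 --weight /path/to/backbone.pth --img /path/to/face.jpg
+# If --img is omitted, a random 112x112 image is generated as a smoke test.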
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/losses.py b/deep_3drecon/deep_3drecon_models/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b4585fc6d3c610265b315404cea1f4543996bd
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/losses.py
@@ -0,0 +1,100 @@
+import torch
+import math
+
+
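+# Combined-margin convention (as in InsightFace): the target logit cos(theta) is replaced by
+# cos(m1 * theta + m2) - m3 and then scaled by s. (m1, m2, m3) = (1.0, 0.5, 0.0) recovers ArcFace
+# and (1.0, 0.0, 0.4) recovers CosFace; forward() below implements the ArcFace-style branch
+# (m1 == 1, m3 == 0) and the additive-margin branch (m3 > 0).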
+class CombinedMarginLoss(torch.nn.Module):
+ def __init__(self,
+ s,
+ m1,
+ m2,
+ m3,
+ interclass_filtering_threshold=0):
+ super().__init__()
+ self.s = s
+ self.m1 = m1
+ self.m2 = m2
+ self.m3 = m3
+ self.interclass_filtering_threshold = interclass_filtering_threshold
+
+ # For ArcFace
+ self.cos_m = math.cos(self.m2)
+ self.sin_m = math.sin(self.m2)
+ self.theta = math.cos(math.pi - self.m2)
+ self.sinmm = math.sin(math.pi - self.m2) * self.m2
+ self.easy_margin = False
+
+
+ def forward(self, logits, labels):
+ index_positive = torch.where(labels != -1)[0]
+
+ if self.interclass_filtering_threshold > 0:
+ with torch.no_grad():
+ dirty = logits > self.interclass_filtering_threshold
+ dirty = dirty.float()
+ mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
+ mask.scatter_(1, labels[index_positive], 0)
+ dirty[index_positive] *= mask
+ tensor_mul = 1 - dirty
+ logits = tensor_mul * logits
+
+ target_logit = logits[index_positive, labels[index_positive].view(-1)]
+
+ if self.m1 == 1.0 and self.m3 == 0.0:
+ with torch.no_grad():
+ target_logit.arccos_()
+ logits.arccos_()
+ final_target_logit = target_logit + self.m2
+ logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
+ logits.cos_()
+ logits = logits * self.s
+
+ elif self.m3 > 0:
+ final_target_logit = target_logit - self.m3
+ logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
+ logits = logits * self.s
+ else:
+ raise ValueError("unsupported margin combination: (m1, m2, m3) = ({}, {}, {})".format(self.m1, self.m2, self.m3))
+
+ return logits
+
+class ArcFace(torch.nn.Module):
+ """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
+ """
+ def __init__(self, s=64.0, margin=0.5):
+ super(ArcFace, self).__init__()
+ self.scale = s
+ self.margin = margin
+ self.cos_m = math.cos(margin)
+ self.sin_m = math.sin(margin)
+ self.theta = math.cos(math.pi - margin)
+ self.sinmm = math.sin(math.pi - margin) * margin
+ self.easy_margin = False
+
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
+ index = torch.where(labels != -1)[0]
+ target_logit = logits[index, labels[index].view(-1)]
+
+ with torch.no_grad():
+ target_logit.arccos_()
+ logits.arccos_()
+ final_target_logit = target_logit + self.margin
+ logits[index, labels[index].view(-1)] = final_target_logit
+ logits.cos_()
+ logits = logits * self.scale  # the scale factor is stored as self.scale in __init__
+ return logits
+
+
+class CosFace(torch.nn.Module):
+ def __init__(self, s=64.0, m=0.40):
+ super(CosFace, self).__init__()
+ self.s = s
+ self.m = m
+
+ def forward(self, logits: torch.Tensor, labels: torch.Tensor):
+ index = torch.where(labels != -1)[0]
+ target_logit = logits[index, labels[index].view(-1)]
+ final_target_logit = target_logit - self.m
+ logits[index, labels[index].view(-1)] = final_target_logit
+ logits = logits * self.s
+ return logits
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/lr_scheduler.py b/deep_3drecon/deep_3drecon_models/arcface_torch/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a703335ca9ed08cc885ca83e2ae27a3d5aea5ca
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/lr_scheduler.py
@@ -0,0 +1,30 @@
+from torch.optim.lr_scheduler import _LRScheduler
+
+
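+# Polynomial (power-2) decay with linear warmup. A worked example of get_lr(), assuming
+# base_lr=0.1, warmup_steps=100, max_steps=1000:
+#   step 50  (warmup): lr = 0.1 * 50 / 100             = 0.05
+#   step 550 (decay) : lr = 0.1 * (1 - 450 / 900) ** 2 = 0.025
+#   step 1000        : lr = 0.0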
+class PolyScheduler(_LRScheduler):
+ def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
+ self.base_lr = base_lr
+ self.warmup_lr_init = 0.0001
+ self.max_steps: int = max_steps
+ self.warmup_steps: int = warmup_steps
+ self.power = 2
+ super(PolyScheduler, self).__init__(optimizer, -1, False)
+ self.last_epoch = last_epoch
+
+ def get_warmup_lr(self):
+ alpha = float(self.last_epoch) / float(self.warmup_steps)
+ return [self.base_lr * alpha for _ in self.optimizer.param_groups]
+
+ def get_lr(self):
+ if self.last_epoch == -1:
+ return [self.warmup_lr_init for _ in self.optimizer.param_groups]
+ if self.last_epoch < self.warmup_steps:
+ return self.get_warmup_lr()
+ else:
+ alpha = pow(
+ 1
+ - float(self.last_epoch - self.warmup_steps)
+ / float(self.max_steps - self.warmup_steps),
+ self.power,
+ )
+ return [self.base_lr * alpha for _ in self.optimizer.param_groups]
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_helper.py b/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca922ca6d410655029e459cf8fd1c323d276c34c
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_helper.py
@@ -0,0 +1,250 @@
+from __future__ import division
+import datetime
+import os
+import os.path as osp
+import glob
+import numpy as np
+import cv2
+import sys
+import onnxruntime
+import onnx
+import argparse
+from onnx import numpy_helper
+from insightface.data import get_image
+
+class ArcFaceORT:
+ def __init__(self, model_path, cpu=False):
+ self.model_path = model_path
+ # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+ self.providers = ['CPUExecutionProvider'] if cpu else None
+
+ # input_size is (w, h); returns an error-message string on failure, or None on success
+ def check(self, track='cfat', test_img = None):
+ # default track is cfat
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=15
+ if track.startswith('ms1m'):
+ max_model_size_mb=1024
+ max_feat_dim=512
+ max_time_cost=10
+ elif track.startswith('glint'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=20
+ elif track.startswith('cfat'):
+ max_model_size_mb = 1024
+ max_feat_dim = 512
+ max_time_cost = 15
+ elif track.startswith('unconstrained'):
+ max_model_size_mb=1024
+ max_feat_dim=1024
+ max_time_cost=30
+ else:
+ return "track not found"
+
+ if not os.path.exists(self.model_path):
+ return "model_path not exists"
+ if not os.path.isdir(self.model_path):
+ return "model_path should be directory"
+ onnx_files = []
+ for _file in os.listdir(self.model_path):
+ if _file.endswith('.onnx'):
+ onnx_files.append(osp.join(self.model_path, _file))
+ if len(onnx_files)==0:
+ return "do not have onnx files"
+ self.model_file = sorted(onnx_files)[-1]
+ print('use onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('input-shape:', input_shape)
+ if len(input_shape)!=4:
+ return "length of input_shape should be 4"
+ if not isinstance(input_shape[0], str):
+ #return "input_shape[0] should be str to support batch-inference"
+ print('reset input-shape[0] to None')
+ model = onnx.load(self.model_file)
+ model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
+ onnx.save(model, new_model_file)
+ self.model_file = new_model_file
+ print('use new onnx-model:', self.model_file)
+ try:
+ session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+ except:
+ return "load onnx failed"
+ input_cfg = session.get_inputs()[0]
+ input_shape = input_cfg.shape
+ print('new-input-shape:', input_shape)
+
+ self.image_size = tuple(input_shape[2:4][::-1])
+ #print('image_size:', self.image_size)
+ input_name = input_cfg.name
+ outputs = session.get_outputs()
+ output_names = []
+ for o in outputs:
+ output_names.append(o.name)
+ #print(o.name, o.shape)
+ if len(output_names)!=1:
+ return "number of output nodes should be 1"
+ self.session = session
+ self.input_name = input_name
+ self.output_names = output_names
+ #print(self.output_names)
+ model = onnx.load(self.model_file)
+ graph = model.graph
+ if len(graph.node)<8:
+ return "too small onnx graph"
+
+ input_size = (112,112)
+ self.crop = None
+ if track=='cfat':
+ crop_file = osp.join(self.model_path, 'crop.txt')
+ if osp.exists(crop_file):
+ lines = open(crop_file,'r').readlines()
+ if len(lines)!=6:
+ return "crop.txt should contain 6 lines"
+ lines = [int(x) for x in lines]
+ self.crop = lines[:4]
+ input_size = tuple(lines[4:6])
+ if input_size!=self.image_size:
+ return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size)
+
+ self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
+ if self.model_size_mb > max_model_size_mb:
+ return "max model size exceed, given %.3f-MB"%self.model_size_mb
+
+ input_mean = None
+ input_std = None
+ if track=='cfat':
+ pn_file = osp.join(self.model_path, 'pixel_norm.txt')
+ if osp.exists(pn_file):
+ lines = open(pn_file,'r').readlines()
+ if len(lines)!=2:
+ return "pixel_norm.txt should contain 2 lines"
+ input_mean = float(lines[0])
+ input_std = float(lines[1])
+ if input_mean is not None or input_std is not None:
+ if input_mean is None or input_std is None:
+ return "please set input_mean and input_std simultaneously"
+ else:
+ find_sub = False
+ find_mul = False
+ for nid, node in enumerate(graph.node[:8]):
+ print(nid, node.name)
+ if node.name.startswith('Sub') or node.name.startswith('_minus'):
+ find_sub = True
+ if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
+ find_mul = True
+ if find_sub and find_mul:
+ print("find sub and mul")
+ #mxnet arcface model
+ input_mean = 0.0
+ input_std = 1.0
+ else:
+ input_mean = 127.5
+ input_std = 127.5
+ self.input_mean = input_mean
+ self.input_std = input_std
+ for initn in graph.initializer:
+ weight_array = numpy_helper.to_array(initn)
+ dt = weight_array.dtype
+ if dt.itemsize<4:
+ return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
+ if test_img is None:
+ test_img = get_image('Tom_Hanks_54745')
+ test_img = cv2.resize(test_img, self.image_size)
+ else:
+ test_img = cv2.resize(test_img, self.image_size)
+ feat, cost = self.benchmark(test_img)
+ batch_result = self.check_batch(test_img)
+ batch_result_sum = float(np.sum(batch_result))
+ if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
+ print(batch_result)
+ print(batch_result_sum)
+ return "batch result output contains NaN!"
+
+ if len(feat.shape) < 2:
+ return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+ if feat.shape[1] > max_feat_dim:
+ return "max feat dim exceed, given %d"%feat.shape[1]
+ self.feat_dim = feat.shape[1]
+ cost_ms = cost*1000
+ if cost_ms>max_time_cost:
+ return "max time cost exceed, given %.4f"%cost_ms
+ self.cost_ms = cost_ms
+ print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
+ return None
+
+ def check_batch(self, img):
+ if not isinstance(img, list):
+ imgs = [img, ] * 32
+ else:
+ imgs = img
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
+ if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+ nimg = cv2.resize(nimg, self.image_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(
+ images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
+ mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+ return net_out
+
+
+ def meta_info(self):
+ return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
+
+
+ def forward(self, imgs):
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ input_size = self.image_size
+ if self.crop is not None:
+ nimgs = []
+ for img in imgs:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ nimgs.append(nimg)
+ imgs = nimgs
+ blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ return net_out
+
+ def benchmark(self, img):
+ input_size = self.image_size
+ if self.crop is not None:
+ nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+ if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+ nimg = cv2.resize(nimg, input_size)
+ img = nimg
+ blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+ costs = []
+ for _ in range(50):
+ ta = datetime.datetime.now()
+ net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+ tb = datetime.datetime.now()
+ cost = (tb-ta).total_seconds()
+ costs.append(cost)
+ costs = sorted(costs)
+ cost = costs[5]
+ return net_out, cost
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='')
+ # general
+ parser.add_argument('workdir', help='submitted work dir', type=str)
+ parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
+ args = parser.parse_args()
+ handler = ArcFaceORT(args.workdir)
+ err = handler.check(args.track)
+ print('err:', err)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_ijbc.py b/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9f880b9d2168ca95897fa87e5e04e8b4b5c96d8
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,269 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+import torch
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+from torch.utils.data import DataLoader
+from onnx_helper import ArcFaceORT
+
+SRC = np.array(
+ [
+ [30.2946, 51.6963],
+ [65.5318, 51.5014],
+ [48.0252, 71.7366],
+ [33.5493, 92.3655],
+ [62.7299, 92.2041]]
+ , dtype=np.float32)
+SRC[:, 0] += 8.0
+
+
+@torch.no_grad()
+class AlignedDataSet(mx.gluon.data.Dataset):
+ def __init__(self, root, lines, align=True):
+ self.lines = lines
+ self.root = root
+ self.align = align
+
+ def __len__(self):
+ return len(self.lines)
+
+ def __getitem__(self, idx):
+ each_line = self.lines[idx]
+ name_lmk_score = each_line.strip().split(' ')
+ name = os.path.join(self.root, name_lmk_score[0])
+ img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+ landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+ st = skimage.transform.SimilarityTransform()
+ st.estimate(landmark5, SRC)
+ img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+ img_1 = np.expand_dims(img, 0)
+ img_2 = np.expand_dims(np.fliplr(img), 0)
+ output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+ output = np.transpose(output, (0, 3, 1, 2))
+ return torch.from_numpy(output)
+
+
+@torch.no_grad()
+def extract(model_root, dataset):
+ model = ArcFaceORT(model_path=model_root)
+ model.check()
+ feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+ def collate_fn(data):
+ return torch.cat(data, dim=0)
+
+ data_loader = DataLoader(
+ dataset, batch_size=128, drop_last=False, num_workers=4, collate_fn=collate_fn, )
+ num_iter = 0
+ for batch in data_loader:
+ batch = batch.numpy()
+ batch = (batch - model.input_mean) / model.input_std
+ feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+ feat = np.reshape(feat, (-1, model.feat_dim * 2))
+ feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
+ num_iter += 1
+ if num_iter % 50 == 0:
+ print(num_iter)
+ return feat_mat
+
+
+def read_template_media_list(path):
+ ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+ templates = ijb_meta[:, 1].astype(np.int_)
+ medias = ijb_meta[:, 2].astype(np.int_)
+ return templates, medias
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(np.int_)
+ t2 = pairs[:, 1].astype(np.int_)
+ label = pairs[:, 2].astype(np.int_)
+ return t1, t2, label
+
+
+def read_image_feature(path):
+ with open(path, 'rb') as fid:
+ img_feats = pickle.load(fid)
+ return img_feats
+
+
+def image2template_feature(img_feats=None,
+ templates=None,
+ medias=None):
+ unique_templates = np.unique(templates)
+ template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+ for count_template, uqt in enumerate(unique_templates):
+ (ind_t,) = np.where(templates == uqt)
+ face_norm_feats = img_feats[ind_t]
+ face_medias = medias[ind_t]
+ unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+ media_norm_feats = []
+ for u, ct in zip(unique_medias, unique_media_counts):
+ (ind_m,) = np.where(face_medias == u)
+ if ct == 1:
+ media_norm_feats += [face_norm_feats[ind_m]]
+ else: # image features from the same video will be aggregated into one feature
+ media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
+ media_norm_feats = np.array(media_norm_feats)
+ template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+ if count_template % 2000 == 0:
+ print('Finish Calculating {} template features.'.format(
+ count_template))
+ template_norm_feats = normalize(template_feats)
+ return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),))
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000
+ sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def verification2(template_norm_feats=None,
+ unique_templates=None,
+ p1=None,
+ p2=None):
+ template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+ for count_template, uqt in enumerate(unique_templates):
+ template2id[uqt] = count_template
+ score = np.zeros((len(p1),)) # save cosine distance between pairs
+ total_pairs = np.array(range(len(p1)))
+ batchsize = 100000  # score in small batches rather than all pairs at once, due to the memory limitation
+ sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
+ total_sublists = len(sublists)
+ for c, s in enumerate(sublists):
+ feat1 = template_norm_feats[template2id[p1[s]]]
+ feat2 = template_norm_feats[template2id[p2[s]]]
+ similarity_score = np.sum(feat1 * feat2, -1)
+ score[s] = similarity_score.flatten()
+ if c % 10 == 0:
+ print('Finish {}/{} pairs.'.format(c, total_sublists))
+ return score
+
+
+def main(args):
+ use_norm_score = True  # if True, TestMode(N1)
+ use_detector_score = True  # if True, TestMode(D1)
+ use_flip_test = True  # if True, TestMode(F1)
+ assert args.target == 'IJBC' or args.target == 'IJBB'
+
+ start = timeit.default_timer()
+ templates, medias = read_template_media_list(
+ os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % args.image_path,
+ '%s_template_pair_label.txt' % args.target.lower()))
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ img_path = '%s/loose_crop' % args.image_path
+ img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
+ img_list = open(img_list_path)
+ files = img_list.readlines()
+ dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+ img_feats = extract(args.model_root, dataset)
+
+ faceness_scores = []
+ for each_line in files:
+ name_lmk_score = each_line.split()
+ faceness_scores.append(name_lmk_score[-1])
+ faceness_scores = np.array(faceness_scores).astype(np.float32)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
+ start = timeit.default_timer()
+
+ if use_flip_test:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
+ else:
+ img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+ if use_norm_score:
+ img_input_feats = img_input_feats
+ else:
+ img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+ if use_detector_score:
+ print(img_input_feats.shape, faceness_scores.shape)
+ img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+ else:
+ img_input_feats = img_input_feats
+
+ template_norm_feats, unique_templates = image2template_feature(
+ img_input_feats, templates, medias)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+
+ start = timeit.default_timer()
+ score = verification(template_norm_feats, unique_templates, p1, p2)
+ stop = timeit.default_timer()
+ print('Time: %.2f s. ' % (stop - start))
+ result_dir = args.model_root
+
+ save_path = os.path.join(result_dir, "{}_result".format(args.target))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ score_save_file = os.path.join(save_path, "{}.npy".format(args.target))
+ np.save(score_save_file, score)
+ files = [score_save_file]
+ methods = []
+ scores = []
+ for file in files:
+ methods.append(os.path.basename(file))
+ scores.append(np.load(file))
+ methods = np.array(methods)
+ scores = dict(zip(methods, scores))
+ x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+ tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
+ for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr)
+ tpr_fpr_row = []
+ tpr_fpr_row.append("%s-%s" % (method, args.target))
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+ print(tpr_fpr_table)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='do ijb test')
+ # general
+ parser.add_argument('--model-root', default='', help='path to load model.')
+ parser.add_argument('--image-path', default='/train_tmp/IJB_release/IJBC', type=str, help='')
+ parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+ main(parser.parse_args())
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc.py b/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeff29d8ef8e256473cad01c1d69c0b3fc8a9d49
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc.py
@@ -0,0 +1,531 @@
+import collections
+from typing import Callable
+
+import torch
+from torch import distributed
+from torch.nn.functional import linear, normalize
+
+
+class PartialFC(torch.nn.Module):
+ """
+ https://arxiv.org/abs/2203.15565
+ A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
+
+ When the sample rate is less than 1, each iteration selects the positive class centers plus a
+ random subset of negative class centers to compute the margin-based softmax loss. All class
+ centers are still maintained throughout training, but only the selected subset is updated in
+ each iteration.
+
+ .. note::
+ When the sample rate equals 1, Partial FC reduces to plain model parallelism (the default
+ sample rate is 1).
+
+ Example:
+ --------
+ >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)
+ >>> for img, labels in data_loader:
+ >>> embeddings = net(img)
+ >>> loss = module_pfc(embeddings, labels, optimizer)
+ >>> loss.backward()
+ >>> optimizer.step()
+ """
+ _version = 1
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Parameters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFC, self).__init__()
+ assert (
+ distributed.is_initialized()
+ ), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(
+ self.rank < num_classes % self.world_size
+ )
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+ self.weight: torch.Tensor
+ self.weight_mom: torch.Tensor
+ self.weight_activated: torch.nn.Parameter
+ self.weight_activated_mom: torch.Tensor
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+
+ if self.sample_rate < 1:
+ self.register_buffer("weight",
+ tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.register_buffer("weight_mom",
+ tensor=torch.zeros_like(self.weight))
+ self.register_parameter("weight_activated",
+ param=torch.nn.Parameter(torch.empty(0, 0)))
+ self.register_buffer("weight_activated_mom",
+ tensor=torch.empty(0, 0))
+ self.register_buffer("weight_index",
+ tensor=torch.empty(0, 0))
+ else:
+ self.weight_activated = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+
+ # margin_loss
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise TypeError("margin_loss must be callable")
+
+ @torch.no_grad()
+ def sample(self,
+ labels: torch.Tensor,
+ index_positive: torch.Tensor,
+ optimizer: torch.optim.Optimizer):
+ """
+ This function changes the value of `labels` in place.
+
+ Parameters:
+ -----------
+ labels: torch.Tensor
+ gathered labels of the global batch; remapped in place to indices within the sampled subset
+ index_positive: torch.Tensor
+ boolean mask selecting the labels that fall in this rank's class range
+ optimizer: torch.optim.Optimizer
+ SGD optimizer whose per-parameter state is re-bound to the sampled (activated) weights
+ """
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+
+ self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
+ self.weight_activated_mom = self.weight_mom[self.weight_index]
+
+ if isinstance(optimizer, torch.optim.SGD):
+ # TODO the params of partial fc must be last in the params list
+ optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
+ optimizer.param_groups[-1]["params"][0] = self.weight_activated
+ optimizer.state[self.weight_activated][
+ "momentum_buffer"
+ ] = self.weight_activated_mom
+ else:
+ raise TypeError("PartialFC.sample only supports torch.optim.SGD")
+
+ @torch.no_grad()
+ def update(self):
+ """ partial weight to global
+ """
+ if self.init_weight_update:
+ self.init_weight_update = False
+ return
+
+ if self.sample_rate < 1:
+ self.weight[self.weight_index] = self.weight_activated
+ self.weight_mom[self.weight_index] = self.weight_activated_mom
+
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ optimizer: torch.optim.Optimizer,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+
+ Returns:
+ -------
+ loss: torch.Tensor
+ the distributed margin-softmax cross-entropy loss for the global batch
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+ self.update()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, (
+ "last batch size do not equal current batch size: {} vs {}".format(
+ self.last_batch_size, batch_size))
+
+ _gather_embeddings = [
+ torch.zeros((batch_size, self.embedding_size)).cuda()
+ for _ in range(self.world_size)
+ ]
+ _gather_labels = [
+ torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
+ ]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (
+ labels < self.class_start + self.num_local
+ )
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ self.sample(labels, index_positive, optimizer)
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(self.weight_activated)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = collections.OrderedDict()
+ destination._metadata = collections.OrderedDict()
+
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
+ if self.sample_rate < 1:
+ destination["weight"] = self.weight.detach()
+ else:
+ destination["weight"] = self.weight_activated.data.detach()
+ return destination
+
+ def load_state_dict(self, state_dict, strict: bool = True):
+ if self.sample_rate < 1:
+ self.weight = state_dict["weight"].to(self.weight.device)
+ self.weight_mom.zero_()
+ self.weight_activated.data.zero_()
+ self.weight_activated_mom.zero_()
+ self.weight_index.zero_()
+ else:
+ self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
+
+
+class PartialFCAdamW(torch.nn.Module):
+ def __init__(self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,):
+ """
+ Parameters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFCAdamW, self).__init__()
+ assert (
+ distributed.is_initialized()
+ ), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
+ self.num_local: int = num_classes // self.world_size + int(
+ self.rank < num_classes % self.world_size
+ )
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+ self.weight: torch.Tensor
+ self.weight_exp_avg: torch.Tensor
+ self.weight_exp_avg_sq: torch.Tensor
+ self.weight_activated: torch.nn.Parameter
+ self.weight_activated_exp_avg: torch.Tensor
+ self.weight_activated_exp_avg_sq: torch.Tensor
+
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+
+ if self.sample_rate < 1:
+ self.register_buffer("weight",
+ tensor=torch.normal(0, 0.01, (self.num_local, embedding_size)))
+ self.register_buffer("weight_exp_avg",
+ tensor=torch.zeros_like(self.weight))
+ self.register_buffer("weight_exp_avg_sq",
+ tensor=torch.zeros_like(self.weight))
+ self.register_parameter("weight_activated",
+ param=torch.nn.Parameter(torch.empty(0, 0)))
+ self.register_buffer("weight_activated_exp_avg",
+ tensor=torch.empty(0, 0))
+ self.register_buffer("weight_activated_exp_avg_sq",
+ tensor=torch.empty(0, 0))
+ else:
+ self.weight_activated = torch.nn.Parameter(
+ torch.normal(0, 0.01, (self.num_local, embedding_size))
+ )
+ self.step = 0
+
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise TypeError("margin_loss must be callable")
+
+ @torch.no_grad()
+ def sample(self, labels, index_positive, optimizer):
+ self.step += 1
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+ self.weight_activated = torch.nn.Parameter(self.weight[self.weight_index])
+ self.weight_activated_exp_avg = self.weight_exp_avg[self.weight_index]
+ self.weight_activated_exp_avg_sq = self.weight_exp_avg_sq[self.weight_index]
+
+ if isinstance(optimizer, (torch.optim.Adam, torch.optim.AdamW)):
+ # TODO the params of partial fc must be last in the params list
+ optimizer.state.pop(optimizer.param_groups[-1]["params"][0], None)
+ optimizer.param_groups[-1]["params"][0] = self.weight_activated
+ optimizer.state[self.weight_activated]["exp_avg"] = self.weight_activated_exp_avg
+ optimizer.state[self.weight_activated]["exp_avg_sq"] = self.weight_activated_exp_avg_sq
+ optimizer.state[self.weight_activated]["step"] = self.step
+ else:
+ raise TypeError("PartialFCAdamW.sample only supports torch.optim.Adam or torch.optim.AdamW")
+
+ @torch.no_grad()
+ def update(self):
+ """ partial weight to global
+ """
+ if self.init_weight_update:
+ self.init_weight_update = False
+ return
+
+ if self.sample_rate < 1:
+ self.weight[self.weight_index] = self.weight_activated
+ self.weight_exp_avg[self.weight_index] = self.weight_activated_exp_avg
+ self.weight_exp_avg_sq[self.weight_index] = self.weight_activated_exp_avg_sq
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ optimizer: torch.optim.Optimizer,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+
+ Returns:
+ -------
+ loss: torch.Tensor
+ the distributed margin-softmax cross-entropy loss for the global batch
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+ self.update()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, (
+ "last batch size do not equal current batch size: {} vs {}".format(
+ self.last_batch_size, batch_size))
+
+ _gather_embeddings = [
+ torch.zeros((batch_size, self.embedding_size)).cuda()
+ for _ in range(self.world_size)
+ ]
+ _gather_labels = [
+ torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
+ ]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (
+ labels < self.class_start + self.num_local
+ )
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
+
+ if self.sample_rate < 1:
+ self.sample(labels, index_positive, optimizer)
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(self.weight_activated)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ if destination is None:
+ destination = collections.OrderedDict()
+ destination._metadata = collections.OrderedDict()
+
+ for name, module in self._modules.items():
+ if module is not None:
+ module.state_dict(destination, prefix + name + ".", keep_vars=keep_vars)
+ if self.sample_rate < 1:
+ destination["weight"] = self.weight.detach()
+ else:
+ destination["weight"] = self.weight_activated.data.detach()
+ return destination
+
+ def load_state_dict(self, state_dict, strict: bool = True):
+ if self.sample_rate < 1:
+ self.weight = state_dict["weight"].to(self.weight.device)
+ self.weight_exp_avg.zero_()
+ self.weight_exp_avg_sq.zero_()
+ self.weight_activated.data.zero_()
+ self.weight_activated_exp_avg.zero_()
+ self.weight_activated_exp_avg_sq.zero_()
+ else:
+ self.weight_activated.data = state_dict["weight"].to(self.weight_activated.data.device)
+
+
+class DistCrossEntropyFunc(torch.autograd.Function):
+ """
+ Cross-entropy computed in parallel across ranks: the softmax denominator is all-reduced so that
+ each rank normalises its local logits against the full (global) class set.
+ Part of the ArcFace implementation (https://arxiv.org/pdf/1801.07698v1.pdf).
+ """
+
+ @staticmethod
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
+ """ """
+ batch_size = logits.size(0)
+ # for numerical stability
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
+ logits.sub_(max_logits)
+ logits.exp_()
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
+ logits.div_(sum_logits_exp)
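+ # logits now hold softmax probabilities normalised over the global (all-rank) class dimension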
+ index = torch.where(label != -1)[0]
+ # loss
+ loss = torch.zeros(batch_size, 1, device=logits.device)
+ loss[index] = logits[index].gather(1, label[index])
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
+ ctx.save_for_backward(index, logits, label)
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ @staticmethod
+ def backward(ctx, loss_gradient):
+ """
+ Args:
+ loss_gradient (torch.Tensor): gradient propagated back from the following layer
+ Returns:
+ gradients for each input in forward function
+ `None` gradients for one-hot label
+ """
+ (
+ index,
+ logits,
+ label,
+ ) = ctx.saved_tensors
+ batch_size = logits.size(0)
+ one_hot = torch.zeros(
+ size=[index.size(0), logits.size(1)], device=logits.device
+ )
+ one_hot.scatter_(1, label[index], 1)
+ logits[index] -= one_hot
+ logits.div_(batch_size)
+ return logits * loss_gradient.item(), None
+
+
+class DistCrossEntropy(torch.nn.Module):
+ def __init__(self):
+ super(DistCrossEntropy, self).__init__()
+
+ def forward(self, logit_part, label_part):
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
+
+
+class AllGatherFunc(torch.autograd.Function):
+ """AllGather op with gradient backward"""
+
+ @staticmethod
+ def forward(ctx, tensor, *gather_list):
+ gather_list = list(gather_list)
+ distributed.all_gather(gather_list, tensor)
+ return tuple(gather_list)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ grad_list = list(grads)
+ rank = distributed.get_rank()
+ grad_out = grad_list[rank]
+
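+ # reduce-sum each gradient slice onto the rank that produced the corresponding embeddings;
+ # only this rank's own slice is returned as grad_out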
+ dist_ops = [
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
+ if i == rank
+ else distributed.reduce(
+ grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
+ )
+ for i in range(distributed.get_world_size())
+ ]
+ for _op in dist_ops:
+ _op.wait()
+
+ grad_out *= len(grad_list) # cooperate with distributed loss function
+ return (grad_out, *[None for _ in range(len(grad_list))])
+
+
+AllGather = AllGatherFunc.apply
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc_v2.py b/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0752554ca1a99c35347dce6cccd121b5cd69f9c6
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/partial_fc_v2.py
@@ -0,0 +1,260 @@
+
+import math
+from typing import Callable
+
+import torch
+from torch import distributed
+from torch.nn.functional import linear, normalize
+
+
+class PartialFC_V2(torch.nn.Module):
+ """
+ https://arxiv.org/abs/2203.15565
+ A distributed, sparsely updating variant of the FC layer, named Partial FC (PFC).
+ When the sample rate is less than 1, each iteration selects the positive class centers and a
+ random subset of negative class centers to compute the margin-based softmax loss. All class
+ centers are still maintained throughout training, but only the selected subset is updated in
+ each iteration.
+ .. note::
+ When the sample rate equals 1 (the default), Partial FC is equivalent to plain model parallelism.
+ Example:
+ --------
+ >>> module_pfc = PartialFC_V2(margin_loss, embedding_size=512, num_classes=8000000, sample_rate=0.2)
+ >>> for img, labels in data_loader:
+ >>> embeddings = net(img)
+ >>> loss = module_pfc(embeddings, labels)
+ >>> loss.backward()
+ >>> optimizer.step()
+ """
+ _version = 2
+
+ def __init__(
+ self,
+ margin_loss: Callable,
+ embedding_size: int,
+ num_classes: int,
+ sample_rate: float = 1.0,
+ fp16: bool = False,
+ ):
+ """
+ Parameters:
+ -----------
+ embedding_size: int
+ The dimension of embedding, required
+ num_classes: int
+ Total number of classes, required
+ sample_rate: float
+ The rate of negative centers participating in the calculation, default is 1.0.
+ """
+ super(PartialFC_V2, self).__init__()
+ assert (
+ distributed.is_initialized()
+ ), "must initialize distributed before create this"
+ self.rank = distributed.get_rank()
+ self.world_size = distributed.get_world_size()
+
+ self.dist_cross_entropy = DistCrossEntropy()
+ self.embedding_size = embedding_size
+ self.sample_rate: float = sample_rate
+ self.fp16 = fp16
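+ # each rank hosts a contiguous shard of the class centers: num_local classes starting at class_start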
+ self.num_local: int = num_classes // self.world_size + int(
+ self.rank < num_classes % self.world_size
+ )
+ self.class_start: int = num_classes // self.world_size * self.rank + min(
+ self.rank, num_classes % self.world_size
+ )
+ self.num_sample: int = int(self.sample_rate * self.num_local)
+ self.last_batch_size: int = 0
+
+ self.is_updated: bool = True
+ self.init_weight_update: bool = True
+ self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))
+
+ # margin_loss
+ if isinstance(margin_loss, Callable):
+ self.margin_softmax = margin_loss
+ else:
+ raise TypeError("margin_loss must be callable")
+
+ def sample(self, labels, index_positive):
+ """
+ This function changes the value of labels in place.
+ Parameters:
+ -----------
+ labels: torch.Tensor
+ gathered labels, already shifted into the local class range (-1 for classes hosted on other ranks)
+ index_positive: torch.Tensor
+ boolean mask of samples whose class center is hosted on this rank
+ """
+ with torch.no_grad():
+ positive = torch.unique(labels[index_positive], sorted=True).cuda()
+ if self.num_sample - positive.size(0) >= 0:
+ perm = torch.rand(size=[self.num_local]).cuda()
+ perm[positive] = 2.0
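+ # random scores lie in [0, 1), so a score of 2.0 guarantees that topk keeps every positive
+ # center and fills the remaining num_sample slots with random negatives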
+ index = torch.topk(perm, k=self.num_sample)[1].cuda()
+ index = index.sort()[0].cuda()
+ else:
+ index = positive
+ self.weight_index = index
+
+ labels[index_positive] = torch.searchsorted(index, labels[index_positive])
+
+ return self.weight[self.weight_index]
+
+ def forward(
+ self,
+ local_embeddings: torch.Tensor,
+ local_labels: torch.Tensor,
+ ):
+ """
+ Parameters:
+ ----------
+ local_embeddings: torch.Tensor
+ feature embeddings on each GPU(Rank).
+ local_labels: torch.Tensor
+ labels on each GPU(Rank).
+ Returns:
+ -------
+ loss: torch.Tensor
+ the margin-based softmax cross-entropy loss, averaged over the global batch
+ """
+ local_labels.squeeze_()
+ local_labels = local_labels.long()
+
+ batch_size = local_embeddings.size(0)
+ if self.last_batch_size == 0:
+ self.last_batch_size = batch_size
+ assert self.last_batch_size == batch_size, (
+ f"last batch size does not equal current batch size: {self.last_batch_size} vs {batch_size}")
+
+ _gather_embeddings = [
+ torch.zeros((batch_size, self.embedding_size)).cuda()
+ for _ in range(self.world_size)
+ ]
+ _gather_labels = [
+ torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
+ ]
+ _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
+ distributed.all_gather(_gather_labels, local_labels)
+
+ embeddings = torch.cat(_list_embeddings)
+ labels = torch.cat(_gather_labels)
+
+ labels = labels.view(-1, 1)
+ index_positive = (self.class_start <= labels) & (
+ labels < self.class_start + self.num_local
+ )
+ labels[~index_positive] = -1
+ labels[index_positive] -= self.class_start
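+ # labels now hold local class indices for samples hosted on this rank and -1 otherwise;
+ # the -1 entries only contribute to the softmax denominator inside DistCrossEntropy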
+
+ if self.sample_rate < 1:
+ weight = self.sample(labels, index_positive)
+ else:
+ weight = self.weight
+
+ with torch.cuda.amp.autocast(self.fp16):
+ norm_embeddings = normalize(embeddings)
+ norm_weight_activated = normalize(weight)
+ logits = linear(norm_embeddings, norm_weight_activated)
+ if self.fp16:
+ logits = logits.float()
+ logits = logits.clamp(-1, 1)
+
+ logits = self.margin_softmax(logits, labels)
+ loss = self.dist_cross_entropy(logits, labels)
+ return loss
+
+
+class DistCrossEntropyFunc(torch.autograd.Function):
+ """
+ Cross-entropy computed in parallel across ranks: the softmax denominator is all-reduced so that
+ each rank normalises its local logits against the full (global) class set.
+ Part of the ArcFace implementation (https://arxiv.org/pdf/1801.07698v1.pdf).
+ """
+
+ @staticmethod
+ def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
+ """ """
+ batch_size = logits.size(0)
+ # for numerical stability
+ max_logits, _ = torch.max(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
+ logits.sub_(max_logits)
+ logits.exp_()
+ sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
+ # local to global
+ distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
+ logits.div_(sum_logits_exp)
+ index = torch.where(label != -1)[0]
+ # loss
+ loss = torch.zeros(batch_size, 1, device=logits.device)
+ loss[index] = logits[index].gather(1, label[index])
+ distributed.all_reduce(loss, distributed.ReduceOp.SUM)
+ ctx.save_for_backward(index, logits, label)
+ return loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+ @staticmethod
+ def backward(ctx, loss_gradient):
+ """
+ Args:
+ loss_gradient (torch.Tensor): gradient propagated back from the following layer
+ Returns:
+ gradients for each input in forward function
+ `None` gradients for one-hot label
+ """
+ (
+ index,
+ logits,
+ label,
+ ) = ctx.saved_tensors
+ batch_size = logits.size(0)
+ one_hot = torch.zeros(
+ size=[index.size(0), logits.size(1)], device=logits.device
+ )
+ one_hot.scatter_(1, label[index], 1)
+ logits[index] -= one_hot
+ logits.div_(batch_size)
+ return logits * loss_gradient.item(), None
+
+
+class DistCrossEntropy(torch.nn.Module):
+ def __init__(self):
+ super(DistCrossEntropy, self).__init__()
+
+ def forward(self, logit_part, label_part):
+ return DistCrossEntropyFunc.apply(logit_part, label_part)
+
+
+class AllGatherFunc(torch.autograd.Function):
+ """AllGather op with gradient backward"""
+
+ @staticmethod
+ def forward(ctx, tensor, *gather_list):
+ gather_list = list(gather_list)
+ distributed.all_gather(gather_list, tensor)
+ return tuple(gather_list)
+
+ @staticmethod
+ def backward(ctx, *grads):
+ grad_list = list(grads)
+ rank = distributed.get_rank()
+ grad_out = grad_list[rank]
+
+ dist_ops = [
+ distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
+ if i == rank
+ else distributed.reduce(
+ grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
+ )
+ for i in range(distributed.get_world_size())
+ ]
+ for _op in dist_ops:
+ _op.wait()
+
+ grad_out *= len(grad_list) # cooperate with distributed loss function
+ return (grad_out, *[None for _ in range(len(grad_list))])
+
+
+AllGather = AllGatherFunc.apply
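+
+# Usage sketch (comments only, not executed): a minimal single-process smoke test of
+# PartialFC_V2, assuming one GPU and the CombinedMarginLoss from losses.py; the margin
+# settings mirror the call in train_v2.py and are illustrative, not prescriptive.
+#
+# import torch
+# import torch.distributed as dist
+# from losses import CombinedMarginLoss
+# from partial_fc_v2 import PartialFC_V2
+#
+# dist.init_process_group("nccl", init_method="tcp://127.0.0.1:23456", rank=0, world_size=1)
+# margin = CombinedMarginLoss(64, 1.0, 0.5, 0.0, 0)
+# pfc = PartialFC_V2(margin, embedding_size=512, num_classes=10000, sample_rate=0.3).cuda()
+# embeddings = torch.randn(8, 512, device="cuda", requires_grad=True)
+# labels = torch.randint(0, 10000, (8,), device="cuda")
+# loss = pfc(embeddings, labels)  # gathers across ranks (here: one) and samples 30% of the centers
+# loss.backward()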
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/requirement.txt b/deep_3drecon/deep_3drecon_models/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1a431ef9c39b258b676411f1081ed9006a8b817
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/requirement.txt
@@ -0,0 +1,6 @@
+tensorboard
+easydict
+mxnet
+onnx
+scikit-learn
+opencv-python
\ No newline at end of file
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/run.sh b/deep_3drecon/deep_3drecon_models/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6eacdf8e814d7bd68650c7eda8f72687ee74db16
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/run.sh
@@ -0,0 +1 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 train_v2.py "$@"
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/scripts/shuffle_rec.py b/deep_3drecon/deep_3drecon_models/arcface_torch/scripts/shuffle_rec.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3b68e938f17aaf98fae269c44119eaef299b1a2
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/scripts/shuffle_rec.py
@@ -0,0 +1,81 @@
+import argparse
+import multiprocessing
+import os
+import time
+
+import mxnet as mx
+import numpy as np
+
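+# Shuffles an MXNet RecordIO face dataset: read_worker streams records through a queue in a
+# random index order, and write_worker re-packs them with fresh sequential ids into a sibling
+# "shuffled_<dataset>" directory (train.idx / train.rec).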
+
+def read_worker(args, q_in):
+ path_imgidx = os.path.join(args.input, "train.idx")
+ path_imgrec = os.path.join(args.input, "train.rec")
+ imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
+
+ s = imgrec.read_idx(0)
+ header, _ = mx.recordio.unpack(s)
+ assert header.flag > 0
+
+ imgidx = np.array(range(1, int(header.label[0])))
+ np.random.shuffle(imgidx)
+
+ for idx in imgidx:
+ item = imgrec.read_idx(idx)
+ q_in.put(item)
+
+ q_in.put(None)
+ imgrec.close()
+
+
+def write_worker(args, q_out):
+ pre_time = time.time()
+
+ if args.input[-1] == '/':
+ args.input = args.input[:-1]
+ dirname = os.path.dirname(args.input)
+ basename = os.path.basename(args.input)
+ output = os.path.join(dirname, f"shuffled_{basename}")
+ os.makedirs(output, exist_ok=True)
+
+ path_imgidx = os.path.join(output, "train.idx")
+ path_imgrec = os.path.join(output, "train.rec")
+ save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
+ more = True
+ count = 0
+ while more:
+ deq = q_out.get()
+ if deq is None:
+ more = False
+ else:
+ header, jpeg = mx.recordio.unpack(deq)
+ # TODO it is currently not fully developed
+ if isinstance(header.label, float):
+ label = header.label
+ else:
+ label = header.label[0]
+
+ header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2)
+ save_record.write_idx(count, mx.recordio.pack(header, jpeg))
+ count += 1
+ if count % 10000 == 0:
+ cur_time = time.time()
+ print('save time:', cur_time - pre_time, ' count:', count)
+ pre_time = cur_time
+ print(count)
+ save_record.close()
+
+
+def main(args):
+ queue = multiprocessing.Queue(10240)
+ read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
+ read_process.daemon = True
+ read_process.start()
+ write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
+ write_process.start()
+ write_process.join()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('input', help='path to source rec.')
+ main(parser.parse_args())
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/torch2onnx.py b/deep_3drecon/deep_3drecon_models/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6055d1fe7d20cbf02812d95b509511c943766de
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/torch2onnx.py
@@ -0,0 +1,53 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+ assert isinstance(net, torch.nn.Module)
+ img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+ img = img.astype(np.float32)  # np.float was removed in NumPy 1.24+; use an explicit dtype
+ img = (img / 255. - 0.5) / 0.5 # torch style norm
+ img = img.transpose((2, 0, 1))
+ img = torch.from_numpy(img).unsqueeze(0).float()
+
+ weight = torch.load(path_module)
+ net.load_state_dict(weight, strict=True)
+ net.eval()
+ torch.onnx.export(net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
+ model = onnx.load(output)
+ graph = model.graph
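+ # mark the batch dimension of the input as dynamic so the exported model accepts any batch size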
+ graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+ if simplify:
+ from onnxsim import simplify
+ model, check = simplify(model)
+ assert check, "Simplified ONNX model could not be validated"
+ onnx.save(model, output)
+
+
+if __name__ == '__main__':
+ import os
+ import argparse
+ from backbones import get_model
+
+ parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
+ parser.add_argument('input', type=str, help='input backbone.pth file or path')
+ parser.add_argument('--output', type=str, default=None, help='output onnx path')
+ parser.add_argument('--network', type=str, default=None, help='backbone network')
+ parser.add_argument('--simplify', action='store_true', help='onnx simplify')  # type=bool would treat any non-empty string as True
+ args = parser.parse_args()
+ input_file = args.input
+ if os.path.isdir(input_file):
+ input_file = os.path.join(input_file, "model.pt")
+ assert os.path.exists(input_file)
+ # model_name = os.path.basename(os.path.dirname(input_file)).lower()
+ # params = model_name.split("_")
+ # if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+ # if args.network is None:
+ # args.network = params[2]
+ assert args.network is not None
+ print(args)
+ backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512)
+ if args.output is None:
+ args.output = os.path.join(os.path.dirname(args.input), "model.onnx")
+ convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/train.py b/deep_3drecon/deep_3drecon_models/arcface_torch/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b49e7104bfed3984825116016aa8050c406cd0
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/train.py
@@ -0,0 +1,260 @@
+import argparse
+import logging
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+from backbones import get_model
+from dataset import get_dataloader
+from losses import CombinedMarginLoss
+from lr_scheduler import PolyScheduler
+from partial_fc import PartialFC, PartialFCAdamW
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_callbacks import CallBackLogging, CallBackVerification
+from utils.utils_config import get_config
+from utils.utils_distributed_sampler import setup_seed
+from utils.utils_logging import AverageMeter, init_logging
+
+assert torch.__version__ >= "1.12.0", \
+ "this codebase requires torch >= 1.12.0; earlier versions may not work."
+
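+# torchrun exports RANK, LOCAL_RANK and WORLD_SIZE; if they are absent, fall back to a
+# single-process group so the script can also run stand-alone.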
+try:
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ distributed.init_process_group("nccl")
+except KeyError:
+ rank = 0
+ local_rank = 0
+ world_size = 1
+ distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:12584",
+ rank=rank,
+ world_size=world_size,
+ )
+
+
+def main(args):
+
+ # get config
+ cfg = get_config(args.config)
+ # global control random seed
+ setup_seed(seed=cfg.seed, cuda_deterministic=False)
+
+ torch.cuda.set_device(local_rank)
+
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ summary_writer = (
+ SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
+ if rank == 0
+ else None
+ )
+
+ wandb_logger = None
+ if cfg.using_wandb:
+ import wandb
+ # Sign in to wandb
+ try:
+ wandb.login(key=cfg.wandb_key)
+ except Exception as e:
+ print("WandB Key must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+ # Initialize wandb
+ run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
+ run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
+ try:
+ wandb_logger = wandb.init(
+ entity = cfg.wandb_entity,
+ project = cfg.wandb_project,
+ sync_tensorboard = True,
+ resume=cfg.wandb_resume,
+ name = run_name,
+ notes = cfg.notes) if rank == 0 or cfg.wandb_log_all else None
+ if wandb_logger:
+ wandb_logger.config.update(cfg)
+ except Exception as e:
+ print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+
+ train_loader = get_dataloader(
+ cfg.rec,
+ local_rank,
+ cfg.batch_size,
+ cfg.dali,
+ cfg.seed,
+ cfg.num_workers
+ )
+
+ backbone = get_model(
+ cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
+ find_unused_parameters=True)
+
+ backbone.train()
+ # FIXME using gradient checkpoint if there are some unused parameters will cause error
+ backbone._set_static_graph()
+
+ margin_loss = CombinedMarginLoss(
+ 64,
+ cfg.margin_list[0],
+ cfg.margin_list[1],
+ cfg.margin_list[2],
+ cfg.interclass_filtering_threshold
+ )
+
+ if cfg.optimizer == "sgd":
+ module_partial_fc = PartialFC(
+ margin_loss, cfg.embedding_size, cfg.num_classes,
+ cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ # TODO the params of partial fc must be last in the params list
+ opt = torch.optim.SGD(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
+
+ elif cfg.optimizer == "adamw":
+ module_partial_fc = PartialFCAdamW(
+ margin_loss, cfg.embedding_size, cfg.num_classes,
+ cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ opt = torch.optim.AdamW(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr, weight_decay=cfg.weight_decay)
+ else:
+ raise ValueError(f"unsupported optimizer: {cfg.optimizer}")
+
+ cfg.total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
+ cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
+
+ lr_scheduler = PolyScheduler(
+ optimizer=opt,
+ base_lr=cfg.lr,
+ max_steps=cfg.total_step,
+ warmup_steps=cfg.warmup_step,
+ last_epoch=-1
+ )
+
+ start_epoch = 0
+ global_step = 0
+ if cfg.resume:
+ dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+ start_epoch = dict_checkpoint["epoch"]
+ global_step = dict_checkpoint["global_step"]
+ backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
+ module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
+ opt.load_state_dict(dict_checkpoint["state_optimizer"])
+ lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
+ del dict_checkpoint
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ callback_verification = CallBackVerification(
+ val_targets=cfg.val_targets, rec_prefix=cfg.rec,
+ summary_writer=summary_writer, wandb_logger = wandb_logger
+ )
+ callback_logging = CallBackLogging(
+ frequent=cfg.frequent,
+ total_step=cfg.total_step,
+ batch_size=cfg.batch_size,
+ start_step = global_step,
+ writer=summary_writer
+ )
+
+ loss_am = AverageMeter()
+ amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+
+ for epoch in range(start_epoch, cfg.num_epoch):
+
+ if isinstance(train_loader, DataLoader):
+ train_loader.sampler.set_epoch(epoch)
+ for _, (img, local_labels) in enumerate(train_loader):
+ global_step += 1
+ local_embeddings = backbone(img)
+ loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)
+
+ if cfg.fp16:
+ amp.scale(loss).backward()
+ amp.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ amp.step(opt)
+ amp.update()
+ else:
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ opt.step()
+
+ opt.zero_grad()
+ lr_scheduler.step()
+
+ with torch.no_grad():
+ if wandb_logger:
+ wandb_logger.log({
+ 'Loss/Step Loss': loss.item(),
+ 'Loss/Train Loss': loss_am.avg,
+ 'Process/Step': global_step,
+ 'Process/Epoch': epoch
+ })
+
+ loss_am.update(loss.item(), 1)
+ callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+ if global_step % cfg.verbose == 0 and global_step > 0:
+ callback_verification(global_step, backbone)
+
+ if cfg.save_all_states:
+ checkpoint = {
+ "epoch": epoch + 1,
+ "global_step": global_step,
+ "state_dict_backbone": backbone.module.state_dict(),
+ "state_dict_softmax_fc": module_partial_fc.state_dict(),
+ "state_optimizer": opt.state_dict(),
+ "state_lr_scheduler": lr_scheduler.state_dict()
+ }
+ torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_E{epoch}"
+ model = wandb.Artifact(artifact_name, type='model')
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ if cfg.dali:
+ train_loader.reset()
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ from torch2onnx import convert_onnx
+ convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_Final"
+ model = wandb.Artifact(artifact_name, type='model')
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ distributed.destroy_process_group()
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(
+ description="Distributed Arcface Training in Pytorch")
+ parser.add_argument("config", type=str, help="py config file")
+ main(parser.parse_args())
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/train_v2.py b/deep_3drecon/deep_3drecon_models/arcface_torch/train_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d53e80178a30a0e0f36b079b8b1743280480d4e
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/train_v2.py
@@ -0,0 +1,258 @@
+import argparse
+import logging
+import os
+from datetime import datetime
+
+import numpy as np
+import torch
+from backbones import get_model
+from dataset import get_dataloader
+from losses import CombinedMarginLoss
+from lr_scheduler import PolyScheduler
+from partial_fc_v2 import PartialFC_V2
+from torch import distributed
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard import SummaryWriter
+from utils.utils_callbacks import CallBackLogging, CallBackVerification
+from utils.utils_config import get_config
+from utils.utils_distributed_sampler import setup_seed
+from utils.utils_logging import AverageMeter, init_logging
+
+assert torch.__version__ >= "1.12.0", \
+ "this codebase requires torch >= 1.12.0; earlier versions may not work."
+
+try:
+ rank = int(os.environ["RANK"])
+ local_rank = int(os.environ["LOCAL_RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+ distributed.init_process_group("nccl")
+except KeyError:
+ rank = 0
+ local_rank = 0
+ world_size = 1
+ distributed.init_process_group(
+ backend="nccl",
+ init_method="tcp://127.0.0.1:12584",
+ rank=rank,
+ world_size=world_size,
+ )
+
+
+def main(args):
+
+ # get config
+ cfg = get_config(args.config)
+ # global control random seed
+ setup_seed(seed=cfg.seed, cuda_deterministic=False)
+
+ torch.cuda.set_device(local_rank)
+
+ os.makedirs(cfg.output, exist_ok=True)
+ init_logging(rank, cfg.output)
+
+ summary_writer = (
+ SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
+ if rank == 0
+ else None
+ )
+
+ wandb_logger = None
+ if cfg.using_wandb:
+ import wandb
+ # Sign in to wandb
+ try:
+ wandb.login(key=cfg.wandb_key)
+ except Exception as e:
+ print("WandB Key must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+ # Initialize wandb
+ run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
+ run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
+ try:
+ wandb_logger = wandb.init(
+ entity = cfg.wandb_entity,
+ project = cfg.wandb_project,
+ sync_tensorboard = True,
+ resume=cfg.wandb_resume,
+ name = run_name,
+ notes = cfg.notes) if rank == 0 or cfg.wandb_log_all else None
+ if wandb_logger:
+ wandb_logger.config.update(cfg)
+ except Exception as e:
+ print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
+ print(f"Config Error: {e}")
+
+ train_loader = get_dataloader(
+ cfg.rec,
+ local_rank,
+ cfg.batch_size,
+ cfg.dali,
+ cfg.seed,
+ cfg.num_workers
+ )
+
+ backbone = get_model(
+ cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()
+
+ backbone = torch.nn.parallel.DistributedDataParallel(
+ module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
+ find_unused_parameters=True)
+
+ backbone.train()
+ # FIXME using gradient checkpoint if there are some unused parameters will cause error
+ backbone._set_static_graph()
+
+ margin_loss = CombinedMarginLoss(
+ 64,
+ cfg.margin_list[0],
+ cfg.margin_list[1],
+ cfg.margin_list[2],
+ cfg.interclass_filtering_threshold
+ )
+
+ if cfg.optimizer == "sgd":
+ module_partial_fc = PartialFC_V2(
+ margin_loss, cfg.embedding_size, cfg.num_classes,
+ cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ # TODO the params of partial fc must be last in the params list
+ opt = torch.optim.SGD(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)
+
+ elif cfg.optimizer == "adamw":
+ module_partial_fc = PartialFC_V2(
+ margin_loss, cfg.embedding_size, cfg.num_classes,
+ cfg.sample_rate, cfg.fp16)
+ module_partial_fc.train().cuda()
+ opt = torch.optim.AdamW(
+ params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
+ lr=cfg.lr, weight_decay=cfg.weight_decay)
+ else:
+ raise ValueError(f"unsupported optimizer: {cfg.optimizer}")
+
+ cfg.total_batch_size = cfg.batch_size * world_size
+ cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
+ cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch
+
+ lr_scheduler = PolyScheduler(
+ optimizer=opt,
+ base_lr=cfg.lr,
+ max_steps=cfg.total_step,
+ warmup_steps=cfg.warmup_step,
+ last_epoch=-1
+ )
+
+ start_epoch = 0
+ global_step = 0
+ if cfg.resume:
+ dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+ start_epoch = dict_checkpoint["epoch"]
+ global_step = dict_checkpoint["global_step"]
+ backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
+ module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
+ opt.load_state_dict(dict_checkpoint["state_optimizer"])
+ lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
+ del dict_checkpoint
+
+ for key, value in cfg.items():
+ num_space = 25 - len(key)
+ logging.info(": " + key + " " * num_space + str(value))
+
+ callback_verification = CallBackVerification(
+ val_targets=cfg.val_targets, rec_prefix=cfg.rec,
+ summary_writer=summary_writer, wandb_logger = wandb_logger
+ )
+ callback_logging = CallBackLogging(
+ frequent=cfg.frequent,
+ total_step=cfg.total_step,
+ batch_size=cfg.batch_size,
+ start_step = global_step,
+ writer=summary_writer
+ )
+
+ loss_am = AverageMeter()
+ amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
+
+ for epoch in range(start_epoch, cfg.num_epoch):
+
+ if isinstance(train_loader, DataLoader):
+ train_loader.sampler.set_epoch(epoch)
+ for _, (img, local_labels) in enumerate(train_loader):
+ global_step += 1
+ local_embeddings = backbone(img)
+ loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels)
+
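+ # the optimizer is stepped only every cfg.gradient_acc iterations (gradient accumulation)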
+ if cfg.fp16:
+ amp.scale(loss).backward()
+ if global_step % cfg.gradient_acc == 0:
+ amp.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ amp.step(opt)
+ amp.update()
+ opt.zero_grad()
+ else:
+ loss.backward()
+ if global_step % cfg.gradient_acc == 0:
+ torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
+ opt.step()
+ opt.zero_grad()
+ lr_scheduler.step()
+
+ with torch.no_grad():
+ if wandb_logger:
+ wandb_logger.log({
+ 'Loss/Step Loss': loss.item(),
+ 'Loss/Train Loss': loss_am.avg,
+ 'Process/Step': global_step,
+ 'Process/Epoch': epoch
+ })
+
+ loss_am.update(loss.item(), 1)
+ callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)
+
+ if global_step % cfg.verbose == 0 and global_step > 0:
+ callback_verification(global_step, backbone)
+
+ if cfg.save_all_states:
+ checkpoint = {
+ "epoch": epoch + 1,
+ "global_step": global_step,
+ "state_dict_backbone": backbone.module.state_dict(),
+ "state_dict_softmax_fc": module_partial_fc.state_dict(),
+ "state_optimizer": opt.state_dict(),
+ "state_lr_scheduler": lr_scheduler.state_dict()
+ }
+ torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_E{epoch}"
+ model = wandb.Artifact(artifact_name, type='model')
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+ if cfg.dali:
+ train_loader.reset()
+
+ if rank == 0:
+ path_module = os.path.join(cfg.output, "model.pt")
+ torch.save(backbone.module.state_dict(), path_module)
+
+ if wandb_logger and cfg.save_artifacts:
+ artifact_name = f"{run_name}_Final"
+ model = wandb.Artifact(artifact_name, type='model')
+ model.add_file(path_module)
+ wandb_logger.log_artifact(model)
+
+
+
+if __name__ == "__main__":
+ torch.backends.cudnn.benchmark = True
+ parser = argparse.ArgumentParser(
+ description="Distributed Arcface Training in Pytorch")
+ parser.add_argument("config", type=str, help="py config file")
+ main(parser.parse_args())
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/utils/__init__.py b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/utils/plot.py b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..33e38ebc2bd500304b87931d811a42f72403aac2
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/plot.py
@@ -0,0 +1,71 @@
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import roc_curve, auc
+
+with open(sys.argv[1], "r") as f:
+ files = f.readlines()
+
+files = [x.strip() for x in files]
+image_path = "/train_tmp/IJB_release/IJBC"
+
+
+def read_template_pair_list(path):
+ pairs = pd.read_csv(path, sep=' ', header=None).values
+ t1 = pairs[:, 0].astype(np.int_)
+ t2 = pairs[:, 1].astype(np.int_)
+ label = pairs[:, 2].astype(np.int_)
+ return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+ os.path.join('%s/meta' % image_path,
+ '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+ methods.append(file)
+ scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+ zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+ fpr, tpr, _ = roc_curve(label, scores[method])
+ roc_auc = auc(fpr, tpr)
+ fpr = np.flipud(fpr)
+ tpr = np.flipud(tpr) # select largest tpr at same fpr
+ plt.plot(fpr,
+ tpr,
+ color=colours[method],
+ lw=1,
+ label=('[%s (AUC = %0.4f %%)]' %
+ (method.split('-')[-1], roc_auc * 100)))
+ tpr_fpr_row = []
+ tpr_fpr_row.append(method)
+ for fpr_iter in np.arange(len(x_labels)):
+ _, min_index = min(
+ list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+ tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+ tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_callbacks.py b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9368073f8bc091b28e7325a9099881dfc5f54cd
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,125 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+from eval import verification
+from utils.utils_logging import AverageMeter
+from torch.utils.tensorboard import SummaryWriter
+from torch import distributed
+
+
+class CallBackVerification(object):
+
+ def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112), wandb_logger=None):
+ self.rank: int = distributed.get_rank()
+ self.highest_acc: float = 0.0
+ self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+ self.ver_list: List[object] = []
+ self.ver_name_list: List[str] = []
+ if self.rank == 0:
+ self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+ self.summary_writer = summary_writer
+ self.wandb_logger = wandb_logger
+
+ def ver_test(self, backbone: torch.nn.Module, global_step: int):
+ results = []
+ for i in range(len(self.ver_list)):
+ acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+ self.ver_list[i], backbone, 10, 10)
+ logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+ logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+
+ self.summary_writer: SummaryWriter
+ self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, )
+ if self.wandb_logger:
+ import wandb
+ self.wandb_logger.log({
+ f'Acc/val-Acc1 {self.ver_name_list[i]}': acc1,
+ f'Acc/val-Acc2 {self.ver_name_list[i]}': acc2,
+ # f'Acc/val-std1 {self.ver_name_list[i]}': std1,
+ # f'Acc/val-std2 {self.ver_name_list[i]}': acc2,
+ })
+
+ if acc2 > self.highest_acc_list[i]:
+ self.highest_acc_list[i] = acc2
+ logging.info(
+ '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+ results.append(acc2)
+
+ def init_dataset(self, val_targets, data_dir, image_size):
+ for name in val_targets:
+ path = os.path.join(data_dir, name + ".bin")
+ if os.path.exists(path):
+ data_set = verification.load_bin(path, image_size)
+ self.ver_list.append(data_set)
+ self.ver_name_list.append(name)
+
+ def __call__(self, num_update, backbone: torch.nn.Module):
+ if self.rank == 0 and num_update > 0:
+ backbone.eval()
+ self.ver_test(backbone, num_update)
+ backbone.train()
+
+
+class CallBackLogging(object):
+ def __init__(self, frequent, total_step, batch_size, start_step=0,writer=None):
+ self.frequent: int = frequent
+ self.rank: int = distributed.get_rank()
+ self.world_size: int = distributed.get_world_size()
+ self.time_start = time.time()
+ self.total_step: int = total_step
+ self.start_step: int = start_step
+ self.batch_size: int = batch_size
+ self.writer = writer
+
+ self.init = False
+ self.tic = 0
+
+ def __call__(self,
+ global_step: int,
+ loss: AverageMeter,
+ epoch: int,
+ fp16: bool,
+ learning_rate: float,
+ grad_scaler: torch.cuda.amp.GradScaler):
+ if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+ if self.init:
+ try:
+ speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+ speed_total = speed * self.world_size
+ except ZeroDivisionError:
+ speed_total = float('inf')
+
+ #time_now = (time.time() - self.time_start) / 3600
+ #time_total = time_now / ((global_step + 1) / self.total_step)
+ #time_for_end = time_total - time_now
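+ # ETA: average seconds per step since start_step, extrapolated over the remaining steps (in hours)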
+ time_now = time.time()
+ time_sec = int(time_now - self.time_start)
+ time_sec_avg = time_sec / (global_step - self.start_step + 1)
+ eta_sec = time_sec_avg * (self.total_step - global_step - 1)
+ time_for_end = eta_sec/3600
+ if self.writer is not None:
+ self.writer.add_scalar('time_for_end', time_for_end, global_step)
+ self.writer.add_scalar('learning_rate', learning_rate, global_step)
+ self.writer.add_scalar('loss', loss.avg, global_step)
+ if fp16:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
+ "Fp16 Grad Scale: %2.f Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step,
+ grad_scaler.get_scale(), time_for_end
+ )
+ else:
+ msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \
+ "Required: %1.f hours" % (
+ speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
+ )
+ logging.info(msg)
+ loss.reset()
+ self.tic = time.time()
+ else:
+ self.init = True
+ self.tic = time.time()
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_config.py b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c02eaf70fc0140aca7925f621c29a496f491cae
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+ assert config_file.startswith('configs/'), 'config file setting must start with configs/'
+ temp_config_name = osp.basename(config_file)
+ temp_module_name = osp.splitext(temp_config_name)[0]
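+ # load configs/base.py as defaults, then override them with the job-specific config module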
+ config = importlib.import_module("configs.base")
+ cfg = config.config
+ config = importlib.import_module("configs.%s" % temp_module_name)
+ job_cfg = config.config
+ cfg.update(job_cfg)
+ if cfg.output is None:
+ cfg.output = osp.join('work_dirs', temp_module_name)
+ return cfg
\ No newline at end of file
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_distributed_sampler.py b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cea6703965bff81f8b789ffd933f9b2f889cb680
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_distributed_sampler.py
@@ -0,0 +1,126 @@
+import math
+import os
+import random
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+def setup_seed(seed, cuda_deterministic=True):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ if cuda_deterministic: # slower, more reproducible
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ else: # faster, less reproducible
+ torch.backends.cudnn.deterministic = False
+ torch.backends.cudnn.benchmark = True
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # The seed of each worker equals to
+ # num_worker * rank + worker_id + user_seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
+ torch.manual_seed(worker_seed)
+
+
+def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+
+ return rank, world_size
+
+
+def sync_random_seed(seed=None, device="cuda"):
+ """Make sure different ranks share the same seed.
+ All workers must call this function, otherwise it will deadlock.
+ This method is generally used in `DistributedSampler`,
+ because the seed should be identical across all processes
+ in the distributed group.
+ In distributed sampling, different ranks should sample non-overlapped
+ data in the dataset. Therefore, this function is used to make sure that
+ each rank shuffles the data indices in the same order based
+ on the same seed. Then different ranks could use different indices
+ to select non-overlapped data from the same data list.
+ Args:
+ seed (int, Optional): The seed. Default to None.
+ device (str): The device where the seed will be put on.
+ Default to 'cuda'.
+ Returns:
+ int: Seed to be used.
+ """
+ if seed is None:
+ seed = np.random.randint(2**31)
+ assert isinstance(seed, int)
+
+ rank, world_size = get_dist_info()
+
+ if world_size == 1:
+ return seed
+
+ if rank == 0:
+ random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+ else:
+ random_num = torch.tensor(0, dtype=torch.int32, device=device)
+
+ dist.broadcast(random_num, src=0)
+
+ return random_num.item()
+
+
+class DistributedSampler(_DistributedSampler):
+ def __init__(
+ self,
+ dataset,
+ num_replicas=None, # world_size
+ rank=None, # local_rank
+ shuffle=True,
+ seed=0,
+ ):
+
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+
+ # In distributed sampling, different ranks should sample
+ # non-overlapped data in the dataset. Therefore, this function
+ # is used to make sure that each rank shuffles the data indices
+ # in the same order based on the same seed. Then different ranks
+ # could use different indices to select non-overlapped data from the
+ # same data list.
+ self.seed = sync_random_seed(seed)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ # When :attr:`shuffle=True`, this ensures all replicas
+ # use a different random ordering for each epoch.
+ # Otherwise, the next iteration of this sampler will
+ # yield the same ordering.
+ g.manual_seed(self.epoch + self.seed)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ # in case that indices is shorter than half of total_size
+ indices = (indices * math.ceil(self.total_size / len(indices)))[
+ : self.total_size
+ ]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
diff --git a/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_logging.py b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..c787b6aae7cd037a4718df44d672b8ffa9e5c249
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,41 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value
+ """
+
+ def __init__(self):
+ self.val = None
+ self.avg = None
+ self.sum = None
+ self.count = None
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+ if rank == 0:
+ log_root = logging.getLogger()
+ log_root.setLevel(logging.INFO)
+ formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+ handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+ handler_stream = logging.StreamHandler(sys.stdout)
+ handler_file.setFormatter(formatter)
+ handler_stream.setFormatter(formatter)
+ log_root.addHandler(handler_file)
+ log_root.addHandler(handler_stream)
+ log_root.info('rank_id: %d' % rank)
diff --git a/deep_3drecon/deep_3drecon_models/base_model.py b/deep_3drecon/deep_3drecon_models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a05d3a000379d28b176f9d052ed7ff15cf5ba1e
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/base_model.py
@@ -0,0 +1,316 @@
+"""This script defines the base network model for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call <BaseModel.__init__(self, opt)>.
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define the networks used in training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ """
+ self.opt = opt
+ self.isTrain = opt.isTrain
+ self.device = torch.device('cpu')
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
+ self.loss_names = []
+ self.model_names = []
+ self.visual_names = []
+ self.parallel_names = []
+ self.optimizers = []
+ self.image_paths = []
+ self.metric = 0 # used for learning rate policy 'plateau'
+
+ @staticmethod
+ def dict_grad_hook_factory(add_func=lambda x: x):
+ saved_dict = dict()
+
+ def hook_gen(name):
+ def grad_hook(grad):
+ saved_vals = add_func(grad)
+ saved_dict[name] = saved_vals
+ return grad_hook
+ return hook_gen, saved_dict
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train):
+ """Add new model-specific options, and rewrite default values for existing options.
+
+ Parameters:
+ parser -- original option parser
+ is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ return parser
+
+ @abstractmethod
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input (dict): includes the data itself and its metadata information.
+ """
+ pass
+
+ @abstractmethod
+ def forward(self):
+ """Run forward pass; called by both <optimize_parameters> and <test>."""
+ pass
+
+ @abstractmethod
+ def optimize_parameters(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ pass
+
+ def setup(self, opt):
+ """Load and print networks; create schedulers
+
+ Parameters:
+ opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ """
+ if self.isTrain:
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+
+ if not self.isTrain or opt.continue_train:
+ load_suffix = opt.epoch
+ self.load_networks(load_suffix)
+
+
+ # self.print_networks(opt.verbose)
+
+ def parallelize(self, convert_sync_batchnorm=True):
+ if not self.opt.use_ddp:
+ for name in self.parallel_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+ else:
+ for name in self.model_names:
+ if isinstance(name, str):
+ module = getattr(self, name)
+ if convert_sync_batchnorm:
+ module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+ setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
+ device_ids=[self.device.index],
+ find_unused_parameters=True, broadcast_buffers=True))
+
+ # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
+ for name in self.parallel_names:
+ if isinstance(name, str) and name not in self.model_names:
+ module = getattr(self, name)
+ setattr(self, name, module.to(self.device))
+
+ # put state_dict of optimizer to gpu device
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ for optim in self.optimizers:
+ for state in optim.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(self.device)
+
+ def data_dependent_initialize(self, data):
+ pass
+
+ def train(self):
+ """Make models train mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.train()
+
+ def eval(self):
+ """Make models eval mode"""
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ net.eval()
+
+ def test(self):
+ """Forward function used in test time.
+
+ This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
+ It also calls <compute_visuals> to produce additional visualization results.
+ """
+ with torch.no_grad():
+ self.forward()
+ self.compute_visuals()
+
+ def compute_visuals(self):
+ """Calculate additional output images for visdom and HTML visualization"""
+ pass
+
+ def get_image_paths(self, name='A'):
+ """ Return image paths that are used to load current data"""
+ return self.image_paths if name =='A' else self.image_paths_B
+
+ def update_learning_rate(self):
+ """Update learning rates for all the networks; called at the end of every epoch"""
+ for scheduler in self.schedulers:
+ if self.opt.lr_policy == 'plateau':
+ scheduler.step(self.metric)
+ else:
+ scheduler.step()
+
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('learning rate = %.7f' % lr)
+
+ def get_current_visuals(self):
+ """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)[:, :3, ...]
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if not os.path.isdir(self.save_dir):
+ os.makedirs(self.save_dir)
+
+ save_filename = 'epoch_%s.pth' % (epoch)
+ save_path = os.path.join(self.save_dir, save_filename)
+
+ save_dict = {}
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel) or isinstance(net,
+ torch.nn.parallel.DistributedDataParallel):
+ net = net.module
+ save_dict[name] = net.state_dict()
+
+
+ for i, optim in enumerate(self.optimizers):
+ save_dict['opt_%02d'%i] = optim.state_dict()
+
+ for i, sched in enumerate(self.schedulers):
+ save_dict['sched_%02d'%i] = sched.state_dict()
+
+ torch.save(save_dict, save_path)
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ if self.opt.isTrain and self.opt.pretrained_name is not None:
+ load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+ else:
+ load_dir = self.save_dir
+ load_filename = 'epoch_%s.pth' % (epoch)
+ load_path = os.path.join(load_dir, load_filename)
+ state_dict = torch.load(load_path, map_location=self.device)
+ print('loading the model from %s' % load_path)
+
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ if isinstance(net, torch.nn.DataParallel):
+ net = net.module
+ net.load_state_dict(state_dict[name])
+
+ if self.opt.phase != 'test':
+ if self.opt.continue_train:
+ print('loading the optim from %s' % load_path)
+ for i, optim in enumerate(self.optimizers):
+ optim.load_state_dict(state_dict['opt_%02d'%i])
+
+ try:
+ print('loading the sched from %s' % load_path)
+ for i, sched in enumerate(self.schedulers):
+ sched.load_state_dict(state_dict['sched_%02d'%i])
+ except Exception:
+ print('Failed to load schedulers; setting last_epoch from the epoch count instead')
+ for i, sched in enumerate(self.schedulers):
+ sched.last_epoch = self.opt.epoch_count - 1
+
+
+
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requires_grad=False for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad
+
+ def generate_visuals_for_evaluation(self, data, mode):
+ return {}
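Not part of the patch: a minimal sketch of the flat checkpoint layout that `save_networks()` writes and `load_networks()` expects above, one `state_dict` per entry in `self.model_names` plus `opt_%02d` / `sched_%02d` entries for optimizers and schedulers. The checkpoint path below is hypothetical.

```python
import torch

ckpt = torch.load("checkpoints/facerecon/epoch_20.pth", map_location="cpu")  # hypothetical path
print("networks:  ", [k for k in ckpt if not k.startswith(("opt_", "sched_"))])  # e.g. ['net_recon']
print("optimizers:", [k for k in ckpt if k.startswith("opt_")])                  # ['opt_00', ...]
print("schedulers:", [k for k in ckpt if k.startswith("sched_")])                # ['sched_00', ...]
```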
diff --git a/deep_3drecon/deep_3drecon_models/bfm.py b/deep_3drecon/deep_3drecon_models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..968fd8a9562e293716db33ae27371a08274bd142
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/bfm.py
@@ -0,0 +1,429 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+from deep_3drecon.util.load_mats import transferBFM09
+import os
+# from utils.commons.tensor_utils import convert_like
+
+
+def perspective_projection(focal, center):
+ # return p.T (N, 3) @ (3, 3)
+ return np.array([
+ focal, 0, center,
+ 0, focal, center,
+ 0, 0, 1
+ ]).reshape([3, 3]).astype(np.float32).transpose() # note the transpose here!
+
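A usage sketch (not in the patch) of the row-vector convention set up here: because `perspective_projection()` returns the transposed intrinsic matrix, `to_image()` below can project a batch of points with a plain right-multiply (`face_shape @ persc_proj`). Values are the module defaults (focal=1015, center=112); the import path assumes the repo layout in this diff.

```python
import numpy as np
from deep_3drecon.deep_3drecon_models.bfm import perspective_projection

K_T = perspective_projection(1015., 112.)               # (3, 3), already transposed
pts = np.array([[0.1, -0.2, 10.0]], dtype=np.float32)   # one camera-space point, z = 10
proj = pts @ K_T                                        # row vectors on the left
uv = proj[:, :2] / proj[:, 2:]                          # perspective divide
print(uv)                                               # ~[[122.15, 91.7]], near the 112-pixel center
```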
+class SH:
+ def __init__(self):
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
+
+
+
+class ParametricFaceModel:
+ def __init__(self,
+ bfm_folder='./BFM',
+ recenter=True,
+ camera_distance=10.,
+ init_lit=np.array([
+ 0.8, 0, 0, 0, 0, 0, 0, 0, 0
+ ]),
+ focal=1015.,
+ center=112.,
+ is_train=True,
+ default_name='BFM_model_front.mat',
+ keypoint_mode='mediapipe'):
+
+ if not os.path.isfile(os.path.join(bfm_folder, default_name)):
+ transferBFM09(bfm_folder)
+ model = loadmat(os.path.join(bfm_folder, default_name))
+ # mean face shape. [3*N,1]
+ self.mean_shape = model['meanshape'].astype(np.float32)
+ # identity basis. [3*N,80]
+ self.id_base = model['idBase'].astype(np.float32)
+ # expression basis. [3*N,64]
+ self.exp_base = model['exBase'].astype(np.float32)
+ # mean face texture. [3*N,1] (0-255)
+ self.mean_tex = model['meantex'].astype(np.float32)
+ # texture basis. [3*N,80]
+ self.tex_base = model['texBase'].astype(np.float32)
+ # indices of the faces that each vertex belongs to. starts from 0. [N,8]
+ self.point_buf = model['point_buf'].astype(np.int64) - 1
+ # vertex indices for each face. starts from 0. [F,3]
+ self.face_buf = model['tri'].astype(np.int64) - 1
+ # vertex indices for 68 landmarks. starts from 0. [68,1]
+ if keypoint_mode == 'mediapipe':
+ self.keypoints = np.load("deep_3drecon/BFM/index_mp468_from_mesh35709.npy").astype(np.int64)
+ unmatch_mask = self.keypoints < 0
+ self.keypoints[unmatch_mask] = 0
+ else:
+ self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
+
+ if is_train:
+ # vertex indices for small face region to compute photometric error. starts from 0.
+ self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
+ # vertex indices for each face from small face region. starts from 0. [f,3]
+ self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
+ # vertex indices for pre-defined skin region to compute reflectance loss
+ self.skin_mask = np.squeeze(model['skinmask'])
+
+ if recenter:
+ mean_shape = self.mean_shape.reshape([-1, 3])
+ mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+ self.mean_shape = mean_shape.reshape([-1, 1])
+
+ self.key_mean_shape = self.mean_shape.reshape([-1, 3])[self.keypoints, :].reshape([-1, 3])
+ self.key_id_base = self.id_base.reshape([-1, 3,80])[self.keypoints, :].reshape([-1, 80])
+ self.key_exp_base = self.exp_base.reshape([-1, 3, 64])[self.keypoints, :].reshape([-1, 64])
+
+ self.focal = focal
+ self.center = center
+ self.persc_proj = perspective_projection(focal, center)
+ self.device = 'cpu'
+ self.camera_distance = camera_distance
+ self.SH = SH()
+ self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+
+ self.initialized = False
+
+ def to(self, device):
+ self.device = device
+ for key, value in self.__dict__.items():
+ if type(value).__module__ == np.__name__:
+ setattr(self, key, torch.tensor(value).to(device))
+ self.initialized = True
+ return self
+
+ def compute_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+ def compute_key_shape(self, id_coeff, exp_coeff):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
+ """
+ batch_size = id_coeff.shape[0]
+ id_part = torch.einsum('ij,aj->ai', self.key_id_base, id_coeff)
+ exp_part = torch.einsum('ij,aj->ai', self.key_exp_base, exp_coeff)
+ face_shape = id_part + exp_part + self.key_mean_shape.reshape([1, -1])
+ return face_shape.reshape([batch_size, -1, 3])
+
+ def compute_texture(self, tex_coeff, normalize=True):
+ """
+ Return:
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+ Parameters:
+ tex_coeff -- torch.tensor, size (B, 80)
+ """
+ batch_size = tex_coeff.shape[0]
+ face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
+ if normalize:
+ face_texture = face_texture / 255.
+ return face_texture.reshape([batch_size, -1, 3])
+
+
+ def compute_norm(self, face_shape):
+ """
+ Return:
+ vertex_norm -- torch.tensor, size (B, N, 3)
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+
+ v1 = face_shape[:, self.face_buf[:, 0]]
+ v2 = face_shape[:, self.face_buf[:, 1]]
+ v3 = face_shape[:, self.face_buf[:, 2]]
+ e1 = v1 - v2
+ e2 = v2 - v3
+ face_norm = torch.cross(e1, e2, dim=-1)
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+ return vertex_norm
+
+
+ def compute_color(self, face_texture, face_norm, gamma):
+ """
+ Return:
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+ Parameters:
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
+ gamma -- torch.tensor, size (B, 27), SH coeffs
+ """
+ batch_size = gamma.shape[0]
+ v_num = face_texture.shape[1]
+ a, c = self.SH.a, self.SH.c
+ gamma = gamma.reshape([batch_size, 3, 9])
+ gamma = gamma + self.init_lit
+ gamma = gamma.permute(0, 2, 1)
+ Y = torch.cat([
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+ -a[1] * c[1] * face_norm[..., 1:2],
+ a[1] * c[1] * face_norm[..., 2:],
+ -a[1] * c[1] * face_norm[..., :1],
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
+ ], dim=-1)
+ r = Y @ gamma[..., :1]
+ g = Y @ gamma[..., 1:2]
+ b = Y @ gamma[..., 2:]
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
+ return face_color
+
+ @staticmethod
+ def compute_rotation(angles, device='cpu'):
+ """
+ Return:
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+ Parameters:
+ angles -- torch.tensor, size (B, 3), radian
+ """
+
+ batch_size = angles.shape[0]
+ angles = angles.to(device)
+ ones = torch.ones([batch_size, 1]).to(device)
+ zeros = torch.zeros([batch_size, 1]).to(device)
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
+
+ rot_x = torch.cat([
+ ones, zeros, zeros,
+ zeros, torch.cos(x), -torch.sin(x),
+ zeros, torch.sin(x), torch.cos(x)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_y = torch.cat([
+ torch.cos(y), zeros, torch.sin(y),
+ zeros, ones, zeros,
+ -torch.sin(y), zeros, torch.cos(y)
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot_z = torch.cat([
+ torch.cos(z), -torch.sin(z), zeros,
+ torch.sin(z), torch.cos(z), zeros,
+ zeros, zeros, ones
+ ], dim=1).reshape([batch_size, 3, 3])
+
+ rot = rot_z @ rot_y @ rot_x
+ return rot.permute(0, 2, 1)
+
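A quick check (illustration only) of the convention returned above: `compute_rotation()` gives the transpose of Rz·Ry·Rx so that `transform()` can apply it to row vectors as `face_shape @ rot`. With a 90° yaw, +x should map to roughly -z. The import path assumes the repo layout in this diff; `compute_rotation` is a staticmethod, so no BFM files are needed.

```python
import numpy as np
import torch
from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel

angles = torch.tensor([[0.0, np.pi / 2, 0.0]])        # (B, 3): rotations about x, y, z in radians
rot = ParametricFaceModel.compute_rotation(angles)    # (1, 3, 3), already transposed
x_axis = torch.tensor([[[1.0, 0.0, 0.0]]])            # a single row-vector point
print(x_axis @ rot)                                   # ~[[[0., 0., -1.]]]
```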
+
+ def to_camera(self, face_shape):
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1] # reverse the depth axis, add a fixed offset of length
+ return face_shape
+
+ def to_image(self, face_shape):
+ """
+ Return:
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ """
+ # to image_plane
+ face_proj = face_shape @ self.persc_proj
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+ return face_proj
+
+
+ def transform(self, face_shape, rot, trans):
+ """
+ Return:
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+ Parameters:
+ face_shape -- torch.tensor, size (B, N, 3)
+ rot -- torch.tensor, size (B, 3, 3)
+ trans -- torch.tensor, size (B, 3)
+ """
+ return face_shape @ rot + trans.unsqueeze(1)
+
+
+ def get_landmarks(self, face_proj):
+ """
+ Return:
+ face_lms -- torch.tensor, size (B, 68, 2)
+
+ Parameters:
+ face_proj -- torch.tensor, size (B, N, 2)
+ """
+ return face_proj[:, self.keypoints]
+
+ def split_coeff(self, coeffs):
+ """
+ Return:
+ coeffs_dict -- a dict of torch.tensors
+
+ Parameters:
+ coeffs -- torch.tensor, size (B, 256)
+ """
+ id_coeffs = coeffs[:, :80]
+ exp_coeffs = coeffs[:, 80: 144]
+ tex_coeffs = coeffs[:, 144: 224]
+ angles = coeffs[:, 224: 227]
+ gammas = coeffs[:, 227: 254]
+ translations = coeffs[:, 254:]
+ return {
+ 'id': id_coeffs,
+ 'exp': exp_coeffs,
+ 'tex': tex_coeffs,
+ 'angle': angles,
+ 'gamma': gammas,
+ 'trans': translations
+ }
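A standalone sketch of the slicing performed by `split_coeff()` above; it mirrors the 257-dim layout (80 id + 64 exp + 80 tex + 3 angle + 27 gamma + 3 trans) without needing the BFM data files.

```python
import torch

coeffs = torch.zeros(2, 257)                          # dummy batch of 2 coefficient vectors
sizes = {'id': 80, 'exp': 64, 'tex': 80, 'angle': 3, 'gamma': 27, 'trans': 3}
offset = 0
for name, width in sizes.items():
    chunk = coeffs[:, offset:offset + width]
    offset += width
    print(name, tuple(chunk.shape))                   # id (2, 80), exp (2, 64), ...
assert offset == 257
```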
+ def compute_for_render(self, coeffs):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ coeffs -- torch.tensor, size (B, 257)
+ """
+ coef_dict = self.split_coeff(coeffs)
+ face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+ rotation = self.compute_rotation(coef_dict['angle'], device=self.device)
+
+
+ face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(coef_dict['tex'])
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+ return face_vertex, face_texture, face_color, landmark
+
+ def compute_face_vertex(self, id, exp, angle, trans):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ Parameters:
+ id, exp, angle, trans -- torch.tensors of size (B, 80), (B, 64), (B, 3), (B, 3)
+ """
+ if not self.initialized:
+ self.to(id.device)
+ face_shape = self.compute_shape(id, exp)
+ rotation = self.compute_rotation(angle, device=self.device)
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = self.to_camera(face_shape_transformed)
+ return face_vertex
+
+ def compute_for_landmark_fit(self, id, exp, angles, trans, ret=None):
+ """
+ Return:
+ landmark -- torch.tensor, size (B, K, 2), projected keypoints, y direction is opposite to v direction
+ Parameters:
+ id, exp, angles, trans -- torch.tensors of size (B, 80), (B, 64), (B, 3), (B, 3)
+ """
+ face_shape = self.compute_key_shape(id, exp)
+ rotation = self.compute_rotation(angles, device=self.device)
+
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = face_proj
+ return landmark
+
+ def compute_for_landmark_fit_nerf(self, id, exp, angles, trans, ret=None):
+ """
+ Return:
+ landmark -- torch.tensor, size (B, K, 2), projected keypoints (no to_camera transform), y direction is opposite to v direction
+ Parameters:
+ id, exp, angles, trans -- torch.tensors of size (B, 80), (B, 64), (B, 3), (B, 3)
+ """
+ face_shape = self.compute_key_shape(id, exp)
+ rotation = self.compute_rotation(angles, device=self.device)
+
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = face_shape_transformed # no to_camera
+
+ face_proj = self.to_image(face_vertex)
+ landmark = face_proj
+ return landmark
+
+ # def compute_for_landmark_fit(self, id, exp, angles, trans, ret={}):
+ # """
+ # Return:
+ # face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ # face_color -- torch.tensor, size (B, N, 3), in RGB order
+ # landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ # Parameters:
+ # coeffs -- torch.tensor, size (B, 257)
+ # """
+ # face_shape = self.compute_shape(id, exp)
+ # rotation = self.compute_rotation(angles)
+
+ # face_shape_transformed = self.transform(face_shape, rotation, trans)
+ # face_vertex = self.to_camera(face_shape_transformed)
+
+ # face_proj = self.to_image(face_vertex)
+ # landmark = self.get_landmarks(face_proj)
+ # return landmark
+
+ def compute_for_render_fit(self, id, exp, angles, trans, tex, gamma):
+ """
+ Return:
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+ Parameters:
+ id, exp, angles, trans, tex, gamma -- torch.tensors of size (B, 80), (B, 64), (B, 3), (B, 3), (B, 80), (B, 27)
+ """
+ face_shape = self.compute_shape(id, exp)
+ rotation = self.compute_rotation(angles, device=self.device)
+
+ face_shape_transformed = self.transform(face_shape, rotation, trans)
+ face_vertex = self.to_camera(face_shape_transformed)
+
+ face_proj = self.to_image(face_vertex)
+ landmark = self.get_landmarks(face_proj)
+
+ face_texture = self.compute_texture(tex)
+ face_norm = self.compute_norm(face_shape)
+ face_norm_roted = face_norm @ rotation
+ face_color = self.compute_color(face_texture, face_norm_roted, gamma)
+
+ return face_color, face_vertex, landmark
\ No newline at end of file
diff --git a/deep_3drecon/deep_3drecon_models/facerecon_model.py b/deep_3drecon/deep_3drecon_models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5659b24ebd9ee78045deadacb9af9b350da5f1e
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/facerecon_model.py
@@ -0,0 +1,228 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+from .bfm import ParametricFaceModel
+from .losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
+from deep_3drecon.util import util
+from deep_3drecon.util.mesh_renderer import MeshRenderer
+from deep_3drecon.util.preprocess import estimate_norm_torch
+
+import trimesh
+from scipy.io import savemat
+
+class FaceReconModel(BaseModel):
+
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Configure options specific for the face reconstruction model
+ """
+ # net structure and parameters
+ parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
+ parser.add_argument('--init_path', type=str, default='checkpoints/init_model/resnet50-0676ba61.pth')
+ parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
+ parser.add_argument('--bfm_folder', type=str, default='./deep_3drecon/BFM')
+ parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+ # renderer parameters
+ parser.add_argument('--focal', type=float, default=1015.)
+ parser.add_argument('--center', type=float, default=112.)
+ parser.add_argument('--camera_d', type=float, default=10.)
+ parser.add_argument('--z_near', type=float, default=5.)
+ parser.add_argument('--z_far', type=float, default=15.)
+ parser.add_argument('--use_opengl', type=util.str2bool, nargs='?', const=True, default=False, help='use opengl context or not')
+
+ if is_train:
+ # training parameters
+ parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure')
+ parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
+ parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
+ parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
+
+
+ # augmentation parameters
+ parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
+ parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
+ parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
+
+ # loss weights
+ parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
+ parser.add_argument('--w_color', type=float, default=1.92, help='weight for color (photometric) loss')
+ parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
+ parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
+ parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
+ parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
+ parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
+ parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
+ parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
+
+
+
+ opt, _ = parser.parse_known_args()
+ parser.set_defaults(
+ focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
+ )
+ if is_train:
+ parser.set_defaults(
+ use_crop_face=True, use_predef_M=False
+ )
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+
+ self.visual_names = ['output_vis']
+ self.model_names = ['net_recon']
+ self.parallel_names = self.model_names + ['renderer']
+
+ self.net_recon = networks.define_net_recon(
+ net_recon=opt.net_recon, use_last_fc=opt.use_last_fc, init_path=opt.init_path
+ )
+
+ self.facemodel = ParametricFaceModel(
+ bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
+ is_train=self.isTrain, default_name=opt.bfm_model
+ )
+
+ fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+ self.renderer = MeshRenderer(
+ rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center), use_opengl=opt.use_opengl
+ )
+
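As a side note (not in the patch), the default intrinsics imply a narrow field of view: with center=112 and focal=1015, fov = 2·arctan(112/1015)·180/π ≈ 12.6°, and rasterize_size = 2·112 = 224, matching 224×224 input crops.

```python
import numpy as np

focal, center = 1015., 112.
fov = 2 * np.arctan(center / focal) * 180 / np.pi
print(round(fov, 2), int(2 * center))   # 12.59 224
```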
+ if self.isTrain:
+ self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
+
+ self.net_recog = networks.define_net_recog(
+ net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
+ )
+ # loss func name: (compute_%s_loss) % loss_name
+ self.compute_feat_loss = perceptual_loss
+ self.compute_color_loss = photo_loss
+ self.compute_lm_loss = landmark_loss
+ self.compute_reg_loss = reg_loss
+ self.compute_reflc_loss = reflectance_loss
+
+ self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+ self.optimizers = [self.optimizer]
+ self.parallel_names += ['net_recog']
+ # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ self.input_img = input['imgs'].to(self.device)
+ self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
+ self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
+ self.trans_m = input['M'].to(self.device) if 'M' in input else None
+ self.image_paths = input['im_paths'] if 'im_paths' in input else None
+
+ def forward(self):
+ output_coeff = self.net_recon(self.input_img)
+ self.facemodel.to(self.device)
+ self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
+ self.facemodel.compute_for_render(output_coeff)
+ self.pred_mask, _, self.pred_face = self.renderer(
+ self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
+
+ self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+ self.output_coeff = output_coeff
+
+ def compute_losses(self):
+ """Calculate all training losses; called in every training iteration (weights are updated in optimize_parameters)"""
+
+ assert self.net_recog.training == False
+ trans_m = self.trans_m
+ if not self.opt.use_predef_M:
+ trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+ pred_feat = self.net_recog(self.pred_face, trans_m)
+ gt_feat = self.net_recog(self.input_img, self.trans_m)
+ self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+ face_mask = self.pred_mask
+ if self.opt.use_crop_face:
+ face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+
+ face_mask = face_mask.detach()
+ self.loss_color = self.opt.w_color * self.compute_color_loss(
+ self.pred_face, self.input_img, self.atten_mask * face_mask)
+
+ loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+ self.loss_reg = self.opt.w_reg * loss_reg
+ self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+ self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+ self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+ self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+ + self.loss_lm + self.loss_reflc
+
+
+ def optimize_parameters(self, isTrain=True):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward()
+ self.compute_losses()
+ if isTrain:
+ self.optimizer.zero_grad()
+ self.loss_all.backward()
+ self.optimizer.step()
+
+ def compute_visuals(self):
+ with torch.no_grad():
+ input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+ output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+ output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+ if self.gt_lm is not None:
+ gt_lm_numpy = self.gt_lm.cpu().numpy()
+ pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
+ output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
+
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw, output_vis_numpy), axis=-2)
+ else:
+ output_vis_numpy = np.concatenate((input_img_numpy,
+ output_vis_numpy_raw), axis=-2)
+
+ self.output_vis = torch.tensor(
+ output_vis_numpy / 255., dtype=torch.float32
+ ).permute(0, 3, 1, 2).to(self.device)
+
+ def save_mesh(self, name):
+
+ recon_shape = self.pred_vertex # get reconstructed shape
+ recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
+ recon_shape = recon_shape.cpu().numpy()[0]
+ recon_color = self.pred_color
+ recon_color = recon_color.cpu().numpy()[0]
+ tri = self.facemodel.face_buf.cpu().numpy()
+ mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8), process=False)
+ mesh.export(name)
+
+ def save_coeff(self,name):
+
+ pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+ pred_lm = self.pred_lm.cpu().numpy()
+ pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
+ pred_coeffs['lm68'] = pred_lm
+ savemat(name,pred_coeffs)
+
+
+
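A hedged, end-to-end sketch of how this model class is typically driven at inference time; `opt` is assumed to come from the project's `TestOptions`, the input is a random stand-in for a 224×224 RGB face crop in [0, 1], and the output paths are hypothetical.

```python
import torch
from deep_3drecon.deep_3drecon_models.facerecon_model import FaceReconModel

model = FaceReconModel(opt)            # opt: parsed TestOptions (assumed available)
model.setup(opt)                       # BaseModel: load networks / schedulers
model.eval()

model.set_input({'imgs': torch.rand(1, 3, 224, 224)})
model.test()                           # forward() under no_grad + compute_visuals()
visuals = model.get_current_visuals()  # {'output_vis': rendered face over the input}
model.save_mesh('recon.obj')           # hypothetical output paths
model.save_coeff('recon_coeffs.mat')
```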
diff --git a/deep_3drecon/deep_3drecon_models/losses.py b/deep_3drecon/deep_3drecon_models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbacb63b6110f3dbe7256eb4d5eb781a41e87b8f
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/losses.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from kornia.geometry import warp_affine
+import torch.nn.functional as F
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize))
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+ def __init__(self, recog_net, input_size=112):
+ super(PerceptualLoss, self).__init__()
+ self.recog_net = recog_net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+ def forward(self, imageA, imageB, M):
+ """
+ 1 - cosine distance
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
+ imageB --same as imageA
+ """
+
+ imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+ imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+ # freeze bn
+ self.recog_net.eval()
+
+ id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+ id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+def perceptual_loss(id_featureA, id_featureB):
+ cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+ # assert torch.sum((cosine_d > 1).float()) == 0
+ return torch.sum(1 - cosine_d) / cosine_d.shape[0]
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+ """
+ l2 norm (with sqrt; eps is added for backward stability, otherwise NaN may occur)
+ Parameters:
+ imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
+ imageB --same as imageA
+ """
+ loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
+ loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+ return loss
+
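A small usage sketch for `photo_loss()` above: per-pixel L2 over RGB, masked and averaged over the valid region. Shapes follow the docstring; the data is random and purely illustrative, and the import path assumes the repo layout in this diff.

```python
import torch
from deep_3drecon.deep_3drecon_models.losses import photo_loss

imgA = torch.rand(2, 3, 224, 224)
imgB = torch.rand(2, 3, 224, 224)
mask = (torch.rand(2, 1, 224, 224) > 0.5).float()   # 1 inside the face region, 0 elsewhere
print(photo_loss(imgA, imgB, mask))                 # scalar tensor
```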
+def landmark_loss(predict_lm, gt_lm, weight=None):
+ """
+ weighted mse loss
+ Parameters:
+ predict_lm --torch.tensor (B, 68, 2)
+ gt_lm --torch.tensor (B, 68, 2)
+ weight --numpy.array (1, 68)
+ """
+ if weight is None:
+ weight = np.ones([68])
+ weight[28:31] = 20
+ weight[-8:] = 20
+ weight = np.expand_dims(weight, 0)
+ weight = torch.tensor(weight).to(predict_lm.device)
+ loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
+ loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+ return loss
+
+
+### regulization
+def reg_loss(coeffs_dict, opt=None):
+ """
+ l2 norm without the sqrt, from yu's implementation (mse)
+ tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+ Parameters:
+ coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
+
+ """
+ # coefficient regularization to ensure plausible 3d faces
+ if opt:
+ w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+ else:
+ w_id, w_exp, w_tex = 1, 1, 1
+ creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) + \
+ w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
+ w_tex * torch.sum(coeffs_dict['tex'] ** 2)
+ creg_loss = creg_loss / coeffs_dict['id'].shape[0]
+
+ # gamma regularization to ensure a nearly-monochromatic light
+ gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
+ gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
+ gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+ return creg_loss, gamma_loss
+
+def reflectance_loss(texture, mask):
+ """
+ minimize texture variance (mse); albedo regularization to ensure a uniform skin albedo
+ Parameters:
+ texture --torch.tensor, (B, N, 3)
+ mask --torch.tensor, (N), 1 or 0
+
+ """
+ mask = mask.reshape([1, mask.shape[0], 1])
+ texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
+ loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
+ return loss
+
diff --git a/deep_3drecon/deep_3drecon_models/networks.py b/deep_3drecon/deep_3drecon_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..685750dec3acf78547e2d2a7f8b89c1a00336b71
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/networks.py
@@ -0,0 +1,522 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch
+from torch import Tensor
+import torch.nn as nn
+try:
+ from torch.hub import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
+
+def resize_n_crop(image, M, dsize=112):
+ # image: (b, c, h, w)
+ # M : (b, 2, 3)
+ return warp_affine(image, M, dsize=(dsize, dsize))
+
+def filter_state_dict(state_dict, remove_name='fc'):
+ new_state_dict = {}
+ for key in state_dict:
+ if remove_name in key:
+ continue
+ new_state_dict[key] = state_dict[key]
+ return new_state_dict
+
+def get_scheduler(optimizer, opt):
+ """Return a learning rate scheduler
+
+ Parameters:
+ optimizer -- the optimizer of the network
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+ See https://pytorch.org/docs/stable/optim.html for more details.
+ """
+ if opt.lr_policy == 'linear':
+ def lambda_rule(epoch):
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+ return lr_l
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+ elif opt.lr_policy == 'step':
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+ elif opt.lr_policy == 'plateau':
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+ elif opt.lr_policy == 'cosine':
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+ else:
+ raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
+ return scheduler
+
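A sketch of wiring `get_scheduler()` to an optimizer; the `opt` namespace below only carries the fields this function reads (`lr_policy`, `epoch_count`, `n_epochs`, `lr_decay_epochs`), with placeholder values rather than the project's defaults.

```python
import argparse
import torch
from deep_3drecon.deep_3drecon_models.networks import get_scheduler

opt = argparse.Namespace(lr_policy='linear', epoch_count=1, n_epochs=20, lr_decay_epochs=10)
net = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
scheduler = get_scheduler(optimizer, opt)

for _ in range(3):
    optimizer.step()
    scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # decays linearly once epoch + epoch_count exceeds n_epochs
```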
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+ return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+def define_net_recog(net_recog, pretrained_path=None):
+ net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+ net.eval()
+ return net
+
+class ReconNetWrapper(nn.Module):
+ fc_dim=257
+ def __init__(self, net_recon, use_last_fc=False, init_path=None):
+ super(ReconNetWrapper, self).__init__()
+ self.use_last_fc = use_last_fc
+ if net_recon not in func_dict:
+ raise NotImplementedError('network [%s] is not implemented' % net_recon)
+ func, last_dim = func_dict[net_recon]
+ backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+ if init_path and os.path.isfile(init_path):
+ state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
+ backbone.load_state_dict(state_dict)
+ print("loading init net_recon %s from %s" %(net_recon, init_path))
+ self.backbone = backbone
+ if not use_last_fc:
+ self.final_layers = nn.ModuleList([
+ conv1x1(last_dim, 80, bias=True), # id layer
+ conv1x1(last_dim, 64, bias=True), # exp layer
+ conv1x1(last_dim, 80, bias=True), # tex layer
+ conv1x1(last_dim, 3, bias=True), # angle layer
+ conv1x1(last_dim, 27, bias=True), # gamma layer
+ conv1x1(last_dim, 2, bias=True), # tx, ty
+ conv1x1(last_dim, 1, bias=True) # tz
+ ])
+ for m in self.final_layers:
+ nn.init.constant_(m.weight, 0.)
+ nn.init.constant_(m.bias, 0.)
+
+ def forward(self, x):
+ x = self.backbone(x)
+ if not self.use_last_fc:
+ output = []
+ for layer in self.final_layers:
+ output.append(layer(x))
+ x = torch.flatten(torch.cat(output, dim=1), 1)
+ return x
+
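Shape check (illustration only): the seven 1x1 heads above emit 80+64+80+3+27+2+1 = 257 channels, which after concatenation and flattening matches the coefficient layout consumed by `ParametricFaceModel.split_coeff()`. No pretrained weights are loaded here (init_path=None).

```python
import torch
from deep_3drecon.deep_3drecon_models.networks import ReconNetWrapper

net = ReconNetWrapper('resnet50', use_last_fc=False, init_path=None)
x = torch.rand(1, 3, 224, 224)
print(net(x).shape)   # torch.Size([1, 257])
```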
+
+class RecogNetWrapper(nn.Module):
+ def __init__(self, net_recog, pretrained_path=None, input_size=112):
+ super(RecogNetWrapper, self).__init__()
+ net = get_model(name=net_recog, fp16=False)
+ if pretrained_path:
+ state_dict = torch.load(pretrained_path, map_location='cpu')
+ net.load_state_dict(state_dict)
+ print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
+ for param in net.parameters():
+ param.requires_grad = False
+ self.net = net
+ self.preprocess = lambda x: 2 * x - 1
+ self.input_size=input_size
+
+ def forward(self, image, M):
+ image = self.preprocess(resize_n_crop(image, M, self.input_size))
+ id_feature = F.normalize(self.net(image), dim=-1, p=2)
+ return id_feature
+
+
+# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+ 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+ 'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+ 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+ 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
+ """1x1 convolution"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
+
+
+class BasicBlock(nn.Module):
+ expansion: int = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(BasicBlock, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ if groups != 1 or base_width != 64:
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+ if dilation > 1:
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = norm_layer(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = norm_layer(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+ expansion: int = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ groups: int = 1,
+ base_width: int = 64,
+ dilation: int = 1,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(Bottleneck, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ width = int(planes * (base_width / 64.)) * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ self.conv1 = conv1x1(inplanes, width)
+ self.bn1 = norm_layer(width)
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
+ self.bn2 = norm_layer(width)
+ self.conv3 = conv1x1(width, planes * self.expansion)
+ self.bn3 = norm_layer(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x: Tensor) -> Tensor:
+ identity = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+
+ def __init__(
+ self,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ num_classes: int = 1000,
+ zero_init_residual: bool = False,
+ use_last_fc: bool = False,
+ groups: int = 1,
+ width_per_group: int = 64,
+ replace_stride_with_dilation: Optional[List[bool]] = None,
+ norm_layer: Optional[Callable[..., nn.Module]] = None
+ ) -> None:
+ super(ResNet, self).__init__()
+ if norm_layer is None:
+ norm_layer = nn.BatchNorm2d
+ self._norm_layer = norm_layer
+
+ self.inplanes = 64
+ self.dilation = 1
+ if replace_stride_with_dilation is None:
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [False, False, False]
+ if len(replace_stride_with_dilation) != 3:
+ raise ValueError("replace_stride_with_dilation should be None "
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+ self.use_last_fc = use_last_fc
+ self.groups = groups
+ self.base_width = width_per_group
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = norm_layer(self.inplanes)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+ dilate=replace_stride_with_dilation[0])
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+ dilate=replace_stride_with_dilation[1])
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+ dilate=replace_stride_with_dilation[2])
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+
+ if self.use_last_fc:
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
+ elif isinstance(m, BasicBlock):
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
+
+ def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
+ norm_layer = self._norm_layer
+ downsample = None
+ previous_dilation = self.dilation
+ if dilate:
+ self.dilation *= stride
+ stride = 1
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes * block.expansion, stride),
+ norm_layer(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+ self.base_width, previous_dilation, norm_layer))
+ self.inplanes = planes * block.expansion
+ for _ in range(1, blocks):
+ layers.append(block(self.inplanes, planes, groups=self.groups,
+ base_width=self.base_width, dilation=self.dilation,
+ norm_layer=norm_layer))
+
+ return nn.Sequential(*layers)
+
+ def _forward_impl(self, x: Tensor) -> Tensor:
+ # See note [TorchScript super()]
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ if self.use_last_fc:
+ x = torch.flatten(x, 1)
+ x = self.fc(x)
+ return x
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self._forward_impl(x)
+
+
+def _resnet(
+ arch: str,
+ block: Type[Union[BasicBlock, Bottleneck]],
+ layers: List[int],
+ pretrained: bool,
+ progress: bool,
+ **kwargs: Any
+) -> ResNet:
+ model = ResNet(block, layers, **kwargs)
+ if pretrained:
+ state_dict = load_state_dict_from_url(model_urls[arch],
+ progress=progress)
+ model.load_state_dict(state_dict)
+ return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-18 model from
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+ **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-34 model from
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-50 model from
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-101 model from
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNet-152 model from
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+ **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-50 32x4d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 4
+ return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""ResNeXt-101 32x8d model from
+ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['groups'] = 32
+ kwargs['width_per_group'] = 8
+ return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-50-2 model from
+ `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+ pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+ r"""Wide ResNet-101-2 model from
+ `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
+
+ The model is the same as ResNet except for the bottleneck number of channels
+ which is twice larger in every block. The number of channels in outer 1x1
+ convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+ channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ progress (bool): If True, displays a progress bar of the download to stderr
+ """
+ kwargs['width_per_group'] = 64 * 2
+ return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+ pretrained, progress, **kwargs)
+
+
+func_dict = {
+ 'resnet18': (resnet18, 512),
+ 'resnet50': (resnet50, 2048)
+}
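Note (not part of the patch): `func_dict` maps each supported backbone name to `(constructor, output feature dim)`; only resnet18 and resnet50 are wired here, so any other `--net_recon` choice ends up in the NotImplementedError branch of `ReconNetWrapper`. A tiny resolution sketch:

```python
from deep_3drecon.deep_3drecon_models.networks import func_dict

func, last_dim = func_dict['resnet50']
backbone = func(use_last_fc=False, num_classes=257)
print(type(backbone).__name__, last_dim)   # ResNet 2048
```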
diff --git a/deep_3drecon/deep_3drecon_models/template_model.py b/deep_3drecon/deep_3drecon_models/template_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..dac7b33d5889777eb63c9882a3b9fa094dcab293
--- /dev/null
+++ b/deep_3drecon/deep_3drecon_models/template_model.py
@@ -0,0 +1,100 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be <model>_model.py
+The class name should be <Model>Model
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+ min_<netG> ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+ <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
+ <__init__>: Initialize this model class.
+ <set_input>: Unpack input data and perform data pre-processing.
+ <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
+ <optimize_parameters>: Update network weights; it will be called in every training iteration.
+"""
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class TemplateModel(BaseModel):
+ @staticmethod
+ def modify_commandline_options(parser, is_train=True):
+ """Add new model-specific options and rewrite default values for existing options.
+
+ Parameters:
+ parser -- the option parser
+ is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+ Returns:
+ the modified parser.
+ """
+ parser.set_defaults(dataset_mode='aligned') # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+ if is_train:
+ parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss') # You can define new arguments for this model.
+
+ return parser
+
+ def __init__(self, opt):
+ """Initialize this model class.
+
+ Parameters:
+ opt -- training/test options
+
+ A few things can be done here.
+ - (required) call the initialization function of BaseModel
+ - define loss function, visualization images, model names, and optimizers
+ """
+ BaseModel.__init__(self, opt) # call the initialization method of BaseModel
+ # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+ self.loss_names = ['loss_G']
+ # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+ self.visual_names = ['data_A', 'data_B', 'output']
+ # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+ # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+ self.model_names = ['G']
+ # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+ self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+ if self.isTrain: # only defined during training time
+ # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+ # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+ self.criterionLoss = torch.nn.L1Loss()
+ # define and initialize optimizers. You can define one optimizer for each network.
+ # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+ self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+ self.optimizers = [self.optimizer]
+
+ # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+ def set_input(self, input):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+ Parameters:
+ input: a dictionary that contains the data itself and its metadata information.
+ """
+ AtoB = self.opt.direction == 'AtoB' # use <direction> to swap data_A and data_B
+ self.data_A = input['A' if AtoB else 'B'].to(self.device) # get image data A
+ self.data_B = input['B' if AtoB else 'A'].to(self.device) # get image data B
+ self.image_paths = input['A_paths' if AtoB else 'B_paths'] # get image paths
+
+ def forward(self):
+ """Run forward pass. This will be called by both <optimize_parameters> and <test>."""
+ self.output = self.netG(self.data_A) # generate output image given the input data_A
+
+ def backward(self):
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
+ # calculate the intermediate results if necessary; here self.output has been computed during the <forward> function
+ # calculate loss given the input and intermediate results
+ self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+ self.loss_G.backward() # calculate gradients of network G w.r.t. loss_G
+
+ def optimize_parameters(self):
+ """Update network weights; it will be called in every training iteration."""
+ self.forward() # first call forward to calculate intermediate results
+ self.optimizer.zero_grad() # clear network G's existing gradients
+ self.backward() # calculate gradients for network G
+ self.optimizer.step() # update gradients for network G
diff --git a/deep_3drecon/generate_reconstructor_opt_for_geneface.py b/deep_3drecon/generate_reconstructor_opt_for_geneface.py
new file mode 100644
index 0000000000000000000000000000000000000000..96e8b2ed441ab2ad5a12fe02bc934c1f868e766f
--- /dev/null
+++ b/deep_3drecon/generate_reconstructor_opt_for_geneface.py
@@ -0,0 +1,12 @@
+from options.test_options import TestOptions
+import pickle as pkl
+
+# run in the root dir!
+opt = TestOptions().parse() # get test options
+opt.name='facerecon'
+opt.epoch=20
+opt.bfm_folder='deep_3drecon/BFM/'
+opt.checkpoints_dir='deep_3drecon/checkpoints/'
+
+with open("deep_3drecon/reconstructor_opt.pkl", 'wb') as f:
+ pkl.dump(opt, f)
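Counterpart sketch: how downstream GeneFace code can read the pickled option namespace back (run from the repo root, as the comment above says).

```python
import pickle as pkl

with open("deep_3drecon/reconstructor_opt.pkl", 'rb') as f:
    opt = pkl.load(f)
print(opt.name, opt.epoch, opt.bfm_folder)   # facerecon 20 deep_3drecon/BFM/
```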
diff --git a/deep_3drecon/ncc_code.npy b/deep_3drecon/ncc_code.npy
new file mode 100644
index 0000000000000000000000000000000000000000..79568a9ce3c7a903cea7ec76f1870f15fd052f13
--- /dev/null
+++ b/deep_3drecon/ncc_code.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da54a620c0981d43cc9f30b3d8b3f5d4beb0ec0e27127a1ef3fb62ea50913609
+size 428636
diff --git a/deep_3drecon/options/__init__.py b/deep_3drecon/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90
--- /dev/null
+++ b/deep_3drecon/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/deep_3drecon/options/base_options.py b/deep_3drecon/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..37ad0bda3c189aac91641c5f7b70936bb257f2a6
--- /dev/null
+++ b/deep_3drecon/options/base_options.py
@@ -0,0 +1,169 @@
+"""This script contains base options for Deep3DFaceRecon_pytorch
+"""
+
+import argparse
+import os
+from util import util
+import numpy as np
+import torch
+import deep_3drecon_models
+import data
+
+
+class BaseOptions():
+ """This class defines options used during both training and test time.
+
+ It also implements several helper functions such as parsing, printing, and saving the options.
+ It also gathers additional options defined in functions in both dataset class and model class.
+ """
+
+ def __init__(self, cmd_line=None):
+ """Reset the class; indicates the class hasn't been initailized"""
+ self.initialized = False
+ self.cmd_line = None
+ if cmd_line is not None:
+ self.cmd_line = cmd_line.split()
+
+ def initialize(self, parser):
+ """Define the common options that are used in both training and test."""
+ # basic parameters
+ parser.add_argument('--name', type=str, default='facerecon', help='name of the experiment. It decides where to store samples and models')
+ parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
+ parser.add_argument('--checkpoints_dir', type=str, default='./deep_3drecon/checkpoints', help='models are saved here')
+ parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visualization')
+ parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')
+ parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use distributed data parallel')
+ parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
+ parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether to display losses per batch')
+ parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether to add images to tensorboard')
+ parser.add_argument('--world_size', type=int, default=1, help='number of processes (GPUs) for distributed training; set from --gpu_ids at parse time')
+
+ # model parameters
+ parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
+
+ # additional parameters
+ parser.add_argument('--epoch', type=str, default='20', help='which epoch to load? set to latest to use latest cached model')
+ parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+ parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+ self.initialized = True
+ return parser
+
+ def gather_options(self):
+ """Initialize our parser with basic options(only once).
+ Add additional model-specific and dataset-specific options.
+ These options are defined in the function
+ in model and dataset classes.
+ """
+ if not self.initialized: # check if it has been initialized
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser = self.initialize(parser)
+
+ # get the basic options
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args()
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line)
+
+ # set cuda visible devices
+ os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
+
+ # modify model-related parser options
+ model_name = opt.model
+ model_option_setter = deep_3drecon_models.get_option_setter(model_name)
+ parser = model_option_setter(parser, self.isTrain)
+ if self.cmd_line is None:
+ opt, _ = parser.parse_known_args() # parse again with new defaults
+ else:
+ opt, _ = parser.parse_known_args(self.cmd_line) # parse again with new defaults
+
+ # modify dataset-related parser options
+ if opt.dataset_mode:
+ dataset_name = opt.dataset_mode
+ dataset_option_setter = data.get_option_setter(dataset_name)
+ parser = dataset_option_setter(parser, self.isTrain)
+
+ # save and return the parser
+ self.parser = parser
+ if self.cmd_line is None:
+ return parser.parse_args()
+ else:
+ return parser.parse_args(self.cmd_line)
+
+ def print_options(self, opt):
+ """Print and save options
+
+ It will print both current options and default values (if different).
+ It will save options into a text file / [checkpoints_dir] / opt.txt
+ """
+ message = ''
+ message += '----------------- Options ---------------\n'
+ for k, v in sorted(vars(opt).items()):
+ comment = ''
+ default = self.parser.get_default(k)
+ if v != default:
+ comment = '\t[default: %s]' % str(default)
+ message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+ message += '----------------- End -------------------'
+ print(message)
+
+ # save to the disk
+ expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ util.mkdirs(expr_dir)
+ file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+ try:
+ with open(file_name, 'wt') as opt_file:
+ opt_file.write(message)
+ opt_file.write('\n')
+ except PermissionError as error:
+ print("permission error {}".format(error))
+ pass
+
+ def parse(self):
+ """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+ opt = self.gather_options()
+ opt.isTrain = self.isTrain # train or test
+
+ # process opt.suffix
+ if opt.suffix:
+ suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+ opt.name = opt.name + suffix
+
+
+ # set gpu ids
+ str_ids = opt.gpu_ids.split(',')
+ gpu_ids = []
+ for str_id in str_ids:
+ id = int(str_id)
+ if id >= 0:
+ gpu_ids.append(id)
+ opt.world_size = len(gpu_ids)
+ # if len(opt.gpu_ids) > 0:
+ # torch.cuda.set_device(gpu_ids[0])
+ if opt.world_size == 1:
+ opt.use_ddp = False
+
+ if opt.phase != 'test':
+ # set continue_train automatically
+ if opt.pretrained_name is None:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ else:
+ model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
+ if os.path.isdir(model_dir):
+ model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
+ if os.path.isdir(model_dir) and len(model_pths) != 0:
+ opt.continue_train= True
+
+ # update the latest epoch count
+ if opt.continue_train:
+ if opt.epoch == 'latest':
+ epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
+ if len(epoch_counts) != 0:
+ opt.epoch_count = max(epoch_counts) + 1
+ else:
+ opt.epoch_count = int(opt.epoch) + 1
+
+
+ self.print_options(opt)
+ self.opt = opt
+ return self.opt
diff --git a/deep_3drecon/options/test_options.py b/deep_3drecon/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff3ad142779850d1d5a1640bc00f70d34d4a862
--- /dev/null
+++ b/deep_3drecon/options/test_options.py
@@ -0,0 +1,21 @@
+"""This script contains the test options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ """This class includes test options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser) # define shared options
+ parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+ parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
+
+ # Dropout and BatchNorm have different behavior during training and test.
+ self.isTrain = False
+ return parser
diff --git a/deep_3drecon/options/train_options.py b/deep_3drecon/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1337bfdd5f372b5c686a91b394a2aadbe5741f44
--- /dev/null
+++ b/deep_3drecon/options/train_options.py
@@ -0,0 +1,53 @@
+"""This script contains the training options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+from util import util
+
+class TrainOptions(BaseOptions):
+ """This class includes training options.
+
+ It also includes shared options defined in BaseOptions.
+ """
+
+ def initialize(self, parser):
+ parser = BaseOptions.initialize(self, parser)
+ # dataset parameters
+ # for train
+ parser.add_argument('--data_root', type=str, default='./', help='dataset root')
+ parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
+ parser.add_argument('--batch_size', type=int, default=32)
+ parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
+ parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+ parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+ parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+ parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
+ parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether to use data augmentation')
+
+ # for val
+ parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
+ parser.add_argument('--batch_size_val', type=int, default=32)
+
+
+ # visualization parameters
+ parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
+ parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+
+ # network saving and loading parameters
+ parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+ parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
+ parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
+ parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
+ parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+ parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model at epoch_count, epoch_count+save_latest_freq, ...')
+ parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+ parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
+
+ # training parameters
+ parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
+ parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
+ parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
+ parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epochs')
+
+ self.isTrain = True
+ return parser
diff --git a/deep_3drecon/reconstructor.py b/deep_3drecon/reconstructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6f8be419b8db053ce81b9caaa6f793062fbeaeb
--- /dev/null
+++ b/deep_3drecon/reconstructor.py
@@ -0,0 +1,90 @@
+"""This script is the test script for Deep3DFaceRecon_pytorch
+Pytorch Deep3D_Recon is 8x faster than TF-based, 16s/iter ==> 2s/iter
+"""
+
+import os
+# os.environ['PYTHONPATH'] = os.environ['PYTHONPATH'] + ":" + os.path.abspath("deep_3drecon")
+import torch
+import torch.nn as nn
+from .deep_3drecon_models.facerecon_model import FaceReconModel
+from .util.preprocess import align_img
+from PIL import Image
+import numpy as np
+from .util.load_mats import load_lm3d
+import pickle as pkl
+
+from utils.commons.tensor_utils import convert_to_tensor, convert_to_np
+
+with open("deep_3drecon/reconstructor_opt.pkl", "rb") as f:
+ opt = pkl.load(f)
+
+class Reconstructor(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.model = FaceReconModel(opt)
+ self.model.setup(opt)
+ self.model.device = 'cuda:0'
+ self.model.parallelize()
+ # self.model.to(self.model.device)
+ self.model.eval()
+ self.lm3d_std = load_lm3d(opt.bfm_folder)
+
+ def preprocess_data(self, im, lm, lm3d_std):
+ # to RGB
+ H,W,_ = im.shape
+ lm = lm.reshape([-1, 2])
+ lm[:, -1] = H - 1 - lm[:, -1]
+
+ _, im, lm, _ = align_img(Image.fromarray(convert_to_np(im)), convert_to_np(lm), convert_to_np(lm3d_std))
+ im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
+ lm = torch.tensor(lm).unsqueeze(0)
+ return im, lm
+
+ @torch.no_grad()
+ def recon_coeff(self, batched_images, batched_lm5, return_image=True, batch_mode=True):
+ bs = batched_images.shape[0]
+ data_lst = []
+ for i in range(bs):
+ img = batched_images[i]
+ lm5 = batched_lm5[i]
+ align_im, lm = self.preprocess_data(img, lm5, self.lm3d_std)
+ data = {
+ 'imgs': align_im,
+ 'lms': lm
+ }
+ data_lst.append(data)
+ if not batch_mode:
+ coeff_lst = []
+ align_lst = []
+ for i in range(bs):
+ data = data_lst[i]  # use the i-th sample in non-batch mode
+ self.model.set_input(data) # unpack data from data loader
+ self.model.forward()
+ pred_coeff = self.model.output_coeff.cpu().numpy()
+ align_im = (data['imgs'].squeeze().permute(1,2,0)*255).int().numpy().astype(np.uint8)
+ coeff_lst.append(pred_coeff)
+ align_lst.append(align_im)
+ batch_coeff = np.concatenate(coeff_lst) # [B, 257]
+ batch_align_img = np.stack(align_lst) # [B, 224, 224, 3]
+ else:
+ imgs = torch.cat([d['imgs'] for d in data_lst])
+ lms = torch.cat([d['lms'] for d in data_lst])
+ data = {
+ 'imgs': imgs,
+ 'lms': lms
+ }
+ self.model.set_input(data) # unpack data from data loader
+ self.model.forward()
+ batch_coeff = self.model.output_coeff.cpu().numpy()
+ batch_align_img = (imgs.permute(0,2,3,1)*255).int().numpy().astype(np.uint8)
+ return batch_coeff, batch_align_img
+
+ # todo: batch-wise recon!
+
+ def forward(self, batched_images, batched_lm5, return_image=True):
+ return self.recon_coeff(batched_images, batched_lm5, return_image)
+
+
+
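+# Usage sketch (illustrative; shapes inferred from preprocess_data/recon_coeff above).
+# `example.png` and `example_lm5.txt` are hypothetical inputs: an H x W x 3 uint8 frame
+# and its five facial landmarks in image coordinates.
+#
+# import imageio
+# recon = Reconstructor()
+# img = imageio.imread('example.png')                    # [H, W, 3] uint8
+# lm5 = np.loadtxt('example_lm5.txt').reshape([1, 5, 2]) # [B, 5, 2]
+# coeff, aligned = recon.recon_coeff(img[None], lm5)     # coeff: [B, 257], aligned: [B, 224, 224, 3]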
diff --git a/deep_3drecon/reconstructor_opt.pkl b/deep_3drecon/reconstructor_opt.pkl
new file mode 100644
index 0000000000000000000000000000000000000000..9917488bf7e85bb7a98d70fb9327f3bf4ef163cd
--- /dev/null
+++ b/deep_3drecon/reconstructor_opt.pkl
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:17be83a1ab4333e2b6d1cbb80e6cdf92ee470c32391134fd072652423b09984e
+size 776
diff --git a/deep_3drecon/secc_renderer.py b/deep_3drecon/secc_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d6b3cdc4051c1ad9ed98b3228a35f35d573ab7c
--- /dev/null
+++ b/deep_3drecon/secc_renderer.py
@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+
+from deep_3drecon.util.mesh_renderer import MeshRenderer
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+
+
+class SECC_Renderer(nn.Module):
+ def __init__(self, rasterize_size=None, device="cuda"):
+ super().__init__()
+ self.face_model = ParametricFaceModel('deep_3drecon/BFM')
+ self.fov = 2 * np.arctan(self.face_model.center / self.face_model.focal) * 180 / np.pi
+ self.znear = 5.
+ self.zfar = 15.
+ if rasterize_size is None:
+ rasterize_size = 2*self.face_model.center
+ self.face_renderer = MeshRenderer(rasterize_fov=self.fov, znear=self.znear, zfar=self.zfar, rasterize_size=rasterize_size, use_opengl=False).cuda()
+ face_feat = np.load("deep_3drecon/ncc_code.npy", allow_pickle=True)
+ self.face_feat = torch.tensor(face_feat.T).unsqueeze(0).to(device=device)
+
+ del_index_re = np.load('deep_3drecon/bfm_right_eye_faces.npy')
+ del_index_re = del_index_re - 1
+ del_index_le = np.load('deep_3drecon/bfm_left_eye_faces.npy')
+ del_index_le = del_index_le - 1
+ face_buf_list = []
+ for i in range(self.face_model.face_buf.shape[0]):
+ if i not in del_index_re and i not in del_index_le:
+ face_buf_list.append(self.face_model.face_buf[i])
+ face_buf_arr = np.array(face_buf_list)
+ self.face_buf = torch.tensor(face_buf_arr).to(device=device)
+
+ def forward(self, id, exp, euler, trans):
+ """
+ id, exp, euler, trans: [B, C] or [B, T, C]
+ return:
+ MASK: [B, 1, 512, 512], values in {0., 1.}; 1. denotes a face pixel
+ SECC MAP: [B, 3, 512, 512], values in [-1, 1]
+ if the input is in BTC format, return [B, C, T, H, W]
+ """
+ bs = id.shape[0]
+ is_btc_flag = id.ndim == 3
+ if is_btc_flag:
+ t = id.shape[1]
+ bs = bs * t
+ id, exp, euler, trans = id.reshape([bs,-1]), exp.reshape([bs,-1]), euler.reshape([bs,-1]), trans.reshape([bs,-1])
+
+ face_vertex = self.face_model.compute_face_vertex(id, exp, euler, trans)
+ face_mask, _, secc_face = self.face_renderer(
+ face_vertex, self.face_buf.unsqueeze(0).repeat([bs, 1, 1]), feat=self.face_feat.repeat([bs,1,1]))
+ secc_face = (secc_face - 0.5) / 0.5 # scale to -1~1
+
+ if is_btc_flag:
+ bs = bs // t
+ face_mask = rearrange(face_mask, "(n t) c h w -> n c t h w", n=bs, t=t)
+ secc_face = rearrange(secc_face, "(n t) c h w -> n c t h w", n=bs, t=t)
+ return face_mask, secc_face
+
+
+if __name__ == '__main__':
+ import imageio
+
+ renderer = SECC_Renderer(rasterize_size=512)
+ ret = np.load("data/processed/videos/May/vid_coeff_fit.npy", allow_pickle=True).tolist()
+ idx = 6
+ id = torch.tensor(ret['id']).cuda()[idx:idx+1]
+ exp = torch.tensor(ret['exp']).cuda()[idx:idx+1]
+ angle = torch.tensor(ret['euler']).cuda()[idx:idx+1]
+ trans = torch.tensor(ret['trans']).cuda()[idx:idx+1]
+ mask, secc = renderer(id, exp, angle*0, trans*0) # [1, 1, 512, 512], [1, 3, 512, 512]
+
+ out_mask = mask[0].permute(1,2,0)
+ out_mask = (out_mask * 127.5 + 127.5).int().cpu().numpy()
+ imageio.imwrite("out_mask.png", out_mask)
+ out_img = secc[0].permute(1,2,0)
+ out_img = (out_img * 127.5 + 127.5).int().cpu().numpy()
+ imageio.imwrite("out_secc.png", out_img)
\ No newline at end of file
diff --git a/deep_3drecon/test.py b/deep_3drecon/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..c52078440b0a1c6ba919eb635d28f3745562cd52
--- /dev/null
+++ b/deep_3drecon/test.py
@@ -0,0 +1,69 @@
+"""This script is the test script for Deep3DFaceRecon_pytorch
+"""
+
+import os
+from options.test_options import TestOptions
+from deep_3drecon_models import create_model
+from util.visualizer import MyVisualizer
+from util.preprocess import align_img
+from PIL import Image
+import numpy as np
+from util.load_mats import load_lm3d
+import torch
+
+def get_data_path(root='examples'):
+ im_path = [os.path.join(root, i) for i in sorted(os.listdir(root)) if i.endswith('png') or i.endswith('jpg')]
+ lm_path = [i.replace('png', 'txt').replace('jpg', 'txt') for i in im_path]
+ lm_path = [os.path.join(i.replace(i.split(os.path.sep)[-1],''),'detections',i.split(os.path.sep)[-1]) for i in lm_path]
+ return im_path, lm_path
+
+def read_data(im_path, lm_path, lm3d_std, to_tensor=True):
+ # to RGB
+ im = Image.open(im_path).convert('RGB')
+ W,H = im.size
+ lm = np.loadtxt(lm_path).astype(np.float32)
+ lm = lm.reshape([-1, 2])
+ lm[:, -1] = H - 1 - lm[:, -1]
+ _, im, lm, _ = align_img(im, lm, lm3d_std)
+ if to_tensor:
+ im = torch.tensor(np.array(im)/255., dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
+ lm = torch.tensor(lm).unsqueeze(0)
+ return im, lm
+
+def main(rank, opt, name='examples'):
+ device = torch.device(rank)
+ torch.cuda.set_device(device)
+ model = create_model(opt)
+ model.setup(opt)
+ model.device = device
+ model.parallelize()
+ model.eval()
+ visualizer = MyVisualizer(opt)
+
+ im_path, lm_path = get_data_path(name)
+ lm3d_std = load_lm3d(opt.bfm_folder)
+
+ for i in range(len(im_path)):
+ print(i, im_path[i])
+ img_name = im_path[i].split(os.path.sep)[-1].replace('.png','').replace('.jpg','')
+ if not os.path.isfile(lm_path[i]):
+ print("%s is not found !!!"%lm_path[i])
+ continue
+ im_tensor, lm_tensor = read_data(im_path[i], lm_path[i], lm3d_std)
+ data = {
+ 'imgs': im_tensor,
+ 'lms': lm_tensor
+ }
+ model.set_input(data) # unpack data from data loader
+ model.test() # run inference
+ visuals = model.get_current_visuals() # get image results
+ visualizer.display_current_results(visuals, 0, opt.epoch, dataset=name.split(os.path.sep)[-1],
+ save_results=True, count=i, name=img_name, add_image=False)
+
+ model.save_mesh(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.obj')) # save reconstruction meshes
+ model.save_coeff(os.path.join(visualizer.img_dir, name.split(os.path.sep)[-1], 'epoch_%s_%06d'%(opt.epoch, 0),img_name+'.mat')) # save predicted coefficients
+
+if __name__ == '__main__':
+ opt = TestOptions().parse() # get test options
+ main(0, opt, 'deep_3drecon/datasets/examples')
+ print(f"results saved at deep_3drecon/checkpoints/facerecon/results/")
diff --git a/deep_3drecon/train.py b/deep_3drecon/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdda882178277a4e1a5e4c3dd87299ab1ba6e8b
--- /dev/null
+++ b/deep_3drecon/train.py
@@ -0,0 +1,166 @@
+"""This script is the training script for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import time
+import numpy as np
+import torch
+from options.train_options import TrainOptions
+from data import create_dataset
+from deep_3drecon_models import create_model
+from util.visualizer import MyVisualizer
+from util.util import genvalconf
+import torch.multiprocessing as mp
+import torch.distributed as dist
+
+
+def setup(rank, world_size, port):
+ os.environ['MASTER_ADDR'] = 'localhost'
+ os.environ['MASTER_PORT'] = port
+
+ # initialize the process group
+ dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+def cleanup():
+ dist.destroy_process_group()
+
+def main(rank, world_size, train_opt):
+ val_opt = genvalconf(train_opt, isTrain=False)
+
+ device = torch.device(rank)
+ torch.cuda.set_device(device)
+ use_ddp = train_opt.use_ddp
+
+ if use_ddp:
+ setup(rank, world_size, train_opt.ddp_port)
+
+ train_dataset, val_dataset = create_dataset(train_opt, rank=rank), create_dataset(val_opt, rank=rank)
+ train_dataset_batches, val_dataset_batches = \
+ len(train_dataset) // train_opt.batch_size, len(val_dataset) // val_opt.batch_size
+
+ model = create_model(train_opt) # create a model given train_opt.model and other options
+ model.setup(train_opt)
+ model.device = device
+ model.parallelize()
+
+ if rank == 0:
+ print('The batch number of training images = %d\n, \
+ the batch number of validation images = %d'% (train_dataset_batches, val_dataset_batches))
+ model.print_networks(train_opt.verbose)
+ visualizer = MyVisualizer(train_opt) # create a visualizer that display/save images and plots
+
+ total_iters = train_dataset_batches * (train_opt.epoch_count - 1) # the total number of training iterations
+ t_data = 0
+ t_val = 0
+ optimize_time = 0.1
+ batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size
+
+ if use_ddp:
+ dist.barrier()
+
+ times = []
+ for epoch in range(train_opt.epoch_count, train_opt.n_epochs + 1): # outer loop for different epochs; we save the model by epoch_count, epoch_count+save_latest_freq, ...
+ epoch_start_time = time.time() # timer for entire epoch
+ iter_data_time = time.time() # timer for train_data loading per iteration
+ epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
+
+ train_dataset.set_epoch(epoch)
+ for i, train_data in enumerate(train_dataset): # inner loop within one epoch
+ iter_start_time = time.time() # timer for computation per iteration
+ if total_iters % train_opt.print_freq == 0:
+ t_data = iter_start_time - iter_data_time
+ total_iters += batch_size
+ epoch_iter += batch_size
+
+ torch.cuda.synchronize()
+ optimize_start_time = time.time()
+
+ model.set_input(train_data) # unpack train_data from dataset and apply preprocessing
+ model.optimize_parameters() # calculate loss functions, get gradients, update network weights
+
+ torch.cuda.synchronize()
+ optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
+
+ if use_ddp:
+ dist.barrier()
+
+ if rank == 0 and (total_iters == batch_size or total_iters % train_opt.display_freq == 0): # display images on visdom and save images to a HTML file
+ model.compute_visuals()
+ visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
+ save_results=True,
+ add_image=train_opt.add_image)
+ # (total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0)
+
+ if rank == 0 and (total_iters == batch_size or total_iters % train_opt.print_freq == 0): # print training losses and save logging information to the disk
+ losses = model.get_current_losses()
+ visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
+ visualizer.plot_current_losses(total_iters, losses)
+
+ if total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0:
+ with torch.no_grad():
+ torch.cuda.synchronize()
+ val_start_time = time.time()
+ losses_avg = {}
+ model.eval()
+ for j, val_data in enumerate(val_dataset):
+ model.set_input(val_data)
+ model.optimize_parameters(isTrain=False)
+ if rank == 0 and j < train_opt.vis_batch_nums:
+ model.compute_visuals()
+ visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
+ dataset='val', save_results=True, count=j * val_opt.batch_size,
+ add_image=train_opt.add_image)
+
+ if j < train_opt.eval_batch_nums:
+ losses = model.get_current_losses()
+ for key, value in losses.items():
+ losses_avg[key] = losses_avg.get(key, 0) + value
+
+ for key, value in losses_avg.items():
+ losses_avg[key] = value / min(train_opt.eval_batch_nums, val_dataset_batches)
+
+ torch.cuda.synchronize()
+ eval_time = time.time() - val_start_time
+
+ if rank == 0:
+ visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val') # visualize training results
+ visualizer.plot_current_losses(total_iters, losses_avg, dataset='val')
+ model.train()
+
+ if use_ddp:
+ dist.barrier()
+
+ if rank == 0 and (total_iters == batch_size or total_iters % train_opt.save_latest_freq == 0): # cache our latest model every save_latest_freq iterations
+ print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
+ print(train_opt.name) # it's useful to occasionally show the experiment name on console
+ save_suffix = 'iter_%d' % total_iters if train_opt.save_by_iter else 'latest'
+ model.save_networks(save_suffix)
+
+ if use_ddp:
+ dist.barrier()
+
+ iter_data_time = time.time()
+
+ print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.n_epochs, time.time() - epoch_start_time))
+ model.update_learning_rate() # update learning rates at the end of every epoch.
+
+ if rank == 0 and epoch % train_opt.save_epoch_freq == 0: # cache our model every save_epoch_freq epochs
+ print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
+ model.save_networks('latest')
+ model.save_networks(epoch)
+
+ if use_ddp:
+ dist.barrier()
+
+if __name__ == '__main__':
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ train_opt = TrainOptions().parse() # get training options
+ world_size = train_opt.world_size
+
+ if train_opt.use_ddp:
+ mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True)
+ else:
+ main(0, world_size, train_opt)
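+
+# Launch sketch (illustrative): the flags below are defined in TrainOptions/BaseOptions;
+# the GPU ids and values are placeholders.
+#
+# $ python deep_3drecon/train.py --name facerecon --gpu_ids 0,1 --batch_size 32 --n_epochs 20
+#
+# With more than one id in --gpu_ids, parse() sets world_size > 1 and main() is spawned once
+# per GPU via torch.multiprocessing; with a single id, use_ddp is disabled and main() runs in-process.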
diff --git a/deep_3drecon/util/BBRegressorParam_r.mat b/deep_3drecon/util/BBRegressorParam_r.mat
new file mode 100644
index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084
Binary files /dev/null and b/deep_3drecon/util/BBRegressorParam_r.mat differ
diff --git a/deep_3drecon/util/__init__.py b/deep_3drecon/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cbc84bf01004432c6a76db481d5e1741b0c32f
--- /dev/null
+++ b/deep_3drecon/util/__init__.py
@@ -0,0 +1,2 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from .util import *
diff --git a/deep_3drecon/util/detect_lm68.py b/deep_3drecon/util/detect_lm68.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7e40997289e17405e1fb6c408d21adce7b626ce
--- /dev/null
+++ b/deep_3drecon/util/detect_lm68.py
@@ -0,0 +1,106 @@
+import os
+import cv2
+import numpy as np
+from scipy.io import loadmat
+import tensorflow as tf
+from util.preprocess import align_for_lm
+from shutil import move
+
+mean_face = np.loadtxt('util/test_mean_face.txt')
+mean_face = mean_face.reshape([68, 2])
+
+def save_label(labels, save_path):
+ np.savetxt(save_path, labels)
+
+def draw_landmarks(img, landmark, save_name):
+ landmark = landmark
+ lm_img = np.zeros([img.shape[0], img.shape[1], 3])
+ lm_img[:] = img.astype(np.float32)
+ landmark = np.round(landmark).astype(np.int32)
+
+ for i in range(len(landmark)):
+ for j in range(-1, 1):
+ for k in range(-1, 1):
+ if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
+ img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
+ landmark[i, 0]+k > 0 and \
+ landmark[i, 0]+k < img.shape[1]:
+ lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
+ :] = np.array([0, 0, 255])
+ lm_img = lm_img.astype(np.uint8)
+
+ cv2.imwrite(save_name, lm_img)
+
+
+def load_data(img_name, txt_name):
+ return cv2.imread(img_name), np.loadtxt(txt_name)
+
+# create tensorflow graph for landmark detector
+def load_lm_graph(graph_filename):
+ with tf.gfile.GFile(graph_filename, 'rb') as f:
+ graph_def = tf.GraphDef()
+ graph_def.ParseFromString(f.read())
+
+ with tf.Graph().as_default() as graph:
+ tf.import_graph_def(graph_def, name='net')
+ img_224 = graph.get_tensor_by_name('net/input_imgs:0')
+ output_lm = graph.get_tensor_by_name('net/lm:0')
+ lm_sess = tf.Session(graph=graph)
+
+ return lm_sess,img_224,output_lm
+
+# landmark detection
+def detect_68p(img_path,sess,input_op,output_op):
+ print('detecting landmarks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ vis_path = os.path.join(img_path, 'vis')
+ remove_path = os.path.join(img_path, 'remove')
+ save_path = os.path.join(img_path, 'landmarks')
+ if not os.path.isdir(vis_path):
+ os.makedirs(vis_path)
+ if not os.path.isdir(remove_path):
+ os.makedirs(remove_path)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
+ full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
+
+ # if an image does not have detected 5 facial landmarks, remove it from the training list
+ if not os.path.isfile(full_txt_name):
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # load data
+ img, five_points = load_data(full_image_name, full_txt_name)
+ input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection
+
+ # if the alignment fails, remove corresponding image from the training list
+ if scale == 0:
+ move(full_txt_name, os.path.join(
+ remove_path, txt_name))
+ move(full_image_name, os.path.join(remove_path, name))
+ continue
+
+ # detect landmarks
+ input_img = np.reshape(
+ input_img, [1, 224, 224, 3]).astype(np.float32)
+ landmark = sess.run(
+ output_op, feed_dict={input_op: input_img})
+
+ # transform back to original image coordinate
+ landmark = landmark.reshape([68, 2]) + mean_face
+ landmark[:, 1] = 223 - landmark[:, 1]
+ landmark = landmark / scale
+ landmark[:, 0] = landmark[:, 0] + bbox[0]
+ landmark[:, 1] = landmark[:, 1] + bbox[1]
+ landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
+
+ if i % 100 == 0:
+ draw_landmarks(img, landmark, os.path.join(vis_path, name))
+ save_label(landmark, os.path.join(save_path, txt_name))
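+
+# Usage sketch (illustrative): load the frozen TF landmark graph and run 68-point detection
+# over a folder of images. The .pb path is a placeholder; the folder is expected to contain
+# a 'detections/' subfolder with the 5-point landmark .txt files.
+#
+# sess, input_op, output_op = load_lm_graph('checkpoints/lm_model/68lm_detector.pb')
+# detect_68p('datasets/examples', sess, input_op, output_op)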
diff --git a/deep_3drecon/util/generate_list.py b/deep_3drecon/util/generate_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..943d906781063c3584a7e5b5c784f8aac0694985
--- /dev/null
+++ b/deep_3drecon/util/generate_list.py
@@ -0,0 +1,34 @@
+"""This script is to generate training list files for Deep3DFaceRecon_pytorch
+"""
+
+import os
+
+# save path to training data
+def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
+ save_path = os.path.join(save_folder, mode)
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+ with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in lms_list])
+
+ with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in imgs_list])
+
+ with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
+ fd.writelines([i + '\n' for i in msks_list])
+
+# check if the path is valid
+def check_list(rlms_list, rimgs_list, rmsks_list):
+ lms_list, imgs_list, msks_list = [], [], []
+ for i in range(len(rlms_list)):
+ flag = 'false'
+ lm_path = rlms_list[i]
+ im_path = rimgs_list[i]
+ msk_path = rmsks_list[i]
+ if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+ flag = 'true'
+ lms_list.append(rlms_list[i])
+ imgs_list.append(rimgs_list[i])
+ msks_list.append(rmsks_list[i])
+ print(i, rlms_list[i], flag)
+ return lms_list, imgs_list, msks_list
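+
+# Usage sketch (illustrative): keep only the samples whose landmark, image, and mask files all
+# exist, then write the training lists. The raw path lists are placeholders.
+#
+# lms, imgs, msks = check_list(raw_lms_list, raw_imgs_list, raw_msks_list)
+# write_list(lms, imgs, msks, mode='train')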
diff --git a/deep_3drecon/util/html.py b/deep_3drecon/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc3262a1eafda34842e4dbad47bb6ba72f0c5a68
--- /dev/null
+++ b/deep_3drecon/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+ """This HTML class allows us to save images and write texts into a single HTML file.
+
+ It consists of functions such as add_header (add a text header to the HTML file),
+ add_images (add a row of images to the HTML file), and save (save the HTML to the disk).
+ It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+ """
+
+ def __init__(self, web_dir, title, refresh=0):
+ """Initialize the HTML classes
+
+ Parameters:
+ web_dir (str) -- a directory that stores the webpage. HTML file will be created at web_dir/index.html; images will be saved at web_dir/images/
+ title (str) -- the webpage name
+ refresh (int) -- how often the website refreshes itself; if 0, no refreshing
+ """
+ self.title = title
+ self.web_dir = web_dir
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ if not os.path.exists(self.web_dir):
+ os.makedirs(self.web_dir)
+ if not os.path.exists(self.img_dir):
+ os.makedirs(self.img_dir)
+
+ self.doc = dominate.document(title=title)
+ if refresh > 0:
+ with self.doc.head:
+ meta(http_equiv="refresh", content=str(refresh))
+
+ def get_image_dir(self):
+ """Return the directory that stores images"""
+ return self.img_dir
+
+ def add_header(self, text):
+ """Insert a header to the HTML file
+
+ Parameters:
+ text (str) -- the header text
+ """
+ with self.doc:
+ h3(text)
+
+ def add_images(self, ims, txts, links, width=400):
+ """add images to the HTML file
+
+ Parameters:
+ ims (str list) -- a list of image paths
+ txts (str list) -- a list of image names shown on the website
+ links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
+ """
+ self.t = table(border=1, style="table-layout: fixed;") # Insert a table
+ self.doc.add(self.t)
+ with self.t:
+ with tr():
+ for im, txt, link in zip(ims, txts, links):
+ with td(style="word-wrap: break-word;", halign="center", valign="top"):
+ with p():
+ with a(href=os.path.join('images', link)):
+ img(style="width:%dpx" % width, src=os.path.join('images', im))
+ br()
+ p(txt)
+
+ def save(self):
+ """save the current content to the HMTL file"""
+ html_file = '%s/index.html' % self.web_dir
+ f = open(html_file, 'wt')
+ f.write(self.doc.render())
+ f.close()
+
+
+if __name__ == '__main__': # we show an example usage here.
+ html = HTML('web/', 'test_html')
+ html.add_header('hello world')
+
+ ims, txts, links = [], [], []
+ for n in range(4):
+ ims.append('image_%d.png' % n)
+ txts.append('text_%d' % n)
+ links.append('image_%d.png' % n)
+ html.add_images(ims, txts, links)
+ html.save()
diff --git a/deep_3drecon/util/load_mats.py b/deep_3drecon/util/load_mats.py
new file mode 100644
index 0000000000000000000000000000000000000000..49ba09a555a7380311bf2a20f87957b61551eb1d
--- /dev/null
+++ b/deep_3drecon/util/load_mats.py
@@ -0,0 +1,120 @@
+"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat, savemat
+from array import array
+import os.path as osp
+
+# load expression basis
+def LoadExpBasis(bfm_folder='BFM'):
+ n_vertex = 53215
+ Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
+ exp_dim = array('i')
+ exp_dim.fromfile(Expbin, 1)
+ expMU = array('f')
+ expPC = array('f')
+ expMU.fromfile(Expbin, 3*n_vertex)
+ expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
+ Expbin.close()
+
+ expPC = np.array(expPC)
+ expPC = np.reshape(expPC, [exp_dim[0], -1])
+ expPC = np.transpose(expPC)
+
+ expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
+
+ return expPC, expEV
+
+
+# transfer original BFM09 to our face model
+def transferBFM09(bfm_folder='BFM'):
+ print('Transfer BFM09 to BFM_model_front......')
+ original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
+ shapePC = original_BFM['shapePC'] # shape basis
+ shapeEV = original_BFM['shapeEV'] # corresponding eigen value
+ shapeMU = original_BFM['shapeMU'] # mean face
+ texPC = original_BFM['texPC'] # texture basis
+ texEV = original_BFM['texEV'] # eigen value
+ texMU = original_BFM['texMU'] # mean texture
+
+ expPC, expEV = LoadExpBasis(bfm_folder)
+
+ # transfer BFM09 to our face model
+
+ idBase = shapePC*np.reshape(shapeEV, [-1, 199])
+ idBase = idBase/1e5 # unify the scale to decimeter
+ idBase = idBase[:, :80] # use only first 80 basis
+
+ exBase = expPC*np.reshape(expEV, [-1, 79])
+ exBase = exBase/1e5 # unify the scale to decimeter
+ exBase = exBase[:, :64] # use only first 64 basis
+
+ texBase = texPC*np.reshape(texEV, [-1, 199])
+ texBase = texBase[:, :80] # use only first 80 basis
+
+ # our face model is cropped along face landmarks and contains only 35709 vertex.
+ # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
+ # thus we select corresponding vertex to get our face model.
+
+ index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
+ index_exp = index_exp['idx'].astype(np.int32) - 1 # starts from 0 (to 53215)
+
+ index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
+ index_shape = index_shape['trimIndex'].astype(
+ np.int32) - 1 # starts from 0 (to 53490)
+ index_shape = index_shape[index_exp]
+
+ idBase = np.reshape(idBase, [-1, 3, 80])
+ idBase = idBase[index_shape, :, :]
+ idBase = np.reshape(idBase, [-1, 80])
+
+ texBase = np.reshape(texBase, [-1, 3, 80])
+ texBase = texBase[index_shape, :, :]
+ texBase = np.reshape(texBase, [-1, 80])
+
+ exBase = np.reshape(exBase, [-1, 3, 64])
+ exBase = exBase[index_exp, :, :]
+ exBase = np.reshape(exBase, [-1, 64])
+
+ meanshape = np.reshape(shapeMU, [-1, 3])/1e5
+ meanshape = meanshape[index_shape, :]
+ meanshape = np.reshape(meanshape, [1, -1])
+
+ meantex = np.reshape(texMU, [-1, 3])
+ meantex = meantex[index_shape, :]
+ meantex = np.reshape(meantex, [1, -1])
+
+ # other info contains triangles, region used for computing photometric loss,
+ # region used for skin texture regularization, and 68 landmarks index etc.
+ other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
+ frontmask2_idx = other_info['frontmask2_idx']
+ skinmask = other_info['skinmask']
+ keypoints = other_info['keypoints']
+ point_buf = other_info['point_buf']
+ tri = other_info['tri']
+ tri_mask2 = other_info['tri_mask2']
+
+ # save our face model
+ savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
+ 'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
+
+
+# load landmarks for standard face, which is used for image preprocessing
+def load_lm3d(bfm_folder):
+
+ Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
+ Lm3D = Lm3D['lm']
+
+ # calculate 5 facial landmarks using 68 landmarks
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
+ Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
+ Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
+
+ return Lm3D
+
+
+if __name__ == '__main__':
+ transferBFM09(bfm_folder='deep_3drecon/BFM')
diff --git a/deep_3drecon/util/mesh_renderer.py b/deep_3drecon/util/mesh_renderer.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6e765d706fb31cbe7f0b4403b492893ca32221
--- /dev/null
+++ b/deep_3drecon/util/mesh_renderer.py
@@ -0,0 +1,131 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+ Note: the antialiasing step is missing in the current version.
+"""
+import torch
+import torch.nn.functional as F
+import kornia
+from kornia.geometry.camera import pixel2cam
+import numpy as np
+from typing import List
+from scipy.io import loadmat
+from torch import nn
+import traceback
+
+try:
+ import pytorch3d.ops
+ from pytorch3d.structures import Meshes
+ from pytorch3d.renderer import (
+ look_at_view_transform,
+ FoVPerspectiveCameras,
+ DirectionalLights,
+ RasterizationSettings,
+ MeshRenderer,
+ MeshRasterizer,
+ SoftPhongShader,
+ TexturesUV,
+ )
+except:
+ traceback.print_exc()
+# def ndc_projection(x=0.1, n=1.0, f=50.0):
+# return np.array([[n/x, 0, 0, 0],
+# [ 0, n/-x, 0, 0],
+# [ 0, 0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+# [ 0, 0, -1, 0]]).astype(np.float32)
+
+class MeshRenderer(nn.Module):
+ def __init__(self,
+ rasterize_fov,
+ znear=0.1,
+ zfar=10,
+ rasterize_size=224,**args):
+ super(MeshRenderer, self).__init__()
+
+ # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+ # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+ # torch.diag(torch.tensor([1., -1, -1, 1])))
+ self.rasterize_size = rasterize_size
+ self.fov = rasterize_fov
+ self.znear = znear
+ self.zfar = zfar
+
+ self.rasterizer = None
+
+ def forward(self, vertex, tri, feat=None):
+ """
+ Return:
+ mask -- torch.tensor, size (B, 1, H, W)
+ depth -- torch.tensor, size (B, 1, H, W)
+ features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+ Parameters:
+ vertex -- torch.tensor, size (B, N, 3)
+ tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+ feat(optional) -- torch.tensor, size (B, N ,C), features
+ """
+ device = vertex.device
+ rsize = int(self.rasterize_size)
+ # ndc_proj = self.ndc_proj.to(device)
+ # convert 3D vertices to homogeneous coordinates; the y direction is the same as v
+ if vertex.shape[-1] == 3:
+ vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+ vertex[..., 0] = -vertex[..., 0]
+
+
+ # vertex_ndc = vertex @ ndc_proj.t()
+ if self.rasterizer is None:
+ self.rasterizer = MeshRasterizer()
+ print("create rasterizer on device cuda:%d"%device.index)
+
+ # ranges = None
+ # if isinstance(tri, List) or len(tri.shape) == 3:
+ # vum = vertex_ndc.shape[1]
+ # fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+ # fstartidx = torch.cumsum(fnum, dim=0) - fnum
+ # ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+ # for i in range(tri.shape[0]):
+ # tri[i] = tri[i] + i*vum
+ # vertex_ndc = torch.cat(vertex_ndc, dim=0)
+ # tri = torch.cat(tri, dim=0)
+
+ # for range_mode vertex: [B*N, 4], tri: [B*M, 3]; for instance_mode vertex: [B, N, 4], tri: [M, 3]
+ tri = tri.type(torch.int32).contiguous()
+
+ # rasterize
+ cameras = FoVPerspectiveCameras(
+ device=device,
+ fov=self.fov,
+ znear=self.znear,
+ zfar=self.zfar,
+ )
+
+ raster_settings = RasterizationSettings(
+ image_size=rsize
+ )
+
+ # print(vertex.shape, tri.shape)
+ if tri.ndim == 2:
+ tri = tri.unsqueeze(0)
+ mesh = Meshes(vertex.contiguous()[...,:3], tri)
+
+ fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
+ rast_out = fragments.pix_to_face.squeeze(-1)
+ depth = fragments.zbuf
+
+ # render depth
+ depth = depth.permute(0, 3, 1, 2)
+ mask = (rast_out > 0).float().unsqueeze(1)
+ depth = mask * depth
+
+
+ image = None
+ if feat is not None:
+ attributes = feat.reshape(-1,3)[mesh.faces_packed()]
+ image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
+ fragments.bary_coords,
+ attributes)
+ # print(image.shape)
+ image = image.squeeze(-2).permute(0, 3, 1, 2)
+ image = mask * image
+
+ return mask, depth, image
+
diff --git a/deep_3drecon/util/preprocess.py b/deep_3drecon/util/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b904c8111f14c00650644fca557ca4f39f209a98
--- /dev/null
+++ b/deep_3drecon/util/preprocess.py
@@ -0,0 +1,231 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from scipy.io import loadmat
+from PIL import Image
+import cv2
+import os
+# from skimage import transform as trans
+import torch
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+warnings.filterwarnings("ignore", category=FutureWarning)
+
+
+# solve a least-squares problem for image alignment: recover the scale s and translation t
+# of a weak-perspective projection that maps the 3D landmarks x onto the 2D landmarks xp
+def POS(xp, x):
+ npts = xp.shape[1]
+
+ A = np.zeros([2*npts, 8])
+
+ A[0:2*npts-1:2, 0:3] = x.transpose()
+ A[0:2*npts-1:2, 3] = 1
+
+ A[1:2*npts:2, 4:7] = x.transpose()
+ A[1:2*npts:2, 7] = 1
+
+ b = np.reshape(xp.transpose(), [2*npts, 1])
+
+ k, _, _, _ = np.linalg.lstsq(A, b)
+
+ R1 = k[0:3]
+ R2 = k[4:7]
+ sTx = k[3]
+ sTy = k[7]
+ s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
+ t = np.stack([sTx, sTy], axis=0)
+
+ return t, s
+
+# bounding box for 68 landmark detection
+def BBRegression(points, params):
+
+ w1 = params['W1']
+ b1 = params['B1']
+ w2 = params['W2']
+ b2 = params['B2']
+ data = points.copy()
+ data = data.reshape([5, 2])
+ data_mean = np.mean(data, axis=0)
+ x_mean = data_mean[0]
+ y_mean = data_mean[1]
+ data[:, 0] = data[:, 0] - x_mean
+ data[:, 1] = data[:, 1] - y_mean
+
+ rms = np.sqrt(np.sum(data ** 2)/5)
+ data = data / rms
+ data = data.reshape([1, 10])
+ data = np.transpose(data)
+ inputs = np.matmul(w1, data) + b1
+ inputs = 2 / (1 + np.exp(-2 * inputs)) - 1
+ inputs = np.matmul(w2, inputs) + b2
+ inputs = np.transpose(inputs)
+ x = inputs[:, 0] * rms + x_mean
+ y = inputs[:, 1] * rms + y_mean
+ w = 224/inputs[:, 2] * rms
+ rects = [x, y, w, w]
+ return np.array(rects).reshape([4])
+
+# utils for landmark detection
+def img_padding(img, box):
+ success = True
+ bbox = box.copy()
+ res = np.zeros([2*img.shape[0], 2*img.shape[1], 3])
+ res[img.shape[0] // 2: img.shape[0] + img.shape[0] //
+ 2, img.shape[1] // 2: img.shape[1] + img.shape[1]//2] = img
+
+ bbox[0] = bbox[0] + img.shape[1] // 2
+ bbox[1] = bbox[1] + img.shape[0] // 2
+ if bbox[0] < 0 or bbox[1] < 0:
+ success = False
+ return res, bbox, success
+
+# utils for landmark detection
+def crop(img, bbox):
+ padded_img, padded_bbox, flag = img_padding(img, bbox)
+ if flag:
+ crop_img = padded_img[padded_bbox[1]: padded_bbox[1] +
+ padded_bbox[3], padded_bbox[0]: padded_bbox[0] + padded_bbox[2]]
+ crop_img = cv2.resize(crop_img.astype(np.uint8),
+ (224, 224), interpolation=cv2.INTER_CUBIC)
+ scale = 224 / padded_bbox[3]
+ return crop_img, scale
+ else:
+ return padded_img, 0
+
+# utils for landmark detection
+def scale_trans(img, lm, t, s):
+ imgw = img.shape[1]
+ imgh = img.shape[0]
+ M_s = np.array([[1, 0, -t[0] + imgw//2 + 0.5], [0, 1, -imgh//2 + t[1]]],
+ dtype=np.float32)
+ img = cv2.warpAffine(img, M_s, (imgw, imgh))
+ w = int(imgw / s * 100)
+ h = int(imgh / s * 100)
+ img = cv2.resize(img, (w, h))
+ lm = np.stack([lm[:, 0] - t[0] + imgw // 2, lm[:, 1] -
+ t[1] + imgh // 2], axis=1) / s * 100
+
+ left = w//2 - 112
+ up = h//2 - 112
+ bbox = [left, up, 224, 224]
+ cropped_img, scale2 = crop(img, bbox)
+ assert(scale2!=0)
+ t1 = np.array([bbox[0], bbox[1]])
+
+ # back to raw img s * crop + s * t1 + t2
+ t1 = np.array([w//2 - 112, h//2 - 112])
+ scale = s / 100
+ t2 = np.array([t[0] - imgw/2, t[1] - imgh / 2])
+ inv = (scale/scale2, scale * t1 + t2.reshape([2]))
+ return cropped_img, inv
+
+# utils for landmark detection
+def align_for_lm(img, five_points):
+ five_points = np.array(five_points).reshape([1, 10])
+ params = loadmat('util/BBRegressorParam_r.mat')
+ bbox = BBRegression(five_points, params)
+ assert(bbox[2] != 0)
+ bbox = np.round(bbox).astype(np.int32)
+ crop_img, scale = crop(img, bbox)
+ return crop_img, scale, bbox
+
+
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+ w0, h0 = img.size
+ w = (w0*s).astype(np.int32)
+ h = (h0*s).astype(np.int32)
+ left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+ right = left + target_size
+ up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+ below = up + target_size
+
+ img = img.resize((w, h), resample=Image.BICUBIC)
+ img = img.crop((left, up, right, below))
+
+ if mask is not None:
+ mask = mask.resize((w, h), resample=Image.BICUBIC)
+ mask = mask.crop((left, up, right, below))
+
+ lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+ t[1] + h0/2], axis=1)*s
+ lm = lm - np.reshape(
+ np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+ return img, lm, mask
+
+# utils for face reconstruction
+def extract_5p(lm):
+ lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+ lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
+ lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
+ lm5p = lm5p[[1, 2, 0, 3, 4], :]
+ return lm5p
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
+ """
+ Return:
+ transparams --numpy.array (raw_W, raw_H, scale, tx, ty)
+ img_new --PIL.Image (target_size, target_size, 3)
+ lm_new --numpy.array (68, 2), y direction is opposite to v direction
+ mask_new --PIL.Image (target_size, target_size)
+
+ Parameters:
+ img --PIL.Image (raw_H, raw_W, 3)
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ lm3D --numpy.array (5, 3)
+ mask --PIL.Image (raw_H, raw_W, 3)
+ """
+ w0, h0 = img.size
+
+
+ if lm.shape[0] != 5:
+ lm5p = extract_5p(lm)
+ else:
+ lm5p = lm
+
+ # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+ t, s = POS(lm5p.transpose(), lm3D.transpose())
+ s = rescale_factor/s
+
+ # processing the image
+ img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+ t = t.reshape([2,])
+ trans_params = np.array([w0, h0, s, t[0], t[1]])
+
+ return trans_params, img_new, lm_new, mask_new
+
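+# Usage sketch (illustrative): align an image with its 68-point landmarks against the standard
+# 3D landmarks. load_lm3d lives in util/load_mats.py; the file paths are placeholders.
+#
+# from util.load_mats import load_lm3d
+# lm3d_std = load_lm3d('deep_3drecon/BFM')               # (5, 3) standard landmarks
+# img = Image.open('example.png').convert('RGB')
+# lm = np.loadtxt('example_lm68.txt').reshape([-1, 2])   # 68 x 2 landmarks
+# lm[:, -1] = img.size[1] - 1 - lm[:, -1]                # flip y so it is opposite to the v direction
+# trans_params, img_new, lm_new, mask_new = align_img(img, lm, lm3d_std)
+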
+# utils for face recognition model
+def estimate_norm(lm_68p, H):
+ # from https://github.com/deepinsight/insightface/blob/c61d3cd208a603dfa4a338bd743b320ce3e94730/recognition/common/face_align.py#L68
+ """
+ Return:
+ trans_m --numpy.array (2, 3)
+ Parameters:
+ lm --numpy.array (68, 2), y direction is opposite to v direction
+ H --int/float , image height
+ """
+ from skimage import transform as trans # local import; the module-level import is commented out above
+ lm = extract_5p(lm_68p)
+ lm[:, -1] = H - 1 - lm[:, -1]
+ tform = trans.SimilarityTransform()
+ src = np.array(
+ [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
+ [41.5493, 92.3655], [70.7299, 92.2041]],
+ dtype=np.float32)
+ tform.estimate(lm, src)
+ M = tform.params
+ if np.linalg.det(M) == 0:
+ M = np.eye(3)
+
+ return M[0:2, :]
+
+def estimate_norm_torch(lm_68p, H):
+ lm_68p_ = lm_68p.detach().cpu().numpy()
+ M = []
+ for i in range(lm_68p_.shape[0]):
+ M.append(estimate_norm(lm_68p_[i], H))
+ M = torch.tensor(np.array(M), dtype=torch.float32).to(lm_68p.device)
+ return M
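+
+# Usage sketch (illustrative): estimate the batch of 5-point similarity transforms used to align
+# faces to the 112x112 recognition template; lm_68p ([B, 68, 2] tensor) and H (image height) are
+# placeholder names here.
+#
+# M = estimate_norm_torch(lm_68p, H) # [B, 2, 3] affine matrices
+# # the matrices can then be passed to an affine warp, e.g. kornia.geometry.transform.warp_affine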
diff --git a/deep_3drecon/util/skin_mask.py b/deep_3drecon/util/skin_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8a74e4c3b40d13b0258b83a12f56321a85bb179
--- /dev/null
+++ b/deep_3drecon/util/skin_mask.py
@@ -0,0 +1,125 @@
+"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
+"""
+
+import math
+import numpy as np
+import os
+import cv2
+
+class GMM:
+ def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
+ self.dim = dim # feature dimension
+ self.num = num # number of Gaussian components
+ self.w = w # weights of Gaussian components (a list of scalars)
+ self.mu= mu # mean of Gaussian components (a list of 1xdim vectors)
+ self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
+ self.cov_det = cov_det # pre-computed determinant of covariance matrices (a list of scalars)
+ self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
+
+ self.factor = [0]*num
+ for i in range(self.num):
+ self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
+
+ def likelihood(self, data):
+ assert(data.shape[1] == self.dim)
+ N = data.shape[0]
+ lh = np.zeros(N)
+
+ for i in range(self.num):
+ data_ = data - self.mu[i]
+
+ tmp = np.matmul(data_,self.cov_inv[i]) * data_
+ tmp = np.sum(tmp,axis=1)
+ power = -0.5 * tmp
+
+ p = np.array([math.exp(power[j]) for j in range(N)])
+ p = p/self.factor[i]
+ lh += p*self.w[i]
+
+ return lh
+
+
+def _rgb2ycbcr(rgb):
+ m = np.array([[65.481, 128.553, 24.966],
+ [-37.797, -74.203, 112],
+ [112, -93.786, -18.214]])
+ shape = rgb.shape
+ rgb = rgb.reshape((shape[0] * shape[1], 3))
+ ycbcr = np.dot(rgb, m.transpose() / 255.)
+ ycbcr[:, 0] += 16.
+ ycbcr[:, 1:] += 128.
+ return ycbcr.reshape(shape)
+
+
+def _bgr2ycbcr(bgr):
+ rgb = bgr[..., ::-1]
+ return _rgb2ycbcr(rgb)
+
+
+gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
+gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
+ np.array([150.19858, 105.18467, 155.51428]),
+ np.array([183.92976, 107.62468, 152.71820]),
+ np.array([114.90524, 113.59782, 151.38217])]
+gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
+ np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
+ np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
+ np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
+
+gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
+
+gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
+gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
+ np.array([110.91392, 125.52969, 130.19237]),
+ np.array([129.75864, 129.96107, 126.96808]),
+ np.array([112.29587, 128.85121, 129.05431])]
+gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
+gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
+ np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
+ np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
+ np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
+
+gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
+
+prior_skin = 0.8
+prior_nonskin = 1 - prior_skin
+
+
+# calculate skin attention mask
+def skinmask(imbgr):
+ im = _bgr2ycbcr(imbgr)
+
+ data = im.reshape((-1,3))
+
+ lh_skin = gmm_skin.likelihood(data)
+ lh_nonskin = gmm_nonskin.likelihood(data)
+
+ tmp1 = prior_skin * lh_skin
+ tmp2 = prior_nonskin * lh_nonskin
+ post_skin = tmp1 / (tmp1+tmp2) # posterior probability
+
+ post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
+
+ post_skin = np.round(post_skin*255)
+ post_skin = post_skin.astype(np.uint8)
+ post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3
+
+ return post_skin
+
+
+def get_skin_mask(img_path):
+ print('generating skin masks......')
+ names = [i for i in sorted(os.listdir(
+ img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+ save_path = os.path.join(img_path, 'mask')
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ for i in range(0, len(names)):
+ name = names[i]
+ print('%05d' % (i), ' ', name)
+ full_image_name = os.path.join(img_path, name)
+ img = cv2.imread(full_image_name).astype(np.float32)
+ skin_img = skinmask(img)
+ cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
diff --git a/deep_3drecon/util/test_mean_face.txt b/deep_3drecon/util/test_mean_face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3a46d4db7699ffed8f898fcee64099631509946d
--- /dev/null
+++ b/deep_3drecon/util/test_mean_face.txt
@@ -0,0 +1,136 @@
+-5.228591537475585938e+01
+2.078247070312500000e-01
+-5.064269638061523438e+01
+-1.315765380859375000e+01
+-4.952939224243164062e+01
+-2.592591094970703125e+01
+-4.793047332763671875e+01
+-3.832135772705078125e+01
+-4.512159729003906250e+01
+-5.059623336791992188e+01
+-3.917720794677734375e+01
+-6.043736648559570312e+01
+-2.929953765869140625e+01
+-6.861183166503906250e+01
+-1.719801330566406250e+01
+-7.572736358642578125e+01
+-1.961936950683593750e+00
+-7.862001037597656250e+01
+1.467941284179687500e+01
+-7.607844543457031250e+01
+2.744073486328125000e+01
+-6.915261840820312500e+01
+3.855677795410156250e+01
+-5.950350570678710938e+01
+4.478240966796875000e+01
+-4.867547225952148438e+01
+4.714337158203125000e+01
+-3.800830078125000000e+01
+4.940315246582031250e+01
+-2.496297454833984375e+01
+5.117234802246093750e+01
+-1.241538238525390625e+01
+5.190507507324218750e+01
+8.244247436523437500e-01
+-4.150688934326171875e+01
+2.386329650878906250e+01
+-3.570307159423828125e+01
+3.017010498046875000e+01
+-2.790358734130859375e+01
+3.212951660156250000e+01
+-1.941773223876953125e+01
+3.156523132324218750e+01
+-1.138106536865234375e+01
+2.841992187500000000e+01
+5.993263244628906250e+00
+2.895182800292968750e+01
+1.343590545654296875e+01
+3.189880371093750000e+01
+2.203153991699218750e+01
+3.302221679687500000e+01
+2.992478942871093750e+01
+3.099150085449218750e+01
+3.628388977050781250e+01
+2.765748596191406250e+01
+-1.933914184570312500e+00
+1.405374145507812500e+01
+-2.153038024902343750e+00
+5.772636413574218750e+00
+-2.270050048828125000e+00
+-2.121643066406250000e+00
+-2.218330383300781250e+00
+-1.068978118896484375e+01
+-1.187252044677734375e+01
+-1.997912597656250000e+01
+-6.879402160644531250e+00
+-2.143579864501953125e+01
+-1.227821350097656250e+00
+-2.193494415283203125e+01
+4.623237609863281250e+00
+-2.152721405029296875e+01
+9.721397399902343750e+00
+-1.953671264648437500e+01
+-3.648714447021484375e+01
+9.811126708984375000e+00
+-3.130242919921875000e+01
+1.422447967529296875e+01
+-2.212834930419921875e+01
+1.493019866943359375e+01
+-1.500880432128906250e+01
+1.073588562011718750e+01
+-2.095037078857421875e+01
+9.054298400878906250e+00
+-3.050099182128906250e+01
+8.704177856445312500e+00
+1.173237609863281250e+01
+1.054329681396484375e+01
+1.856353759765625000e+01
+1.535009765625000000e+01
+2.893331909179687500e+01
+1.451992797851562500e+01
+3.452944946289062500e+01
+1.065280151367187500e+01
+2.875990295410156250e+01
+8.654792785644531250e+00
+1.942100524902343750e+01
+9.422447204589843750e+00
+-2.204488372802734375e+01
+-3.983994293212890625e+01
+-1.324458312988281250e+01
+-3.467377471923828125e+01
+-6.749649047851562500e+00
+-3.092894744873046875e+01
+-9.183349609375000000e-01
+-3.196458435058593750e+01
+4.220649719238281250e+00
+-3.090406036376953125e+01
+1.089889526367187500e+01
+-3.497008514404296875e+01
+1.874589538574218750e+01
+-4.065438079833984375e+01
+1.124106597900390625e+01
+-4.438417816162109375e+01
+5.181709289550781250e+00
+-4.649170684814453125e+01
+-1.158607482910156250e+00
+-4.680406951904296875e+01
+-7.918922424316406250e+00
+-4.671575164794921875e+01
+-1.452505493164062500e+01
+-4.416526031494140625e+01
+-2.005007171630859375e+01
+-3.997841644287109375e+01
+-1.054919433593750000e+01
+-3.849683380126953125e+01
+-1.051826477050781250e+00
+-3.794863128662109375e+01
+6.412681579589843750e+00
+-3.804645538330078125e+01
+1.627674865722656250e+01
+-4.039697265625000000e+01
+6.373878479003906250e+00
+-4.087213897705078125e+01
+-8.551712036132812500e-01
+-4.157129669189453125e+01
+-1.014953613281250000e+01
+-4.128469085693359375e+01
diff --git a/deep_3drecon/util/util.py b/deep_3drecon/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d689ca138fc0fbf5bec794511ea0f9e638f9ea9
--- /dev/null
+++ b/deep_3drecon/util/util.py
@@ -0,0 +1,208 @@
+"""This script contains basic utilities for Deep3DFaceRecon_pytorch
+"""
+from __future__ import print_function
+import numpy as np
+import torch
+from PIL import Image
+import os
+import importlib
+import argparse
+from argparse import Namespace
+import torchvision
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def copyconf(default_opt, **kwargs):
+ conf = Namespace(**vars(default_opt))
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+ return conf
+
+def genvalconf(train_opt, **kwargs):
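+ # Copy the training options, then, for every option whose name contains 'val' and whose prefix
+ # before the first '_' is itself an option, override that base option with the val-specific value
+ # (e.g. a hypothetical `lr_val` entry would override `lr` for validation).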
+ conf = Namespace(**vars(train_opt))
+ attr_dict = train_opt.__dict__
+ for key, value in attr_dict.items():
+ if 'val' in key and key.split('_')[0] in attr_dict:
+ setattr(conf, key.split('_')[0], value)
+
+ for key in kwargs:
+ setattr(conf, key, kwargs[key])
+
+ return conf
+
+def find_class_in_module(target_cls_name, module):
+ target_cls_name = target_cls_name.replace('_', '').lower()
+ clslib = importlib.import_module(module)
+ cls = None
+ for name, clsobj in clslib.__dict__.items():
+ if name.lower() == target_cls_name:
+ cls = clsobj
+
+ assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
+
+ return cls
+
+
+def tensor2im(input_image, imtype=np.uint8):
+ """Converts a Tensor array into a numpy image array.
+
+ Parameters:
+ input_image (tensor) -- the input image tensor array, range(0, 1)
+ imtype (type) -- the desired type of the converted numpy array
+ """
+ if not isinstance(input_image, np.ndarray):
+ if isinstance(input_image, torch.Tensor): # get the data from a variable
+ image_tensor = input_image.data
+ else:
+ return input_image
+ image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy() # convert it into a numpy array
+ if image_numpy.shape[0] == 1: # grayscale to RGB
+ image_numpy = np.tile(image_numpy, (3, 1, 1))
+ image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 # post-processing: transpose and scaling
+ else: # if it is a numpy array, do nothing
+ image_numpy = input_image
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ """Calculate and print the mean of the average absolute gradients
+
+ Parameters:
+ net (torch network) -- Torch network
+ name (str) -- the name of the network
+ """
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+ """Save a numpy image to the disk
+
+ Parameters:
+ image_numpy (numpy array) -- input numpy array
+ image_path (str) -- the path of the image
+ """
+
+ image_pil = Image.fromarray(image_numpy)
+ h, w, _ = image_numpy.shape
+
+ if aspect_ratio is None:
+ pass
+ elif aspect_ratio > 1.0:
+ image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+ elif aspect_ratio < 1.0:
+ image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+ image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+ """Print the mean, min, max, median, std, and size of a numpy array
+
+ Parameters:
+ val (bool) -- if print the values of the numpy array
+ shp (bool) -- if print the shape of the numpy array
+ """
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ """create empty directories if they don't exist
+
+ Parameters:
+ paths (str list) -- a list of directory paths
+ """
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ """create a single empty directory if it doesn't exist
+
+ Parameters:
+ path (str) -- a single directory path
+ """
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i, :1]
+ one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+ one_np = one_np[:, :, 0]
+ one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+ resized_t = torch.from_numpy(np.array(one_image)).long()
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.BICUBIC):
+ device = t.device
+ t = t.detach().cpu()
+ resized = []
+ for i in range(t.size(0)):
+ one_t = t[i:i + 1]
+ one_image = Image.fromarray(tensor2im(one_t)).resize(size, mode)
+ resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+ resized.append(resized_t)
+ return torch.stack(resized, dim=0).to(device)
+
+def draw_landmarks(img, landmark, color='r', step=2):
+ """
+ Return:
+ img -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
+
+
+ Parameters:
+ img -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
+ landmark -- numpy.array, (B, 68, 2), y direction is opposite to v direction
+ color -- str, 'r' or 'b' (red or blue)
+ """
+ if color =='r':
+ c = np.array([255., 0, 0])
+ else:
+ c = np.array([0, 0, 255.])
+
+ _, H, W, _ = img.shape
+ img, landmark = img.copy(), landmark.copy()
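+ # flip the y coordinate: landmark y grows upward while image rows (v) grow downward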
+ landmark[..., 1] = H - 1 - landmark[..., 1]
+ landmark = np.round(landmark).astype(np.int32)
+ for i in range(landmark.shape[1]):
+ x, y = landmark[:, i, 0], landmark[:, i, 1]
+ for j in range(-step, step):
+ for k in range(-step, step):
+ u = np.clip(x + j, 0, W - 1)
+ v = np.clip(y + k, 0, H - 1)
+ for m in range(landmark.shape[0]):
+ img[m, v[m], u[m]] = c
+ return img
diff --git a/deep_3drecon/util/visualizer.py b/deep_3drecon/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4023a6d4086acba9bc88e079f625194d324d7c9e
--- /dev/null
+++ b/deep_3drecon/util/visualizer.py
@@ -0,0 +1,227 @@
+"""This script defines the visualizer for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+from torch.utils.tensorboard import SummaryWriter
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+ """Save images to the disk.
+
+ Parameters:
+ webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
+ visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+ image_path (str) -- the string is used to create image paths
+ aspect_ratio (float) -- the aspect ratio of saved images
+ width (int) -- the images will be resized to width x width
+
+ This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+ """
+ image_dir = webpage.get_image_dir()
+ short_path = ntpath.basename(image_path[0])
+ name = os.path.splitext(short_path)[0]
+
+ webpage.add_header(name)
+ ims, txts, links = [], [], []
+
+ for label, im_data in visuals.items():
+ im = util.tensor2im(im_data)
+ image_name = '%s/%s.png' % (label, name)
+ os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+ save_path = os.path.join(image_dir, image_name)
+ util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+ ims.append(image_name)
+ txts.append(label)
+ links.append(image_name)
+ webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+ """This class includes several functions that can display/save images and print/save logging information.
+
+ It uses the tensorboard SummaryWriter (torch.utils.tensorboard) for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+ """
+
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saving HTML files
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ self.use_html = opt.isTrain and not opt.no_html
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
+ self.win_size = opt.display_winsize
+ self.name = opt.name
+ self.saved = False
+ if self.use_html: # create an HTML object at /web/; images will be saved under /web/images/
+ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+ self.img_dir = os.path.join(self.web_dir, 'images')
+ print('create web directory %s...' % self.web_dir)
+ util.mkdirs([self.web_dir, self.img_dir])
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+ def reset(self):
+ """Reset the self.saved status"""
+ self.saved = False
+
+
+ def display_current_results(self, visuals, total_iters, epoch, save_result):
+ """Display current results on tensorboard; save current results to an HTML file.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ save_result (bool) - - if save the current results to an HTML file
+ """
+ for label, image in visuals.items():
+ self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
+
+ if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
+ self.saved = True
+ # save images to the disk
+ for label, image in visuals.items():
+ image_numpy = util.tensor2im(image)
+ img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+ util.save_image(image_numpy, img_path)
+
+ # update website
+ webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+ for n in range(epoch, 0, -1):
+ webpage.add_header('epoch [%d]' % n)
+ ims, txts, links = [], [], []
+
+ for label, image_numpy in visuals.items():
+ image_numpy = util.tensor2im(image_numpy)
+ img_path = 'epoch%.3d_%s.png' % (n, label)
+ ims.append(img_path)
+ txts.append(label)
+ links.append(img_path)
+ webpage.add_images(ims, txts, links, width=self.win_size)
+ webpage.save()
+
+ def plot_current_losses(self, total_iters, losses):
+ # G_loss_collection = {}
+ # D_loss_collection = {}
+ # for name, value in losses.items():
+ # if 'G' in name or 'NCE' in name or 'idt' in name:
+ # G_loss_collection[name] = value
+ # else:
+ # D_loss_collection[name] = value
+ # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+ # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+ for name, value in losses.items():
+ self.writer.add_scalar(name, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
+
+
+class MyVisualizer:
+ def __init__(self, opt):
+ """Initialize the Visualizer class
+
+ Parameters:
+ opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+ Step 1: Cache the training/test options
+ Step 2: create a tensorboard writer
+ Step 3: create an HTML object for saveing HTML filters
+ Step 4: create a logging file to store training losses
+ """
+ self.opt = opt # cache the option
+ self.name = opt.name
+ self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
+
+ if opt.phase != 'test':
+ self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
+ # create a logging file to store training losses
+ self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+ with open(self.log_name, "a") as log_file:
+ now = time.strftime("%c")
+ log_file.write('================ Training Loss (%s) ================\n' % now)
+
+
+ def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
+ add_image=True):
+ """Display current results on tensorboard; optionally save current results to the disk.
+
+ Parameters:
+ visuals (OrderedDict) - - dictionary of images to display or save
+ total_iters (int) -- total iterations
+ epoch (int) - - the current epoch
+ dataset (str) - - 'train' or 'val' or 'test'
+ """
+ # if (not add_image) and (not save_results): return
+
+ for label, image in visuals.items():
+ for i in range(image.shape[0]):
+ image_numpy = util.tensor2im(image[i])
+ if add_image:
+ self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
+ image_numpy, total_iters, dataformats='HWC')
+
+ if save_results:
+ save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
+ if not os.path.isdir(save_path):
+ os.makedirs(save_path)
+
+ if name is not None:
+ img_path = os.path.join(save_path, '%s.png' % name)
+ else:
+ img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
+ util.save_image(image_numpy, img_path)
+
+
+ def plot_current_losses(self, total_iters, losses, dataset='train'):
+ for name, value in losses.items():
+ self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
+
+ # losses: same format as |losses| of plot_current_losses
+ def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
+ """print current losses on console; also save the losses to the disk
+
+ Parameters:
+ epoch (int) -- current epoch
+ iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+ losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+ t_comp (float) -- computational time per data point (normalized by batch_size)
+ t_data (float) -- data loading time per data point (normalized by batch_size)
+ """
+ message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
+ dataset, epoch, iters, t_comp, t_data)
+ for k, v in losses.items():
+ message += '%s: %.3f ' % (k, v)
+
+ print(message) # print the message
+ with open(self.log_name, "a") as log_file:
+ log_file.write('%s\n' % message) # save the message
diff --git a/docs/prepare_env/install_guide-zh.md b/docs/prepare_env/install_guide-zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..c648d4bf451c882298f4a6611211b3cf96b487cb
--- /dev/null
+++ b/docs/prepare_env/install_guide-zh.md
@@ -0,0 +1,49 @@
+# Environment Setup
+[English Doc](./install_guide.md)
+
+This document describes the steps to build the Python environment for Real3D-Portrait; we use Conda to manage the dependencies.
+
+The following setup has been verified on A100/V100 GPUs with CUDA 11.7.
+
+
+# 1. Install CUDA
+We recommend installing CUDA `11.7`; other CUDA versions (e.g. `10.2`, `12.x`) may also work.
+
+# 2. Install Python Dependencies
+```
+cd
+source /bin/activate
+conda create -n real3dportrait python=3.9
+conda activate real3dportrait
+
+# We recommend torch 2.0.1 + cuda 11.7.
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# Install pytorch3d via conda (for fast installation, Linux only)
+conda install pytorch3d::pytorch3d
+## If the conda installation fails, a more compatible alternative is to build pytorch3d from the GitHub source:
+## this may take a long time (roughly tens of minutes); since it pulls from GitHub, time-outs are common, so consider using a proxy.
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+# Install MMCV
+pip install cython
+pip install openmim==0.3.9
+mim install mmcv==2.1.0 # use mim to speed up the mmcv installation
+
+# Other dependencies
+pip install -r docs/prepare_env/requirements.txt -v
+
+
+If you encounter the following error, try installing the dependencies with the command below instead:
+pip install -r docs/prepare_env/requirements.txt -v --use-deprecated=legacy-resolver
+
+> ERROR: pip's dependency resolver does not currently take into account all the packages
+> that are installed. This behaviour is the source of the following dependency conflicts.
+> openxlab 0.0.34 requires setuptools~=60.2.0, but you have setuptools 69.1.1 which is incompatible.
+
+
+```
+
diff --git a/docs/prepare_env/install_guide.md b/docs/prepare_env/install_guide.md
new file mode 100644
index 0000000000000000000000000000000000000000..d3651d11bdd39b72948968e5bcecf36b1f41bcfd
--- /dev/null
+++ b/docs/prepare_env/install_guide.md
@@ -0,0 +1,44 @@
+# Prepare the Environment
+[中文文档](./install_guide-zh.md)
+
+This guide describes how to build a Python environment for Real3D-Portrait with Conda.
+
+The following installation process is verified in A100/V100 + CUDA11.7.
+
+
+# 1. Install CUDA
+We recommend installing CUDA `11.7` (verified on various types of GPUs), but other CUDA versions (such as `10.2`, `12.x`) may also work well.
+
+# 2. Install Python Packages
+```
+cd
+source /bin/activate
+conda create -n real3dportrait python=3.9
+conda activate real3dportrait
+conda install conda-forge::ffmpeg # ffmpeg with libx264 codec to turn images to video
+
+### We recommend torch2.0.1+cuda11.7.
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# Install pytorch3d from conda (for fast installation, Linux only)
+conda install pytorch3d::pytorch3d
+## Alternatively, for better compatibility, build pytorch3d from the GitHub source code.
+## This may take a long time (roughly tens of minutes); a proxy is recommended if you encounter time-out problems.
+pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+# MMCV for some network structure
+pip install cython
+pip install openmim==0.3.9
+mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
+
+# other dependencies
+pip install -r docs/prepare_env/requirements.txt -v
+
+If you encounter the following error, please try to install the dependencies with the following command:
+pip install -r docs/prepare_env/requirements.txt -v --use-deprecated=legacy-resolver
+
+> ERROR: pip's dependency resolver does not currently take into account all the packages
+> that are installed. This behaviour is the source of the following dependency conflicts.
+> openxlab 0.0.34 requires setuptools~=60.2.0, but you have setuptools 69.1.1 which is incompatible.
+
+```
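+
+As an optional sanity check (a minimal sketch, assuming the `real3dportrait` env created above is activated), verify that the core packages import and that CUDA is visible:
+
+```
+python -c "import torch, torchvision, pytorch3d, mmcv; print(torch.__version__, torch.cuda.is_available())"
+```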
diff --git a/docs/prepare_env/requirements.txt b/docs/prepare_env/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..65f5ec537cda7d42cb755594fa52ae6ebe17637b
--- /dev/null
+++ b/docs/prepare_env/requirements.txt
@@ -0,0 +1,75 @@
+Cython
+numpy # ==1.23.0
+numba==0.56.4
+pandas
+transformers
+scipy==1.11.1 # required by cal_fid. https://github.com/mseitzer/pytorch-fid/issues/103
+scikit-learn
+scikit-image
+# tensorflow # you can flexible it, this is gpu version
+tensorboard
+tensorboardX
+python_speech_features
+resampy
+opencv_python
+face_alignment
+matplotlib
+configargparse
+librosa==0.9.2
+praat-parselmouth # ==0.4.3
+trimesh
+kornia==0.5.0
+PyMCubes
+lpips
+setuptools # ==59.5.0
+ffmpeg-python
+moviepy
+dearpygui
+ninja
+pyaudio # for extracting esperanto features
+mediapipe
+protobuf
+decord
+soundfile
+pillow
+# torch # it's better to install torch with conda
+av
+timm
+pretrainedmodels
+faiss-cpu # for fast nearest camera pose retrieval
+einops
+# mmcv # use mim install is faster
+
+# conditional flow matching
+beartype
+torchode
+torchdiffeq
+
+# tts
+cython
+textgrid
+pyloudnorm
+websocket-client
+pyworld==0.2.1rc0
+pypinyin==0.42.0
+webrtcvad
+torchshow
+
+# cal spk sim
+s3prl
+fire
+
+# cal LMD
+dlib
+
+# debug
+ipykernel
+
+# lama
+hydra-core
+pytorch_lightning
+setproctitle
+
+# Gradio GUI
+httpx==0.23.3
+gradio==4.16.0
\ No newline at end of file
diff --git a/docs/process_data/process_th1kh.md b/docs/process_data/process_th1kh.md
new file mode 100644
index 0000000000000000000000000000000000000000..728315c2e9d29f6e4a36cad2cd0ea9f477640efa
--- /dev/null
+++ b/docs/process_data/process_th1kh.md
@@ -0,0 +1,32 @@
+# Process the dataset
+We use Talking-Head-1K-Hour as the example.
+
+## download and crop the talking person video clips
+- Please follow the steps in [https://github.com/tcwang0509/TalkingHead-1KH](https://github.com/tcwang0509/TalkingHead-1KH)
+- Put all extracted video clips in a directory like `/home/xxx/TH1KH_512/video_raw/*.mp4`
+
+## resample & resize video clips to 512x512 resolution and 25FPS
+- You can use the example code in `data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py`
+- It will generate processed video clips in `/home/xxx/TH1KH_512/video/*.mp4`
+
+## extract segment images
+- You can use the example code in `data_gen/utils/process_video/extract_segment_imgs.py`
+- It will generate segment images in `/home/xxx/TH1KH_512/{gt_imgs, head_imgs, inpaint_torso_imgs, com_imgs}/*`
+
+## extract 2d facial landmarks
+- You can use the example code in `data_gen/utils/process_video/extract_lm2d.py`
+- It will generate 2d landmarks in `/home/xxx/TH1KH_512/lms_2d/*_lms_2d.npy`
+
+## extract 3dmm coefficients
+- You can use the example code in `data_gen/utils/process_video/fit_3dmm_landmark.py`
+- It will generate 3dmm coefficients in `/home/xxx/TH1KH_512/coeff_fit_mp/*_coeff_fit_mp.npy`
+
+## extract audio features
+- You can use the example code in `data_gen/utils/process_audio/extract_mel_f0.py`
+- It will generate raw wav in `/home/xxx/TH1KH_512/audio/*.wav` and mel_f0 in `/home/xxx/TH1KH_512/mel_f0/*_mel_f0.npy`
+- You can use the example code in `data_gen/utils/process_audio/extract_hubert.py`
+- It will generate hubert in `/home/xxx/TH1KH_512/hubert/*_hubert.npy`
+
+## Binarize the dataset
+- You can use the example code in `data_gen/runs/binarizer_th1kh.py`
+- You will see a binarized dataset at `data/binary/th1kh`
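+
+Putting the steps above together, the overall order looks like the sketch below (a sketch only: the `/home/xxx/TH1KH_512` root is assumed to be configured inside each script, and extra CLI flags may be required; check each script's argument parser for the exact options):
+
+```
+python data_gen/utils/process_video/resample_video_to_25fps_resize_to_512.py
+python data_gen/utils/process_video/extract_segment_imgs.py
+python data_gen/utils/process_video/extract_lm2d.py
+python data_gen/utils/process_video/fit_3dmm_landmark.py
+python data_gen/utils/process_audio/extract_mel_f0.py
+python data_gen/utils/process_audio/extract_hubert.py
+python data_gen/runs/binarizer_th1kh.py
+```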
diff --git a/docs/train_models/train_audio2motion.md b/docs/train_models/train_audio2motion.md
new file mode 100644
index 0000000000000000000000000000000000000000..5a9563aaf20245164eb3be4edd9a407a575ad070
--- /dev/null
+++ b/docs/train_models/train_audio2motion.md
@@ -0,0 +1,12 @@
+# 0. Get pre-trained models & Data
+- Get the binarized dataset following `docs/process_data/process_th1kh.md`. You will see `data/binary/th1kh/train.data`
+
+# 1. Train audio_lm3d_syncnet
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tasks/run.py --config=egs/os_avatar/audio_lm3d_syncnet.yaml --exp_name=audio_lm3d_syncnet --reset
+
+
+# 2. Train audio2motion model
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tasks/run.py --config=egs/os_avatar/audio2motion_vae.yaml --exp_name=audio2motion_vae --hparams=syncnet_ckpt_dir=checkpoints/audio_lm3d_syncnet --reset
+
+# 3. Inference
+- See `README.md`, and change the checkpoint name to your own audio2motion_vae model.
diff --git a/docs/train_models/train_motion2video.md b/docs/train_models/train_motion2video.md
new file mode 100644
index 0000000000000000000000000000000000000000..40272b53412ffddf034b19248fe62d977b9189d0
--- /dev/null
+++ b/docs/train_models/train_motion2video.md
@@ -0,0 +1,26 @@
+# 0. Get pre-trained models & Data
+- Get the binarized dataset following `docs/process_data/process_th1kh.md`. You will see `data/binary/th1kh/train.data`
+- Download `pretrained_ckpts.zip` in this [Google Drive](https://drive.google.com/drive/folders/1MAveJf7RvJ-Opg1f5qhLdoRoC_Gc6nD9?usp=sharing), unzip it and place it into `checkpoints/pretrained_ckpts`. You will see `checkpoints/pretrained_ckpts/mit_b0.pth` and `checkpoints/pretrained_ckpts/eg3d_baseline_run2`.
+
+
+# 1. Train Img-to-Plane Model
+## 1.1 image-to-triplane model in real3d-portrait
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tasks/run.py --config=egs/os_avatar/img2plane.yaml --hparams=triplane_feature_type=triplane --exp_name=img2plane --reset
+```
+## 1.2 image-to-grid model in zera-portrait (Recommended)
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tasks/run.py --config=egs/os_avatar/img2plane.yaml --exp_name=img2grid --reset
+```
+
+# 2. Train Motion-to-Video Model
+```
+# secc2plane_head
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tasks/run.py --config=egs/os_avatar/srcc_img2plane.yaml --exp_name=secc2plane --hparams=init_from_ckpt=checkpoints/img2grid --reset
+
+# secc2plane_torso
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tasks/run.py --config=egs/os_avatar/srcc_img2plane_torso.yaml --exp_name=secc2plane_torso --hparams=init_from_ckpt=checkpoints/secc2plane --reset
+```
+
+# 3. Inference
+- See `README.md`, and change the checkpoint name to your own secc2plane_torso model.
diff --git a/egs/egs_bases/audio2motion/base.yaml b/egs/egs_bases/audio2motion/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..56721e44b84be7390f4d391a2a47e6f68cf5abf4
--- /dev/null
+++ b/egs/egs_bases/audio2motion/base.yaml
@@ -0,0 +1,57 @@
+# project-related
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+
+# testing related
+gen_dir_name: ''
+save_gt: true
+
+# training-scheme-related
+num_ckpt_keep: 100
+val_check_interval: 2000
+valid_infer_interval: 2000
+max_updates: 4_0000
+seed: 9999
+lr: 0.0005
+scheduler: exponential # exponential|rsqrt|warmup|none|step_lr
+warmup_updates: 1000
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.999
+weight_decay: 0
+accumulate_grad_batches: 1
+clip_grad_norm: 1
+clip_grad_value: 0
+num_sanity_val_steps: 5
+num_valid_plots: 1
+eval_max_batches: 10 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: false
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+
+# model-related
+hidden_size: 256
+
+# infer-related
+infer_audio_source_name: ''
+infer_out_npy_name: ''
+infer_ckpt_steps: 40000
+
+load_db_to_memory: false # enable it for faster indexing
+
+max_sentences_per_batch: 512
+max_tokens_per_batch: 20000
+num_workers: 4
+
+audio_type: hubert
+motion_type: idexp_lm3d
+use_kv_dataset: false
+use_fork: true
\ No newline at end of file
diff --git a/egs/egs_bases/audio2motion/vae.yaml b/egs/egs_bases/audio2motion/vae.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..30a08d117366931c5466fababa02d846642aa93a
--- /dev/null
+++ b/egs/egs_bases/audio2motion/vae.yaml
@@ -0,0 +1,7 @@
+base_config:
+ - ./base.yaml
+
+# VAE related
+task_cls: tasks.audio2motion.lm3d_vae.VAEAudio2MotionTask
+lambda_kl: 0.5
+
diff --git a/egs/egs_bases/audio2motion/vae_sync.yaml b/egs/egs_bases/audio2motion/vae_sync.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ec78001c7533819f9311bbc9bae1a118666dd820
--- /dev/null
+++ b/egs/egs_bases/audio2motion/vae_sync.yaml
@@ -0,0 +1,10 @@
+base_config:
+ - ./base.yaml
+
+# VAE related
+task_cls: tasks.audio2motion.lm3d_vae_sync.VAESyncAudio2MotionTask
+lambda_kl: 0.5
+
+# SyncNet related
+syncnet_work_dir: checkpoints/lrs3/syncnet
+syncnet_ckpt_steps: 40000
diff --git a/egs/egs_bases/audio2pose/base.yaml b/egs/egs_bases/audio2pose/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6ef617e955468a09e2618d6f9ecd0cfab260c318
--- /dev/null
+++ b/egs/egs_bases/audio2pose/base.yaml
@@ -0,0 +1,47 @@
+# dataset-related
+raw_data_dir: data/raw/videos
+processed_data_dir: data/processed/videos
+binary_data_dir: data/binary/videos
+video_id: ''
+task_cls: ''
+
+# project-related
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+val_check_interval: 1000
+valid_infer_interval: 1000
+num_sanity_val_steps: 5
+num_valid_plots: 1
+eval_max_batches: 10 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: true
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+accumulate_grad_batches: 1
+clip_grad_norm: 1.
+
+# training-scheme-related
+task_cls: tasks.audio2pose.audio2pose.Audio2PoseTask
+max_updates: 1_0000
+seed: 9999
+lr: 0.0005
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.999
+scheduler: exponential # exponential|rsqrt|warmup|none|step_lr
+warmup_updates: 1000
+
+valid_infer_interval: 1000
+val_check_interval: 1000
+num_ckpt_keep: 10
+
+source_name: ''
+infer_out_npy_name: ''
+reception_field: 100
\ No newline at end of file
diff --git a/egs/egs_bases/eg3d/base.yaml b/egs/egs_bases/eg3d/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea51621a41193ac369d25ce0ac9e8c77cfef834e
--- /dev/null
+++ b/egs/egs_bases/eg3d/base.yaml
@@ -0,0 +1,99 @@
+# dataset-related
+raw_data_dir: data/raw/videos
+processed_data_dir: data/processed/videos
+binary_data_dir: data/binary/videos
+video_id: May
+
+# feature-related
+cond_type: idexp_lm3d_normalized
+smo_win_size: 5
+cond_hid_dim: 32
+cond_out_dim: 16
+# generator_condition_on_pose: false # pose is camera extrinsic and intrinsic
+generator_condition_on_pose: true # pose is camera extrinsic and intrinsic
+gpc_reg_prob: 0.5
+gpc_reg_fade_kimg: 1000
+
+# network-related
+task_cls: tasks.eg3ds.eg3d_task.EG3DTask
+z_dim: 512
+w_dim: 512
+neural_rendering_resolution: 128
+final_resolution: 512
+
+base_channel: 32768 # Capacity multiplier
+max_channel: 512 # Max. feature maps
+mapping_network_depth: 2 # num of layers in mapping network
+num_fp16_layers_in_super_resolution: 4
+num_fp16_layers_in_generator: 0
+num_fp16_layers_in_discriminator: 4
+
+
+# GAN-related
+disc_c_noise: 1.0
+blur_raw_target: true
+blur_init_sigma: 10
+# blur_fade_kimg: 200 # Fade out the blur during the first N kimg.
+blur_fade_kimg: 20 # Fade out the blur during the first N kimg.
+# neural rendering-related
+num_samples_coarse: 48 # number of uniform samples to take per ray.
+num_samples_fine: 48 # number of importance samples to take per ray.
+ray_near: 2.25
+# ray_far: 4.05
+ray_far: 3.3
+box_warp: 1 # the side-length of the bounding box spanned by the tri-planes; box_warp=1 means [-0.5, -0.5, -0.5] -> [0.5, 0.5, 0.5].
+
+# loss related
+group_size_for_mini_batch_std: 2 # 4
+lambda_gradient_penalty: 5. # gradient penalty to discriminator
+
+
+lambda_G_supervise_adv: 1.0
+lambda_G_supervise_mse_raw: 1.0
+lambda_G_supervise_mse: 1.0
+lambda_G_adversarial_adv: 1.0
+
+
+lambda_density_reg: 0.25 # strength of density regularization for Generator
+density_reg_p_dist: 0.004 # distance at which to sample perturbed points for density regularization
+
+
+# trainer related
+seed: 9999
+lr_g: 0.0025
+lr_d: 0.002
+optimizer_adam_beta1_g: 0.
+optimizer_adam_beta2_g: 0.99
+optimizer_adam_beta1_d: 0.
+optimizer_adam_beta2_d: 0.99
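+# lazy regularization as in StyleGAN2/EG3D: the regularization terms (density reg for G,
+# gradient penalty for D) are only applied every reg_interval steps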
+reg_interval_g: 4
+reg_interval_d: 16
+
+batch_size: 4
+ema_interval: 400 # bs * 10 / 32 kimg
+max_updates: 25000_000 # 25000 kimg
+num_workers: 4
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+num_ckpt_keep: 1000
+val_check_interval: 2000
+valid_infer_interval: 2000
+num_sanity_val_steps: 1
+num_valid_plots: 25
+eval_max_batches: 100 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: true
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+accumulate_grad_batches: 1
+clip_grad_norm: 0 #1
+clip_grad_value: 0
+
diff --git a/egs/egs_bases/eg3d/base_mse.yaml b/egs/egs_bases/eg3d/base_mse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34f3058ae97f7f19212f0f6894d8b4c7e9cd3d43
--- /dev/null
+++ b/egs/egs_bases/eg3d/base_mse.yaml
@@ -0,0 +1,96 @@
+# dataset-related
+raw_data_dir: data/raw/videos
+processed_data_dir: data/processed/videos
+binary_data_dir: data/binary/videos
+video_id: May
+
+# feature-related
+cond_type: idexp_lm3d_normalized
+smo_win_size: 5
+cond_hid_dim: 32
+cond_out_dim: 16
+# generator_condition_on_pose: false # pose is camera extrinsic and intrinsic
+generator_condition_on_pose: true # pose is camera extrinsic and intrinsic
+gpc_reg_prob: 0.5
+gpc_reg_fade_kimg: 1000
+
+# network-related
+task_cls: tasks.eg3ds.eg3d_task.EG3DTask
+z_dim: 512
+w_dim: 512
+neural_rendering_resolution: 128
+final_resolution: 512
+
+base_channel: 32768 # Capacity multiplier
+max_channel: 512 # Max. feature maps
+mapping_network_depth: 2 # num of layers in mapping network
+num_fp16_layers_in_super_resolution: 4
+num_fp16_layers_in_generator: 0
+num_fp16_layers_in_discriminator: 4
+
+
+# GAN-related
+blur_raw_target: true
+blur_init_sigma: 10
+# blur_fade_kimg: 200 # Fade out the blur during the first N kimg.
+blur_fade_kimg: 20 # Fade out the blur during the first N kimg.
+# neural rendering-related
+num_samples_coarse: 48 # number of uniform samples to take per ray.
+num_samples_fine: 48 # number of importance samples to take per ray.
+ray_near: 2.25
+ray_far: 4.05
+box_warp: 1 # the side-length of the bounding box spanned by the tri-planes; box_warp=1 means [-0.5, -0.5, -0.5] -> [0.5, 0.5, 0.5].
+
+# loss related
+group_size_for_mini_batch_std: 2 # 4
+lambda_gradient_penalty: 5. # gradient penalty to discriminator
+
+
+lambda_G_supervise_adv: 0.
+lambda_G_supervise_mse_raw: 1.0
+lambda_G_supervise_mse: 0.
+lambda_G_adversarial_adv: 0.
+
+lambda_density_reg: 0.25 # strength of density regularization for Generator
+density_reg_p_dist: 0.004 # distance at which to sample perturbed points for density regularization
+
+
+# trainer related
+seed: 9999
+lr_g: 0.0025
+lr_d: 0.002
+optimizer_adam_beta1_g: 0.
+optimizer_adam_beta2_g: 0.99
+optimizer_adam_beta1_d: 0.
+optimizer_adam_beta2_d: 0.99
+reg_interval_g: 4
+reg_interval_d: 16
+
+batch_size: 4
+ema_interval: 400 # bs * 10 / 32 kimg
+max_updates: 25000_000 # 25000 kimg
+num_workers: 4
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+num_ckpt_keep: 1
+val_check_interval: 2000
+valid_infer_interval: 2000
+num_sanity_val_steps: 1
+num_valid_plots: 25
+eval_max_batches: 100 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: true
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+accumulate_grad_batches: 1
+clip_grad_norm: 0 #1
+clip_grad_value: 0
+
diff --git a/egs/egs_bases/nerf/adnerf.yaml b/egs/egs_bases/nerf/adnerf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..492061204315b7eb2630430cebf0be678a78e639
--- /dev/null
+++ b/egs/egs_bases/nerf/adnerf.yaml
@@ -0,0 +1,8 @@
+base_config:
+ - egs/egs_bases/nerf/base.yaml
+
+task_cls: tasks.nerfs.adnerf.ADNeRFTask
+cond_type: deepspeech
+no_smo_iterations: 20_0000
+cond_win_size: 16
+smo_win_size: 8
\ No newline at end of file
diff --git a/egs/egs_bases/nerf/adnerf_torso.yaml b/egs/egs_bases/nerf/adnerf_torso.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..222077e666e7317a24af043544cae9ed2ab6f4d8
--- /dev/null
+++ b/egs/egs_bases/nerf/adnerf_torso.yaml
@@ -0,0 +1,7 @@
+base_config:
+ - egs/egs_bases/nerf/adnerf.yaml
+
+task_cls: tasks.nerfs.adnerf_torso.ADNeRFTorsoTask
+no_smo_iterations: 0 # nerf_torso uses the fixed audatt_net from head_nerf
+head_model_dir: ''
+use_color: false
diff --git a/egs/egs_bases/nerf/base.yaml b/egs/egs_bases/nerf/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..218cc9df96d588fa35d8af54f5c0efa2da2a7608
--- /dev/null
+++ b/egs/egs_bases/nerf/base.yaml
@@ -0,0 +1,79 @@
+# dataset-related
+raw_data_dir: data/raw/videos
+processed_data_dir: data/processed/videos
+binary_data_dir: data/binary/videos
+video_id: ''
+task_cls: ''
+
+# project-related
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+num_ckpt_keep: 1
+val_check_interval: 10000
+valid_infer_interval: 10000
+num_sanity_val_steps: 0
+num_valid_plots: 5
+eval_max_batches: 100 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: true
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+
+# testing related
+gen_dir_name: ''
+save_gt: true
+
+# training-scheme-related
+max_updates: 40_0000
+seed: 9999
+lr: 0.0005
+scheduler: exponential # exponential|rsqrt|warmup|none|step_lr
+warmup_updates: 0
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.999
+weight_decay: 0
+clip_grad_norm: 0 # disable grad clipping
+clip_grad_value: 0 # disable grad clipping
+rays_sampler_type: uniform
+in_rect_percent: 0.95
+accumulate_grad_batches: 1
+
+# model-related
+use_window_cond: true
+with_att: true # only available when using win_cond; use an attention net as in AD-NeRF
+cond_type: ''
+cond_dim: 64
+hidden_size: 256
+
+# NeRF-related
+near: 0.3
+far: 0.9
+n_rays: 1600 # default 2048, 1600 for RTX2080Ti
+n_samples_per_ray: 64
+n_samples_per_ray_fine: 128
+embedding_args:
+ multi_res_pos: 10 # log2+1 of max freq for positional encoding (3D location)
+ multi_res_views: 4 # log2+1 of max freq for positional encoding (2D direction)
+
+infer_cond_name: ''
+infer_out_video_name: ''
+infer_scale_factor: 1.0
+infer_smo_std: 0.
+infer_audio_source_name: ''
+infer_c2w_name: ''
+
+# postprocessing params
+infer_lm3d_clamp_std: 1.5
+infer_lm3d_lle_percent: 0.25 # percent of lle fused feature to compose the processed lm3d
+infer_lm3d_smooth_sigma: 0. # sigma of gaussian kernel to smooth the predicted lm3d
+infer_pose_smooth_sigma: 2.
+
+load_imgs_to_memory: false # load uint8 training imgs to memory, which reduces io costs at the expense of more memory occupation
\ No newline at end of file
diff --git a/egs/egs_bases/nerf/lm3d_nerf.yaml b/egs/egs_bases/nerf/lm3d_nerf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87af5a4e642da43b0cdb7b94773e9b16779841a9
--- /dev/null
+++ b/egs/egs_bases/nerf/lm3d_nerf.yaml
@@ -0,0 +1,18 @@
+base_config:
+ - egs/egs_bases/nerf/base.yaml
+
+task_cls: tasks.nerfs.lm3d_nerf.Lm3dNeRFTask
+cond_type: idexp_lm3d_normalized
+no_smo_iterations: 20_0000
+
+use_window_cond: true # the NeRF only takes the exp at current frame as condition
+with_att: true # only available when using win_cond; use an attention net as in AD-NeRF
+cond_win_size: 1
+smo_win_size: 5
+
+infer_inject_eye_blink_mode: none # none|gt|period. `gt` uses the eye blink sequence from the GT dataset, `period` uses a ref blink sequence from the GT dataset and repeats it to the final length
+infer_eye_blink_ref_frames_start_idx: '' # start index of the ref blink sequence in the GT dataset
+infer_eye_blink_ref_frames_end_idx: '' # end index of the ref blink sequence in the GT dataset
+
+infer_close_mouth_when_sil: False # detect sil frames, then set the mouth to close in these frames
+infer_sil_ref_frame_idx: '' # index of the ref frame with a closed mouth in the GT dataset
\ No newline at end of file
diff --git a/egs/egs_bases/nerf/lm3d_nerf_torso.yaml b/egs/egs_bases/nerf/lm3d_nerf_torso.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ea05f8d23571b3962bae7103c3c20ee6ce2af813
--- /dev/null
+++ b/egs/egs_bases/nerf/lm3d_nerf_torso.yaml
@@ -0,0 +1,9 @@
+base_config:
+ - egs/egs_bases/nerf/lm3d_nerf.yaml
+
+task_cls: tasks.nerfs.lm3d_nerf_torso.Lm3dNeRFTorsoTask
+
+no_smo_iterations: 0 # nerf_torso uses the fixed audatt_net from head_nerf
+use_color: true
+
+head_model_dir: ''
diff --git a/egs/egs_bases/os_facev2v/base.yaml b/egs/egs_bases/os_facev2v/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..637fa47ee701b9d4f966ecd311bb1ac627050977
--- /dev/null
+++ b/egs/egs_bases/os_facev2v/base.yaml
@@ -0,0 +1,87 @@
+dataset_params:
+ root_dir: /zlh/VoxCeleb/first-order-256
+ frame_shape: [256, 256, 3]
+ id_sampling: True
+ pairs_list: None
+ augmentation_params:
+ flip_param:
+ horizontal_flip: True
+ time_flip: True
+ jitter_param:
+ brightness: 0.1
+ contrast: 0.1
+ saturation: 0.1
+ hue: 0.1
+
+model_params:
+ common_params:
+ num_kp: 15
+ image_channel: 3
+ feature_channel: 32
+ estimate_jacobian: False # True
+ kp_detector_params:
+ temperature: 0.1
+ block_expansion: 32
+ max_features: 1024
+ scale_factor: 0.25 # 0.25
+ num_blocks: 5
+ reshape_channel: 16384 # 16384 = 1024 * 16
+ reshape_depth: 16
+ he_estimator_params:
+ block_expansion: 64
+ max_features: 2048
+ num_bins: 66
+ generator_params:
+ block_expansion: 64
+ max_features: 512
+ num_down_blocks: 2
+ reshape_channel: 32
+ reshape_depth: 16 # 512 = 32 * 16
+ num_resblocks: 6
+ estimate_occlusion_map: True
+ dense_motion_params:
+ block_expansion: 32
+ max_features: 1024
+ num_blocks: 5
+ # reshape_channel: 32
+ reshape_depth: 16
+ compress: 4
+ discriminator_params:
+ scales: [1]
+ block_expansion: 32
+ max_features: 512
+ num_blocks: 4
+ sn: True
+
+train_params:
+ num_epochs: 300
+ num_repeats: 75
+ epoch_milestones: [180,]
+ lr_generator: 2.0e-4
+ lr_discriminator: 2.0e-4
+ lr_kp_detector: 2.0e-4
+ lr_he_estimator: 2.0e-4
+ gan_mode: 'hinge' # hinge or ls
+ batch_size: 32
+ scales: [1, 0.5, 0.25, 0.125]
+ checkpoint_freq: 10
+ hopenet_snapshot: "/mnt/bn/sa-ag-data/yezhenhui/myenv/cache/useful_ckpts/hopenet_robust_alpha1.pkl" # https://drive.google.com/open?id=1m25PrSE7g9D2q2XJVMR6IA7RaCvWSzCR
+ transform_params:
+ sigma_affine: 0.05
+ sigma_tps: 0.005
+ points_tps: 5
+ loss_weights:
+ generator_gan: 1
+ discriminator_gan: 1
+ feature_matching: [10, 10, 10, 10]
+ perceptual: [10, 10, 10, 10, 10]
+ equivariance_value: 10
+ equivariance_jacobian: 0 # 10
+ keypoint: 10
+ headpose: 20
+ expression: 5
+
+visualizer_params:
+ kp_size: 5
+ draw_border: True
+ colormap: 'gist_rainbow'
diff --git a/egs/egs_bases/postnet/base.yaml b/egs/egs_bases/postnet/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e7eadb02d1c62a0ab2eb5303efffda92654f1954
--- /dev/null
+++ b/egs/egs_bases/postnet/base.yaml
@@ -0,0 +1,40 @@
+base_config:
+ - egs/egs_bases/audio2motion/vae_sync.yaml
+
+task_cls: tasks.postnet.lm3d_postnet_adv_sync.PostnetAdvSyncTask
+audio2motion_task_cls: tasks.audio2motion.lm3d_vae_sync.VAESyncAudio2MotionTask
+person_binary_data_dir: data/binary/videos
+# postnet training
+postnet_lr: 0.0001
+postnet_lambda_adv: 0.85
+postnet_lambda_sync: 0.1
+postnet_lambda_mse: 0.05
+
+# Discriminator
+postnet_disc_lr: 0.0001
+discriminator_scheduler_params:
+ gamma: 0.5
+ step_size: 40000
+postnet_disc_start_steps: 0
+postnet_disc_interval: 1
+
+# Training Schedule
+scheduler: none
+num_ckpt_keep: 500
+val_check_interval: 1000
+valid_infer_interval: 1000
+max_updates: 100000 # 20000
+
+# Pretrained Ckpts
+audio2motion_work_dir: checkpoints/th1kh/lm3d_vae_sync_pitch/
+audio2motion_ckpt_steps: 160000
+syncnet_work_dir: checkpoints/th1kh/lm3d_syncnet
+syncnet_ckpt_steps: 160000
+syncnet_num_layers_per_block: 3
+syncnet_base_hid_size: 128
+
+infer_audio_source_name: data/raw/val_wavs/zozo.wav
+infer_out_npy_name: infer_out/May/pred_lm3d/zozo.npy
+infer_ckpt_steps: 6000
+
+load_db_to_memory: false # enable it for faster indexing
diff --git a/egs/egs_bases/radnerf/base.yaml b/egs/egs_bases/radnerf/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7448e462994575d3cfc63b8f160d23774318e5c5
--- /dev/null
+++ b/egs/egs_bases/radnerf/base.yaml
@@ -0,0 +1,125 @@
+# dataset-related
+raw_data_dir: data/raw/videos
+processed_data_dir: data/processed/videos
+binary_data_dir: data/binary/videos
+video_id: ''
+task_cls: ''
+not_save_modules: ['criterion_lpips']
+
+# project-related
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+num_ckpt_keep: 1
+val_check_interval: 2000
+valid_infer_interval: 10000
+num_sanity_val_steps: 2
+num_valid_plots: 5
+eval_max_batches: 100 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: true
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+
+# testing related
+save_gt: true
+
+# training-scheme-related
+seed: 9999
+lr: 0.0005
+scheduler: exponential # exponential|rsqrt|warmup|none|step_lr
+warmup_updates: 0
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.999
+weight_decay: 0
+clip_grad_norm: 0 # disable grad clipping
+clip_grad_value: 0 # disable grad clipping
+accumulate_grad_batches: 1
+
+# model-related
+cond_type: '' # deepspeech, esperanto, idexp_lm3d
+
+# training
+amp: true # use fp16
+load_imgs_to_memory: false # load uint8 training imgs to memory, which reduces io costs at the expense of more memory occupation
+
+# NeRF-related
+near: 0.3
+far: 0.9
+n_rays: 65536 # num rays sampled per image for each training step, default 256*256
+cuda_ray: true # use CUDA raymarching instead of pytorch
+max_steps: 16 # max num steps sampled per ray (only valid when using --cuda_ray)
+num_steps: 16 # num steps sampled per ray (only valid when NOT using --cuda_ray)
+upsample_steps: 0 # num steps up-sampled per ray (only valid when NOT using --cuda_ray)
+update_extra_interval: 16 # iter interval to update extra status (only valid when using --cuda_ray)
+max_ray_batch: 4096 # batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)
+
+
+max_updates: 25_0000 # 40_0000 for training the whole head, 5_0000 for finetuning the mouth
+finetune_lips: true
+finetune_lips_start_iter: 20_0000
+lambda_lpips_loss: 0.01 # auxiliary loss for finetune lips
+lambda_weights_entropy: 0.0001
+lambda_ambient: 0.1
+
+min_near: 0.05 # minimum near distance for camera
+bound: 1 # assume the scene is bounded in box[-bound, bound]^3, if > 1, will invoke adaptive ray marching.
+camera_scale: 4. # scale camera location into box[-bound, bound]^3
+camera_offset: [0, 0, 0] # offset of camera location
+grid_size: 128
+desired_resolution: 2048
+log2_hashmap_size: 16
+dt_gamma: 0.00390625 # default 1/256, dt_gamma (>=0) for adaptive ray marching. set to 0 to disable, >0 to accelerate rendering (but usually with worse quality)
+density_thresh: 10 # threshold for density grid to be occupied (sigma)
+density_thresh_torso: 0.01 # threshold for density grid to be occupied (alpha)
+torso_shrink: 0.8 # shrink bg coords to allow more flexibility in deform
+
+smooth_lips: false
+
+# Network
+grid_type: tiledgrid # tiledgrid or hashgrid
+grid_interpolation_type: linear # smoothstep or linear
+with_att: true
+use_window_cond: true
+torso_head_aware: false # head aware torso nerf to avoid head-torso separation artifacts!
+num_layers_sigma: 3
+hidden_dim_sigma: 128 # 64 (the RAD-NeRF default) is too small
+geo_feat_dim: 128 # 64 (the RAD-NeRF default) is too small
+num_layers_color: 2
+hidden_dim_color: 128 # 64 (the RAD-NeRF default) is too small
+cond_out_dim: 64
+num_layers_ambient: 3
+hidden_dim_ambient: 128 # 64 (the RAD-NeRF default) is too small
+ambient_coord_dim: 2
+individual_embedding_num: 13000
+individual_embedding_dim: 4
+torso_individual_embedding_dim: 8
+
+# infer
+infer_cond_name: ''
+infer_out_video_name: ''
+infer_scale_factor: 1.0
+infer_smo_std: 0.
+infer_audio_source_name: ''
+infer_c2w_name: ''
+infer_lm3d_clamp_std: 1.5
+infer_lm3d_lle_percent: 0.25 # percent of lle fused feature to compose the processed lm3d
+infer_lm3d_smooth_sigma: 0. # sigma of gaussian kernel to smooth the predicted lm3d
+infer_bg_img_fname: '' # black, white, or a img fname
+infer_smooth_camera_path: true
+infer_smooth_camera_path_kernel_size: 7
+
+# gui feat
+gui_w: 512
+gui_h: 512
+gui_radius: 3.35
+gui_fovy: 21.24
+gui_max_spp: 1 # GUI rendering max sample per pixel
+
diff --git a/egs/egs_bases/radnerf/lm3d_radnerf.yaml b/egs/egs_bases/radnerf/lm3d_radnerf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..faf1003f6c3bda0c3d65140daca01691d9ce2428
--- /dev/null
+++ b/egs/egs_bases/radnerf/lm3d_radnerf.yaml
@@ -0,0 +1,12 @@
+base_config:
+ - ./base.yaml
+
+task_cls: tasks.radnerfs.radnerf.RADNeRFTask
+cond_type: idexp_lm3d_normalized
+cond_win_size: 1
+smo_win_size: 5
+lambda_lap_ambient_loss: 0.
+cond_dropout_rate: 0.
+zero_dummy: true
+
+ambient_coord_dim: 3
diff --git a/egs/egs_bases/radnerf/radnerf.yaml b/egs/egs_bases/radnerf/radnerf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..956c5100e1fcd2e773394df88bc3e074f039264c
--- /dev/null
+++ b/egs/egs_bases/radnerf/radnerf.yaml
@@ -0,0 +1,10 @@
+base_config:
+ - ./base.yaml
+
+task_cls: tasks.radnerfs.radnerf.RADNeRFTask
+cond_type: esperanto
+cond_win_size: 16
+smo_win_size: 8
+cond_dropout_rate: 0.
+lambda_lap_ambient_loss: 0.
+mask_cond: false
\ No newline at end of file
diff --git a/egs/egs_bases/syncnet/base.yaml b/egs/egs_bases/syncnet/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bd3f92ffd5c71f437151e0bcbd405745aaff5ac0
--- /dev/null
+++ b/egs/egs_bases/syncnet/base.yaml
@@ -0,0 +1,47 @@
+# dataset-related
+binary_data_dir: data/binary/lrs3
+
+# project-related
+work_dir: ''
+load_ckpt: ''
+tb_log_interval: 100
+val_check_interval: 1000
+valid_infer_interval: 1000
+num_sanity_val_steps: 5
+num_valid_plots: 1
+eval_max_batches: 10 # num_test_plots
+print_nan_grads: false
+resume_from_checkpoint: 0 # specify the step, 0 for latest
+amp: false
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+save_best: true
+debug: false
+save_codes:
+- tasks
+- modules
+- egs
+accumulate_grad_batches: 1
+clip_grad_norm: 1.
+
+# training-scheme-related
+task_cls: tasks.syncnet.lm3d_syncnet.SyncNetTask
+max_updates: 4_0000
+seed: 9999
+lr: 0.0005
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.999
+scheduler: none
+num_ckpt_keep: 100
+
+load_db_to_memory: false # enable it for faster indexing
+max_sentences_per_batch: 1024
+max_tokens_per_batch: 20000
+
+audio_type: hubert
+motion_type: idexp_lm3d
+use_kv_dataset: false
+
+syncnet_num_layers_per_block: 3
+syncnet_base_hid_size: 128
+use_fork: true
\ No newline at end of file
diff --git a/egs/os_avatar/audio2motion_vae.yaml b/egs/os_avatar/audio2motion_vae.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a0298e63080b2838d39dda6fe503389288d93570
--- /dev/null
+++ b/egs/os_avatar/audio2motion_vae.yaml
@@ -0,0 +1,33 @@
+base_config:
+ - egs/egs_bases/audio2motion/vae.yaml
+
+ds_name: # overrides the binary_data_dir below
+binary_data_dir: data/binary/th1kh
+use_kv_dataset: true
+num_workers: 4
+
+task_cls: tasks.os_avatar.audio2motion_task.Audio2MotionTask
+max_updates: 40_0000
+
+motion_type: exp # exp | id_exp if finegrained_id
+sample_min_length: 32
+init_from_ckpt: ''
+
+lambda_mse_lm2d: 0.
+ref_id_mode: 'first_frame' # first_frame | random_frame if finegrained_id
+
+blink_mode: blink_unit # eye_area_percent | blink_unit | none
+use_pitch: true
+use_flow: true
+
+use_eye_amp_embed: false
+use_mouth_amp_embed: true
+lambda_l2_reg_exp: 0.1
+syncnet_ckpt_dir: ''
+audio_type: hubert # hubert | mfcc | mel
+lambda_mse_exp: 0.5
+lambda_mse_lm3d: 0.5
+lambda_lap_exp: 1.0
+lambda_kl: 0.02
+lambda_kl_t1: 2000
+lambda_kl_t2: 2000
\ No newline at end of file
diff --git a/egs/os_avatar/audio_lm3d_syncnet.yaml b/egs/os_avatar/audio_lm3d_syncnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3398871a0fbe165b9522813e072a5d2ba828ebee
--- /dev/null
+++ b/egs/os_avatar/audio_lm3d_syncnet.yaml
@@ -0,0 +1,25 @@
+base_config:
+ - egs/egs_bases/syncnet/base.yaml
+
+init_from_ckpt: ''
+binary_data_dir: data/binary/th1kh
+task_cls: tasks.os_avatar.audio_lm3d_syncnet.SyncNetTask
+use_kv_dataset: true
+num_workers: 8 # 4
+
+syncnet_num_clip_pairs: 8192
+max_sentences_per_batch: 1024
+max_tokens_per_batch: 20000
+sample_min_length: 64
+max_updates: 400_0000
+
+syncnet_num_layers_per_block: 3 # 3
+syncnet_base_hid_size: 128
+syncnet_out_hid_size: 1024 # 1024
+syncnet_keypoint_mode: lm468
+
+lr: 0.001
+lr_decay_rate: 0.98
+lr_decay_interval: 5000
+
+audio_type: hubert # hubert | mfcc
diff --git a/egs/os_avatar/img2plane.yaml b/egs/os_avatar/img2plane.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5b21ef3b86acab2017670c4e9e56042a4a0e4971
--- /dev/null
+++ b/egs/os_avatar/img2plane.yaml
@@ -0,0 +1,73 @@
+base_config: egs/egs_bases/eg3d/base.yaml
+ds_name: TH1KH_512
+binary_data_dir: data/binary/th1kh
+process_id: 0 # rank id when pre-processing dataset
+total_process: 1 # number of ranks when pre-processing dataset
+split_seed: 999 # random seed that split chunks during pre-processing dataset
+seed: 999
+batch_size: 4
+num_workers: 4
+use_kv_dataset: true
+ones_ws_for_sr: true
+
+# ray_near: 2.2
+# ray_far: 4.0
+ray_near: auto
+ray_far: auto
+
+batch_size: 4 # use smaller bs from 4 when using multiple machines to speed up training
+
+lr_g: 0.0001 # follow the setting of < Real-Time Radiance Fields for Single-Image Portrait View Synthesis >
+# lr_g: 0.0004 # larger lr leads to degradation, even using 32 gpus.
+lr_d: 0.0002 # follow the setting of EG3D
+
+warmup_updates: 4000
+
+flipped_to_world_coord: true
+random_sample_pose: true
+mimic_plane: false # minimize the error with EG3D plane
+
+pretrained_eg3d_ckpt: /mnt/bn/sa-ag-data/yezhenhui/projects/GeneFace_private/checkpoints/geneface2_ckpts/eg3d_baseline_run2/model_ckpt_steps_100000.ckpt
+seg_out_mode: none
+img2plane_backbone_mode: vit
+num_ckpt_keep: 1
+
+not_save_modules: ['criterion_lpips', 'eg3d_model']
+task_cls: tasks.os_avatar.img2plane_task.OSAvatarImg2PlaneTask
+
+batch_size: 1
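+# note: some keys (e.g. batch_size, lr_g, img2plane_backbone_mode) appear more than once in this file; with common YAML loaders (e.g. PyYAML) the last occurrence wins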
+normalize_radius: false
+
+optimizer_adam_beta1_g: 0.
+optimizer_adam_beta2_g: 0.99
+optimizer_adam_beta1_d: 0.
+optimizer_adam_beta2_d: 0.99
+
+lambda_mse_depth: 0.
+
+start_adv_iters: 30000
+lr_g: 0.0001
+lr_d: 0.0002
+
+img2plane_backbone_mode: composite # composite | segformer
+
+ffhq_disc_inp_mode: eg3d_gen
+use_th1kh_disc: false # enable only when ds_name == FFHQ_and_TH1KH_512
+lpips_mode: vgg19_v2 # vgg19 | vgg16 | alex | vgg19_v2
+
+enable_rescale_plane_regulation: true
+img2plane_backbone_scale: standard # standard | large
+update_on_th1kh_samples: false
+
+init_from_ckpt: ''
+
+img2plane_input_mode: rgb # rgb_alpha | rgb_camera | rgb_alpha_camera
+triplane_feature_type: trigrid_v2 # triplane # trigrid
+triplane_depth: 3 # 1
+triplane_hid_dim: 32 # 32
+clip_grad_norm: 1.0
+neural_rendering_resolution: 128 # will be upscaled 4x by the SR module
+
+use_th1kh_mv_adv: false
+torch_compile: true
+use_mse: false
\ No newline at end of file
diff --git a/egs/os_avatar/real3d_orig/img2plane_orig.yaml b/egs/os_avatar/real3d_orig/img2plane_orig.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ae4c1134b7d1e5c654dc5fa11ec6915a61c01a43
--- /dev/null
+++ b/egs/os_avatar/real3d_orig/img2plane_orig.yaml
@@ -0,0 +1,42 @@
+base_config:
+ - ../../ffhq/img2plane.yaml
+ - ../../ffhq/base.yaml
+
+not_save_modules: ['criterion_lpips', 'eg3d_model']
+ds_name: FFHQ # FFHQ | FFHQ_and_TH1KH_512 # we found that adding the video data degrades image quality
+task_cls: tasks.os_avatar.img2plane_task.OSAvatarImg2PlaneTask
+
+batch_size: 1
+normalize_radius: false
+
+optimizer_adam_beta1_g: 0.
+optimizer_adam_beta2_g: 0.99
+optimizer_adam_beta1_d: 0.
+optimizer_adam_beta2_d: 0.99
+
+lambda_mse_depth: 0.
+
+start_adv_iters: 30000
+lr_g: 0.0001
+lr_d: 0.0002
+
+img2plane_backbone_mode: composite # composite | segformer
+
+ffhq_disc_inp_mode: eg3d_gen
+use_th1kh_disc: false # enable only when ds_name == FFHQ_and_TH1KH_512
+lpips_mode: vgg19_v2 # vgg19 | vgg16 | alex | vgg19_v2
+
+enable_rescale_plane_regulation: true
+img2plane_backbone_scale: standard # standard | large
+update_on_th1kh_samples: false
+
+init_from_ckpt: 'checkpoints/0823_img2plane/img2plane'
+
+triplane_feature_type: triplane # triplane # trigrid # trigrid_v2
+triplane_depth: 1 # now use 3
+triplane_hid_dim: 32 # 32
+clip_grad_norm: 1.0
+
+use_th1kh_mv_adv: false
+torch_compile: true
+use_mse: false
\ No newline at end of file
diff --git a/egs/os_avatar/real3d_orig/secc_img2plane_orig.yaml b/egs/os_avatar/real3d_orig/secc_img2plane_orig.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d407f3810bfb79103503e76818c59cdf12400e00
--- /dev/null
+++ b/egs/os_avatar/real3d_orig/secc_img2plane_orig.yaml
@@ -0,0 +1,56 @@
+base_config:
+ - ./img2plane_orig.yaml
+
+task_cls: tasks.os_avatar.secc_img2plane_task.SECC_Img2PlaneEG3DTask
+# ds_name: Concat_VFHQ_CelebVHQ_TH1KH_RAVDESS # CelebV_HQ | Concat_CelebVHQ_TH1KH | Concat_CelebVHQ_TH1KH_RAVDESS
+ds_name: FULL_Concat_VFHQ_CelebVHQ_TH1KH_RAVDESS
+binary_data_dir: data/binary/CelebV-HQ
+
+img2plane_backbone_mode: composite # composite | segformer
+num_workers: 8 # 4
+pncc_cond_mode: cano_src_tgt # cano_tgt | cano_src_tgt
+seg_out_mode: head
+
+# we currently find the mouth becomes uncontrollable after adversarial training starts; see checkpoints/0702_img2planes/osavatar_secc_img2plane_baseline_vit_from_pretrained
+start_adv_iters: 25_0000 # 250k if initialized from img2plane; if initialized from secc2plane, use your judgement, around 50k~100k also works.
+max_updates: 25_0000 # 25_0000
+lambda_th1kh_mv_adv: 0.002 # 0.005 # 0.01
+add_ffhq_singe_disc: false
+lambda_ffhq_mv_adv: 0.002 # enable when add_ffhq_singe_disc is True
+lr_mul_cano_img2plane: 1.0 # 1.0 | 0. | 0.1
+lambda_mse: 1.0
+lr_decay_rate: 0.95
+lr_decay_interval: 5000
+
+secc_segformer_scale: b0 # b0-b5
+use_motion_smo_net: false
+motion_smo_win_size: 5
+
+# regularization on Spatial plane
+density_reg_p_dist: 0.004 # distance at which to sample perturbed points for density regularization
+
+# regularization on SECC plane
+reg_interval_g: 4
+enable_rescale_plane_regulation: false # we tried rescaling and found it has little effect
+min_rescale_factor: 0.25
+# how we fuse the secc
+phase1_plane_fusion_mode: add # add | mul
+init_from_ckpt: checkpoints/240126_real3dportrait_orig/img2plane_orig
+
+disable_highreso_at_stage1: true
+secc_pertube_mode: randn # randn | tv | laplacian | none
+secc_pertube_randn_scale: 0.01 # enable when pertube_mode==randn
+# target_pertube_blink_secc_loss: 0.05 # the task automatically tunes the corresponding lambda so the perturbation loss approaches this target
+target_pertube_blink_secc_loss: 0.15 # the task automatically tunes the corresponding lambda so the perturbation loss approaches this target
+target_pertube_secc_loss: 0.5 # 0.3 # the task automatically tunes the corresponding lambda so the perturbation loss approaches this target
+lr_lambda_pertube_secc: 0.01 # learning rate for auto-tuning the lambda
+
+sr_type: vanilla # vanilla | spade
+two_stage_training: true # if yes, when adv starts, fix the NeRF and only finetune the SR. We found it necessary; otherwise the i2p could produce bad cases (such as a darkened face)
+also_update_decoder: false # update decoder at stage 2
+lambda_weights_l1: 0.1 # 0.5
+lambda_weights_entropy: 0.01 # 0.05
+lambda_density_reg: 0.25 # default 0.25 in EG3D, strength of pertube density regularization for Generator
+reg_interval_g_cond: 4
+ckpt_milestone_interval: 50000
+update_src2src_interval: 16
diff --git a/egs/os_avatar/real3d_orig/secc_img2plane_torso_orig.yaml b/egs/os_avatar/real3d_orig/secc_img2plane_torso_orig.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c3507071be1a1da77b972410e9ef7647bb0b285
--- /dev/null
+++ b/egs/os_avatar/real3d_orig/secc_img2plane_torso_orig.yaml
@@ -0,0 +1,31 @@
+base_config:
+ - ./secc_img2plane_orig.yaml
+
+task_cls: tasks.os_avatar.secc_img2plane_torso_task.SECC_Img2PlaneEG3D_TorsoTask
+torso_ref_segout_mode: torso # torso | torso_with_bg | person | full (person_with_bg)
+
+lr_g: 0.00001
+
+weight_fuse: true
+
+start_adv_iters: 40000
+max_updates: 10_0000 # 25_0000
+lambda_th1kh_mv_adv: 0.003
+add_ffhq_singe_disc: false
+lambda_ffhq_mv_adv: 0.002 # enable when add_ffhq_singe_disc is True
+lambda_mse: 1.0
+init_from_ckpt: checkpoints/240207_robust_secc2plane/secc2plane_orig_blink0.3_pertubeNone/model_ckpt_steps_150000.ckpt # checkpoints/0725_img2planes/secc_img2plane_torso | can be either a secc_img2plane or a secc_img2plane_torso ckpt
+reload_head_ckpt: '' # checkpoints/0804_secc2plane/secc_img2plane_lap0.1_blink0.05_run2 | will override the secc_img2plane from init_from_ckpt and be reloaded during training
+
+fuse_with_deform_source: false # fusing with the source causes severe artifacts
+lam_occlusion_2_reg_l1: 0.0 # 0.001
+torso_occlusion_reg_unmask_factor: 0.3
+lam_occlusion_weights_entropy: 0.001 # 0.0001
+
+lam_occlusion_reg_l1: 0.00 # setting it to 0.02 causes color mismatch on both the face and the torso, and when the head shakes only the neck moves while the body barely moves, which looks unrealistic.
+torso_kp_num: 4
+torso_inp_mode: rgb_alpha
+htbsr_head_threshold: 0.9
+torso_model_version: v2
+htbsr_head_weight_fuse_mode: v2
+appearance_feat_mul_torso_mask: true
\ No newline at end of file
diff --git a/egs/os_avatar/secc_img2plane.yaml b/egs/os_avatar/secc_img2plane.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..12bae65b7cee80072f78bb8df5f0c7a62c4e22b8
--- /dev/null
+++ b/egs/os_avatar/secc_img2plane.yaml
@@ -0,0 +1,56 @@
+base_config:
+ - ./img2plane.yaml
+
+task_cls: tasks.os_avatar.secc_img2plane_task.SECC_Img2PlaneEG3DTask
+ds_name: TH1KH_512 # CelebV_HQ | Concat_CelebVHQ_TH1KH | Concat_CelebVHQ_TH1KH_RAVDESS
+binary_data_dir: data/binary/th1kh
+
+img2plane_backbone_mode: composite # composite | segformer
+num_workers: 8 # 4
+pncc_cond_mode: cano_src_tgt # cano_tgt | cano_src_tgt
+seg_out_mode: head
+
+# we currently find the mouth becomes uncontrollable after adversarial training starts; see checkpoints/0702_img2planes/osavatar_secc_img2plane_baseline_vit_from_pretrained
+start_adv_iters: 20_0000 # 150k if initialized from img2plane; if initialized from secc2plane, use your judgement, around 50k~100k also works.
+stop_update_i2p_iters: 7_0000
+max_updates: 25_0000 # we found it overfits by around 200k updates, which hurts performance on OOD identities
+lambda_th1kh_mv_adv: 0.002 # 0.005 # 0.01
+add_ffhq_singe_disc: false
+lambda_ffhq_mv_adv: 0.002 # enable when add_ffhq_singe_disc is True
+lr_mul_cano_img2plane: 1.0 # 1.0 | 0. | 0.1
+lambda_mse: 1.0
+lr_decay_rate: 0.95
+lr_decay_interval: 5000
+
+secc_segformer_scale: b0 # b0-b5
+use_motion_smo_net: false
+motion_smo_win_size: 5
+
+# regularization on Spatial plane
+density_reg_p_dist: 0.004 # distance at which to sample perturbed points for density regularization
+
+# regularization on SECC plane
+reg_interval_g: 4
+enable_rescale_plane_regulation: false # we tried rescaling and found it has little effect
+min_rescale_factor: 0.25
+# how we fuse the secc
+phase1_plane_fusion_mode: add # add | mul
+init_from_ckpt: '' # checkpoints/240126_improve_i2p/img2plane_rgb_alpha
+
+disable_highreso_at_stage1: true
+secc_pertube_mode: randn # randn | tv | laplacian | none
+secc_pertube_randn_scale: 0.01 # enable when pertube_mode==randn
+target_pertube_blink_secc_loss: 0.3 # the task automatically tunes the corresponding lambda so the perturbation loss approaches this target
+target_pertube_secc_loss: 0. # 0.5 # the task automatically tunes the corresponding lambda so the perturbation loss approaches this target
+pertube_ref_prob: 0.25
+lr_lambda_pertube_secc: 0.01 # learning rate for auto-tuning the lambda
+
+sr_type: vanilla # vanilla | spade
+two_stage_training: true # if yes, when adv starts, fix the NeRF and only finetune the SR. We found it necessary; otherwise the i2p could produce bad cases (such as a darkened face)
+also_update_decoder: false # update decoder at stage 2
+lambda_weights_l1: 0.1 # 0.5
+lambda_weights_entropy: 0.01 # 0.05
+lambda_density_reg: 0.25 # default 0.25 in EG3D, strength of pertube density regularization for Generator
+reg_interval_g_cond: 4
+ckpt_milestone_interval: 50000
+update_src2src_interval: 16
diff --git a/egs/os_avatar/secc_img2plane_torso.yaml b/egs/os_avatar/secc_img2plane_torso.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..762c289026bc5183db76caa8b5ae019de5ed1291
--- /dev/null
+++ b/egs/os_avatar/secc_img2plane_torso.yaml
@@ -0,0 +1,31 @@
+base_config:
+ - ./secc_img2plane.yaml
+
+task_cls: tasks.os_avatar.secc_img2plane_torso_task.SECC_Img2PlaneEG3D_TorsoTask
+torso_ref_segout_mode: torso # torso | torso_with_bg | person | full (person_with_bg)
+
+lr_g: 0.00001
+
+weight_fuse: true
+
+start_adv_iters: 40000
+max_updates: 10_0000 # 25_0000
+lambda_th1kh_mv_adv: 0.001
+add_ffhq_singe_disc: false
+lambda_ffhq_mv_adv: 0.002 # enable when add_ffhq_singe_disc is True
+lambda_mse: 1.0
+init_from_ckpt: '' # checkpoints/0725_img2planes/secc_img2plane_torso | can be either a secc_img2plane or a secc_img2plane_torso ckpt
+reload_head_ckpt: '' # checkpoints/0804_secc2plane/secc_img2plane_lap0.1_blink0.05_run2 | will override the secc_img2plane from init_from_ckpt and be reloaded during training
+
+
+fuse_with_deform_source: false # fusing with the source causes severe artifacts
+lam_occlusion_2_reg_l1: 0.0 # 0.001
+torso_occlusion_reg_unmask_factor: 0.3
+lam_occlusion_weights_entropy: 0.001 # 0.0001
+
+lam_occlusion_reg_l1: 0.00 # setting it to 0.02 causes color mismatch on both the face and the torso, and when the head shakes only the neck moves while the body barely moves, which looks unrealistic.
+occlusion_fuse: true
+torso_kp_num: 4
+htbsr_head_weight_fuse_mode: v2
+htbsr_head_threshold: 0.9
+torso_model_version: v2
diff --git a/egs/th1kh_512/base.yaml b/egs/th1kh_512/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c56412eceb0b667daa44a02ffafa3ce71c4dc9c8
--- /dev/null
+++ b/egs/th1kh_512/base.yaml
@@ -0,0 +1,21 @@
+ds_name: TH1KH_512
+raw_data_dir: /mnt/bn/sa-ag-data/yezhenhui/datasets/raw/TH1KH_512
+binary_data_dir: data/binary/TH1KH_512
+# binary_data_dir: /dev/shm/TH1KH
+process_id: 0 # rank id when pre-processing dataset
+total_process: 1 # number of ranks when pre-processing dataset
+split_seed: 999 # random seed that split chunks during pre-processing dataset
+
+max_sentences_per_batch: 1024
+max_tokens_per_batch: 200000
+
+load_db_to_memory: false
+
+num_workers: 4
+use_kv_dataset: true
+
+binarization_args:
+ with_hubert: false
+ with_mel: false
+ with_coeff: true
+
diff --git a/egs/th1kh_512/secc_img2plane.yaml b/egs/th1kh_512/secc_img2plane.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2cfcd507d0909efabd4e4f0abe96ac758749f76e
--- /dev/null
+++ b/egs/th1kh_512/secc_img2plane.yaml
@@ -0,0 +1,8 @@
+base_config:
+ - ../os_avatar/secc_img2plane.yaml
+ - ./base.yaml
+
+
+init_from_ckpt: /mnt/bn/sa-ag-data/yezhenhui/projects/GeneFace_private/checkpoints/0720_img2planes/secc_img2plane_one_stage
+lr_g: 0.0001 # 1e-4, larger than for RAVDESS, because TH1KH_512 is a larger dataset
+lr_d: 0.0002 # 2e-4
\ No newline at end of file
diff --git a/egs/th1kh_512/secc_img2plane_torso.yaml b/egs/th1kh_512/secc_img2plane_torso.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9283b1d7509bca714576f30ca70cb1f55278e7c8
--- /dev/null
+++ b/egs/th1kh_512/secc_img2plane_torso.yaml
@@ -0,0 +1,8 @@
+base_config:
+ - ../os_avatar/secc_img2plane_torso.yaml
+ - ./base.yaml
+
+
+init_from_ckpt: /mnt/bn/sa-ag-data/yezhenhui/projects/GeneFace_private/checkpoints/0729_th1kh/secc_img2plane
+lr_g: 0.00001 # 1e-5
+lr_d: 0.0002 # 2e-4
\ No newline at end of file
diff --git a/egs/th1kh_512_audio2motion/base.yaml b/egs/th1kh_512_audio2motion/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25d3fe7e63151356e1bf9e2283f70b693c68d27d
--- /dev/null
+++ b/egs/th1kh_512_audio2motion/base.yaml
@@ -0,0 +1,20 @@
+ds_name: TH1KH_512
+raw_data_dir: /mnt/bn/sa-ag-data/yezhenhui/datasets/raw/TH1KH_512
+binary_data_dir: data/binary/TH1KH_512_audio2motion
+# binary_data_dir: /dev/shm/TH1KH_512
+process_id: 0 # rank id when pre-processing dataset
+total_process: 1 # number of ranks when pre-processing dataset
+split_seed: 999 # random seed that split chunks during pre-processing dataset
+
+smo_win_size: 5
+batch_size: 4
+num_workers: 4
+
+use_kv_dataset: true
+
+binarization_args:
+ with_hubert: true
+ with_mel: true
+ with_coeff: true
+
+sample_min_length: 0
\ No newline at end of file
diff --git a/egs/th1kh_512_audio2motion/lm3d_syncnet.yaml b/egs/th1kh_512_audio2motion/lm3d_syncnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3fc59ea27ce1ac5ed5a119e81b4712b4401b78bd
--- /dev/null
+++ b/egs/th1kh_512_audio2motion/lm3d_syncnet.yaml
@@ -0,0 +1,17 @@
+base_config:
+ - egs/egs_bases/syncnet/base.yaml
+ - ./base.yaml
+
+max_updates: 250000
+motion_type: idexp_lm3d
+audio_type: hubert
+
+syncnet_num_layers_per_block: 3
+syncnet_base_hid_size: 128
+
+# max_sentences_per_batch: 1024
+max_sentences_per_batch: 2048
+max_tokens_per_batch: 40_000
+# max_tokens_per_batch: 20_000
+
+num_workers: 16
\ No newline at end of file
diff --git a/egs/th1kh_512_audio2motion/lm3d_vae.yaml b/egs/th1kh_512_audio2motion/lm3d_vae.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cb0263a9b51d23139e77ec47fd24a903ecebea14
--- /dev/null
+++ b/egs/th1kh_512_audio2motion/lm3d_vae.yaml
@@ -0,0 +1,9 @@
+base_config:
+ - egs/egs_bases/audio2motion/vae.yaml
+ - ./base.yaml
+
+lambda_kl: 0.02
+motion_type: idexp_lm3d
+audio_type: hubert
+
+max_updates: 160000
diff --git a/egs/th1kh_512_audio2motion/lm3d_vae_pitch.yaml b/egs/th1kh_512_audio2motion/lm3d_vae_pitch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..40238f72ee26720714c061a1a8ddba62839ae125
--- /dev/null
+++ b/egs/th1kh_512_audio2motion/lm3d_vae_pitch.yaml
@@ -0,0 +1,10 @@
+base_config:
+ - egs/egs_bases/audio2motion/vae.yaml
+ - ./base.yaml
+
+lambda_kl: 0.02
+motion_type: idexp_lm3d
+audio_type: hubert
+
+task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask
+max_updates: 160000
diff --git a/egs/th1kh_512_audio2motion/lm3d_vae_sync.yaml b/egs/th1kh_512_audio2motion/lm3d_vae_sync.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..914f94b3a07e4c88125a7d2cff06687b8c40be22
--- /dev/null
+++ b/egs/th1kh_512_audio2motion/lm3d_vae_sync.yaml
@@ -0,0 +1,13 @@
+base_config:
+ - egs/egs_bases/audio2motion/vae_sync.yaml
+ - ./base.yaml
+
+syncnet_work_dir: checkpoints/th1kh/lm3d_syncnet
+syncnet_ckpt_steps: 250000
+lambda_kl: 0.02
+max_updates: 160000
+motion_type: idexp_lm3d
+audio_type: hubert
+
+syncnet_num_layers_per_block: 3
+syncnet_base_hid_size: 128
\ No newline at end of file
diff --git a/egs/th1kh_512_audio2motion/lm3d_vae_sync_pitch.yaml b/egs/th1kh_512_audio2motion/lm3d_vae_sync_pitch.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d1668a87cc758520bdc5fd99b8ebfd9adc24888c
--- /dev/null
+++ b/egs/th1kh_512_audio2motion/lm3d_vae_sync_pitch.yaml
@@ -0,0 +1,14 @@
+base_config:
+ - ./lm3d_vae_sync.yaml
+ - ./base.yaml
+
+lambda_kl: 0.02
+syncnet_work_dir: checkpoints/th1kh/lm3d_syncnet
+syncnet_ckpt_steps: 230000
+task_cls: tasks.audio2motion.lm3d_vae_sync_pitch.VAESyncAudio2MotionTask
+max_updates: 160000
+motion_type: idexp_lm3d
+audio_type: hubert
+
+syncnet_num_layers_per_block: 3
+syncnet_base_hid_size: 128
\ No newline at end of file
diff --git a/inference/app_real3dportrait.py b/inference/app_real3dportrait.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d231a538dd1117f82ecc49bc5df92b7d6f2db81
--- /dev/null
+++ b/inference/app_real3dportrait.py
@@ -0,0 +1,247 @@
+import os, sys
+sys.path.append('./')
+import argparse
+import gradio as gr
+from inference.real3d_infer import GeneFace2Infer
+from utils.commons.hparams import hparams
+
+class Inferer(GeneFace2Infer):
+ def infer_once_args(self, *args, **kargs):
+ assert len(kargs) == 0
+ keys = [
+ 'src_image_name',
+ 'drv_audio_name',
+ 'drv_pose_name',
+ 'bg_image_name',
+ 'blink_mode',
+ 'temperature',
+ 'mouth_amp',
+ 'out_mode',
+ 'map_to_init_pose',
+ 'low_memory_usage',
+ 'hold_eye_opened',
+ 'a2m_ckpt',
+ 'head_ckpt',
+ 'torso_ckpt',
+ 'min_face_area_percent',
+ ]
+ inp = {}
+ out_name = None
+ info = ""
+
+ try: # try to catch errors and jump to return
+ for key_index in range(len(keys)):
+ key = keys[key_index]
+ inp[key] = args[key_index]
+ if '_name' in key:
+ inp[key] = inp[key] if inp[key] is not None else ''
+
+ if inp['src_image_name'] == '':
+ info = "Input Error: Source image is REQUIRED!"
+ raise ValueError
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] == '':
+ info = "Input Error: At least one of driving audio or video is REQUIRED!"
+ raise ValueError
+
+
+ if inp['drv_audio_name'] == '' and inp['drv_pose_name'] != '':
+ inp['drv_audio_name'] = inp['drv_pose_name']
+ print("No audio input, we use driving pose video for video driving")
+
+ if inp['drv_pose_name'] == '':
+ inp['drv_pose_name'] = 'static'
+
+ reload_flag = False
+ if inp['a2m_ckpt'] != self.audio2secc_dir:
+ print("Changes of a2m_ckpt detected, reloading model")
+ reload_flag = True
+ if inp['head_ckpt'] != self.head_model_dir:
+ print("Changes of head_ckpt detected, reloading model")
+ reload_flag = True
+ if inp['torso_ckpt'] != self.torso_model_dir:
+ print("Changes of torso_ckpt detected, reloading model")
+ reload_flag = True
+
+ inp['out_name'] = ''
+ inp['seed'] = 42
+
+ print(f"infer inputs : {inp}")
+
+ try:
+ if reload_flag:
+ self.__init__(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp, device=self.device)
+ except Exception as e:
+ content = f"{e}"
+ info = f"Reload ERROR: {content}"
+ raise ValueError
+ try:
+ out_name = self.infer_once(inp)
+ except Exception as e:
+ content = f"{e}"
+ info = f"Inference ERROR: {content}"
+ raise ValueError
+ except Exception as e:
+ if info == "": # unexpected errors
+ content = f"{e}"
+ info = f"WebUI ERROR: {content}"
+
+ # output part
+        if len(info) > 0: # there are errors
+ print(info)
+ info_gr = gr.update(visible=True, value=info)
+ else: # no errors
+ info_gr = gr.update(visible=False, value=info)
+ if out_name is not None and len(out_name) > 0 and os.path.exists(out_name): # good output
+ print(f"Succefully generated in {out_name}")
+ video_gr = gr.update(visible=True, value=out_name)
+ else:
+ print(f"Failed to generate")
+ video_gr = gr.update(visible=True, value=out_name)
+
+ return video_gr, info_gr
+
+def toggle_audio_file(choice):
+ if choice == False:
+ return gr.update(visible=True), gr.update(visible=False)
+ else:
+ return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+ if path_of_ref_video is not None:
+ return gr.update(value=True)
+ else:
+ return gr.update(value=False)
+
+def real3dportrait_demo(
+ audio2secc_dir,
+ head_model_dir,
+ torso_model_dir,
+ device = 'cuda',
+ warpfn = None,
+ ):
+
+ sep_line = "-" * 40
+
+ infer_obj = Inferer(
+ audio2secc_dir=audio2secc_dir,
+ head_model_dir=head_model_dir,
+ torso_model_dir=torso_model_dir,
+ device=device,
+ )
+
+ print(sep_line)
+ print("Model loading is finished.")
+ print(sep_line)
+ with gr.Blocks(analytics_enabled=False) as real3dportrait_interface:
+ gr.Markdown("\
+ Real3D-Portrait: One-shot Realistic 3D Talking Portrait Synthesis (ICLR 2024 Spotlight) \
+
Arxiv \
+
Homepage \
+
Github ")
+
+ sources = None
+ with gr.Row():
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="source_image"):
+ with gr.TabItem('Upload image'):
+ with gr.Row():
+ src_image_name = gr.Image(label="Source image (required)", sources=sources, type="filepath", value="data/raw/examples/Macron.png")
+ with gr.Tabs(elem_id="driven_audio"):
+ with gr.TabItem('Upload audio'):
+ with gr.Column(variant='panel'):
+ drv_audio_name = gr.Audio(label="Input audio (required for audio-driven)", sources=sources, type="filepath", value="data/raw/examples/Obama_5s.wav")
+ with gr.Tabs(elem_id="driven_pose"):
+ with gr.TabItem('Upload video'):
+ with gr.Column(variant='panel'):
+ drv_pose_name = gr.Video(label="Driven Pose (required for video-driven, optional for audio-driven)", sources=sources, value="data/raw/examples/May_5s.mp4")
+ with gr.Tabs(elem_id="bg_image"):
+ with gr.TabItem('Upload image'):
+ with gr.Row():
+ bg_image_name = gr.Image(label="Background image (optional)", sources=sources, type="filepath", value="data/raw/examples/bg.png")
+
+
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('General Settings'):
+ with gr.Column(variant='panel'):
+
+                            blink_mode = gr.Radio(['none', 'period'], value='period', label='blink mode', info="whether to blink periodically")
+                            min_face_area_percent = gr.Slider(minimum=0.15, maximum=0.5, step=0.01, label="min_face_area_percent", value=0.2, info='The minimum percentage of the output frame occupied by the face, to prevent bad cases caused by an overly small face.',)
+                            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="temperature", value=0.2, info='audio to secc temperature',)
+                            mouth_amp = gr.Slider(minimum=0.0, maximum=1.0, step=0.025, label="mouth amplitude", value=0.45, info='higher -> the mouth opens wider; default is 0.45',)
+                            out_mode = gr.Radio(['final', 'concat_debug'], value='concat_debug', label='output layout', info="final: only the final output; concat_debug: the final output concatenated with internal features")
+                            low_memory_usage = gr.Checkbox(label="Low Memory Usage Mode: save memory at the expense of lower inference speed. Useful when running on a long audio (minutes-long).", value=False)
+ map_to_init_pose = gr.Checkbox(label="Whether to map pose of first frame to initial pose", value=True)
+ hold_eye_opened = gr.Checkbox(label="Whether to maintain eyes always open")
+
+ submit = gr.Button('Generate', elem_id="generate", variant='primary')
+
+ with gr.Tabs(elem_id="genearted_video"):
+ info_box = gr.Textbox(label="Error", interactive=False, visible=False)
+ gen_video = gr.Video(label="Generated video", format="mp4", visible=True)
+ with gr.Column(variant='panel'):
+ with gr.Tabs(elem_id="checkbox"):
+ with gr.TabItem('Checkpoints'):
+ with gr.Column(variant='panel'):
+                        ckpt_info_box = gr.Textbox(value="Please select a \"ckpt\" file under the checkpoint folder", interactive=False, visible=True, show_label=False)
+ audio2secc_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=audio2secc_dir, file_count='single', label='audio2secc model ckpt path or directory')
+ head_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=head_model_dir, file_count='single', label='head model ckpt path or directory (will be ignored if torso model is set)')
+ torso_model_dir = gr.FileExplorer(glob="checkpoints/**/*.ckpt", value=torso_model_dir, file_count='single', label='torso model ckpt path or directory')
+ # audio2secc_dir = gr.Textbox(audio2secc_dir, max_lines=1, label='audio2secc model ckpt path or directory (will be ignored if torso model is set)')
+ # head_model_dir = gr.Textbox(head_model_dir, max_lines=1, label='head model ckpt path or directory (will be ignored if torso model is set)')
+ # torso_model_dir = gr.Textbox(torso_model_dir, max_lines=1, label='torso model ckpt path or directory')
+
+
+ fn = infer_obj.infer_once_args
+ if warpfn:
+ fn = warpfn(fn)
+ submit.click(
+ fn=fn,
+ inputs=[
+ src_image_name,
+ drv_audio_name,
+ drv_pose_name,
+ bg_image_name,
+ blink_mode,
+ temperature,
+ mouth_amp,
+ out_mode,
+ map_to_init_pose,
+ low_memory_usage,
+ hold_eye_opened,
+ audio2secc_dir,
+ head_model_dir,
+ torso_model_dir,
+ min_face_area_percent,
+ ],
+ outputs=[
+ gen_video,
+ info_box,
+ ],
+ )
+
+ print(sep_line)
+ print("Gradio page is constructed.")
+ print(sep_line)
+
+ return real3dportrait_interface
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/audio2secc_vae/model_ckpt_steps_400000.ckpt')
+ parser.add_argument("--head_ckpt", type=str, default='')
+ parser.add_argument("--torso_ckpt", type=str, default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig/model_ckpt_steps_100000.ckpt')
+ parser.add_argument("--port", type=int, default=None)
+ parser.add_argument("--server", type=str, default='127.0.0.1')
+ parser.add_argument("--share", action='store_true', dest='share', help='share srever to Internet')
+
+ args = parser.parse_args()
+ demo = real3dportrait_demo(
+ audio2secc_dir=args.a2m_ckpt,
+ head_model_dir=args.head_ckpt,
+ torso_model_dir=args.torso_ckpt,
+ device='cuda:0',
+ warpfn=None,
+ )
+ demo.queue()
+ demo.launch(share=args.share, server_name=args.server, server_port=args.port)
diff --git a/inference/edit_secc.py b/inference/edit_secc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1e602b389665c2710eb76e5ab1244030096db2
--- /dev/null
+++ b/inference/edit_secc.py
@@ -0,0 +1,147 @@
+import cv2
+import torch
+from utils.commons.image_utils import dilate, erode
+from sklearn.neighbors import NearestNeighbors
+import copy
+import numpy as np
+from utils.commons.meters import Timer
+
+def hold_eye_opened_for_secc(img):
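+    # Carve "opened eye" holes into the SECC map: pixels of a reference opened-eye mask that lie close
+    # to the detected eye region are zeroed out (treated as background), so downstream rendering keeps
+    # the eyes open regardless of the driving expression.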
+ img = img.permute(1,2,0).cpu().numpy()
+ img = ((img +1)/2*255).astype(np.uint)
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
+    face_xys = np.stack(np.nonzero(face_mask)).transpose(1, 0) # [N_face,2] coordinates of face pixels (unused here)
+ h,w = face_mask.shape
+ # get face and eye mask
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+
+ opened_eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
+    opened_eye_mask = torch.nn.functional.interpolate(torch.tensor(opened_eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[0], img.shape[1]), mode='nearest')[0].permute(1,2,0).sum(-1).bool().cpu() # [h,w] bool mask
+    coarse_opened_eye_xys = np.stack(np.nonzero(opened_eye_mask)).transpose(1, 0) # [N_eye,2] coordinates of opened-eye mask pixels
+
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
+ dists, _ = nbrs.kneighbors(coarse_opened_eye_xys) # [512*512, 1] distance to nearest non-bg pixel
+ # print(dists.max())
+    non_opened_eye_pixs = dists > max(dists.max()*0.75, 4) # opened-eye pixels farther than this distance from the detected eye region will be closed
+ non_opened_eye_pixs = non_opened_eye_pixs.reshape([-1])
+ opened_eye_xys_to_erode = coarse_opened_eye_xys[non_opened_eye_pixs]
+    opened_eye_mask[opened_eye_xys_to_erode[...,0], opened_eye_xys_to_erode[...,1]] = False # shrink: contract the mask at the face-eye boundary by a few pixels, for smoothness
+
+ img[opened_eye_mask] = 0
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+
+
+# def hold_eye_opened_for_secc(img):
+# img = copy.copy(img)
+# eye_mask = cv2.imread('inference/os_avatar/opened_eye_mask.png')
+# eye_mask = torch.nn.functional.interpolate(torch.tensor(eye_mask).permute(2,0,1).unsqueeze(0), size=(img.shape[-2], img.shape[-1]), mode='nearest')[0].bool().to(img.device) # [3,512,512]
+# img[eye_mask] = -1
+# return img
+
+def blink_eye_for_secc(img, close_eye_percent=0.5):
+ """
+ secc_img: [3,h,w], tensor, -1~1
+ """
+ img = img.permute(1,2,0).cpu().numpy()
+ img = ((img +1)/2*255).astype(np.uint)
+ assert close_eye_percent <= 1.0 and close_eye_percent >= 0.
+ if close_eye_percent == 0: return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+ img = copy.deepcopy(img)
+ face_mask = (img[...,0] != 0) & (img[...,1] != 0) & (img[...,2] != 0)
+ h,w = face_mask.shape
+
+ # get face and eye mask
+ left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ left_eye_prior_reigon[h//4:h//2, w//4:w//2] = True
+ right_eye_prior_reigon[h//4:h//2, w//2:w//4*3] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+ coarse_eye_mask = (~ face_mask) & eye_prior_reigon
+ coarse_left_eye_mask = (~ face_mask) & left_eye_prior_reigon
+ coarse_right_eye_mask = (~ face_mask) & right_eye_prior_reigon
+ coarse_eye_xys = np.stack(np.nonzero(coarse_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ min_h = coarse_eye_xys[:, 0].min()
+ max_h = coarse_eye_xys[:, 0].max()
+ coarse_left_eye_xys = np.stack(np.nonzero(coarse_left_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ left_min_w = coarse_left_eye_xys[:, 1].min()
+ left_max_w = coarse_left_eye_xys[:, 1].max()
+ coarse_right_eye_xys = np.stack(np.nonzero(coarse_right_eye_mask)).transpose(1, 0) # [N_nonbg,2] coordinate of non-face pixels
+ right_min_w = coarse_right_eye_xys[:, 1].min()
+ right_max_w = coarse_right_eye_xys[:, 1].max()
+
+    # reduce the number of face pixels that need to be considered, to lower the cost of the KNN search
+    left_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+    more_room = 4 # too small a margin causes problems
+ left_eye_prior_reigon[min_h-more_room:max_h+more_room, left_min_w-more_room:left_max_w+more_room] = True
+ right_eye_prior_reigon = np.zeros([h,w], dtype=bool)
+ right_eye_prior_reigon[min_h-more_room:max_h+more_room, right_min_w-more_room:right_max_w+more_room] = True
+ eye_prior_reigon = left_eye_prior_reigon | right_eye_prior_reigon
+
+ around_eye_face_mask = face_mask & eye_prior_reigon
+ face_mask = around_eye_face_mask
+    face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_face,2] coordinates of face pixels around the eyes
+
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(coarse_eye_xys)
+ dists, _ = nbrs.kneighbors(face_xys) # [512*512, 1] distance to nearest non-bg pixel
+    face_pixs = dists > 5 # only pixels whose distance to the nearest eye pixel exceeds 5 are treated as face; too small a threshold causes problems
+ face_pixs = face_pixs.reshape([-1])
+ face_xys_to_erode = face_xys[~face_pixs]
+    face_mask[face_xys_to_erode[...,0], face_xys_to_erode[...,1]] = False # shrink: contract the mask at the face-eye boundary by a few pixels, for smoothness
+ eye_mask = (~ face_mask) & eye_prior_reigon
+
+ h_grid = np.mgrid[0:h, 0:w][0]
+ eye_num_pixel_along_w_axis = eye_mask.sum(axis=0)
+ eye_mask_along_w_axis = eye_num_pixel_along_w_axis != 0
+
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 0
+ eye_mean_h_coord_along_w_axis = tmp_h_grid.sum(axis=0) / np.clip(eye_num_pixel_along_w_axis, a_min=1, a_max=h)
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 99999
+ eye_min_h_coord_along_w_axis = tmp_h_grid.min(axis=0)
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = -99999
+ eye_max_h_coord_along_w_axis = tmp_h_grid.max(axis=0)
+
+ eye_low_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_min_h_coord_along_w_axis # upper eye
+ eye_high_h_coord_along_w_axis = close_eye_percent * eye_mean_h_coord_along_w_axis + (1-close_eye_percent) * eye_max_h_coord_along_w_axis # lower eye
+
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = 99999
+ upper_eye_blink_mask = tmp_h_grid <= eye_low_h_coord_along_w_axis
+ tmp_h_grid = h_grid.copy()
+ tmp_h_grid[~eye_mask] = -99999
+ lower_eye_blink_mask = tmp_h_grid >= eye_high_h_coord_along_w_axis
+ eye_blink_mask = upper_eye_blink_mask | lower_eye_blink_mask
+
+    face_xys = np.stack(np.nonzero(around_eye_face_mask)).transpose(1, 0) # [N_face,2] coordinates of face pixels around the eyes
+    eye_blink_xys = np.stack(np.nonzero(eye_blink_mask)).transpose(1, 0) # [N_blink,2] coordinates of eye pixels to be covered by the blink
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(face_xys)
+ distances, indices = nbrs.kneighbors(eye_blink_xys)
+ bg_fg_xys = face_xys[indices[:, 0]]
+ img[eye_blink_xys[:, 0], eye_blink_xys[:, 1], :] = img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ return torch.tensor(img.astype(np.float32) / 127.5 - 1).permute(2,0,1)
+
+
+if __name__ == '__main__':
+ import imageio
+ import tqdm
+ img = cv2.imread("assets/cano_secc.png")
+ img = img / 127.5 - 1
+ img = torch.FloatTensor(img).permute(2, 0, 1)
+ fps = 25
+ writer = imageio.get_writer('demo_blink.mp4', fps=fps)
+
+ for i in tqdm.trange(33):
+ blink_percent = 0.03 * i
+ with Timer("Blink", True):
+ out_img = blink_eye_for_secc(img, blink_percent)
+ out_img = ((out_img.permute(1,2,0)+1)*127.5).int().numpy()
+ writer.append_data(out_img)
+ writer.close()
\ No newline at end of file
diff --git a/inference/infer_utils.py b/inference/infer_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb406fa8c734bd5295aae3f2a6e276f4b697da48
--- /dev/null
+++ b/inference/infer_utils.py
@@ -0,0 +1,154 @@
+import os
+import torch
+import torch.nn.functional as F
+import librosa
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+from scipy.spatial.transform import Rotation
+
+
+def load_img_to_512_hwc_array(img_name):
+ img = cv2.imread(img_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = cv2.resize(img, (512, 512))
+ return img
+
+def load_img_to_normalized_512_bchw_tensor(img_name):
+ img = load_img_to_512_hwc_array(img_name)
+ img = ((torch.tensor(img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2) # [b,c,h,w]
+ return img
+
+def mirror_index(index, len_seq):
+ """
+ get mirror index when indexing a sequence and the index is larger than len_pose
+ args:
+ index: int
+ len_pose: int
+ return:
+ mirror_index: int
+ """
+ turn = index // len_seq
+ res = index % len_seq
+ if turn % 2 == 0:
+ return res # forward indexing
+ else:
+ return len_seq - res - 1 # reverse indexing
+
+def smooth_camera_sequence(camera, kernel_size=7):
+ """
+ smooth the camera trajectory (i.e., rotation & translation)...
+ args:
+ camera: [N, 25] or [N, 16]. np.ndarray
+ kernel_size: int
+ return:
+ smoothed_camera: [N, 25] or [N, 16]. np.ndarray
+ """
+ # poses: [N, 25], numpy array
+ N = camera.shape[0]
+ K = kernel_size // 2
+ poses = camera[:, :16].reshape([-1, 4, 4]).copy()
+ trans = poses[:, :3, 3].copy() # [N, 3]
+ rots = poses[:, :3, :3].copy() # [N, 3, 3]
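+    # Moving-average smoothing over a window of `kernel_size` frames: translations are averaged directly,
+    # while rotations are averaged on the rotation manifold via scipy's Rotation.mean(); if that fails,
+    # fall back to the previous (or current) frame's rotation.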
+
+ for i in range(N):
+ start = max(0, i - K)
+ end = min(N, i + K + 1)
+ poses[i, :3, 3] = trans[start:end].mean(0)
+ try:
+ poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()
+ except:
+ if i == 0:
+ poses[i, :3, :3] = rots[i]
+ else:
+ poses[i, :3, :3] = poses[i-1, :3, :3]
+ poses = poses.reshape([-1, 16])
+ camera[:, :16] = poses
+ return camera
+
+def smooth_features_xd(in_tensor, kernel_size=7):
+ """
+ smooth the feature maps
+ args:
+ in_tensor: [T, c,h,w] or [T, c1,c2,h,w]
+ kernel_size: int
+ return:
+ out_tensor: [T, c,h,w] or [T, c1,c2,h,w]
+ """
+ t = in_tensor.shape[0]
+ ndim = in_tensor.ndim
+ pad = (kernel_size- 1)//2
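+    # Temporal moving average: reflect-pad the sequence at both ends, flatten all feature dims into the
+    # channel axis, then apply a length-`kernel_size` average filter along the time axis with conv1d.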
+ in_tensor = torch.cat([torch.flip(in_tensor[0:pad], dims=[0]), in_tensor, torch.flip(in_tensor[t-pad:t], dims=[0])], dim=0)
+ if ndim == 2: # tc
+ _,c = in_tensor.shape
+ in_tensor = in_tensor.permute(1,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ elif ndim == 4: # tchw
+ _,c,h,w = in_tensor.shape
+ in_tensor = in_tensor.permute(1,2,3,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ elif ndim == 5: # tcchw, like deformation
+ _,c1,c2, h,w = in_tensor.shape
+ in_tensor = in_tensor.permute(1,2,3,4,0).reshape([-1,1,t+2*pad]) # [c, 1, t]
+ else: raise NotImplementedError()
+ avg_kernel = 1 / kernel_size * torch.Tensor([1.]*kernel_size).reshape([1,1,kernel_size]).float().to(in_tensor.device) # [1, 1, kw]
+ out_tensor = F.conv1d(in_tensor, avg_kernel)
+ if ndim == 2: # tc
+ return out_tensor.reshape([c,t]).permute(1,0)
+ elif ndim == 4: # tchw
+ return out_tensor.reshape([c,h,w,t]).permute(3,0,1,2)
+ elif ndim == 5: # tcchw, like deformation
+ return out_tensor.reshape([c1,c2,h,w,t]).permute(4,0,1,2,3)
+
+
+def extract_audio_motion_from_ref_video(video_name):
+ def save_wav16k(audio_name):
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
+ wav16k_name = audio_name[:-4] + '_16k.wav'
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name} -y"
+ os.system(extract_wav_cmd)
+ print(f"Extracted wav file (16khz) from {audio_name} to {wav16k_name}.")
+ return wav16k_name
+
+ def get_f0( wav16k_name):
+ from data_gen.process_lrs3.process_audio_mel_f0 import extract_mel_from_fname,extract_f0_from_wav_and_mel
+ wav, mel = extract_mel_from_fname(wav16k_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ f0 = f0.reshape([-1,1])
+ f0 = torch.tensor(f0)
+ return f0
+
+ def get_hubert(wav16k_name):
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
+ len_mel = hubert.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
+ hubert = torch.tensor(hubert)
+ return hubert
+
+ def get_exp(video_name):
+ from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(video_name, save=False)
+ exp = torch.tensor(drv_motion_coeff_dict['exp'])
+ return exp
+
+ wav16k_name = save_wav16k(video_name)
+ f0 = get_f0(wav16k_name)
+ hubert = get_hubert(wav16k_name)
+ os.system(f"rm {wav16k_name}")
+ exp = get_exp(video_name)
+ target_length = min(len(exp), len(hubert)//2, len(f0)//2)
+ exp = exp[:target_length]
+ f0 = f0[:target_length*2]
+ hubert = hubert[:target_length*2]
+ return exp.unsqueeze(0), hubert.unsqueeze(0), f0.unsqueeze(0)
+
+
+if __name__ == '__main__':
+ extract_audio_motion_from_ref_video('data/raw/videos/crop_0213.mp4')
\ No newline at end of file
diff --git a/inference/real3d_infer.py b/inference/real3d_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8121654a46f52bd4055bd6eaac0f93637f9d893f
--- /dev/null
+++ b/inference/real3d_infer.py
@@ -0,0 +1,625 @@
+import os
+import sys
+sys.path.append('./')
+import torch
+import torch.nn.functional as F
+import torchshow as ts
+import librosa
+import random
+import time
+import numpy as np
+import importlib
+import tqdm
+import copy
+import cv2
+import math
+
+# common utils
+from utils.commons.hparams import hparams, set_hparams
+from utils.commons.tensor_utils import move_to_cuda, convert_to_tensor
+from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint
+# 3DMM-related utils
+from deep_3drecon.deep_3drecon_models.bfm import ParametricFaceModel
+from data_util.face3d_helper import Face3DHelper
+from data_gen.utils.process_image.fit_3dmm_landmark import fit_3dmm_for_a_image
+from data_gen.utils.process_video.fit_3dmm_landmark import fit_3dmm_for_a_video
+from deep_3drecon.secc_renderer import SECC_Renderer
+from data_gen.eg3d.convert_to_eg3d_convention import get_eg3d_convention_camera_pose_intrinsic
+from data_gen.utils.process_image.extract_lm2d import extract_lms_mediapipe_job
+
+# Face Parsing
+from data_gen.utils.mp_feature_extractors.mp_segmenter import MediapipeSegmenter
+from data_gen.utils.process_video.extract_segment_imgs import inpaint_torso_job, extract_background
+# other inference utils
+from inference.infer_utils import mirror_index, load_img_to_512_hwc_array, load_img_to_normalized_512_bchw_tensor
+from inference.infer_utils import smooth_camera_sequence, smooth_features_xd
+from inference.edit_secc import blink_eye_for_secc
+
+
+def read_first_frame_from_a_video(vid_name):
+ frames = []
+ cap = cv2.VideoCapture(vid_name)
+ ret, frame_bgr = cap.read()
+ frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
+ return frame_rgb
+
+def analyze_weights_img(gen_output):
+ img_raw = gen_output['image_raw']
+ mask_005_to_03 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.3).repeat([1,3,1,1])
+ mask_005_to_05 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.5).repeat([1,3,1,1])
+ mask_005_to_07 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.7).repeat([1,3,1,1])
+ mask_005_to_09 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<0.9).repeat([1,3,1,1])
+ mask_005_to_10 = torch.bitwise_and(gen_output['weights_img']>0.05, gen_output['weights_img']<1.0).repeat([1,3,1,1])
+
+ img_raw_005_to_03 = img_raw.clone()
+ img_raw_005_to_03[~mask_005_to_03] = -1
+ img_raw_005_to_05 = img_raw.clone()
+ img_raw_005_to_05[~mask_005_to_05] = -1
+ img_raw_005_to_07 = img_raw.clone()
+ img_raw_005_to_07[~mask_005_to_07] = -1
+ img_raw_005_to_09 = img_raw.clone()
+ img_raw_005_to_09[~mask_005_to_09] = -1
+ img_raw_005_to_10 = img_raw.clone()
+ img_raw_005_to_10[~mask_005_to_10] = -1
+ ts.save([img_raw_005_to_03[0], img_raw_005_to_05[0], img_raw_005_to_07[0], img_raw_005_to_09[0], img_raw_005_to_10[0]])
+
+def cal_face_area_percent(img_name):
+ img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
+ lm478 = extract_lms_mediapipe_job(img) / 512
+ min_x = lm478[:,0].min()
+ max_x = lm478[:,0].max()
+ min_y = lm478[:,1].min()
+ max_y = lm478[:,1].max()
+ area = (max_x - min_x) * (max_y - min_y)
+ return area
+
+def crop_img_on_face_area_percent(img_name, out_name='temp/cropped_src_img.png', min_face_area_percent=0.2):
+ try:
+ os.makedirs(os.path.dirname(out_name), exist_ok=True)
+ except: pass
+ face_area_percent = cal_face_area_percent(img_name)
+ if face_area_percent >= min_face_area_percent:
+ print(f"face area percent {face_area_percent} larger than threshold {min_face_area_percent}, directly use the input image...")
+ cmd = f"cp {img_name} {out_name}"
+ os.system(cmd)
+ return out_name
+ else:
+ print(f"face area percent {face_area_percent} smaller than threshold {min_face_area_percent}, crop the input image...")
+ img = cv2.resize(cv2.imread(img_name)[:,:,::-1], (512,512))
+ lm478 = extract_lms_mediapipe_job(img).astype(int)
+ min_x = lm478[:,0].min()
+ max_x = lm478[:,0].max()
+ min_y = lm478[:,1].min()
+ max_y = lm478[:,1].max()
+ face_area = (max_x - min_x) * (max_y - min_y)
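+        # Choose a square crop whose area makes the face cover at least `min_face_area_percent` of it,
+        # then shrink the square if it would extend past the borders of the 512x512 image.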
+ target_total_area = face_area / min_face_area_percent
+ target_hw = int(target_total_area**0.5)
+ center_x, center_y = (min_x+max_x)/2, (min_y+max_y)/2
+ shrink_pixels = 2 * max(-(center_x - target_hw/2), center_x + target_hw/2 - 512, -(center_y - target_hw/2), center_y + target_hw/2-512)
+ shrink_pixels = max(0, shrink_pixels)
+ hw = math.floor(target_hw - shrink_pixels)
+ new_min_x = int(center_x - hw/2)
+ new_max_x = int(center_x + hw/2)
+ new_min_y = int(center_y - hw/2)
+ new_max_y = int(center_y + hw/2)
+
+ img = img[new_min_y:new_max_y, new_min_x:new_max_x]
+ img = cv2.resize(img, (512, 512))
+ cv2.imwrite(out_name, img[:,:,::-1])
+ return out_name
+
+
+class GeneFace2Infer:
+ def __init__(self, audio2secc_dir, head_model_dir, torso_model_dir, device=None, inp=None):
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.device = device
+ self.audio2secc_model = self.load_audio2secc(audio2secc_dir)
+ self.secc2video_model = self.load_secc2video(head_model_dir, torso_model_dir, inp)
+ self.audio2secc_model.to(device).eval()
+ self.secc2video_model.to(device).eval()
+ self.seg_model = MediapipeSegmenter()
+ self.secc_renderer = SECC_Renderer(512)
+ self.face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='lm68')
+ self.mp_face3d_helper = Face3DHelper(use_gpu=True, keypoint_mode='mediapipe')
+
+ def load_audio2secc(self, audio2secc_dir):
+ config_name = f"{audio2secc_dir}/config.yaml" if not audio2secc_dir.endswith(".ckpt") else f"{os.path.dirname(audio2secc_dir)}/config.yaml"
+ set_hparams(f"{config_name}", print_hparams=False)
+ self.audio2secc_dir = audio2secc_dir
+ self.audio2secc_hparams = copy.deepcopy(hparams)
+ from modules.audio2motion.vae import VAEModel, PitchContourVAEModel
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ audio_in_dim = 1024
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ audio_in_dim = 13
+
+ if 'icl' in hparams['task_cls']:
+ self.use_icl_audio2motion = True
+ model = InContextAudio2MotionModel(hparams['icl_model_type'], hparams=self.audio2secc_hparams)
+ else:
+ self.use_icl_audio2motion = False
+ if hparams.get("use_pitch", False) is True:
+ model = PitchContourVAEModel(hparams, in_out_dim=64, audio_in_dim=audio_in_dim)
+ else:
+ model = VAEModel(in_out_dim=64, audio_in_dim=audio_in_dim)
+ load_ckpt(model, f"{audio2secc_dir}", model_name='model', strict=True)
+ return model
+
+ def load_secc2video(self, head_model_dir, torso_model_dir, inp):
+ if inp is None:
+ inp = {}
+ self.head_model_dir = head_model_dir
+ self.torso_model_dir = torso_model_dir
+ if torso_model_dir != '':
+ if torso_model_dir.endswith(".ckpt"):
+ set_hparams(f"{os.path.dirname(torso_model_dir)}/config.yaml", print_hparams=False)
+ else:
+ set_hparams(f"{torso_model_dir}/config.yaml", print_hparams=False)
+ if inp.get('head_torso_threshold', None) is not None:
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane_Torso
+ model = OSAvatarSECC_Img2plane_Torso()
+ load_ckpt(model, f"{torso_model_dir}", model_name='model', strict=True)
+ if head_model_dir != '':
+ print("| Warning: Assigned --torso_ckpt which also contains head, but --head_ckpt is also assigned, skipping the --head_ckpt.")
+ else:
+ from modules.real3d.secc_img2plane_torso import OSAvatarSECC_Img2plane
+ if head_model_dir.endswith(".ckpt"):
+ set_hparams(f"{os.path.dirname(head_model_dir)}/config.yaml", print_hparams=False)
+ else:
+ set_hparams(f"{head_model_dir}/config.yaml", print_hparams=False)
+ if inp.get('head_torso_threshold', None) is not None:
+ hparams['htbsr_head_threshold'] = inp['head_torso_threshold']
+ self.secc2video_hparams = copy.deepcopy(hparams)
+ model = OSAvatarSECC_Img2plane()
+ load_ckpt(model, f"{head_model_dir}", model_name='model', strict=True)
+ return model
+
+ def infer_once(self, inp):
+ self.inp = inp
+ samples = self.prepare_batch_from_inp(inp)
+ seed = inp['seed'] if inp['seed'] is not None else int(time.time())
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ out_name = self.forward_system(samples, inp)
+ return out_name
+
+ def prepare_batch_from_inp(self, inp):
+ """
+        :param inp: a dict of user inputs, e.g. {'src_image_name': (str), 'drv_audio_name': (str), ...}
+ :return: a dict that contains the condition feature of NeRF
+ """
+ tmp_img_name = 'infer_out/tmp/cropped_src_img.png'
+ crop_img_on_face_area_percent(inp['src_image_name'], tmp_img_name, min_face_area_percent=inp['min_face_area_percent'])
+ inp['src_image_name'] = tmp_img_name
+
+ sample = {}
+ # Process Driving Motion
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ self.save_wav16k(inp['drv_audio_name'])
+ if self.audio2secc_hparams['audio_type'] == 'hubert':
+ hubert = self.get_hubert(self.wav16k_name)
+ elif self.audio2secc_hparams['audio_type'] == 'mfcc':
+ hubert = self.get_mfcc(self.wav16k_name) / 100
+
+ f0 = self.get_f0(self.wav16k_name)
+ if f0.shape[0] > len(hubert):
+ f0 = f0[:len(hubert)]
+ else:
+ num_to_pad = len(hubert) - len(f0)
+ f0 = np.pad(f0, pad_width=((0,num_to_pad), (0,0)))
+ t_x = hubert.shape[0]
+ x_mask = torch.ones([1, t_x]).float() # mask for audio frames
+ y_mask = torch.ones([1, t_x//2]).float() # mask for motion/image frames
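+            # Note: the audio features are assumed to run at twice the video frame rate
+            # (e.g. 50 Hz HuBERT vs. 25 fps video), so t_x audio frames map to t_x//2 motion/image frames.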
+ sample.update({
+ 'hubert': torch.from_numpy(hubert).float().unsqueeze(0).cuda(),
+ 'f0': torch.from_numpy(f0).float().reshape([1,-1]).cuda(),
+ 'x_mask': x_mask.cuda(),
+ 'y_mask': y_mask.cuda(),
+ })
+ sample['blink'] = torch.zeros([1, t_x, 1]).long().cuda()
+ sample['audio'] = sample['hubert']
+ sample['eye_amp'] = torch.ones([1, 1]).cuda() * 1.0
+ sample['mouth_amp'] = torch.ones([1, 1]).cuda() * inp['mouth_amp']
+ elif inp['drv_audio_name'][-4:] in ['.mp4']:
+ drv_motion_coeff_dict = fit_3dmm_for_a_video(inp['drv_audio_name'], save=False)
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+ elif inp['drv_audio_name'][-4:] in ['.npy']:
+ drv_motion_coeff_dict = np.load(inp['drv_audio_name'], allow_pickle=True).tolist()
+ drv_motion_coeff_dict = convert_to_tensor(drv_motion_coeff_dict)
+ t_x = drv_motion_coeff_dict['exp'].shape[0] * 2
+ self.drv_motion_coeff_dict = drv_motion_coeff_dict
+
+ # Face Parsing
+ image_name = inp['src_image_name']
+ if image_name.endswith(".mp4"):
+ img = read_first_frame_from_a_video(image_name)
+ image_name = inp['src_image_name'] = image_name[:-4] + '.png'
+ cv2.imwrite(image_name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+ sample['ref_gt_img'] = load_img_to_normalized_512_bchw_tensor(image_name).cuda()
+ img = load_img_to_512_hwc_array(image_name)
+ segmap = self.seg_model._cal_seg_map(img)
+ sample['segmap'] = torch.tensor(segmap).float().unsqueeze(0).cuda()
+ head_img = self.seg_model._seg_out_img_with_segmap(img, segmap, mode='head')[0]
+ sample['ref_head_img'] = ((torch.tensor(head_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+ ts.save(sample['ref_head_img'])
+ inpaint_torso_img, _, _, _ = inpaint_torso_job(img, segmap)
+ sample['ref_torso_img'] = ((torch.tensor(inpaint_torso_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ if inp['bg_image_name'] == '':
+ bg_img = extract_background([img], [segmap], 'knn')
+ else:
+ bg_img = cv2.imread(inp['bg_image_name'])
+ bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
+ bg_img = cv2.resize(bg_img, (512,512))
+ sample['bg_img'] = ((torch.tensor(bg_img) - 127.5)/127.5).float().unsqueeze(0).permute(0, 3, 1,2).cuda() # [b,c,h,w]
+
+ # 3DMM, get identity code and camera pose
+ coeff_dict = fit_3dmm_for_a_image(image_name, save=False)
+ assert coeff_dict is not None
+ src_id = torch.tensor(coeff_dict['id']).reshape([1,80]).cuda()
+ src_exp = torch.tensor(coeff_dict['exp']).reshape([1,64]).cuda()
+ src_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda()
+ src_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda()
+ sample['id'] = src_id.repeat([t_x//2,1])
+
+ # get the src_kp for torso model
+ src_kp = self.face3d_helper.reconstruct_lm2d(src_id, src_exp, src_euler, src_trans) # [1, 68, 2]
+ src_kp = (src_kp-0.5) / 0.5 # rescale to -1~1
+ sample['src_kp'] = torch.clamp(src_kp, -1, 1).repeat([t_x//2,1,1])
+
+ # get camera pose file
+ # random.seed(time.time())
+ print(f"| To extract pose from {inp['drv_pose_name']}")
+
+ # extract camera pose
+ if inp['drv_pose_name'] == 'static':
+ sample['euler'] = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda().repeat([t_x//2,1]) # default static pose
+ sample['trans'] = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda().repeat([t_x//2,1])
+ else: # from file
+ if inp['drv_pose_name'].endswith('.mp4'):
+ # extract coeff from video
+ drv_pose_coeff_dict = fit_3dmm_for_a_video(inp['drv_pose_name'], save=False)
+ else:
+ # load from npy
+ drv_pose_coeff_dict = np.load(inp['drv_pose_name'], allow_pickle=True).tolist()
+ print(f"| Extracted pose from {inp['drv_pose_name']}")
+ eulers = convert_to_tensor(drv_pose_coeff_dict['euler']).reshape([-1,3]).cuda()
+ trans = convert_to_tensor(drv_pose_coeff_dict['trans']).reshape([-1,3]).cuda()
+ len_pose = len(eulers)
+ index_lst = [mirror_index(i, len_pose) for i in range(t_x//2)]
+ sample['euler'] = eulers[index_lst]
+ sample['trans'] = trans[index_lst]
+
+ # fix the z axis
+ sample['trans'][:, -1] = sample['trans'][0:1, -1].repeat([sample['trans'].shape[0]])
+
+ # mapping to the init pose
+ print(inp)
+ if inp.get("map_to_init_pose", 'True') in ['True', True]:
+ diff_euler = torch.tensor(coeff_dict['euler']).reshape([1,3]).cuda() - sample['euler'][0:1]
+ sample['euler'] = sample['euler'] + diff_euler
+ diff_trans = torch.tensor(coeff_dict['trans']).reshape([1,3]).cuda() - sample['trans'][0:1]
+ sample['trans'] = sample['trans'] + diff_trans
+
+ # prepare camera
+ camera_ret = get_eg3d_convention_camera_pose_intrinsic({'euler':sample['euler'].cpu(), 'trans':sample['trans'].cpu()})
+ c2w, intrinsics = camera_ret['c2w'], camera_ret['intrinsics']
+ # smooth camera
+ camera_smo_ksize = 7
+ camera = np.concatenate([c2w.reshape([-1,16]), intrinsics.reshape([-1,9])], axis=-1)
+ camera = smooth_camera_sequence(camera, kernel_size=camera_smo_ksize) # [T, 25]
+ camera = torch.tensor(camera).cuda().float()
+ sample['camera'] = camera
+
+ return sample
+
+ @torch.no_grad()
+ def get_hubert(self, wav16k_name):
+ from data_gen.utils.process_audio.extract_hubert import get_hubert_from_16k_wav
+ hubert = get_hubert_from_16k_wav(wav16k_name).detach().numpy()
+ len_mel = hubert.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ hubert = np.pad(hubert, pad_width=((0,num_to_pad), (0,0)))
+ return hubert
+
+ def get_mfcc(self, wav16k_name):
+ from utils.audio import librosa_wav2mfcc
+ hparams['fft_size'] = 1200
+ hparams['win_size'] = 1200
+ hparams['hop_size'] = 480
+ hparams['audio_num_mel_bins'] = 80
+ hparams['fmin'] = 80
+ hparams['fmax'] = 12000
+ hparams['audio_sample_rate'] = 24000
+ mfcc = librosa_wav2mfcc(wav16k_name,
+ fft_size=hparams['fft_size'],
+ hop_size=hparams['hop_size'],
+ win_length=hparams['win_size'],
+ num_mels=hparams['audio_num_mel_bins'],
+ fmin=hparams['fmin'],
+ fmax=hparams['fmax'],
+ sample_rate=hparams['audio_sample_rate'],
+ center=True)
+ mfcc = np.array(mfcc).reshape([-1, 13])
+ len_mel = mfcc.shape[0]
+ x_multiply = 8
+ if len_mel % x_multiply == 0:
+ num_to_pad = 0
+ else:
+ num_to_pad = x_multiply - len_mel % x_multiply
+ mfcc = np.pad(mfcc, pad_width=((0,num_to_pad), (0,0)))
+ return mfcc
+
+ @torch.no_grad()
+ def forward_audio2secc(self, batch, inp=None):
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ # audio-to-exp
+ ret = {}
+ pred = self.audio2secc_model.forward(batch, ret=ret,train=False, temperature=inp['temperature'],)
+ print("| audio-to-motion finished")
+ if pred.shape[-1] == 144:
+ id = ret['pred'][0][:,:80]
+ exp = ret['pred'][0][:,80:]
+ else:
+ id = batch['id']
+ exp = ret['pred'][0]
+ if len(id) < len(exp): # happens when use ICL
+ id = torch.cat([id, id[0].unsqueeze(0).repeat([len(exp)-len(id),1])])
+ batch['id'] = id
+ batch['exp'] = exp
+ else:
+ drv_motion_coeff_dict = self.drv_motion_coeff_dict
+ batch['exp'] = torch.FloatTensor(drv_motion_coeff_dict['exp']).cuda()
+
+ batch = self.get_driving_motion(batch['id'], batch['exp'], batch['euler'], batch['trans'], batch, inp)
+ if self.use_icl_audio2motion:
+ self.audio2secc_model.empty_context()
+ return batch
+
+ @torch.no_grad()
+ def get_driving_motion(self, id, exp, euler, trans, batch, inp):
+ zero_eulers = torch.zeros([id.shape[0], 3]).to(id.device)
+ zero_trans = torch.zeros([id.shape[0], 3]).to(exp.device)
+ # render the secc given the id,exp
+ with torch.no_grad():
+ chunk_size = 50
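+            # render the driving SECC maps in chunks of `chunk_size` frames to bound GPU memory,
+            # emptying the CUDA cache before each chunk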
+ drv_secc_color_lst = []
+ num_iters = len(id)//chunk_size if len(id)%chunk_size == 0 else len(id)//chunk_size+1
+ for i in tqdm.trange(num_iters, desc="rendering drv secc"):
+ torch.cuda.empty_cache()
+ face_mask, drv_secc_color = self.secc_renderer(id[i*chunk_size:(i+1)*chunk_size], exp[i*chunk_size:(i+1)*chunk_size], zero_eulers[i*chunk_size:(i+1)*chunk_size], zero_trans[i*chunk_size:(i+1)*chunk_size])
+ drv_secc_color_lst.append(drv_secc_color.cpu())
+ drv_secc_colors = torch.cat(drv_secc_color_lst, dim=0)
+ _, src_secc_color = self.secc_renderer(id[0:1], exp[0:1], zero_eulers[0:1], zero_trans[0:1])
+ _, cano_secc_color = self.secc_renderer(id[0:1], exp[0:1]*0, zero_eulers[0:1], zero_trans[0:1])
+ batch['drv_secc'] = drv_secc_colors.cuda()
+ batch['src_secc'] = src_secc_color.cuda()
+ batch['cano_secc'] = cano_secc_color.cuda()
+
+ # blinking secc
+ if inp['blink_mode'] == 'period':
+ period = 5 # second
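+ # at 25 fps, trigger a synthetic blink every `period` seconds; each blink lasts 8~12 frames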
+
+ for i in tqdm.trange(len(drv_secc_colors),desc="blinking secc"):
+ if i % (25*period) == 0:
+ blink_dur_frames = random.randint(8, 12)
+ for offset in range(blink_dur_frames):
+ j = offset + i
+ if j >= len(drv_secc_colors)-1: break
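+ # parabola in the frame offset t: 0 at t=0 and t=T, reaching its maximum of 1 at t=T/2 (mid-blink)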
+ def blink_percent_fn(t, T):
+ return -4/T**2 * t**2 + 4/T * t
+ blink_percent = blink_percent_fn(offset, blink_dur_frames)
+ secc = batch['drv_secc'][j]
+ out_secc = blink_eye_for_secc(secc, blink_percent)
+ out_secc = out_secc.cuda()
+ batch['drv_secc'][j] = out_secc
+
+ # get the drv_kp for torso model, using the transformed trajectory
+ drv_kp = self.face3d_helper.reconstruct_lm2d(id, exp, euler, trans) # [T, 68, 2]
+
+ drv_kp = (drv_kp-0.5) / 0.5 # rescale to -1~1
+ batch['drv_kp'] = torch.clamp(drv_kp, -1, 1)
+ return batch
+
+ @torch.no_grad()
+ def forward_secc2video(self, batch, inp=None):
+ num_frames = len(batch['drv_secc'])
+ camera = batch['camera']
+ src_kps = batch['src_kp']
+ drv_kps = batch['drv_kp']
+ cano_secc_color = batch['cano_secc']
+ src_secc_color = batch['src_secc']
+ drv_secc_colors = batch['drv_secc']
+ ref_img_gt = batch['ref_gt_img']
+ ref_img_head = batch['ref_head_img']
+ ref_torso_img = batch['ref_torso_img']
+ bg_img = batch['bg_img']
+ segmap = batch['segmap']
+
+ # smooth torso drv_kp
+ torso_smo_ksize = 7
+ drv_kps = smooth_features_xd(drv_kps.reshape([-1, 68*2]), kernel_size=torso_smo_ksize).reshape([-1, 68, 2])
+
+ # forward renderer
+ if inp['low_memory_usage']:
+ # to save memory, write each frame into the video as soon as it is rendered
+ import imageio
+ debug_name = 'demo.mp4'
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
+
+ with torch.no_grad():
+ for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
+ 'kp_s': kp_src, 'kp_d': kp_drv}
+ if i == 0:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
+ else:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+ img = ((gen_output['image']+1)/2 * 255.).permute(0, 2, 3, 1)[0].int().cpu().numpy().astype(np.uint8)
+ writer.append_data(img)
+ writer.close()
+ else:
+ img_raw_lst = []
+ img_lst = []
+ depth_img_lst = []
+ with torch.no_grad():
+ for i in tqdm.trange(num_frames, desc="Real3D-Portrait is rendering frames"):
+ kp_src = torch.cat([src_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(src_kps.device)],dim=-1)
+ kp_drv = torch.cat([drv_kps[i:i+1].reshape([1, 68, 2]), torch.zeros([1, 68,1]).to(drv_kps.device)],dim=-1)
+ cond={'cond_cano': cano_secc_color,'cond_src': src_secc_color, 'cond_tgt': drv_secc_colors[i:i+1].cuda(),
+ 'ref_torso_img': ref_torso_img, 'bg_img': bg_img, 'segmap': segmap,
+ 'kp_s': kp_src, 'kp_d': kp_drv}
+ if i == 0:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=True, use_cached_backbone=False)
+ else:
+ gen_output = self.secc2video_model.forward(img=ref_img_head, camera=camera[i:i+1], cond=cond, ret={}, cache_backbone=False, use_cached_backbone=True)
+ img_lst.append(gen_output['image'])
+ img_raw_lst.append(gen_output['image_raw'])
+ depth_img_lst.append(gen_output['image_depth'])
+
+ # save demo video
+ depth_imgs = torch.cat(depth_img_lst)
+ imgs = torch.cat(img_lst)
+ imgs_raw = torch.cat(img_raw_lst)
+ secc_img = torch.cat([torch.nn.functional.interpolate(drv_secc_colors[i:i+1], (512,512)) for i in range(num_frames)])
+
+ if inp['out_mode'] == 'concat_debug':
+ secc_img = secc_img.cpu()
+ secc_img = ((secc_img + 1) * 127.5).permute(0, 2, 3, 1).int().numpy()
+
+ depth_img = F.interpolate(depth_imgs, (512,512)).cpu()
+ depth_img = depth_img.repeat([1,3,1,1])
+ depth_img = (depth_img - depth_img.min()) / (depth_img.max() - depth_img.min())
+ depth_img = depth_img * 2 - 1
+ depth_img = depth_img.clamp(-1,1)
+
+ secc_img = secc_img / 127.5 - 1
+ secc_img = torch.from_numpy(secc_img).permute(0, 3, 1, 2)
+ imgs = torch.cat([ref_img_gt.repeat([imgs.shape[0],1,1,1]).cpu(), secc_img, F.interpolate(imgs_raw, (512,512)).cpu(), depth_img, imgs.cpu()], dim=-1)
+ elif inp['out_mode'] == 'final':
+ imgs = imgs.cpu()
+ elif inp['out_mode'] == 'debug':
+ raise NotImplementedError("to do: save separate videos")
+ imgs = imgs.clamp(-1,1)
+
+ import imageio
+ debug_name = 'demo.mp4'
+ out_imgs = ((imgs.permute(0, 2, 3, 1) + 1)/2 * 255).int().cpu().numpy().astype(np.uint8)
+ writer = imageio.get_writer(debug_name, fps=25, format='FFMPEG', codec='h264')
+
+ for i in tqdm.trange(len(out_imgs), desc="Imageio is saving video"):
+ writer.append_data(out_imgs[i])
+ writer.close()
+
+ # add audio track
+ out_fname = 'infer_out/tmp/' + os.path.basename(inp['src_image_name'])[:-4] + '_' + os.path.basename(inp['drv_pose_name'])[:-4] + '.mp4' if inp['out_name'] == '' else inp['out_name']
+ try:
+ os.makedirs(os.path.dirname(out_fname), exist_ok=True)
+ except Exception: pass
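+ # mux the silent rendered video with the driving audio; -shortest trims the output to the shorter stream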
+ if inp['drv_audio_name'][-4:] in ['.wav', '.mp3']:
+ os.system(f"ffmpeg -i {debug_name} -i {self.wav16k_name} -y -v quiet -shortest {out_fname}")
+ os.system(f"rm {debug_name}")
+ os.system(f"rm {self.wav16k_name}")
+ else:
+ ret = os.system(f"ffmpeg -i {debug_name} -i {inp['drv_audio_name']} -map 0:v -map 1:a -y -v quiet -shortest {out_fname}")
+ if ret != 0: # failed to extract an audio track from drv_audio_name, so output the video without an audio track
+ os.system(f"mv {debug_name} {out_fname}")
+ print(f"Saved at {out_fname}")
+ return out_fname
+
+ @torch.no_grad()
+ def forward_system(self, batch, inp):
+ self.forward_audio2secc(batch, inp)
+ out_fname = self.forward_secc2video(batch, inp)
+ return out_fname
+
+ @classmethod
+ def example_run(cls, inp=None):
+ inp_tmp = {
+ 'drv_audio_name': 'data/raw/val_wavs/zozo.wav',
+ 'src_image_name': 'data/raw/val_imgs/Macron.png'
+ }
+ if inp is not None:
+ inp_tmp.update(inp)
+ inp = inp_tmp
+
+ infer_instance = cls(inp['a2m_ckpt'], inp['head_ckpt'], inp['torso_ckpt'], inp=inp)
+ infer_instance.infer_once(inp)
+
+ ##############
+ # IO-related
+ ##############
+ def save_wav16k(self, audio_name):
+ supported_types = ('.wav', '.mp3', '.mp4', '.avi')
+ assert audio_name.endswith(supported_types), f"Now we only support {','.join(supported_types)} as audio source!"
+ wav16k_name = audio_name[:-4] + '_16k.wav'
+ self.wav16k_name = wav16k_name
+ extract_wav_cmd = f"ffmpeg -i {audio_name} -f wav -ar 16000 -v quiet -y {wav16k_name}"
+ os.system(extract_wav_cmd)
+ print(f"Extracted wav file (16 kHz) from {audio_name} to {wav16k_name}.")
+
+ def get_f0(self, wav16k_name):
+ from data_gen.utils.process_audio.extract_mel_f0 import extract_mel_from_fname, extract_f0_from_wav_and_mel
+ wav, mel = extract_mel_from_fname(wav16k_name)
+ f0, f0_coarse = extract_f0_from_wav_and_mel(wav, mel)
+ f0 = f0.reshape([-1,1])
+ return f0
+
+if __name__ == '__main__':
+ import argparse, glob, tqdm
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--a2m_ckpt", default='checkpoints/240210_real3dportrait_orig/audio2secc_vae', type=str)
+ parser.add_argument("--head_ckpt", default='', type=str)
+ parser.add_argument("--torso_ckpt", default='checkpoints/240210_real3dportrait_orig/secc2plane_torso_orig', type=str)
+ parser.add_argument("--src_img", default='data/raw/examples/Macron.png', type=str) # data/raw/examples/Macron.png
+ parser.add_argument("--bg_img", default='', type=str) # data/raw/examples/bg.png
+ parser.add_argument("--drv_aud", default='data/raw/examples/Obama_5s.wav', type=str) # data/raw/examples/Obama_5s.wav
+ parser.add_argument("--drv_pose", default='data/raw/examples/May_5s.mp4', type=str) # data/raw/examples/May_5s.mp4
+ parser.add_argument("--blink_mode", default='period', type=str) # none | period
+ parser.add_argument("--temperature", default=0.2, type=float) # sampling temperature in audio2motion, higher -> more diverse, less accurate
+ parser.add_argument("--mouth_amp", default=0.45, type=float) # amplitude of the predicted mouth motion, only effective in audio-driven mode
+ parser.add_argument("--head_torso_threshold", default=None, type=float, help="0.1~1.0, turn up this value if the hair is translucent")
+ parser.add_argument("--out_name", default='') # output filename
+ parser.add_argument("--out_mode", default='concat_debug') # final: only output talking head video; concat_debug: talking head with internal features
+ parser.add_argument("--map_to_init_pose", default='True') # whether to map the pose of first frame to source image
+ parser.add_argument("--seed", default=None, type=int) # random seed, default None to use time.time()
+ parser.add_argument("--min_face_area_percent", default=0.2, type=float) # minimum proportion of the image area the face should occupy when cropping the source image
+ parser.add_argument("--low_memory_usage", action='store_true', help='write each frame into the video as soon as it is generated; slower, but uses less memory')
+
+ args = parser.parse_args()
+
+ inp = {
+ 'a2m_ckpt': args.a2m_ckpt,
+ 'head_ckpt': args.head_ckpt,
+ 'torso_ckpt': args.torso_ckpt,
+ 'src_image_name': args.src_img,
+ 'bg_image_name': args.bg_img,
+ 'drv_audio_name': args.drv_aud,
+ 'drv_pose_name': args.drv_pose,
+ 'blink_mode': args.blink_mode,
+ 'temperature': args.temperature,
+ 'mouth_amp': args.mouth_amp,
+ 'out_name': args.out_name,
+ 'out_mode': args.out_mode,
+ 'map_to_init_pose': args.map_to_init_pose,
+ 'head_torso_threshold': args.head_torso_threshold,
+ 'seed': args.seed,
+ 'min_face_area_percent': args.min_face_area_percent,
+ 'low_memory_usage': args.low_memory_usage,
+ }
+
+ GeneFace2Infer.example_run(inp)
\ No newline at end of file
diff --git a/inference/real3dportrait_demo.ipynb b/inference/real3dportrait_demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..03a4b31c363b71c90efc16b1abde1cc13d59263e
--- /dev/null
+++ b/inference/real3dportrait_demo.ipynb
@@ -0,0 +1,287 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "QS04K9oO21AW"
+ },
+ "source": [
+ "Check GPU"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1ESQRDb-yVUG"
+ },
+ "outputs": [],
+ "source": [
+ "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "y-ctmIvu3Ei8"
+ },
+ "source": [
+ "Installation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "gXu76wdDgaxo"
+ },
+ "outputs": [],
+ "source": [
+ "# install pytorch3d, about 15s\n",
+ "import os\n",
+ "import sys\n",
+ "import torch\n",
+ "need_pytorch3d=False\n",
+ "try:\n",
+ " import pytorch3d\n",
+ "except ModuleNotFoundError:\n",
+ " need_pytorch3d=True\n",
+ "if need_pytorch3d:\n",
+ " if torch.__version__.startswith(\"2.1.\") and sys.platform.startswith(\"linux\"):\n",
+ " # We try to install PyTorch3D via a released wheel.\n",
+ " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
+ " version_str=\"\".join([\n",
+ " f\"py3{sys.version_info.minor}_cu\",\n",
+ " torch.version.cuda.replace(\".\",\"\"),\n",
+ " f\"_pyt{pyt_version_str}\"\n",
+ " ])\n",
+ " !pip install fvcore iopath\n",
+ " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
+ " else:\n",
+ " # We try to install PyTorch3D from source.\n",
+ " !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "DuUynxmotG_-"
+ },
+ "outputs": [],
+ "source": [
+ "# install dependencies, about 5~10 min\n",
+ "!pip install tensorboard==2.13.0 tensorboardX==2.6.1\n",
+ "!pip install pyspy==0.1.1\n",
+ "!pip install protobuf==3.20.3\n",
+ "!pip install scipy==1.9.1\n",
+ "!pip install kornia==0.5.0\n",
+ "!pip install trimesh==3.22.0\n",
+ "!pip install einops==0.6.1 torchshow==0.5.1\n",
+ "!pip install imageio==2.31.1 imageio-ffmpeg==0.4.8\n",
+ "!pip install scikit-learn==1.3.0 scikit-image==0.21.0\n",
+ "!pip install av==10.0.0 lpips==0.1.4\n",
+ "!pip install timm==0.9.2 librosa==0.9.2\n",
+ "!pip install openmim==0.3.9\n",
+ "!mim install mmcv==2.1.0 # use mim to speed up installation for mmcv\n",
+ "!pip install transformers==4.33.2\n",
+ "!pip install pretrainedmodels==0.7.4\n",
+ "!pip install ninja==1.11.1\n",
+ "!pip install faiss-cpu==1.7.4\n",
+ "!pip install praat-parselmouth==0.4.3 moviepy==1.0.3\n",
+ "!pip install mediapipe==0.10.7\n",
+ "!pip install --upgrade attr\n",
+ "!pip install beartype==0.16.4 gateloop_transformer==0.4.0\n",
+ "!pip install torchode==0.2.0 torchdiffeq==0.2.3\n",
+ "!pip install hydra-core==1.3.2 pandas==2.1.3\n",
+ "!pip install pytorch_lightning==2.1.2\n",
+ "!pip install httpx==0.23.3\n",
+ "!pip install gradio==4.16.0\n",
+ "!pip install gdown\n",
+ "!pip install pyloudnorm webrtcvad pyworld==0.2.1rc0 pypinyin==0.42.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "0GLEV0HVu8rj"
+ },
+ "outputs": [],
+ "source": [
+ "# RESTART the kernel if you hit runtime errors after installing the dependencies\n",
+ "# import os\n",
+ "# os.kill(os.getpid(), 9)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "5UfKHKrH6kcq"
+ },
+ "source": [
+ "Clone code and download checkpoints"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "-gfRsd9DwIgl"
+ },
+ "outputs": [],
+ "source": [
+ "# clone Real3DPortrait repo from github\n",
+ "!git clone https://github.com/yerfor/Real3DPortrait\n",
+ "%cd Real3DPortrait"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "Yju8dQY7x5OS"
+ },
+ "outputs": [],
+ "source": [
+ "# download pretrained ckpts & third-party ckpts from google drive, about 1 min\n",
+ "!pip install --upgrade --no-cache-dir gdown\n",
+ "%cd deep_3drecon/BFM\n",
+ "!gdown https://drive.google.com/uc?id=1SPM3IHsyNAaVMwqZZGV6QVaV7I2Hly0v\n",
+ "!gdown https://drive.google.com/uc?id=1MSldX9UChKEb3AXLVTPzZQcsbGD4VmGF\n",
+ "!gdown https://drive.google.com/uc?id=180ciTvm16peWrcpl4DOekT9eUQ-lJfMU\n",
+ "!gdown https://drive.google.com/uc?id=1KX9MyGueFB3M-X0Ss152x_johyTXHTfU\n",
+ "!gdown https://drive.google.com/uc?id=19-NyZn_I0_mkF-F5GPyFMwQJ_-WecZIL\n",
+ "!gdown https://drive.google.com/uc?id=11ouQ7Wr2I-JKStp2Fd1afedmWeuifhof\n",
+ "!gdown https://drive.google.com/uc?id=18ICIvQoKX-7feYWP61RbpppzDuYTptCq\n",
+ "!gdown https://drive.google.com/uc?id=1VktuY46m0v_n_d4nvOupauJkK4LF6mHE\n",
+ "%cd ../..\n",
+ "\n",
+ "%cd checkpoints\n",
+ "!gdown https://drive.google.com/uc?id=1gz8A6xestHp__GbZT5qozb43YaybRJhZ\n",
+ "!gdown https://drive.google.com/uc?id=1gSUIw2AkkKnlLJnNfS2FCqtaVw9tw3QF\n",
+ "!unzip 240210_real3dportrait_orig.zip\n",
+ "!unzip pretrained_ckpts.zip\n",
+ "!ls\n",
+ "%cd ..\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "LHzLro206pnA"
+ },
+ "source": [
+ "Inference sample"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!python inference/real3d_infer.py -h"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "2aCDwxNivQoS"
+ },
+ "outputs": [],
+ "source": [
+ "# sample inference, about 3 min\n",
+ "!python inference/real3d_infer.py \\\n",
+ "--src_img data/raw/examples/Macron.png \\\n",
+ "--drv_aud data/raw/examples/Obama_5s.wav \\\n",
+ "--drv_pose data/raw/examples/May_5s.mp4 \\\n",
+ "--bg_img data/raw/examples/bg.png \\\n",
+ "--out_name output.mp4 \\\n",
+ "--out_mode concat_debug \\\n",
+ "--low_memory_usage"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "XL0c54l19mBG"
+ },
+ "source": [
+ "Display output video"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "6olmWwZP9Icj"
+ },
+ "outputs": [],
+ "source": [
+ "# adapted from MakeItTalk\n",
+ "from IPython.display import HTML\n",
+ "from base64 import b64encode\n",
+ "import os, sys\n",
+ "import glob\n",
+ "\n",
+ "mp4_name = './output.mp4'\n",
+ "\n",
+ "mp4 = open('{}'.format(mp4_name),'rb').read()\n",
+ "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
+ "\n",
+ "print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
+ "display(HTML(\"\"\"\n",
+ "  <video width=512 controls>\n",
+ "    <source src=\"%s\" type=\"video/mp4\">\n",
+ "  </video>\n",
+ " \"\"\" % data_url))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "WebUI"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "!python inference/app_real3dportrait.py --share"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "authorship_tag": "ABX9TyPu++zOlOS4yKF4xn4FHGtZ",
+ "gpuType": "T4",
+ "include_colab_link": true,
+ "private_outputs": true,
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/install.sh b/install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7e59cf64b97ef2c923857de968b91af02478cce2
--- /dev/null
+++ b/install.sh
@@ -0,0 +1,21 @@
+conda install conda-forge::ffmpeg
+
+### We recommend torch2.0.1+cuda11.7.
+conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
+
+# Install pytorch3d from conda (fast installation, Linux only)
+conda install pytorch3d::pytorch3d
+## Alternatively, for better compatibility, build pytorch3d from the GitHub source.
+## This may take a long time (tens of minutes); a proxy is recommended if you run into time-outs.
+# pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
+
+# MMCV for some network structure
+pip install cython
+pip install openmim==0.3.9
+mim install mmcv==2.1.0 # use mim to speed up installation for mmcv
+
+# other dependencies
+pip install -r docs/prepare_env/requirements.txt -v
+
+# If the command above fails with dependency-resolution errors, retry with the legacy resolver:
+# pip install -r docs/prepare_env/requirements.txt -v --use-deprecated=legacy-resolver
\ No newline at end of file
diff --git a/modules/audio2motion/cnn_models.py b/modules/audio2motion/cnn_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58e8c472349f59ab1f733a384906644a0b796c2
--- /dev/null
+++ b/modules/audio2motion/cnn_models.py
@@ -0,0 +1,359 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def init_weights_func(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv1d") != -1:
+ torch.nn.init.xavier_uniform_(m.weight)
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct a LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+
+class ResidualBlock(nn.Module):
+ """Implements (norm -> conv1d -> GELU -> conv1d) with a residual connection, repeated n times"""
+
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
+ c_multiple=2, ln_eps=1e-12, bias=False):
+ super(ResidualBlock, self).__init__()
+
+ if norm_type == 'bn':
+ norm_builder = lambda: nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm_builder = lambda: nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
+ else:
+ norm_builder = lambda: nn.Identity()
+
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
+ padding=(dilation * (kernel_size - 1)) // 2, bias=bias),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ nn.GELU(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, bias=bias),
+ )
+ for _ in range(n)
+ ]
+
+ self.blocks = nn.ModuleList(self.blocks)
+ self.dropout = dropout
+
+ def forward(self, x):
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
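+ # frames whose feature vector is all zero are treated as padding and re-masked after every residual update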
+ for b in self.blocks:
+ x_ = b(x)
+ if self.dropout > 0 and self.training:
+ x_ = F.dropout(x_, self.dropout, training=self.training)
+ x = x + x_
+ x = x * nonpadding
+ return x
+
+
+class ConvBlocks(nn.Module):
+ """A stack of residual convolutional blocks followed by a projection to out_dims"""
+
+ def __init__(self, channels, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, bias=False):
+ super(ConvBlocks, self).__init__()
+ self.is_BTC = is_BTC
+ self.res_blocks = nn.Sequential(
+ *[ResidualBlock(channels, kernel_size, d,
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
+ dropout=dropout, ln_eps=ln_eps, bias=bias)
+ for d in dilations],
+ )
+ if norm_type == 'bn':
+ norm = nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm = nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm = nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm = LayerNorm(channels, dim=1, eps=ln_eps)
+ self.last_norm = norm
+ self.post_net1 = nn.Conv1d(channels, out_dims, kernel_size=3, padding=1, bias=bias)
+ if init_weights:
+ self.apply(init_weights_func)
+
+ def forward(self, x):
+ """
+ :param x: [B, T, C]
+ :return: [B, T, out_dims]
+ """
+ if self.is_BTC:
+ x = x.transpose(1, 2) # [B, C, T]
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ x = self.res_blocks(x) * nonpadding
+ x = self.last_norm(x) * nonpadding
+ x = self.post_net1(x) * nonpadding
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x
+
+
+class SeqLevelConvolutionalModel(nn.Module):
+ def __init__(self, out_dim=64, dropout=0.5, audio_feat_type='ppg', backbone_type='unet', norm_type='bn'):
+ nn.Module.__init__(self)
+ self.audio_feat_type = audio_feat_type
+ if audio_feat_type == 'ppg':
+ self.audio_encoder = nn.Sequential(*[
+ nn.Conv1d(29, 48, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(48) if norm_type=='bn' else LayerNorm(48, dim=1),
+ nn.GELU(),
+ nn.Conv1d(48, 48, 3, 1, 1, bias=False)
+ ])
+ self.energy_encoder = nn.Sequential(*[
+ nn.Conv1d(1, 16, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(16) if norm_type=='bn' else LayerNorm(16, dim=1),
+ nn.GELU(),
+ nn.Conv1d(16, 16, 3, 1, 1, bias=False)
+ ])
+ elif audio_feat_type == 'mel':
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(80, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64) if norm_type=='bn' else LayerNorm(64, dim=1),
+ nn.GELU(),
+ nn.Conv1d(64, 64, 3, 1, 1, bias=False)
+ ])
+ else:
+ raise NotImplementedError("now only ppg or mel are supported!")
+
+ self.style_encoder = nn.Sequential(*[
+ nn.Linear(135, 256),
+ nn.GELU(),
+ nn.Linear(256, 256)
+ ])
+
+ if backbone_type == 'resnet':
+ self.backbone = ResNetBackbone()
+ elif backbone_type == 'unet':
+ self.backbone = UNetBackbone()
+ elif backbone_type == 'resblocks':
+ self.backbone = ResBlocksBackbone()
+ else:
+ raise NotImplementedError("Now only resnet, unet and resblocks are supported!")
+
+ self.out_layer = nn.Sequential(
+ nn.BatchNorm1d(512) if norm_type=='bn' else LayerNorm(512, dim=1),
+ nn.Conv1d(512, 64, 3, 1, 1, bias=False),
+ nn.PReLU(),
+ nn.Conv1d(64, out_dim, 3, 1, 1, bias=False)
+ )
+ self.feat_dropout = nn.Dropout(p=dropout)
+
+ @property
+ def device(self):
+ return next(self.backbone.parameters()).device
+
+ def forward(self, batch, ret, log_dict=None):
+ style, x_mask = batch['style'].to(self.device), batch['x_mask'].to(self.device)
+ style_feat = self.style_encoder(style) # [B,C=135] => [B,C=256]
+
+ if self.audio_feat_type == 'ppg':
+ audio, energy = batch['audio'].to(self.device), batch['energy'].to(self.device)
+ audio_feat = self.audio_encoder(audio.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=29] => [B,T,C=48]
+ energy_feat = self.energy_encoder(energy.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=1] => [B,T,C=16]
+ feat = torch.cat([audio_feat, energy_feat], dim=2) # [B,T,C=48+16]
+ elif self.audio_feat_type == 'mel':
+ mel = batch['mel'].to(self.device)
+ feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T,C=64]
+
+ feat, x_mask = self.backbone(x=feat, sty=style_feat, x_mask=x_mask)
+
+ out = self.out_layer(feat.transpose(1,2)).transpose(1,2) * x_mask.unsqueeze(2) # [B,T//2,C=512] => [B,T//2,C=out_dim]
+
+ ret['pred'] = out
+ ret['mask'] = x_mask
+ return out
+
+
+class ResBlocksBackbone(nn.Module):
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
+ super(ResBlocksBackbone,self).__init__()
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
+ self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
+
+ self.dropout = nn.Dropout(p=p_dropout)
+
+ def forward(self, x, sty, x_mask=1.):
+ """
+ x: [B, T, C]
+ sty: [B, C=256]
+ x_mask: [B, T]
+ ret: [B, T/2, C]
+ """
+ x = x.transpose(1, 2) # [B, C, T]
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+
+ x = self.resblocks_0(x) * x_mask # [B, C, T]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
+ x = self.downsampler(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_1(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_2(x) * x_mask # [B, C, T/2]
+
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/2]
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/2]
+
+ x = self.resblocks_3(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_4(x) * x_mask # [B, C, T/2]
+
+ x = x.transpose(1,2)
+ x_mask = x_mask.squeeze(1)
+ return x, x_mask
+
+
+
+class ResNetBackbone(nn.Module):
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
+ super(ResNetBackbone,self).__init__()
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*14, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_4 = ConvBlocks(channels=512, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=0.5, mode='linear'))
+ self.upsampler = LambdaLayer(lambda x: F.interpolate(x, scale_factor=4, mode='linear'))
+
+ self.dropout = nn.Dropout(p=p_dropout)
+
+ def forward(self, x, sty, x_mask=1.):
+ """
+ x: [B, T, C]
+ sty: [B, C=256]
+ x_mask: [B, T]
+ ret: [B, T/2, C]
+ """
+ x = x.transpose(1, 2) # [B, C, T]
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+
+ x = self.resblocks_0(x) * x_mask # [B, C, T]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
+ x = self.downsampler(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_1(x) * x_mask # [B, C, T/2]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/4]
+ x = self.downsampler(x) * x_mask # [B, C, T/4]
+ x = self.resblocks_2(x) * x_mask # [B, C, T/4]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/8]
+ x = self.downsampler(x) * x_mask # [B, C, T/8]
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
+ x = self.resblocks_3(x) * x_mask # [B, C, T/8]
+
+ x_mask = self.upsampler(x_mask) # [B, 1, T/2]
+ x = self.upsampler(x) * x_mask # [B, C, T/2]
+ x = self.resblocks_4(x) * x_mask # [B, C, T/2]
+
+ x = x.transpose(1,2)
+ x_mask = x_mask.squeeze(1)
+ return x, x_mask
+
+
+class UNetBackbone(nn.Module):
+ def __init__(self, in_dim=64, out_dim=512, p_dropout=0.5, norm_type='bn'):
+ super(UNetBackbone, self).__init__()
+ self.resblocks_0 = ConvBlocks(channels=in_dim, out_dims=64, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_1 = ConvBlocks(channels=64, out_dims=128, dilations=[1]*4, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_2 = ConvBlocks(channels=128, out_dims=256, dilations=[1]*8, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_3 = ConvBlocks(channels=512, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False)
+ self.resblocks_4 = ConvBlocks(channels=768, out_dims=512, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [768 = c3(512) + c2(256)]
+ self.resblocks_5 = ConvBlocks(channels=640, out_dims=out_dim, dilations=[1]*3, kernel_size=3, norm_type=norm_type, is_BTC=False) # [640 = c4(512) + c1(128)]
+
+ self.downsampler = nn.Upsample(scale_factor=0.5, mode='linear')
+ self.upsampler = nn.Upsample(scale_factor=2, mode='linear')
+ self.dropout = nn.Dropout(p=p_dropout)
+
+ def forward(self, x, sty, x_mask=1.):
+ """
+ x: [B, T, C]
+ sty: [B, C=256]
+ x_mask: [B, T]
+ ret: [B, T/2, C]
+ """
+ x = x.transpose(1, 2) # [B, C, T]
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+
+ x0 = self.resblocks_0(x) * x_mask # [B, C, T]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/2]
+ x = self.downsampler(x0) * x_mask # [B, C, T/2]
+ x1 = self.resblocks_1(x) * x_mask # [B, C, T/2]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/4]
+ x = self.downsampler(x1) * x_mask # [B, C, T/4]
+ x2 = self.resblocks_2(x) * x_mask # [B, C, T/4]
+
+ x_mask = self.downsampler(x_mask) # [B, 1, T/8]
+ x = self.downsampler(x2) * x_mask # [B, C, T/8]
+ x = self.dropout(x.transpose(1,2)).transpose(1,2)
+ sty = sty[:, :, None].repeat([1,1,x_mask.shape[2]]) # [B,C=256,T/8]
+ x = torch.cat([x, sty], dim=1) # [B, C=256+256, T/8]
+ x3 = self.resblocks_3(x) * x_mask # [B, C, T/8]
+
+ x_mask = self.upsampler(x_mask) # [B, 1, T/4]
+ x = self.upsampler(x3) * x_mask # [B, C, T/4]
+ x = torch.cat([x, self.dropout(x2.transpose(1,2)).transpose(1,2)], dim=1) # skip connection from x2: [B, C=512+256, T/4]
+ x4 = self.resblocks_4(x) * x_mask # [B, C, T/4]
+
+ x_mask = self.upsampler(x_mask) # [B, 1, T/2]
+ x = self.upsampler(x4) * x_mask # [B, C, T/2]
+ x = torch.cat([x, self.dropout(x1.transpose(1,2)).transpose(1,2)], dim=1)
+ x5 = self.resblocks_5(x) * x_mask # [B, C, T/2]
+
+ x = x5.transpose(1,2)
+ x_mask = x_mask.squeeze(1)
+ return x, x_mask
+
+
+if __name__ == '__main__':
+ pass
diff --git a/modules/audio2motion/flow_base.py b/modules/audio2motion/flow_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ff1c626cc3e4aef72406e16971db7331aa5c85
--- /dev/null
+++ b/modules/audio2motion/flow_base.py
@@ -0,0 +1,838 @@
+import scipy
+from scipy import linalg
+from torch.nn import functional as F
+import torch
+from torch import nn
+import numpy as np
+
+import modules.audio2motion.utils as utils
+from modules.audio2motion.transformer_models import FFTBlocks
+from utils.commons.hparams import hparams
+
+
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+class WN(torch.nn.Module):
+ def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0,
+ p_dropout=0, share_cond_layers=False):
+ super(WN, self).__init__()
+ assert (kernel_size % 2 == 1)
+ assert (hidden_channels % 2 == 0)
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.share_cond_layers = share_cond_layers
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+
+ self.drop = nn.Dropout(p_dropout)
+
+ self.use_adapters = hparams.get("use_adapters", False)
+ if self.use_adapters:
+ self.adapter_layers = torch.nn.ModuleList()
+
+ if gin_channels != 0 and not share_cond_layers:
+ cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_channels
+ else:
+ res_skip_channels = hidden_channels
+
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ if self.use_adapters:
+ adapter_layer = MlpAdapter(in_out_dim=res_skip_channels, hid_dim=res_skip_channels//4)
+ self.adapter_layers.append(adapter_layer)
+
+ def forward(self, x, x_mask=None, g=None, **kwargs):
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
+
+ if g is not None and not self.share_cond_layers:
+ g = self.cond_layer(g)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ x_in = self.drop(x_in)
+ if g is not None:
+ cond_offset = i * 2 * self.hidden_channels
+ g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
+ else:
+ g_l = torch.zeros_like(x_in)
+
+ acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if self.use_adapters:
+ res_skip_acts = self.adapter_layers[i](res_skip_acts.transpose(1,2)).transpose(1,2)
+ if i < self.n_layers - 1:
+ x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
+ output = output + res_skip_acts[:, self.hidden_channels:, :]
+ else:
+ output = output + res_skip_acts
+ return output * x_mask
+
+ def remove_weight_norm(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
+
+ def enable_adapters(self):
+ if not self.use_adapters:
+ return
+ for adapter_layer in self.adapter_layers:
+ adapter_layer.enable()
+
+ def disable_adapters(self):
+ if not self.use_adapters:
+ return
+ for adapter_layer in self.adapter_layers:
+ adapter_layer.disable()
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-4):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 1."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+
+class ActNorm(nn.Module):
+ def __init__(self, channels, ddi=False, **kwargs):
+ super().__init__()
+ self.channels = channels
+ self.initialized = not ddi
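+ # ddi: data-dependent initialization; when enabled, bias/logs are set from the statistics of the first batch (see initialize())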
+
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ if x_mask is None:
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
+ x_len = torch.sum(x_mask, [1, 2])
+ if not self.initialized:
+ self.initialize(x, x_mask)
+ self.initialized = True
+
+ if reverse:
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
+ logdet = torch.sum(-self.logs) * x_len
+ else:
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
+ logdet = torch.sum(self.logs) * x_len # [b]
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+ def set_ddi(self, ddi):
+ self.initialized = not ddi
+
+ def initialize(self, x, x_mask):
+ with torch.no_grad():
+ denom = torch.sum(x_mask, [0, 2])
+ m = torch.sum(x * x_mask, [0, 2]) / denom
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
+ v = m_sq - (m ** 2)
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
+
+ bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
+
+ self.bias.data.copy_(bias_init)
+ self.logs.data.copy_(logs_init)
+
+
+class InvConvNear(nn.Module):
+ def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
+ super().__init__()
+ assert (n_split % 2 == 0)
+ self.channels = channels
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.no_jacobian = no_jacobian
+
+ w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
+ if torch.det(w_init) < 0:
+ w_init[:, 0] = -1 * w_init[:, 0]
+ self.lu = lu
+ if lu:
+ # LU decomposition can slightly speed up the inverse
+ np_p, np_l, np_u = linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
+ eye = np.eye(*w_init.shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
+ self.register_buffer('l_mask', torch.Tensor(l_mask))
+ self.register_buffer('eye', torch.Tensor(eye))
+ else:
+ self.weight = nn.Parameter(w_init)
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ b, c, t = x.size()
+ assert (c % self.n_split == 0)
+ if x_mask is None:
+ x_mask = 1
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+
+ x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
+ x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
+
+ if self.lu:
+ self.weight, log_s = self._get_weight()
+ logdet = log_s.sum()
+ logdet = logdet * (c / self.n_split) * x_len
+ else:
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
+
+ if reverse:
+ if hasattr(self, "weight_inv"):
+ weight = self.weight_inv
+ else:
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
+ logdet = -logdet
+ else:
+ weight = self.weight
+ if self.no_jacobian:
+ logdet = 0
+
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
+ z = F.conv2d(x, weight)
+
+ z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
+ return z, logdet
+
+ def _get_weight(self):
+ l, log_s, u = self.l, self.log_s, self.u
+ l = l * self.l_mask + self.eye
+ u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
+ weight = torch.matmul(self.p, torch.matmul(l, u))
+ return weight, log_s
+
+ def store_inverse(self):
+ weight, _ = self._get_weight()
+ self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
+
+
+class InvConv(nn.Module):
+ def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
+ super().__init__()
+ w_shape = [channels, channels]
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
+ LU_decomposed = lu
+ if not LU_decomposed:
+ # Sample a random orthogonal matrix:
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
+ else:
+ np_p, np_l, np_u = linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
+ eye = np.eye(*w_shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
+ self.l_mask = torch.Tensor(l_mask)
+ self.eye = torch.Tensor(eye)
+ self.w_shape = w_shape
+ self.LU = LU_decomposed
+ self.weight = None
+
+ def get_weight(self, device, reverse):
+ w_shape = self.w_shape
+ self.p = self.p.to(device)
+ self.sign_s = self.sign_s.to(device)
+ self.l_mask = self.l_mask.to(device)
+ self.eye = self.eye.to(device)
+ l = self.l * self.l_mask + self.eye
+ u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
+ dlogdet = self.log_s.sum()
+ if not reverse:
+ w = torch.matmul(self.p, torch.matmul(l, u))
+ else:
+ l = torch.inverse(l.double()).float()
+ u = torch.inverse(u.double()).float()
+ w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
+ return w.view(w_shape[0], w_shape[1], 1), dlogdet
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ """
+ log-det = log|abs(|W|)| * pixels
+ """
+ b, c, t = x.size()
+ if x_mask is None:
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+ logdet = 0
+ if not reverse:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet + dlogdet * x_len
+ return z, logdet
+ else:
+ if self.weight is None:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ else:
+ weight, dlogdet = self.weight, self.dlogdet
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet - dlogdet * x_len
+ return z, logdet
+
+ def store_inverse(self):
+ self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
+
+
+class Flip(nn.Module):
+ def forward(self, x, *args, reverse=False, **kwargs):
+ x = torch.flip(x, [1])
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
+ return x, logdet
+
+ def store_inverse(self):
+ pass
+
+
+class CouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False,
+ share_cond_layers=False, wn=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
+ start = torch.nn.utils.weight_norm(start)
+ self.start = start
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+ self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels,
+ p_dropout, share_cond_layers)
+ if wn is not None:
+ self.wn.in_layers = wn.in_layers
+ self.wn.res_skip_layers = wn.res_skip_layers
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+
+ x = self.start(x_0) * x_mask
+ x = self.wn(x, x_mask, g)
+ out = self.end(x)
+
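+ # affine coupling: the first half of the channels passes through unchanged, while the second half is
+ # shifted/scaled by (m, logs) predicted from the first half; logdet is the sum of the log-scales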
+ z_0 = x_0
+ m = out[:, :self.in_channels // 2, :]
+ logs = out[:, self.in_channels // 2:, :]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ self.wn.remove_weight_norm()
+
+
+class GlowFFTBlocks(FFTBlocks):
+ def __init__(self, hidden_size=128, gin_channels=256, num_layers=2, ffn_kernel_size=5,
+ dropout=None, num_heads=4, use_pos_embed=True, use_last_norm=True,
+ norm='ln', use_pos_embed_alpha=True):
+ super().__init__(hidden_size, num_layers, ffn_kernel_size, dropout, num_heads, use_pos_embed,
+ use_last_norm, norm, use_pos_embed_alpha)
+ self.inp_proj = nn.Conv1d(hidden_size + gin_channels, hidden_size, 1)
+
+ def forward(self, x, x_mask=None, g=None):
+ """
+ :param x: [B, C_x, T]
+ :param x_mask: [B, 1, T]
+ :param g: [B, C_g, T]
+ :return: [B, C_x, T]
+ """
+ if g is not None:
+ x = self.inp_proj(torch.cat([x, g], 1))
+ x = x.transpose(1, 2)
+ x = super(GlowFFTBlocks, self).forward(x, x_mask[:, 0] == 0)
+ x = x.transpose(1, 2)
+ return x
+
+
+class TransformerCouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
+ self.start = start
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+ self.fft_blocks = GlowFFTBlocks(
+ hidden_size=hidden_channels,
+ ffn_kernel_size=3,
+ gin_channels=gin_channels,
+ num_layers=n_layers)
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+
+ x = self.start(x_0) * x_mask
+ x = self.fft_blocks(x, x_mask, g)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, :self.in_channels // 2, :]
+ logs = out[:, self.in_channels // 2:, :]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+
+class FreqFFTCouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ hs = hidden_channels
+ stride = 8
+ self.start = torch.nn.Conv2d(3, hs, kernel_size=stride * 2,
+ stride=stride, padding=stride // 2)
+ end = nn.ConvTranspose2d(hs, 2, kernel_size=stride, stride=stride)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = nn.Sequential(
+ nn.Conv2d(hs * 3, hs, 3, 1, 1),
+ nn.ReLU(),
+ nn.GroupNorm(4, hs),
+ nn.Conv2d(hs, hs, 3, 1, 1),
+ end
+ )
+ self.fft_v = FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers)
+ self.fft_h = nn.Sequential(
+ nn.Conv1d(hs, hs, 3, 1, 1),
+ nn.ReLU(),
+ nn.Conv1d(hs, hs, 3, 1, 1),
+ )
+ self.fft_g = nn.Sequential(
+ nn.Conv1d(
+ gin_channels - 160, hs, kernel_size=stride * 2, stride=stride, padding=stride // 2),
+ Permute(0, 2, 1),
+ FFTBlocks(hidden_size=hs, ffn_kernel_size=1, num_layers=n_layers),
+ Permute(0, 2, 1),
+ )
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ g_, _ = utils.unsqueeze(g)
+ g_mel = g_[:, :80]
+ g_txt = g_[:, 80:]
+ g_mel, _ = utils.squeeze(g_mel)
+ g_txt, _ = utils.squeeze(g_txt) # [B, C, T]
+
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+ x = torch.stack([x_0, g_mel[:, :80], g_mel[:, 80:]], 1)
+ x = self.start(x) # [B, C, N_bins, T]
+ B, C, N_bins, T = x.shape
+
+ x_v = self.fft_v(x.permute(0, 3, 2, 1).reshape(B * T, N_bins, C))
+ x_v = x_v.reshape(B, T, N_bins, -1).permute(0, 3, 2, 1)
+ # x_v = x
+
+ x_h = self.fft_h(x.permute(0, 2, 1, 3).reshape(B * N_bins, C, T))
+ x_h = x_h.reshape(B, N_bins, -1, T).permute(0, 2, 1, 3)
+ # x_h = x
+
+ x_g = self.fft_g(g_txt)[:, :, None, :].repeat(1, 1, 10, 1)
+ x = torch.cat([x_v, x_h, x_g], 1)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, 0]
+ logs = out[:, 1]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+
+
+class ResidualCouplingLayer(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ p_dropout=0,
+ gin_channels=0,
+ mean_only=False,
+ nn_type='wn'):
+ assert channels % 2 == 0, "channels should be divisible by 2"
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.half_channels = channels // 2
+ self.mean_only = mean_only
+
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
+ if nn_type == 'wn':
+ self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout,
+ gin_channels=gin_channels)
+ # elif nn_type == 'conv':
+ # self.enc = ConditionalConvBlocks(
+ # hidden_channels, gin_channels, hidden_channels, [1] * n_layers, kernel_size,
+ # layers_in_block=1, is_BTC=False)
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
+ self.post.weight.data.zero_()
+ self.post.bias.data.zero_()
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
+ h = self.pre(x0) * x_mask
+ h = self.enc(h, x_mask=x_mask, g=g)
+ stats = self.post(h) * x_mask
+ if not self.mean_only:
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
+ else:
+ m = stats
+ logs = torch.zeros_like(m)
+
+ if not reverse:
+ x1 = m + x1 * torch.exp(logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = torch.sum(logs, [1, 2])
+ return x, logdet
+ else:
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
+ x = torch.cat([x0, x1], 1)
+ logdet = -torch.sum(logs, [1, 2])
+ return x, logdet
+
+
+class ResidualCouplingBlock(nn.Module):
+ def __init__(self,
+ channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_layers,
+ n_flows=4,
+ gin_channels=0,
+ nn_type='wn'):
+ super().__init__()
+ self.channels = channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.n_flows = n_flows
+ self.gin_channels = gin_channels
+
+ self.flows = nn.ModuleList()
+ for i in range(n_flows):
+ self.flows.append(
+ ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=gin_channels, mean_only=True, nn_type=nn_type))
+ self.flows.append(Flip())
+
+ def forward(self, x, x_mask, g=None, reverse=False):
+ if not reverse:
+ for flow in self.flows:
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ else:
+ for flow in reversed(self.flows):
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
+ return x
+
+
+class Glow(nn.Module):
+ def __init__(self,
+ in_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_blocks,
+ n_layers,
+ p_dropout=0.,
+ n_split=4,
+ n_sqz=2,
+ sigmoid_scale=False,
+ gin_channels=0,
+ inv_conv_type='near',
+ share_cond_layers=False,
+ share_wn_layers=0,
+ ):
+ super().__init__()
+ """
+ Note that regularization like weight decay can lead to NaN errors!
+ """
+
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_blocks = n_blocks
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.sigmoid_scale = sigmoid_scale
+ self.gin_channels = gin_channels
+ self.share_cond_layers = share_cond_layers
+ if gin_channels != 0 and share_cond_layers:
+ cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+ wn = None
+ self.flows = nn.ModuleList()
+ for b in range(n_blocks):
+ self.flows.append(ActNorm(channels=in_channels * n_sqz))
+ if inv_conv_type == 'near':
+ self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
+ if inv_conv_type == 'invconv':
+ self.flows.append(InvConv(channels=in_channels * n_sqz))
+ if share_wn_layers > 0:
+ if b % share_wn_layers == 0:
+ wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
+ p_dropout, share_cond_layers)
+ self.flows.append(
+ CouplingBlock(
+ in_channels * n_sqz,
+ hidden_channels,
+ kernel_size=kernel_size,
+ dilation_rate=dilation_rate,
+ n_layers=n_layers,
+ gin_channels=gin_channels * n_sqz,
+ p_dropout=p_dropout,
+ sigmoid_scale=sigmoid_scale,
+ share_cond_layers=share_cond_layers,
+ wn=wn
+ ))
+
+ def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
+ """
+ x: [B,T,C]
+ x_mask: [B,T]
+ g: [B,T,C]
+ """
+ x = x.transpose(1,2)
+ x_mask = x_mask.unsqueeze(1)
+ if g is not None:
+ g = g.transpose(1,2)
+
+ logdet_tot = 0
+ if not reverse:
+ flows = self.flows
+ else:
+ flows = reversed(self.flows)
+ if return_hiddens:
+ hs = []
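+ # squeeze folds n_sqz consecutive time steps into the channel dim before the flow and unfolds them afterwards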
+ if self.n_sqz > 1:
+ x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
+ if g is not None:
+ g, _ = utils.squeeze(g, x_mask, self.n_sqz)
+ x_mask = x_mask_
+ if self.share_cond_layers and g is not None:
+ g = self.cond_layer(g)
+ for f in flows:
+ x, logdet = f(x, x_mask, g=g, reverse=reverse)
+ if return_hiddens:
+ hs.append(x)
+ logdet_tot += logdet
+ if self.n_sqz > 1:
+ x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)
+
+ x = x.transpose(1,2)
+ if return_hiddens:
+ return x, logdet_tot, hs
+ return x, logdet_tot
+
+ def store_inverse(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
+ for f in self.flows:
+ f.store_inverse()
+
+
+if __name__ == '__main__':
+ model = Glow(in_channels=64,
+ hidden_channels=128,
+ kernel_size=5,
+ dilation_rate=1,
+ n_blocks=12,
+ n_layers=4,
+ p_dropout=0.0,
+ n_split=4,
+ n_sqz=2,
+ sigmoid_scale=False,
+ gin_channels=80
+ )
+ exp = torch.rand([1,1440,64])
+ mel = torch.rand([1,1440,80])
+ x_mask = torch.ones([1,1440],dtype=torch.float32)
+    y, logdet = model(exp, x_mask, g=mel, reverse=False)      # encode: exp -> latent y, [b, t, c=64]
+    pred_exp, logdet = model(y, x_mask, g=mel, reverse=True)  # decode: latent y -> reconstructed exp
+    print(y.shape, pred_exp.shape)
\ No newline at end of file
diff --git a/modules/audio2motion/multi_length_disc.py b/modules/audio2motion/multi_length_disc.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a57df2cef929691f2f1fa41981ed8316ff5dce6
--- /dev/null
+++ b/modules/audio2motion/multi_length_disc.py
@@ -0,0 +1,340 @@
+import numpy as np
+import random
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.audio2motion.cnn_models import LambdaLayer
+
+
+class Discriminator1DFactory(nn.Module):
+ def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
+ super(Discriminator1DFactory, self).__init__()
+ padding = kernel_size // 2
+
+ def discriminator_block(in_filters, out_filters, first=False):
+ """
+ Input: (B, c, T)
+ Output:(B, c, T//2)
+ """
+ conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
+ block = [
+ conv, # padding = kernel//2
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25)
+ ]
+ if norm_type == 'bn' and not first:
+ block.append(nn.BatchNorm1d(out_filters, 0.8))
+ if norm_type == 'in' and not first:
+ block.append(nn.InstanceNorm1d(out_filters, affine=True))
+ block = nn.Sequential(*block)
+ return block
+
+ if time_length >= 8:
+ self.model = nn.ModuleList([
+ discriminator_block(in_dim, hidden_size, first=True),
+ discriminator_block(hidden_size, hidden_size),
+ discriminator_block(hidden_size, hidden_size),
+ ])
+ ds_size = time_length // (2 ** 3)
+ elif time_length == 3:
+ self.model = nn.ModuleList([
+ nn.Sequential(*[
+ nn.Conv1d(in_dim, hidden_size, 3, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.BatchNorm1d(hidden_size, 0.8),
+ nn.Conv1d(hidden_size, hidden_size, 1, 1, 0),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.BatchNorm1d(hidden_size, 0.8)
+ ])
+ ])
+ ds_size = 1
+ elif time_length == 1:
+ self.model = nn.ModuleList([
+ nn.Sequential(*[
+ nn.Linear(in_dim, hidden_size),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ nn.Linear(hidden_size, hidden_size),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25),
+ ])
+ ])
+ ds_size = 1
+
+ self.adv_layer = nn.Linear(hidden_size * ds_size, 1)
+
+ def forward(self, x):
+ """
+
+ :param x: [B, C, T]
+ :return: validity: [B, 1], h: List of hiddens
+ """
+ h = []
+ if x.shape[-1] == 1:
+ x = x.squeeze(-1)
+ for l in self.model:
+ x = l(x)
+ h.append(x)
+ if x.ndim == 2:
+ b, ct = x.shape
+ use_sigmoid = True
+ else:
+ b, c, t = x.shape
+ ct = c * t
+ use_sigmoid = False
+ x = x.view(b, ct)
+ validity = self.adv_layer(x) # [B, 1]
+ if use_sigmoid:
+ validity = torch.sigmoid(validity)
+ return validity, h
+
+
+class CosineDiscriminator1DFactory(nn.Module):
+ def __init__(self, time_length, kernel_size=3, in_dim=1, hidden_size=128, norm_type='bn'):
+ super().__init__()
+ padding = kernel_size // 2
+
+ def discriminator_block(in_filters, out_filters, first=False):
+ """
+ Input: (B, c, T)
+ Output:(B, c, T//2)
+ """
+ conv = nn.Conv1d(in_filters, out_filters, kernel_size, 2, padding)
+ block = [
+ conv, # padding = kernel//2
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Dropout2d(0.25)
+ ]
+ if norm_type == 'bn' and not first:
+ block.append(nn.BatchNorm1d(out_filters, 0.8))
+ if norm_type == 'in' and not first:
+ block.append(nn.InstanceNorm1d(out_filters, affine=True))
+ block = nn.Sequential(*block)
+ return block
+
+ self.model1 = nn.ModuleList([
+ discriminator_block(in_dim, hidden_size, first=True),
+ discriminator_block(hidden_size, hidden_size),
+ discriminator_block(hidden_size, hidden_size),
+ ])
+
+ self.model2 = nn.ModuleList([
+ discriminator_block(in_dim, hidden_size, first=True),
+ discriminator_block(hidden_size, hidden_size),
+ discriminator_block(hidden_size, hidden_size),
+ ])
+
+ self.relu = nn.ReLU()
+ def forward(self, x1, x2):
+ """
+
+ :param x1: [B, C, T]
+ :param x2: [B, C, T]
+ :return: validity: [B, 1], h: List of hiddens
+ """
+ h1, h2 = [], []
+ for l in self.model1:
+ x1 = l(x1)
+ h1.append(x1)
+ for l in self.model2:
+ x2 = l(x2)
+            h2.append(x2)
+ b,c,t = x1.shape
+ x1 = x1.view(b, c*t)
+ x2 = x2.view(b, c*t)
+ x1 = self.relu(x1)
+ x2 = self.relu(x2)
+ # x1 = F.normalize(x1, p=2, dim=1)
+ # x2 = F.normalize(x2, p=2, dim=1)
+ validity = F.cosine_similarity(x1, x2)
+ return validity, [h1,h2]
+
+
+class MultiWindowDiscriminator(nn.Module):
+ def __init__(self, time_lengths, cond_dim=80, in_dim=64, kernel_size=3, hidden_size=128, disc_type='standard', norm_type='bn', reduction='sum'):
+ super(MultiWindowDiscriminator, self).__init__()
+ self.win_lengths = time_lengths
+ self.reduction = reduction
+ self.disc_type = disc_type
+
+ if cond_dim > 0:
+ self.use_cond = True
+ self.cond_proj_layers = nn.ModuleList()
+ self.in_proj_layers = nn.ModuleList()
+ else:
+ self.use_cond = False
+
+ self.conv_layers = nn.ModuleList()
+ for time_length in time_lengths:
+ conv_layer = [
+ Discriminator1DFactory(
+ time_length, kernel_size, in_dim=64, hidden_size=hidden_size,
+ norm_type=norm_type) if self.disc_type == 'standard'
+ else CosineDiscriminator1DFactory(time_length, kernel_size, in_dim=64,
+ hidden_size=hidden_size,norm_type=norm_type)
+ ]
+ self.conv_layers += conv_layer
+ if self.use_cond:
+ self.cond_proj_layers.append(nn.Linear(cond_dim, 64))
+ self.in_proj_layers.append(nn.Linear(in_dim, 64))
+
+ def clip(self, x, cond, x_len, win_length, start_frames=None):
+        '''Randomly clip x (and cond) to win_length frames.
+        Args:
+            x (tensor): (B, T, C).
+            cond (tensor): (B, T, H).
+            x_len (tensor): (B,).
+            win_length (int): target clip length.
+
+        Returns:
+            x_batch (tensor): (B, win_length, C).
+            c_batch (tensor or None): (B, win_length, H).
+            start_frames (list): sampled start frame per item.
+        '''
+ clip_from_same_frame = start_frames is None
+ T_start = 0
+ # T_end = x_len.max() - win_length
+ T_end = x_len.min() - win_length
+ if T_end < 0:
+ return None, None, start_frames
+ T_end = T_end.item()
+ if start_frames is None:
+ start_frame = np.random.randint(low=T_start, high=T_end + 1)
+ start_frames = [start_frame] * x.size(0)
+ else:
+ start_frame = start_frames[0]
+
+
+ if clip_from_same_frame:
+ x_batch = x[:, start_frame: start_frame + win_length, :]
+ c_batch = cond[:, start_frame: start_frame + win_length, :] if cond is not None else None
+ else:
+ x_lst = []
+ c_lst = []
+ for i, start_frame in enumerate(start_frames):
+ x_lst.append(x[i, start_frame: start_frame + win_length, :])
+ if cond is not None:
+ c_lst.append(cond[i, start_frame: start_frame + win_length, :])
+ x_batch = torch.stack(x_lst, dim=0)
+ if cond is None:
+ c_batch = None
+ else:
+ c_batch = torch.stack(c_lst, dim=0)
+ return x_batch, c_batch, start_frames
+
+ def forward(self, x, x_len, cond=None, start_frames_wins=None):
+        '''
+        Args:
+            x (tensor): input motion sequence, (B, T, C).
+            x_len (tensor): valid length of each item, (B,).
+            cond (tensor, optional): conditioning features, (B, T, H).
+
+        Returns:
+            validity (tensor): (B, 1) if reduction='sum', else (B, 1, n_windows).
+            start_frames_wins (list): sampled start frames per window size.
+            h (list): hidden activations of the window discriminators.
+        '''
+ validity = []
+ if start_frames_wins is None:
+ start_frames_wins = [None] * len(self.conv_layers)
+ h = []
+ for i, start_frames in zip(range(len(self.conv_layers)), start_frames_wins):
+ x_clip, c_clip, start_frames = self.clip(
+ x, cond, x_len, self.win_lengths[i], start_frames) # (B, win_length, C)
+ start_frames_wins[i] = start_frames
+ if x_clip is None:
+ continue
+ if self.disc_type == 'standard':
+ if self.use_cond:
+ x_clip = self.in_proj_layers[i](x_clip) # (B, T, C)
+ c_clip = self.cond_proj_layers[i](c_clip)
+ x_clip = x_clip + c_clip
+ validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2))
+ elif self.disc_type == 'cosine':
+ assert self.use_cond is True
+ x_clip = self.in_proj_layers[i](x_clip) # (B, T, C)
+ c_clip = self.cond_proj_layers[i](c_clip)
+ validity_pred, h_ = self.conv_layers[i](x_clip.transpose(1,2), c_clip.transpose(1,2))
+ else:
+ raise NotImplementedError
+
+ h += h_
+ validity.append(validity_pred)
+ if len(validity) != len(self.conv_layers):
+ return None, start_frames_wins, h
+ if self.reduction == 'sum':
+ validity = sum(validity) # [B]
+ elif self.reduction == 'stack':
+ validity = torch.stack(validity, -1) # [B, W_L]
+ return validity, start_frames_wins, h
+
+
+class Discriminator(nn.Module):
+ def __init__(self, x_dim=80, y_dim=64, disc_type='standard',
+ uncond_disc=False, kernel_size=3, hidden_size=128, norm_type='bn', reduction='sum', time_lengths=(8,16,32)):
+ """_summary_
+
+ Args:
+ time_lengths (list, optional): the list of window size. Defaults to [32, 64, 128].
+ x_dim (int, optional): the dim of audio features. Defaults to 80, corresponding to mel-spec.
+ y_dim (int, optional): the dim of facial coeff. Defaults to 64, correspond to exp; other options can be 7(pose) or 71(exp+pose).
+ kernel (tuple, optional): _description_. Defaults to (3, 3).
+ c_in (int, optional): _description_. Defaults to 1.
+ hidden_size (int, optional): _description_. Defaults to 128.
+ norm_type (str, optional): _description_. Defaults to 'bn'.
+ reduction (str, optional): _description_. Defaults to 'sum'.
+ uncond_disc (bool, optional): _description_. Defaults to False.
+ """
+ super(Discriminator, self).__init__()
+ self.time_lengths = time_lengths
+ self.x_dim, self.y_dim = x_dim, y_dim
+ self.disc_type = disc_type
+ self.reduction = reduction
+ self.uncond_disc = uncond_disc
+
+ if uncond_disc:
+ self.x_dim = 0
+ cond_dim = 0
+
+ else:
+ cond_dim = 64
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(self.x_dim, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64),
+ nn.GELU(),
+ nn.Conv1d(64, cond_dim, 3, 1, 1, bias=False)
+ ])
+
+ self.disc = MultiWindowDiscriminator(
+ time_lengths=self.time_lengths,
+ in_dim=self.y_dim,
+ cond_dim=cond_dim,
+ kernel_size=kernel_size,
+ hidden_size=hidden_size, norm_type=norm_type,
+ reduction=reduction,
+ disc_type=disc_type
+ )
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ @property
+ def device(self):
+        return next(self.disc.parameters()).device
+
+ def forward(self,x, batch, start_frames_wins=None):
+ """
+
+ :param x: [B, T, C]
+ :param cond: [B, T, cond_size]
+ :return:
+ """
+ x = x.to(self.device)
+ if not self.uncond_disc:
+ mel = self.downsampler(batch['mel'].to(self.device))
+ mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+ else:
+ mel_feat = None
+ x_len = x.sum(-1).ne(0).int().sum([1])
+ disc_confidence, start_frames_wins, h = self.disc(x, x_len, mel_feat, start_frames_wins=start_frames_wins)
+ return disc_confidence
+
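+
+if __name__ == '__main__':
+    # Minimal smoke test; the shapes below are illustrative assumptions: a 64-dim expression
+    # sequence conditioned on an 80-dim mel spectrogram sampled at twice the motion frame rate.
+    disc = Discriminator(x_dim=80, y_dim=64, disc_type='standard')
+    exp = torch.rand(2, 100, 64)             # [B, T, C_exp]
+    batch = {'mel': torch.rand(2, 200, 80)}  # downsampled by 2 inside forward to match T
+    confidence = disc(exp, batch)
+    print(confidence.shape)                  # torch.Size([2, 1]) with reduction='sum'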
diff --git a/modules/audio2motion/transformer_base.py b/modules/audio2motion/transformer_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bbe0073907742b2921b28afed2b241b7caeb60
--- /dev/null
+++ b/modules/audio2motion/transformer_base.py
@@ -0,0 +1,988 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter
+import torch.onnx.operators
+import torch.nn.functional as F
+from collections import defaultdict
+
+
+def make_positions(tensor, padding_idx):
+ """Replace non-padding symbols with their position numbers.
+
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
+ ).long() + padding_idx
+
+
+def softmax(x, dim):
+ return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+def _get_full_incremental_state_key(module_instance, key):
+ module_name = module_instance.__class__.__name__
+
+ # assign a unique ID to each module instance, so that incremental state is
+ # not shared across module instances
+ if not hasattr(module_instance, '_instance_id'):
+ INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+ module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+ return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+
+def get_incremental_state(module, incremental_state, key):
+ """Helper for getting incremental state for an nn.Module."""
+ full_key = _get_full_incremental_state_key(module, key)
+ if incremental_state is None or full_key not in incremental_state:
+ return None
+ return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+ """Helper for setting incremental state for an nn.Module."""
+ if incremental_state is not None:
+ full_key = _get_full_incremental_state_key(module, key)
+ incremental_state[full_key] = value
+
+
+
+class Reshape(nn.Module):
+ def __init__(self, *args):
+ super(Reshape, self).__init__()
+ self.shape = args
+
+ def forward(self, x):
+ return x.view(self.shape)
+
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+class GroupNorm1DTBC(nn.GroupNorm):
+ def forward(self, input):
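+        # input is [T, B, C]; GroupNorm expects [B, C, T], so permute, normalize, then permute back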
+ return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1)
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+ if not export and torch.cuda.is_available():
+ try:
+ from apex.normalization import FusedLayerNorm
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ except ImportError:
+ pass
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+def Linear(in_features, out_features, bias=True):
+ m = nn.Linear(in_features, out_features, bias)
+ nn.init.xavier_uniform_(m.weight)
+ if bias:
+ nn.init.constant_(m.bias, 0.)
+ return m
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class ConvTBC(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+ super(ConvTBC, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.padding = padding
+
+ self.weight = torch.nn.Parameter(torch.Tensor(
+ self.kernel_size, in_channels, out_channels))
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+
+ def forward(self, input):
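+        # torch.conv_tbc expects time-major input of shape (T, B, C)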
+ return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ if hasattr(F, "multi_head_attention_forward"):
+ self.enable_torch_version = True
+ else:
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query, key, value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask)
+ else:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
+ if static_kv:
+ key_padding_mask = prev_key_padding_mask
+ else:
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_key_padding_mask'] = key_padding_mask
+
+ self._set_input_buffer(incremental_state, saved_state)
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class Swish(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ i = ctx.saved_variables[0]
+ sigmoid_i = torch.sigmoid(i)
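+        # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))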
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class CustomSwish(nn.Module):
+ def forward(self, input_tensor):
+ return Swish.apply(input_tensor)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = Linear(filter_size, hidden_size)
+ if self.act == 'swish':
+ self.swish_fn = CustomSwish()
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-1:]
+ if self.act == 'gelu':
+ x = F.gelu(x)
+ if self.act == 'relu':
+ x = F.relu(x)
+ if self.act == 'swish':
+ x = self.swish_fn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class BatchNorm1dTBC(nn.Module):
+ def __init__(self, c):
+ super(BatchNorm1dTBC, self).__init__()
+ self.bn = nn.BatchNorm1d(c)
+
+ def forward(self, x):
+ """
+
+ :param x: [T, B, C]
+ :return: [T, B, C]
+ """
+ x = x.permute(1, 2, 0) # [B, C, T]
+ x = self.bn(x) # [B, C, T]
+ x = x.permute(2, 0, 1) # [T, B, C]
+ return x
+
+
+class EncSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+ relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if num_heads > 0:
+ if norm == 'ln':
+ self.layer_norm1 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm1 = BatchNorm1dTBC(c)
+ elif norm == 'gn':
+ self.layer_norm1 = GroupNorm1DTBC(8, c)
+ self.self_attn = MultiheadAttention(
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+ if norm == 'ln':
+ self.layer_norm2 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm2 = BatchNorm1dTBC(c)
+ elif norm == 'gn':
+ self.layer_norm2 = GroupNorm1DTBC(8, c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+ x, _, = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+ return x
+
+
+class DecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, act='gelu', norm='ln'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ if norm == 'ln':
+ self.layer_norm1 = LayerNorm(c)
+ elif norm == 'gn':
+ self.layer_norm1 = GroupNorm1DTBC(8, c)
+ self.self_attn = MultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ if norm == 'ln':
+ self.layer_norm2 = LayerNorm(c)
+ elif norm == 'gn':
+ self.layer_norm2 = GroupNorm1DTBC(8, c)
+ self.encoder_attn = MultiheadAttention(
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+ )
+ if norm == 'ln':
+ self.layer_norm3 = LayerNorm(c)
+ elif norm == 'gn':
+ self.layer_norm3 = GroupNorm1DTBC(8, c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ self.layer_norm3.training = layer_norm_training
+ residual = x
+ x = self.layer_norm1(x)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ attn_logits = None
+ if encoder_out is not None or attn_out is not None:
+ residual = x
+ x = self.layer_norm2(x)
+ if encoder_out is not None:
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ reset_attn_weight=reset_attn_weight
+ )
+ attn_logits = attn[1]
+ elif attn_out is not None:
+ x = self.encoder_attn.in_proj_v(attn_out)
+ if encoder_out is not None or attn_out is not None:
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ residual = x
+ x = self.layer_norm3(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
+ super().__init__()
+ self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
+ self.norm = norm
+ if self.norm == 'bn':
+ self.norm = nn.BatchNorm1d(n_chans)
+ elif self.norm == 'in':
+ self.norm = nn.InstanceNorm1d(n_chans, affine=True)
+ elif self.norm == 'gn':
+ self.norm = nn.GroupNorm(n_chans // 16, n_chans)
+ elif self.norm == 'ln':
+ self.norm = LayerNorm(n_chans // 16, n_chans)
+ elif self.norm == 'wn':
+ self.conv = torch.nn.utils.weight_norm(self.conv.conv)
+ self.dropout = nn.Dropout(dropout)
+ self.relu = nn.ReLU()
+
+ def forward(self, x):
+ """
+
+ :param x: [B, C, T]
+ :return: [B, C, T]
+ """
+ x = self.conv(x)
+        # self.norm is a plain string only for 'none' and 'wn'; otherwise __init__ replaced it with a norm module
+        if not isinstance(self.norm, str):
+            x = self.norm(x)
+ x = self.relu(x)
+ x = self.dropout(x)
+ return x
+
+
+class ConvStacks(nn.Module):
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
+ dropout=0, strides=None, res=True):
+ super().__init__()
+ self.conv = torch.nn.ModuleList()
+ self.kernel_size = kernel_size
+ self.res = res
+ self.in_proj = Linear(idim, n_chans)
+ if strides is None:
+ strides = [1] * n_layers
+ else:
+ assert len(strides) == n_layers
+ for idx in range(n_layers):
+ self.conv.append(ConvBlock(
+ n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
+ self.out_proj = Linear(n_chans, odim)
+
+ def forward(self, x, return_hiddens=False):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ x = self.in_proj(x)
+ x = x.transpose(1, -1) # (B, idim, Tmax)
+ hiddens = []
+ for f in self.conv:
+ x_ = f(x)
+ x = x + x_ if self.res else x_ # (B, C, Tmax)
+ hiddens.append(x)
+ x = x.transpose(1, -1)
+ x = self.out_proj(x) # (B, Tmax, H)
+ if return_hiddens:
+ hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
+ return x, hiddens
+ return x
+
+
+class ConvGlobalStacks(nn.Module):
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0,
+ strides=[2, 2, 2, 2, 2]):
+ super().__init__()
+ self.conv = torch.nn.ModuleList()
+ self.pooling = torch.nn.ModuleList()
+ self.kernel_size = kernel_size
+ self.in_proj = Linear(idim, n_chans)
+ for idx in range(n_layers):
+ self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx],
+ norm=norm, dropout=dropout))
+ self.pooling.append(nn.MaxPool1d(strides[idx]))
+ self.out_proj = Linear(n_chans, odim)
+
+ def forward(self, x):
+ """
+
+ :param x: [B, T, H]
+ :return: [B, T, H]
+ """
+ x = self.in_proj(x)
+ x = x.transpose(1, -1) # (B, idim, Tmax)
+ for f, p in zip(self.conv, self.pooling):
+ x = f(x) # (B, C, T)
+ x = x.transpose(1, -1)
+ x = self.out_proj(x.mean(1)) # (B, H)
+ return x
+
+
+class ConvDecoder(nn.Module):
+ def __init__(self, c, dropout, kernel_size=9, act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+
+ self.pre_convs = nn.ModuleList()
+ self.pre_lns = nn.ModuleList()
+ for i in range(2):
+ self.pre_convs.append(TransformerFFNLayer(
+ c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
+ self.pre_lns.append(LayerNorm(c))
+
+ self.layer_norm_attn = LayerNorm(c)
+ self.encoder_attn = MultiheadAttention(c, 1, encoder_decoder_attention=True, bias=False)
+
+ self.post_convs = nn.ModuleList()
+ self.post_lns = nn.ModuleList()
+ for i in range(8):
+ self.post_convs.append(TransformerFFNLayer(
+ c, c * 2, padding='LEFT', kernel_size=kernel_size, dropout=dropout, act=act))
+ self.post_lns.append(LayerNorm(c))
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ **kwargs,
+ ):
+ attn_logits = None
+ for conv, ln in zip(self.pre_convs, self.pre_lns):
+ residual = x
+ x = ln(x)
+ x = conv(x) + residual
+ if encoder_out is not None:
+ residual = x
+ x = self.layer_norm_attn(x)
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ )
+ attn_logits = attn[1]
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ for conv, ln in zip(self.post_convs, self.post_lns):
+ residual = x
+ x = ln(x)
+ x = conv(x) + residual
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
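+
+
+if __name__ == '__main__':
+    # Illustrative check (assumed usage): positions are derived from the non-padding entries of the
+    # input, so any non-zero dummy "token" tensor with padding_idx=0 exercises the embedding.
+    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=64, padding_idx=0, init_size=128)
+    tokens = torch.ones(2, 10, dtype=torch.long)
+    print(pos_emb(tokens).shape)  # torch.Size([2, 10, 64])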
diff --git a/modules/audio2motion/transformer_models.py b/modules/audio2motion/transformer_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..05cc5ea196fb0af0f7cce8b8a41a2dcd5562f631
--- /dev/null
+++ b/modules/audio2motion/transformer_models.py
@@ -0,0 +1,208 @@
+import torch
+import torch.nn as nn
+from modules.audio2motion.transformer_base import *
+
+DEFAULT_MAX_SOURCE_POSITIONS = 2000
+DEFAULT_MAX_TARGET_POSITIONS = 2000
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = EncSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size
+ if kernel_size is not None else 9,
+ padding='SAME',
+ norm=norm, act='gelu'
+ )
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+
+######################
+# fastspeech modules
+######################
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct an LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class FFTBlocks(nn.Module):
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None,
+ num_heads=2, use_pos_embed=True, use_last_norm=True, norm='ln',
+ use_pos_embed_alpha=True):
+ super().__init__()
+ self.num_layers = num_layers
+ embed_dim = self.hidden_size = hidden_size
+ self.dropout = dropout if dropout is not None else 0.1
+ self.use_pos_embed = use_pos_embed
+ self.use_last_norm = use_last_norm
+ if use_pos_embed:
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+ self.padding_idx = 0
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend([
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
+ kernel_size=ffn_kernel_size, num_heads=num_heads,
+ norm=norm)
+ for _ in range(self.num_layers)
+ ])
+ if self.use_last_norm:
+ if norm == 'ln':
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ elif norm == 'bn':
+ self.layer_norm = BatchNorm1dTBC(embed_dim)
+ elif norm == 'gn':
+ self.layer_norm = GroupNorm1DTBC(8, embed_dim)
+ else:
+ self.layer_norm = None
+
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+ """
+ :param x: [B, T, C]
+ :param padding_mask: [B, T]
+ :return: [B, T, C] or [L, B, T, C]
+ """
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
+ if self.use_pos_embed:
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1) * nonpadding_mask_TB
+ hiddens = []
+ for layer in self.layers:
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+ hiddens.append(x)
+ if self.use_last_norm:
+ x = self.layer_norm(x) * nonpadding_mask_TB
+ if return_hiddens:
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
+ x = x.transpose(1, 2) # [L, B, T, C]
+ else:
+ x = x.transpose(0, 1) # [B, T, C]
+ return x
+
+class SequentialSA(nn.Module):
+ def __init__(self,layers):
+ super(SequentialSA,self).__init__()
+ self.layers = nn.ModuleList(layers)
+
+ def forward(self,x,x_mask):
+ """
+ x: [batch, T, H]
+ x_mask: [batch, T]
+ """
+ pad_mask = 1. - x_mask
+ for layer in self.layers:
+ if isinstance(layer, EncSALayer):
+ x = x.permute(1,0,2)
+ x = layer(x,pad_mask)
+ x = x.permute(1,0,2)
+ elif isinstance(layer, nn.Linear):
+ x = layer(x) * x_mask.unsqueeze(2)
+ elif isinstance(layer, nn.AvgPool1d):
+ x = x.permute(0,2,1)
+ x = layer(x)
+ x = x.permute(0,2,1)
+ elif isinstance(layer, nn.PReLU):
+ bs, t, hid = x.shape
+ x = x.reshape([bs*t,hid])
+ x = layer(x)
+ x = x.reshape([bs, t, hid])
+ else: # Relu
+ x = layer(x)
+
+ return x
+
+class TransformerStyleFusionModel(nn.Module):
+ def __init__(self, num_heads=4, dropout = 0.1, out_dim = 64):
+ super(TransformerStyleFusionModel, self).__init__()
+ self.audio_layer = SequentialSA([
+ nn.Linear(29, 48),
+            nn.ReLU(),
+ nn.Linear(48, 128),
+ ])
+
+ self.energy_layer = SequentialSA([
+ nn.Linear(1, 16),
+            nn.ReLU(),
+ nn.Linear(16, 64),
+ ])
+
+ self.backbone1 = FFTBlocks(hidden_size=192,num_layers=3)
+
+ self.sty_encoder = nn.Sequential(*[
+ nn.Linear(135, 64),
+ nn.ReLU(),
+ nn.Linear(64, 128)
+ ])
+
+ self.backbone2 = FFTBlocks(hidden_size=320,num_layers=3)
+
+ self.out_layer = SequentialSA([
+ nn.AvgPool1d(kernel_size=2,stride=2,padding=0), #[b,hid,t_audio]=>[b,hid,t_audio//2]
+ nn.Linear(320,out_dim),
+ nn.PReLU(out_dim),
+ nn.Linear(out_dim,out_dim),
+ ])
+
+ self.dropout = nn.Dropout(p = dropout)
+
+ def forward(self, audio, energy, style, x_mask, y_mask):
+ pad_mask = 1. - x_mask
+ audio_feat = self.audio_layer(audio, x_mask)
+ energy_feat = self.energy_layer(energy, x_mask)
+        feat = torch.cat((audio_feat, energy_feat), dim=-1)  # [batch, T, H=128+64=192]
+        feat = self.backbone1(feat, pad_mask)
+        feat = self.dropout(feat)
+
+        sty_feat = self.sty_encoder(style)  # [batch, 135] => [batch, H=128]
+        sty_feat = sty_feat.unsqueeze(1).repeat(1, feat.shape[1], 1)  # [batch, T, H=128]
+
+        feat = torch.cat([feat, sty_feat], dim=-1)  # [batch, T, H=192+128=320]
+        feat = self.backbone2(feat, pad_mask)  # [batch, T, H=320]
+ out = self.out_layer(feat, y_mask) # [batch, T//2, H=out_dim]
+
+ return out
+
+
+if __name__ == '__main__':
+ model = TransformerStyleFusionModel()
+ audio = torch.rand(4,200,29) # [B,T,H]
+ energy = torch.rand(4,200,1) # [B,T,H]
+    style = torch.ones(4, 135)   # [B, C_style]
+    x_mask = torch.ones(4, 200)  # [B, T]
+    x_mask[3, 10:] = 0
+    y_mask = torch.ones(4, 100)  # [B, T//2]; out_layer halves the time axis with AvgPool1d
+    ret = model(audio, energy, style, x_mask, y_mask)
+    print(ret.shape)  # torch.Size([4, 100, 64])
\ No newline at end of file
diff --git a/modules/audio2motion/utils.py b/modules/audio2motion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb56ec514bff822ba1a19a6474207ed82492410
--- /dev/null
+++ b/modules/audio2motion/utils.py
@@ -0,0 +1,29 @@
+import torch
+
+
+def squeeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ t = (t // n_sqz) * n_sqz
+ x = x[:, :, :t]
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
+ else:
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_sqz * x_mask, x_mask
+
+
+def unsqueeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
+ else:
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_unsqz * x_mask, x_mask
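+
+
+if __name__ == '__main__':
+    # Illustrative round-trip check: squeeze folds every n_sqz frames into the channel axis,
+    # and unsqueeze undoes it, so shapes go [B, C, T] -> [B, C*n_sqz, T//n_sqz] -> [B, C, T].
+    x = torch.randn(2, 4, 10)
+    x_sqz, mask_sqz = squeeze(x, n_sqz=2)           # [2, 8, 5]
+    x_rec, _ = unsqueeze(x_sqz, mask_sqz, n_sqz=2)  # [2, 4, 10]
+    print(x_sqz.shape, x_rec.shape, torch.allclose(x_rec, x))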
diff --git a/modules/audio2motion/vae.py b/modules/audio2motion/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..9801ed631a6142297ce96d33c93ee508f32304b9
--- /dev/null
+++ b/modules/audio2motion/vae.py
@@ -0,0 +1,468 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torch.distributions as dist
+import numpy as np
+import copy
+from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
+from modules.audio2motion.transformer_base import Embedding
+
+from utils.commons.pitch_utils import f0_to_coarse
+from utils.commons.hparams import hparams
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+
+def make_positions(tensor, padding_idx):
+ """Replace non-padding symbols with their position numbers.
+
+ Position numbers begin at padding_idx+1. Padding symbols are ignored.
+ """
+ # The series of casts and type-conversions here are carefully
+ # balanced to both work with ONNX export and XLA. In particular XLA
+ # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+ # how to handle the dtype kwarg in cumsum.
+ mask = tensor.ne(padding_idx).int()
+ return (
+ torch.cumsum(mask, dim=1).type_as(mask) * mask
+ ).long() + padding_idx
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e4) # an arbitrary large number
+
+class FVAEEncoder(nn.Module):
+ def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0, strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ if i == 0 else
+ nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
+
+ self.latent_channels = latent_channels
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
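+        # subsample the mask to the strided time resolution of the pre-net output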
+ x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ m, logs = torch.split(x, self.latent_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs))
+ return z, m, logs, x_mask
+
+
+class FVAEDecoder(nn.Module):
+ def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0,
+ strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
+ if i == 0 else
+ nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ return x
+
+class FVAE(nn.Module):
+ def __init__(self,
+ in_out_channels=64, hidden_channels=256, latent_size=16,
+ kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,],
+ use_prior_glow=True, glow_hidden=256, glow_kernel_size=3, glow_n_blocks=5,
+ sqz_prior=False, use_pos_emb=False):
+ super(FVAE, self).__init__()
+ self.in_out_channels = in_out_channels
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.latent_size = latent_size
+ self.use_prior_glow = use_prior_glow
+ self.sqz_prior = sqz_prior
+ self.g_pre_net = nn.Sequential(*[
+ nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.encoder = FVAEEncoder(in_out_channels, hidden_channels, latent_size, kernel_size,
+ enc_n_layers, gin_channels, strides=strides)
+ if use_prior_glow:
+ self.prior_flow = ResidualCouplingBlock(
+ latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels)
+ self.use_pos_embed = use_pos_emb
+ if sqz_prior:
+ self.query_proj = nn.Linear(latent_size, latent_size)
+ self.key_proj = nn.Linear(latent_size, latent_size)
+ self.value_proj = nn.Linear(latent_size, hidden_channels)
+ if self.in_out_channels in [7, 64]:
+ self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ elif self.in_out_channels == 71:
+ self.exp_decoder = FVAEDecoder(hidden_channels, hidden_channels, 64, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ self.pose_decoder = FVAEDecoder(hidden_channels, hidden_channels, 7, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ if self.use_pos_embed:
+ self.embed_positions = SinusoidalPositionalEmbedding(self.latent_size, 0,init_size=2000+1,)
+ else:
+ self.decoder = FVAEDecoder(latent_size, hidden_channels, in_out_channels, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+
+ self.prior_dist = dist.Normal(0, 1)
+
+ def forward(self, x=None, x_mask=None, g=None, infer=False, temperature=1. , **kwargs):
+ """
+
+ :param x: [B, T, C_in_out]
+ :param x_mask: [B, T]
+ :param g: [B, T, C_g]
+ :return:
+ """
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+ g = g.transpose(1,2) # [B, C_g, T]
+ g_for_sqz = g
+
+ g_sqz = self.g_pre_net(g_for_sqz)
+
+ if not infer:
+ x = x.transpose(1,2) # [B, C, T]
+ z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
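+ # When sqz_prior is enabled, the frame-level latent z_q is pooled into a single
+ # global "style" vector with a one-query attention over time, which is then
+ # broadcast back across all frames before decoding.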
+ if self.sqz_prior:
+ z = z_q
+ if self.use_pos_embed:
+ position = self.embed_positions(z.transpose(1,2).abs().sum(-1)).transpose(1,2)
+ z = z + position
+ q = self.query_proj(z.mean(dim=-1,keepdim=True).transpose(1,2)) # [B, 1, C=16]
+ k = self.key_proj(z.transpose(1,2)) # [B, T, C=16]
+ v = self.value_proj(z.transpose(1,2)) # [B, T, C=256]
+ attn = torch.bmm(q,k.transpose(1,2)) # [B, 1, T]
+ attn = F.softmax(attn, dim=-1)
+ out = torch.bmm(attn, v) # [B, 1, C=256]
+ style_encoding = out.repeat([1,z_q.shape[-1],1]).transpose(1,2) # [B, C=256, T]
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(style_encoding, x_mask, g), self.pose_decoder(style_encoding, x_mask, g)], dim=1)
+ else:
+ x_recon = self.decoder(style_encoding, x_mask, g)
+ else:
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(z_q, x_mask, g), self.pose_decoder(z_q, x_mask, g)], dim=1)
+ else:
+ x_recon = self.decoder(z_q, x_mask, g)
+ q_dist = dist.Normal(m_q, logs_q.exp())
+ if self.use_prior_glow:
+ logqx = q_dist.log_prob(z_q)
+ z_p = self.prior_flow(z_q, x_mask_sqz, g_sqz)
+ logpx = self.prior_dist.log_prob(z_p)
+ loss_kl = ((logqx - logpx) * x_mask_sqz).sum() / x_mask_sqz.sum() / logqx.shape[1]
+ else:
+ loss_kl = torch.distributions.kl_divergence(q_dist, self.prior_dist)
+ loss_kl = (loss_kl * x_mask_sqz).sum() / x_mask_sqz.sum() / z_q.shape[1]
+ z_p = z_q
+ return x_recon.transpose(1,2), loss_kl, z_p.transpose(1,2), m_q.transpose(1,2), logs_q.transpose(1,2)
+ else:
+ latent_shape = [g_sqz.shape[0], self.latent_size, g_sqz.shape[2]]
+ z_p = self.prior_dist.sample(latent_shape).to(g.device) * temperature # [B, latent_size, T_sqz]
+ if self.use_prior_glow:
+ z_p = self.prior_flow(z_p, 1, g_sqz, reverse=True)
+ if self.sqz_prior:
+ z = z_p
+ if self.use_pos_embed:
+ position = self.embed_positions(z.transpose(1,2).abs().sum(-1)).transpose(1,2)
+ z = z + position
+ q = self.query_proj(z.mean(dim=-1,keepdim=True).transpose(1,2)) # [B, 1, C=16]
+ k = self.key_proj(z.transpose(1,2)) # [B, T, C=16]
+ v = self.value_proj(z.transpose(1,2)) # [B, T, C=256]
+ attn = torch.bmm(q,k.transpose(1,2)) # [B, 1, T]
+ attn = F.softmax(attn, dim=-1)
+ out = torch.bmm(attn, v) # [B, 1, C=256]
+ style_encoding = out.repeat([1,z_p.shape[-1],1]).transpose(1,2) # [B, C=256, T]
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(style_encoding, 1, g), self.pose_decoder(style_encoding, 1, g)], dim=1)
+ else:
+ x_recon = self.decoder(style_encoding, 1, g)
+ else:
+ if self.in_out_channels == 71:
+ x_recon = torch.cat([self.exp_decoder(z_p, 1, g), self.pose_decoder(z_p, 1, g)], dim=1)
+ else:
+ x_recon = self.decoder(z_p, 1, g)
+ return x_recon.transpose(1,2), z_p.transpose(1,2)
+
+
+class VAEModel(nn.Module):
+ def __init__(self, in_out_dim=64, audio_in_dim=1024, sqz_prior=False, cond_drop=False, use_prior_flow=True):
+ super().__init__()
+ feat_dim = 64
+ self.blink_embed = nn.Embedding(2, feat_dim)
+ self.audio_in_dim = audio_in_dim
+ cond_dim = feat_dim
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(audio_in_dim, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64),
+ nn.GELU(),
+ nn.Conv1d(64, feat_dim, 3, 1, 1, bias=False)
+ ])
+ self.cond_drop = cond_drop
+ if self.cond_drop:
+ self.dropout = nn.Dropout(0.5)
+
+ self.in_dim, self.out_dim = in_out_dim, in_out_dim
+ self.sqz_prior = sqz_prior
+ self.use_prior_flow = use_prior_flow
+ self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
+ enc_n_layers=8, dec_n_layers=4, gin_channels=cond_dim, strides=[4,],
+ use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,sqz_prior=sqz_prior)
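+ # Halve the temporal resolution of the conditioning audio features via linear interpolation
+ # (presumably so that they match the motion frame rate).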
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='linear').transpose(1,2))
+ # self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ def num_params(self, model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
+
+ @property
+ def device(self):
+ return self.vae.parameters().__next__().device
+
+ def forward(self, batch, ret, train=True, return_latent=False, temperature=1.):
+ infer = not train
+ mask = batch['y_mask'].to(self.device)
+ mel = batch['audio'].to(self.device)
+ mel = self.downsampler(mel)
+ cond_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+
+ if self.cond_drop:
+ cond_feat = self.dropout(cond_feat)
+
+ if not infer:
+ exp = batch['y'].to(self.device)
+ x = exp
+ x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+ ret['loss_kl'] = loss_kl
+ if return_latent:
+ ret['m_q'] = m_q
+ ret['z_p'] = z_p
+ return x_recon, loss_kl, m_q, logs_q
+ else:
+ x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+
+ return x_recon
+
+
+class PitchContourVAEModel(nn.Module):
+ def __init__(self, hparams, in_out_dim=64, audio_in_dim=1024, sqz_prior=False, cond_drop=False, use_prior_flow=True):
+ super().__init__()
+ self.hparams = copy.deepcopy(hparams)
+ feat_dim = 128
+ self.audio_in_dim = audio_in_dim
+ self.blink_embed = nn.Embedding(2, feat_dim)
+
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(audio_in_dim, feat_dim, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(feat_dim ),
+ nn.GELU(),
+ nn.Conv1d(feat_dim , feat_dim, 3, 1, 1, bias=False)
+ ])
+
+ self.pitch_embed = Embedding(300, feat_dim, None)
+ self.pitch_encoder = nn.Sequential(*[
+ nn.Conv1d(feat_dim, feat_dim , 3, 1, 1, bias=False),
+ nn.BatchNorm1d(feat_dim),
+ nn.GELU(),
+ nn.Conv1d(feat_dim, feat_dim, 3, 1, 1, bias=False)
+ ])
+
+ cond_dim = feat_dim + feat_dim + feat_dim
+
+ if hparams.get('use_mouth_amp_embed', False):
+ self.mouth_amp_embed = nn.Parameter(torch.randn(feat_dim))
+ cond_dim += feat_dim
+
+ if hparams.get('use_eye_amp_embed', False):
+ self.eye_amp_embed = nn.Parameter(torch.randn(feat_dim))
+ cond_dim += feat_dim
+
+ self.cond_proj = nn.Linear(cond_dim, feat_dim, bias=True)
+
+ self.cond_drop = cond_drop
+ if self.cond_drop:
+ self.dropout = nn.Dropout(0.5)
+
+ self.in_dim, self.out_dim = in_out_dim, in_out_dim
+ self.sqz_prior = sqz_prior
+ self.use_prior_flow = use_prior_flow
+ self.vae = FVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
+ enc_n_layers=8, dec_n_layers=4, gin_channels=feat_dim, strides=[4,],
+ use_prior_glow=self.use_prior_flow, glow_hidden=64, glow_kernel_size=3, glow_n_blocks=4,sqz_prior=sqz_prior)
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ def num_params(self, model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
+
+ @property
+ def device(self):
+ return self.vae.parameters().__next__().device
+
+ def forward(self, batch, ret, train=True, return_latent=False, temperature=1.):
+ infer = not train
+ hparams = self.hparams
+ mask = batch['y_mask'].to(self.device)
+ mel = batch['audio'].to(self.device)
+ f0 = batch['f0'].to(self.device) # [b,t]
+ if 'blink' not in batch:
+ batch['blink'] = torch.zeros([f0.shape[0], f0.shape[1], 1], dtype=torch.long, device=f0.device)
+ blink = batch['blink'].to(self.device)
+ blink_feat = self.blink_embed(blink.squeeze(2))
+
+ blink_feat = self.downsampler(blink_feat)
+ mel = self.downsampler(mel)
+ f0 = self.downsampler(f0.unsqueeze(-1)).squeeze(-1)
+ f0_coarse = f0_to_coarse(f0)
+ pitch_emb = self.pitch_embed(f0_coarse)
+ cond_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+ pitch_feat = self.pitch_encoder(pitch_emb.transpose(1,2)).transpose(1,2)
+
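+ # Concatenate audio, pitch and blink features (plus the optional mouth/eye amplitude
+ # embeddings below) along the channel dim, then project them to the conditioning size
+ # expected by the FVAE.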
+ cond_feats = [cond_feat, pitch_feat, blink_feat]
+ if hparams.get('use_mouth_amp_embed', False):
+ mouth_amp = batch.get('mouth_amp', torch.ones([f0.shape[0], 1], device=f0.device) * 0.4)
+ mouth_amp_feat = mouth_amp.unsqueeze(1) * self.mouth_amp_embed.unsqueeze(0)
+ mouth_amp_feat = mouth_amp_feat.repeat([1,cond_feat.shape[1],1])
+ cond_feats.append(mouth_amp_feat)
+
+ if hparams.get('use_eye_amp_embed', False):
+ eye_amp = batch.get('eye_amp', torch.ones([f0.shape[0], 1], device=f0.device) * 0.4)
+ eye_amp_feat = eye_amp.unsqueeze(1) * self.eye_amp_embed.unsqueeze(0)
+ eye_amp_feat = eye_amp_feat.repeat([1,cond_feat.shape[1],1])
+ cond_feats.append(eye_amp_feat)
+
+ cond_feat = torch.cat(cond_feats, dim=-1)
+ cond_feat = self.cond_proj(cond_feat)
+
+ if self.cond_drop:
+ cond_feat = self.dropout(cond_feat)
+
+ if not infer:
+ exp = batch['y'].to(self.device)
+ x = exp
+ x_recon, loss_kl, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=cond_feat, infer=False)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+ ret['loss_kl'] = loss_kl
+ if return_latent:
+ ret['m_q'] = m_q
+ ret['z_p'] = z_p
+ return x_recon, loss_kl, m_q, logs_q
+ else:
+ x_recon, z_p = self.vae(x=None, x_mask=mask, g=cond_feat, infer=True, temperature=temperature)
+ x_recon = x_recon * mask.unsqueeze(-1)
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+
+ return x_recon
+
+
+if __name__ == '__main__':
+ model = FVAE(in_out_channels=64, hidden_channels=128, latent_size=32,kernel_size=3, enc_n_layers=6, dec_n_layers=2,
+ gin_channels=80, strides=[4], use_prior_glow=False, glow_hidden=128, glow_kernel_size=3, glow_n_blocks=3)
+ x = torch.rand([8, 64, 1000])
+ x_mask = torch.ones([8, 1000])
+ g = torch.rand([8, 80, 1000])
+ train_out = model(x,x_mask,g,infer=False)
+ x_recon, loss_kl, z_p, m_q, logs_q = train_out
+ print(" ")
+ infer_out = model(x,x_mask,g,infer=True)
+ x_recon, z_p = infer_out
+ print(" ")
diff --git a/modules/audio2motion/vqvae.py b/modules/audio2motion/vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..310ffc7bf4bf1c5a8c2901163439bba179a968fc
--- /dev/null
+++ b/modules/audio2motion/vqvae.py
@@ -0,0 +1,200 @@
+import scipy
+from scipy import linalg
+from torch.nn import functional as F
+import torch
+from torch import nn
+import numpy as np
+from modules.audio2motion.transformer_models import FFTBlocks
+import modules.audio2motion.utils as utils
+from modules.audio2motion.flow_base import Glow, WN, ResidualCouplingBlock
+import torch.distributions as dist
+from modules.audio2motion.cnn_models import LambdaLayer, LayerNorm
+
+from vector_quantize_pytorch import VectorQuantize
+
+
+class FVAEEncoder(nn.Module):
+ def __init__(self, in_channels, hidden_channels, latent_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0, strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.Conv1d(in_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ if i == 0 else
+ nn.Conv1d(hidden_channels, hidden_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, latent_channels * 2, 1)
+ self.latent_channels = latent_channels
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x_mask = x_mask[:, :, ::np.prod(self.strides)][:, :, :x.shape[-1]]
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ m, logs = torch.split(x, self.latent_channels, dim=1)
+ z = (m + torch.randn_like(m) * torch.exp(logs))
+ return z, m, logs, x_mask
+
+
+class FVAEDecoder(nn.Module):
+ def __init__(self, latent_channels, hidden_channels, out_channels, kernel_size,
+ n_layers, gin_channels=0, p_dropout=0,
+ strides=[4]):
+ super().__init__()
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.pre_net = nn.Sequential(*[
+ nn.ConvTranspose1d(latent_channels, hidden_channels, kernel_size=s, stride=s)
+ if i == 0 else
+ nn.ConvTranspose1d(hidden_channels, hidden_channels, kernel_size=s, stride=s)
+ for i, s in enumerate(strides)
+ ])
+ self.wn = WN(hidden_channels, kernel_size, 1, n_layers, gin_channels, p_dropout)
+ self.out_proj = nn.Conv1d(hidden_channels, out_channels, 1)
+
+ def forward(self, x, x_mask, g):
+ x = self.pre_net(x)
+ x = x * x_mask
+ x = self.wn(x, x_mask, g) * x_mask
+ x = self.out_proj(x)
+ return x
+
+
+class VQVAE(nn.Module):
+ def __init__(self,
+ in_out_channels=64, hidden_channels=256, latent_size=16,
+ kernel_size=3, enc_n_layers=5, dec_n_layers=5, gin_channels=80, strides=[4,],
+ sqz_prior=False):
+ super().__init__()
+ self.in_out_channels = in_out_channels
+ self.strides = strides
+ self.hidden_size = hidden_channels
+ self.latent_size = latent_size
+ self.g_pre_net = nn.Sequential(*[
+ nn.Conv1d(gin_channels, gin_channels, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+ self.encoder = FVAEEncoder(in_out_channels, hidden_channels, hidden_channels, kernel_size,
+ enc_n_layers, gin_channels, strides=strides)
+ # if use_prior_glow:
+ # self.prior_flow = ResidualCouplingBlock(
+ # latent_size, glow_hidden, glow_kernel_size, 1, glow_n_blocks, 4, gin_channels=gin_channels)
+ self.vq = VectorQuantize(dim=hidden_channels, codebook_size=256, codebook_dim=16)
+
+ self.decoder = FVAEDecoder(hidden_channels, hidden_channels, in_out_channels, kernel_size,
+ dec_n_layers, gin_channels, strides=strides)
+ self.prior_dist = dist.Normal(0, 1)
+ self.sqz_prior = sqz_prior
+
+ def forward(self, x=None, x_mask=None, g=None, infer=False, **kwargs):
+ """
+
+ :param x: [B, T, C_in_out]
+ :param x_mask: [B, T]
+ :param g: [B, T, C_g]
+ :return:
+ """
+ x_mask = x_mask[:, None, :] # [B, 1, T]
+ g = g.transpose(1,2) # [B, C_g, T]
+ g_for_sqz = g
+
+ g_sqz = self.g_pre_net(g_for_sqz)
+
+ if not infer:
+ x = x.transpose(1,2) # [B, C, T]
+ z_q, m_q, logs_q, x_mask_sqz = self.encoder(x, x_mask, g_sqz)
+ if self.sqz_prior:
+ z_q = F.interpolate(z_q, scale_factor=1/8)
+ z_p, idx, commit_loss = self.vq(z_q.transpose(1,2))
+ if self.sqz_prior:
+ z_p = F.interpolate(z_p.transpose(1,2),scale_factor=8).transpose(1,2)
+
+ x_recon = self.decoder(z_p.transpose(1,2), x_mask, g)
+ return x_recon.transpose(1,2), commit_loss, z_p.transpose(1,2), m_q.transpose(1,2), logs_q.transpose(1,2)
+ else:
+ bs, t = g_sqz.shape[0], g_sqz.shape[2]
+ if self.sqz_prior:
+ t = t // 8
+ latent_shape = [int(bs * t)]
+ latent_idx = torch.randint(0,256,latent_shape).to(self.vq.codebook.device)
+ # latent_idx = torch.ones_like(latent_idx, dtype=torch.long)
+ # z_p = torch.gather(self.vq.codebook, 0, latent_idx)# self.vq.codebook[latent_idx]
+ z_p = self.vq.codebook[latent_idx]
+ z_p = z_p.reshape([bs, t, -1])
+ z_p = self.vq.project_out(z_p)
+ if self.sqz_prior:
+ z_p = F.interpolate(z_p.transpose(1,2),scale_factor=8).transpose(1,2)
+
+ x_recon = self.decoder(z_p.transpose(1,2), 1, g)
+ return x_recon.transpose(1,2), z_p.transpose(1,2)
+
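+# A minimal usage sketch for VQVAE (illustrative; the shapes below are assumptions chosen
+# to be consistent with the constructor defaults, not part of the original file):
+#   model = VQVAE(in_out_channels=64, hidden_channels=256, latent_size=16, gin_channels=80)
+#   x = torch.rand([2, 400, 64]); x_mask = torch.ones([2, 400]); g = torch.rand([2, 400, 80])
+#   x_recon, commit_loss, z_p, m_q, logs_q = model(x, x_mask, g, infer=False)
+#   x_recon_sampled, z_p_sampled = model(x=None, x_mask=x_mask, g=g, infer=True)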
+
+class VQVAEModel(nn.Module):
+ def __init__(self, in_out_dim=71, sqz_prior=False, enc_no_cond=False):
+ super().__init__()
+ self.mel_encoder = nn.Sequential(*[
+ nn.Conv1d(80, 64, 3, 1, 1, bias=False),
+ nn.BatchNorm1d(64),
+ nn.GELU(),
+ nn.Conv1d(64, 64, 3, 1, 1, bias=False)
+ ])
+ self.in_dim, self.out_dim = in_out_dim, in_out_dim
+ self.sqz_prior = sqz_prior
+ self.enc_no_cond = enc_no_cond
+ self.vae = VQVAE(in_out_channels=in_out_dim, hidden_channels=256, latent_size=16, kernel_size=5,
+ enc_n_layers=8, dec_n_layers=4, gin_channels=64, strides=[4,], sqz_prior=sqz_prior)
+ self.downsampler = LambdaLayer(lambda x: F.interpolate(x.transpose(1,2), scale_factor=0.5, mode='nearest').transpose(1,2))
+
+ @property
+ def device(self):
+ return self.vae.parameters().__next__().device
+
+ def forward(self, batch, ret, log_dict=None, train=True):
+ infer = not train
+ mask = batch['y_mask'].to(self.device)
+ mel = batch['mel'].to(self.device)
+ mel = self.downsampler(mel)
+
+ mel_feat = self.mel_encoder(mel.transpose(1,2)).transpose(1,2)
+ if not infer:
+ exp = batch['exp'].to(self.device)
+ pose = batch['pose'].to(self.device)
+ if self.in_dim == 71:
+ x = torch.cat([exp, pose], dim=-1) # [B, T, C=64 + 7]
+ elif self.in_dim == 64:
+ x = exp
+ elif self.in_dim == 7:
+ x = pose
+ if self.enc_no_cond:
+ x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=torch.zeros_like(mel_feat), infer=False)
+ else:
+ x_recon, loss_commit, z_p, m_q, logs_q = self.vae(x=x, x_mask=mask, g=mel_feat, infer=False)
+ loss_commit = loss_commit.reshape([])
+ ret['pred'] = x_recon
+ ret['mask'] = mask
+ ret['loss_commit'] = loss_commit
+ return x_recon, loss_commit, m_q, logs_q
+ else:
+ x_recon, z_p = self.vae(x=None, x_mask=mask, g=mel_feat, infer=True)
+ return x_recon
+
+ # def __get_feat(self, exp, pose):
+ # diff_exp = exp[:-1, :] - exp[1:, :]
+ # exp_std = (np.std(exp, axis = 0) - self.exp_std_mean) / self.exp_std_std
+ # diff_exp_std = (np.std(diff_exp, axis = 0) - self.exp_diff_std_mean) / self.exp_diff_std_std
+
+ # diff_pose = pose[:-1, :] - pose[1:, :]
+ # diff_pose_std = (np.std(diff_pose, axis = 0) - self.pose_diff_std_mean) / self.pose_diff_std_std
+
+ # return np.concatenate((exp_std, diff_exp_std, diff_pose_std))
+
+ def num_params(self, model, print_out=True, model_name="model"):
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
+ if print_out:
+ print(f'| {model_name} Trainable Parameters: %.3fM' % parameters)
+ return parameters
diff --git a/modules/commons/attention/attentions.py b/modules/commons/attention/attentions.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b2b5bc03732ff17a0cb135e977fbe526dff3341
--- /dev/null
+++ b/modules/commons/attention/attentions.py
@@ -0,0 +1,427 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+import numpy as np
+from typing import Optional, Tuple
+
+
+class ScaledDotProductAttention(nn.Module):
+ """
+ Scaled Dot-Product Attention proposed in "Attention Is All You Need"
+ Compute the dot products of the query with all keys, divide each by sqrt(dim),
+ and apply a softmax function to obtain the weights on the values
+ Args: dim, mask
+ dim (int): dimension of attention
+ mask (torch.Tensor): tensor containing indices to be masked
+ Inputs: query, key, value, mask
+ - **query** (batch, q_len, d_model): tensor containing projection vector for decoder.
+ - **key** (batch, k_len, d_model): tensor containing projection vector for encoder.
+ - **value** (batch, v_len, d_model): tensor containing features of the encoded input sequence.
+ - **mask** (-): tensor containing indices to be masked
+ Returns: context, attn
+ - **context**: tensor containing the context vector from attention mechanism.
+ - **attn**: tensor containing the attention (alignment) from the encoder outputs.
+ """
+ def __init__(self, dim: int):
+ super(ScaledDotProductAttention, self).__init__()
+ self.sqrt_dim = np.sqrt(dim)
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
+ score = torch.bmm(query, key.transpose(1, 2)) / self.sqrt_dim
+
+ if mask is not None:
+ score.masked_fill_(mask.view(score.size()), -float('Inf'))
+
+ attn = F.softmax(score, -1)
+ context = torch.bmm(attn, value)
+ return context, attn
+
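+# Minimal usage sketch (illustrative shapes, not part of the original module):
+#   attn_layer = ScaledDotProductAttention(dim=64)
+#   q = torch.rand([2, 5, 64]); k = v = torch.rand([2, 7, 64])
+#   context, attn = attn_layer(q, k, v)   # context: [2, 5, 64], attn: [2, 5, 7]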
+
+class DotProductAttention(nn.Module):
+ """
+ Compute the dot products of the query with all values and apply a softmax function to obtain the weights on the values
+ """
+ def __init__(self, hidden_dim):
+ super(DotProductAttention, self).__init__()
+
+ def forward(self, query: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, hidden_dim, input_size = query.size(0), query.size(2), value.size(1)
+
+ score = torch.bmm(query, value.transpose(1, 2))
+ attn = F.softmax(score.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
+ context = torch.bmm(attn, value)
+
+ return context, attn
+
+
+class AdditiveAttention(nn.Module):
+ """
+ Applies an additive attention (Bahdanau) mechanism on the output features from the decoder.
+ Additive attention proposed in "Neural Machine Translation by Jointly Learning to Align and Translate" paper.
+ Args:
+ hidden_dim (int): dimension of the hidden state vector
+ Inputs: query, value
+ - **query** (batch_size, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch_size, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ Returns: context, attn
+ - **context**: tensor containing the context vector from attention mechanism.
+ - **attn**: tensor containing the alignment from the encoder outputs.
+ Reference:
+ - **Neural Machine Translation by Jointly Learning to Align and Translate**: https://arxiv.org/abs/1409.0473
+ """
+ def __init__(self, hidden_dim: int) -> None:
+ super(AdditiveAttention, self).__init__()
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.key_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
+ self.score_proj = nn.Linear(hidden_dim, 1)
+
+ def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tuple[Tensor, Tensor]:
+ score = self.score_proj(torch.tanh(self.key_proj(key) + self.query_proj(query) + self.bias)).squeeze(-1)
+ attn = F.softmax(score, dim=-1)
+ context = torch.bmm(attn.unsqueeze(1), value)
+ return context, attn
+
+
+class LocationAwareAttention(nn.Module):
+ """
+ Applies a location-aware attention mechanism on the output features from the decoder.
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+ The location-aware attention mechanism performs well in speech recognition tasks.
+ We follow the attention implementation style of ClovaCall.
+ Args:
+ hidden_dim (int): dimension of the hidden state vector
+ smoothing (bool): flag indicating whether to use smoothing or not.
+ Inputs: query, value, last_attn, smoothing
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing the previous timestep's attention (alignment)
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+ Reference:
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+ - **ClovaCall**: https://github.com/clovaai/ClovaCall/blob/master/las.pytorch/models/attention.py
+ """
+ def __init__(self, hidden_dim: int, smoothing: bool = True) -> None:
+ super(LocationAwareAttention, self).__init__()
+ self.hidden_dim = hidden_dim
+ self.conv1d = nn.Conv1d(in_channels=1, out_channels=hidden_dim, kernel_size=3, padding=1)
+ self.query_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.value_proj = nn.Linear(hidden_dim, hidden_dim, bias=False)
+ self.score_proj = nn.Linear(hidden_dim, 1, bias=True)
+ self.bias = nn.Parameter(torch.rand(hidden_dim).uniform_(-0.1, 0.1))
+ self.smoothing = smoothing
+
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, hidden_dim, seq_len = query.size(0), query.size(2), value.size(1)
+
+ # Initialize previous attention (alignment) to zeros
+ if last_attn is None:
+ last_attn = value.new_zeros(batch_size, seq_len)
+
+ conv_attn = torch.transpose(self.conv1d(last_attn.unsqueeze(1)), 1, 2)
+ score = self.score_proj(torch.tanh(
+ self.query_proj(query.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
+ + self.value_proj(value.reshape(-1, hidden_dim)).view(batch_size, -1, hidden_dim)
+ + conv_attn
+ + self.bias
+ )).squeeze(dim=-1)
+
+ if self.smoothing:
+ score = torch.sigmoid(score)
+ attn = torch.div(score, score.sum(dim=-1).unsqueeze(dim=-1))
+ else:
+ attn = F.softmax(score, dim=-1)
+
+ context = torch.bmm(attn.unsqueeze(dim=1), value).squeeze(dim=1) # Bx1xT X BxTxD => Bx1xD => BxD
+
+ return context, attn
+
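+# Minimal usage sketch (illustrative shapes, not part of the original module):
+#   att = LocationAwareAttention(hidden_dim=256)
+#   q = torch.rand([2, 1, 256]); v = torch.rand([2, 30, 256])
+#   context, align = att(q, v, last_attn=None)   # context: [2, 256], align: [2, 30]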
+
+class MultiHeadLocationAwareAttention(nn.Module):
+ """
+ Applies a multi-headed location-aware attention mechanism on the output features from the decoder.
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+ The location-aware attention mechanism performs well in speech recognition tasks.
+ The original paper applies a single head; here we extend it to a multi-head formulation.
+ Args:
+ hidden_dim (int): The number of expected features in the output
+ num_heads (int): The number of heads. (default: 8)
+ conv_out_channel (int): The number of output channels of the convolution
+ Inputs: query, value, prev_attn
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ - **prev_attn** (batch_size * num_heads, v_len): tensor containing the previous timestep's attention (alignment)
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the feature from encoder outputs
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+ Reference:
+ - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+ """
+ def __init__(self, hidden_dim: int, num_heads: int = 8, conv_out_channel: int = 10) -> None:
+ super(MultiHeadLocationAwareAttention, self).__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.dim = int(hidden_dim / num_heads)
+ self.conv1d = nn.Conv1d(num_heads, conv_out_channel, kernel_size=3, padding=1)
+ self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
+ self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+ self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+ self.score_proj = nn.Linear(self.dim, 1, bias=True)
+ self.bias = nn.Parameter(torch.rand(self.dim).uniform_(-0.1, 0.1))
+
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, seq_len = value.size(0), value.size(1)
+
+ if last_attn is None:
+ last_attn = value.new_zeros(batch_size, self.num_heads, seq_len)
+
+ loc_energy = torch.tanh(self.loc_proj(self.conv1d(last_attn).transpose(1, 2)))
+ loc_energy = loc_energy.unsqueeze(1).repeat(1, self.num_heads, 1, 1).view(-1, seq_len, self.dim)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.dim).permute(0, 2, 1, 3)
+ query = query.contiguous().view(-1, 1, self.dim)
+ value = value.contiguous().view(-1, seq_len, self.dim)
+
+ score = self.score_proj(torch.tanh(value + query + loc_energy + self.bias)).squeeze(2)
+ attn = F.softmax(score, dim=1)
+
+ value = value.view(batch_size, seq_len, self.num_heads, self.dim).permute(0, 2, 1, 3)
+ value = value.contiguous().view(-1, seq_len, self.dim)
+
+ context = torch.bmm(attn.unsqueeze(1), value).view(batch_size, -1, self.num_heads * self.dim)
+ attn = attn.view(batch_size, self.num_heads, -1)
+
+ return context, attn
+
+
+class MultiHeadAttention(nn.Module):
+ """
+ Multi-Head Attention proposed in "Attention Is All You Need"
+ Instead of performing a single attention function with d_model-dimensional keys, values, and queries,
+ project the queries, keys and values h times with different, learned linear projections to d_head dimensions.
+ These are concatenated and once again projected, resulting in the final values.
+ Multi-head attention allows the model to jointly attend to information from different representation
+ subspaces at different positions.
+ MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_o
+ where head_i = Attention(Q · W_q, K · W_k, V · W_v)
+ Args:
+ d_model (int): The dimension of keys / values / queries (default: 512)
+ num_heads (int): The number of attention heads. (default: 8)
+ Inputs: query, key, value, mask
+ - **query** (batch, q_len, d_model): In the transformer, it comes from one of three places:
+ Case 1: the previous decoder layer
+ Case 2: the input embedding
+ Case 3: the output embedding (masked)
+ - **key** (batch, k_len, d_model): In the transformer, it comes from one of three places:
+ Case 1: the output of the encoder
+ Case 2: the input embeddings
+ Case 3: the output embedding (masked)
+ - **value** (batch, v_len, d_model): In the transformer, it comes from one of three places:
+ Case 1: the output of the encoder
+ Case 2: the input embeddings
+ Case 3: the output embedding (masked)
+ - **mask** (-): tensor containing indices to be masked
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the attended output features.
+ - **attn** (batch * num_heads, v_len): tensor containing the attention (alignment) from the encoder outputs.
+ """
+ def __init__(self, d_model: int = 512, num_heads: int = 8):
+ super(MultiHeadAttention, self).__init__()
+
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+
+ self.d_head = int(d_model / num_heads)
+ self.num_heads = num_heads
+ self.scaled_dot_attn = ScaledDotProductAttention(self.d_head)
+ self.query_proj = nn.Linear(d_model, self.d_head * num_heads)
+ self.key_proj = nn.Linear(d_model, self.d_head * num_heads)
+ self.value_proj = nn.Linear(d_model, self.d_head * num_heads)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ mask: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Tensor]:
+ batch_size = value.size(0)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) # BxQ_LENxNxD
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head) # BxK_LENxNxD
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head) # BxV_LENxNxD
+
+ query = query.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxQ_LENxD
+ key = key.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxK_LENxD
+ value = value.permute(2, 0, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.d_head) # BNxV_LENxD
+
+ if mask is not None:
+ mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) # BxNxQ_LENxK_LEN
+
+ context, attn = self.scaled_dot_attn(query, key, value, mask)
+
+ context = context.view(self.num_heads, batch_size, -1, self.d_head)
+ context = context.permute(1, 2, 0, 3).contiguous().view(batch_size, -1, self.num_heads * self.d_head) # BxTxND
+
+ return context, attn
+
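+# Minimal usage sketch (illustrative; the d_model/num_heads values are assumptions):
+#   mha = MultiHeadAttention(d_model=512, num_heads=8)
+#   q = torch.rand([2, 10, 512]); k = v = torch.rand([2, 20, 512])
+#   context, attn = mha(q, k, v)          # context: [2, 10, 512]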
+
+class RelativeMultiHeadAttention(nn.Module):
+ """
+ Multi-head attention with relative positional encoding.
+ This concept was proposed in the "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
+ Args:
+ d_model (int): The dimension of model
+ num_heads (int): The number of attention heads.
+ dropout_p (float): probability of dropout
+ Inputs: query, key, value, pos_embedding, mask
+ - **query** (batch, time, dim): Tensor containing query vector
+ - **key** (batch, time, dim): Tensor containing key vector
+ - **value** (batch, time, dim): Tensor containing value vector
+ - **pos_embedding** (batch, time, dim): Positional embedding tensor
+ - **mask** (batch, 1, time2) or (batch, time1, time2): Tensor containing indices to be masked
+ Returns:
+ - **outputs**: Tensor produced by the relative multi-head attention module.
+ """
+ def __init__(
+ self,
+ d_model: int = 512,
+ num_heads: int = 16,
+ dropout_p: float = 0.1,
+ ):
+ super(RelativeMultiHeadAttention, self).__init__()
+ assert d_model % num_heads == 0, "d_model % num_heads should be zero."
+ self.d_model = d_model
+ self.d_head = int(d_model / num_heads)
+ self.num_heads = num_heads
+ self.sqrt_dim = math.sqrt(d_model)
+
+ self.query_proj = nn.Linear(d_model, d_model)
+ self.key_proj = nn.Linear(d_model, d_model)
+ self.value_proj = nn.Linear(d_model, d_model)
+ self.pos_proj = nn.Linear(d_model, d_model, bias=False)
+
+ self.dropout = nn.Dropout(p=dropout_p)
+ self.u_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+ self.v_bias = nn.Parameter(torch.Tensor(self.num_heads, self.d_head))
+ torch.nn.init.xavier_uniform_(self.u_bias)
+ torch.nn.init.xavier_uniform_(self.v_bias)
+
+ self.out_proj = nn.Linear(d_model, d_model)
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_embedding: Tensor,
+ mask: Optional[Tensor] = None,
+ ) -> Tensor:
+ batch_size = value.size(0)
+
+ query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head)
+ key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+ value = self.value_proj(value).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3)
+ pos_embedding = self.pos_proj(pos_embedding).view(batch_size, -1, self.num_heads, self.d_head)
+
+ content_score = torch.matmul((query + self.u_bias).transpose(1, 2), key.transpose(2, 3))
+ pos_score = torch.matmul((query + self.v_bias).transpose(1, 2), pos_embedding.permute(0, 2, 3, 1))
+ pos_score = self._compute_relative_positional_encoding(pos_score)
+
+ score = (content_score + pos_score) / self.sqrt_dim
+
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ score.masked_fill_(mask, -1e9)
+
+ attn = F.softmax(score, -1)
+ attn = self.dropout(attn)
+
+ context = torch.matmul(attn, value).transpose(1, 2)
+ context = context.contiguous().view(batch_size, -1, self.d_model)
+
+ return self.out_proj(context)
+
+ def _compute_relative_positional_encoding(self, pos_score: Tensor) -> Tensor:
+ batch_size, num_heads, seq_length1, seq_length2 = pos_score.size()
+ zeros = pos_score.new_zeros(batch_size, num_heads, seq_length1, 1)
+ padded_pos_score = torch.cat([zeros, pos_score], dim=-1)
+
+ padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1)
+ pos_score = padded_pos_score[:, :, 1:].view_as(pos_score)
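+ # This zero-pad-and-reshape is the Transformer-XL "relative shift" trick: it realigns the
+ # relative-position scores for each query without building an explicit index matrix.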
+
+ return pos_score
+
+
+class CustomizingAttention(nn.Module):
+ r"""
+ Customizing Attention
+ Applies a multi-head + location-aware attention mechanism on the output features from the decoder.
+ Multi-head attention proposed in "Attention Is All You Need" paper.
+ Location-aware attention proposed in "Attention-Based Models for Speech Recognition" paper.
+ This module combines these two attention mechanisms into a single custom attention.
+ Args:
+ hidden_dim (int): The number of expected features in the output
+ num_heads (int): The number of heads. (default: 4)
+ conv_out_channel (int): The number of output channels of the convolution
+ Inputs: query, value, last_attn
+ - **query** (batch, q_len, hidden_dim): tensor containing the output features from the decoder.
+ - **value** (batch, v_len, hidden_dim): tensor containing features of the encoded input sequence.
+ - **last_attn** (batch_size * num_heads, v_len): tensor containing the previous timestep's alignment
+ Returns: output, attn
+ - **output** (batch, output_len, dimensions): tensor containing the attended output features from the decoder.
+ - **attn** (batch * num_heads, v_len): tensor containing the alignment from the encoder outputs.
+ Reference:
+ - **Attention Is All You Need**: https://arxiv.org/abs/1706.03762
+ - **Attention-Based Models for Speech Recognition**: https://arxiv.org/abs/1506.07503
+ """
+
+ def __init__(self, hidden_dim: int, num_heads: int = 4, conv_out_channel: int = 10) -> None:
+ super(CustomizingAttention, self).__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ self.dim = int(hidden_dim / num_heads)
+ self.scaled_dot_attn = ScaledDotProductAttention(self.dim)
+ self.conv1d = nn.Conv1d(1, conv_out_channel, kernel_size=3, padding=1)
+ self.query_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=True)
+ self.value_proj = nn.Linear(hidden_dim, self.dim * num_heads, bias=False)
+ self.loc_proj = nn.Linear(conv_out_channel, self.dim, bias=False)
+ self.bias = nn.Parameter(torch.rand(self.dim * num_heads).uniform_(-0.1, 0.1))
+
+ def forward(self, query: Tensor, value: Tensor, last_attn: Tensor) -> Tuple[Tensor, Tensor]:
+ batch_size, q_len, v_len = value.size(0), query.size(1), value.size(1)
+
+ if last_attn is None:
+ last_attn = value.new_zeros(batch_size * self.num_heads, v_len)
+
+ loc_energy = self.get_loc_energy(last_attn, batch_size, v_len) # get location energy
+
+ query = self.query_proj(query).view(batch_size, q_len, self.num_heads * self.dim)
+ value = self.value_proj(value).view(batch_size, v_len, self.num_heads * self.dim) + loc_energy + self.bias
+
+ query = query.view(batch_size, q_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
+ value = value.view(batch_size, v_len, self.num_heads, self.dim).permute(2, 0, 1, 3)
+ query = query.contiguous().view(-1, q_len, self.dim)
+ value = value.contiguous().view(-1, v_len, self.dim)
+
+ context, attn = self.scaled_dot_attn(query, value, value)  # value serves as both key and value
+ attn = attn.squeeze()
+
+ context = context.view(self.num_heads, batch_size, q_len, self.dim).permute(1, 2, 0, 3)
+ context = context.contiguous().view(batch_size, q_len, -1)
+
+ return context, attn
+
+ def get_loc_energy(self, last_attn: Tensor, batch_size: int, v_len: int) -> Tensor:
+ conv_feat = self.conv1d(last_attn.unsqueeze(1))
+ conv_feat = conv_feat.view(batch_size, self.num_heads, -1, v_len).permute(0, 1, 3, 2)
+
+ loc_energy = self.loc_proj(conv_feat).view(batch_size, self.num_heads, v_len, self.dim)
+ loc_energy = loc_energy.permute(0, 2, 1, 3).reshape(batch_size, v_len, self.num_heads * self.dim)
+
+ return loc_energy
\ No newline at end of file
diff --git a/modules/commons/attention/simple_attention.py b/modules/commons/attention/simple_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8c451ce9324491a5c9fa8546b0fe98dc146c6c1
--- /dev/null
+++ b/modules/commons/attention/simple_attention.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def split_heads(x, num_heads):
+ """ Split heads
+ :param x: A tensor with shape [batch, length, channels]
+ :param num_heads: An integer
+ :returns: A tensor with shape [batch, heads, length, channels / heads]
+ """
+ assert x.shape[-1] % num_heads == 0, str(x.shape)
+ return x.reshape(x.shape[:-1] + (num_heads, x.shape[-1] // num_heads)).permute(0, 2, 1, 3)
+
+
+def combine_heads(x):
+ """ Combine heads
+ :param x: A tensor with shape [batch, heads, length, channels]
+ :returns: A tensor with shape [batch, length, heads * channels]
+ """
+ x = x.permute([0, 2, 1, 3])
+ return x.reshape(x.shape[:-2] + (x.shape[-1] * x.shape[-2],))
+
+
+class SimpleAttention(nn.Module):
+ def __init__(self, query_size=192, key_size=192, value_size=192, num_heads=1):
+ super(SimpleAttention, self).__init__()
+ self.q_transform = nn.Linear(query_size, query_size, bias=False)
+ self.k_transform = nn.Linear(key_size, query_size, bias=False)
+ self.v_transform = nn.Linear(value_size, query_size, bias=False)
+ self.output_transform = nn.Linear(query_size, query_size, bias=False)
+ self.query_size = query_size
+ self.key_size = key_size
+ self.value_size = value_size
+ self.num_heads = num_heads
+
+ def forward(self, query, key, value, attn_mask=None, bias=None):
+ q = self.q_transform(query)
+ k = self.k_transform(key)
+ v = self.v_transform(value)
+
+ logits = torch.bmm(q, k.transpose(1, 2)) # [batch, length_q, length_k]
+ if bias is not None:
+ logits += bias
+ if attn_mask is not None:
+ logits = logits + attn_mask * -1e9
+ weights = F.softmax(logits, dim=-1)
+ out = torch.bmm(weights, v)
+ out = self.output_transform(out)
+ return out, weights
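+
+
+# Minimal usage sketch for SimpleAttention (illustrative shapes, not part of the original file):
+#   attn = SimpleAttention(query_size=192, key_size=192, value_size=192)
+#   q = torch.rand([2, 10, 192]); k = v = torch.rand([2, 20, 192])
+#   out, weights = attn(q, k, v)          # out: [2, 10, 192], weights: [2, 10, 20]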
diff --git a/modules/commons/conformer/conformer.py b/modules/commons/conformer/conformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b5d719a7b67ef745f178cf44e8c452191a3a2a
--- /dev/null
+++ b/modules/commons/conformer/conformer.py
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+from .espnet_positional_embedding import RelPositionalEncoding
+from .espnet_transformer_attn import RelPositionMultiHeadedAttention, MultiHeadedAttention
+from .layers import Swish, ConvolutionModule, EncoderLayer, MultiLayeredConv1d
+from ..layers import Embedding
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
+
+class ConformerLayers(nn.Module):
+ def __init__(self, hidden_size, num_layers, kernel_size=9, dropout=0.0, num_heads=4, use_last_norm=True):
+ super().__init__()
+ self.use_last_norm = use_last_norm
+ self.layers = nn.ModuleList()
+ positionwise_layer = MultiLayeredConv1d
+ positionwise_layer_args = (hidden_size, hidden_size * 4, 1, dropout)
+ self.encoder_layers = nn.ModuleList([EncoderLayer(
+ hidden_size,
+ MultiHeadedAttention(num_heads, hidden_size, 0.0),
+ positionwise_layer(*positionwise_layer_args),
+ positionwise_layer(*positionwise_layer_args),
+ ConvolutionModule(hidden_size, kernel_size, Swish()),
+ dropout,
+ ) for _ in range(num_layers)])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(hidden_size)
+ else:
+ self.layer_norm = nn.Linear(hidden_size, hidden_size)
+
+ def forward(self, x, x_mask):
+ """
+
+ :param x: [B, T, H]
+ :param x_mask: [B, T, 1] padding mask (1 for valid frames, 0 for padding)
+ :return: [B, T, H]
+ """
+ for l in self.encoder_layers:
+ x, mask = l(x, x_mask)
+ x = self.layer_norm(x) * x_mask
+ return x
+
+
+class ConformerEncoder(ConformerLayers):
+ def __init__(self, hidden_size, dict_size=0, in_size=0, strides=[2,2], num_layers=None):
+ conformer_enc_kernel_size = 9
+ super().__init__(hidden_size, num_layers, conformer_enc_kernel_size)
+ self.dict_size = dict_size
+ if dict_size != 0:
+ self.embed = Embedding(dict_size, hidden_size, padding_idx=0)
+ else:
+ self.seq_proj_in = torch.nn.Linear(in_size, hidden_size)
+ self.seq_proj_out = torch.nn.Linear(hidden_size, in_size)
+ self.mel_in = torch.nn.Linear(160, hidden_size)
+ self.mel_pre_net = torch.nn.Sequential(*[
+ torch.nn.Conv1d(hidden_size, hidden_size, kernel_size=s * 2, stride=s, padding=s // 2)
+ for i, s in enumerate(strides)
+ ])
+
+ def forward(self, seq_out, mels_timbre, other_embeds=0):
+ """
+
+ :param seq_out: [B, T] token ids (when dict_size != 0) or [B, T, in_size] features
+ :param mels_timbre: [B, T_mel, 160] reference mel frames used as a timbre prompt
+ :return: [B, T, C]
+ """
+ x_lengths = (seq_out > 0).long().sum(-1)
+ x = seq_out
+ if self.dict_size != 0:
+ x = self.embed(x) + other_embeds # [B, T, H]
+ else:
+ x = self.seq_proj_in(x) + other_embeds # [B, T, H]
+ mels_timbre = self.mel_in(mels_timbre).transpose(1, 2)
+ mels_timbre = self.mel_pre_net(mels_timbre).transpose(1, 2)
+
+ T_out = x.size(1)
+ if self.dict_size != 0:
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths + mels_timbre.size(1), x.size(1) + mels_timbre.size(1)), 2).to(x.dtype)
+ else:
+ x_mask = torch.cat((torch.ones(x.size(0), mels_timbre.size(1), 1).to(x.device), (x.abs().sum(2) > 0).float()[:, :, None]), dim=1)
+ x = torch.cat((mels_timbre, x), 1)
+ x = super(ConformerEncoder, self).forward(x, x_mask)
+ if self.dict_size != 0:
+ x = x[:, -T_out:, :]
+ else:
+ x = self.seq_proj_out(x[:, -T_out:, :])
+ return x
+
+
+class ConformerDecoder(ConformerLayers):
+ def __init__(self, hidden_size, num_layers):
+ conformer_dec_kernel_size = 9
+ super().__init__(hidden_size, num_layers, conformer_dec_kernel_size)
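+
+
+# A minimal usage sketch for ConformerEncoder (illustrative; every value below is an
+# assumption, not taken from the original file):
+#   enc = ConformerEncoder(hidden_size=256, dict_size=100, num_layers=4)
+#   tokens = torch.randint(1, 100, [2, 48])       # [B, T] token ids, 0 is padding
+#   mels_timbre = torch.rand([2, 200, 160])       # [B, T_mel, 160] reference mel prompt
+#   out = enc(tokens, mels_timbre)                # -> [2, 48, 256]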
diff --git a/modules/commons/conformer/espnet_positional_embedding.py b/modules/commons/conformer/espnet_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b9b5549cc779d1ea67f052b1c99cad92365503
--- /dev/null
+++ b/modules/commons/conformer/espnet_positional_embedding.py
@@ -0,0 +1,113 @@
+import math
+import torch
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ reverse (bool): Whether to reverse the input position.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+ """Construct an PositionalEncoding object."""
+ super(PositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.reverse = reverse
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class ScaledPositionalEncoding(PositionalEncoding):
+ """Scaled positional encoding module.
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ self.alpha.data = torch.tensor(1.0)
+
+ def forward(self, x):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x + self.alpha * self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self, x):
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[:, : x.size(1)]
+ return self.dropout(x), self.dropout(pos_emb)
\ No newline at end of file
diff --git a/modules/commons/conformer/espnet_transformer_attn.py b/modules/commons/conformer/espnet_transformer_attn.py
new file mode 100644
index 0000000000000000000000000000000000000000..48a52aacbaf07ef191c28baf12123036c2bc6b10
--- /dev/null
+++ b/modules/commons/conformer/espnet_transformer_attn.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadedAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+ if not self.flash:
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+
+ B, Nh, Nt, E = q.shape
+ q = q / math.sqrt(E)
+ mask = mask * mask[:, None, :, 0]
+ mask = mask[:, None]
+ if self.flash:
+ attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False, attn_mask=mask)
+ else:
+ attn = self.slow_attn(q, k, v, is_causal=False, attn_mask=mask)
+ attn = attn.transpose(1, 2)
+ attn = attn.reshape(B, -1, self.h * self.d_k)
+ attn = self.linear_out(attn)
+ return attn
+
+ def slow_attn(self, Q, K, V, is_causal, attn_mask):
+ # convert a boolean mask into an additive float mask before adding it to the scores
+ attn_mask = torch.zeros_like(attn_mask, dtype=Q.dtype).masked_fill(attn_mask.logical_not(), -float('inf')) if attn_mask.dtype == torch.bool else attn_mask
+ attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1)
+ return attn_weight @ V
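+ # Note: slow_attn and the flash path above compute the same masked softmax attention;
+ # F.scaled_dot_product_attention is simply the fused kernel available in PyTorch >= 2.0.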
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate)
+ # linear transformation for positional encoding
+ self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+ torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x, zero_triu=False):
+ """Compute relative positinal encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, size).
+ zero_triu (bool): If true, return the lower triangular part of the matrix.
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)
+
+ if zero_triu:
+ ones = torch.ones((x.size(2), x.size(3)))
+ x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+ return x
+
+ def forward(self, query, key, value, pos_emb, mask):
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ pos_emb (torch.Tensor): Positional embedding tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k
+ ) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask)
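+
+# Usage sketch (shapes only, matching the docstring above): for an input x of shape
+# (#batch, time, n_feat), a relative positional embedding pos_emb of shape
+# (1, time, n_feat) and a mask of shape (#batch, 1, time),
+#     attn = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+#     out = attn(x, x, x, pos_emb, mask)  # -> (#batch, time, n_feat)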
diff --git a/modules/commons/conformer/layers.py b/modules/commons/conformer/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd7f501667e0b8aa816373d843adc816748e73a8
--- /dev/null
+++ b/modules/commons/conformer/layers.py
@@ -0,0 +1,260 @@
+from torch import nn
+import torch
+
+from modules.commons.layers import LayerNorm
+
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model.
+ Args:
+ channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+ """
+
+ def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+        """Construct a ConvolutionModule object."""
+ super(ConvolutionModule, self).__init__()
+        # kernel_size should be an odd number for 'SAME' padding
+ assert (kernel_size - 1) % 2 == 0
+
+ self.pointwise_conv1 = nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.depthwise_conv = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ groups=channels,
+ bias=bias,
+ )
+ self.norm = nn.BatchNorm1d(channels)
+ self.pointwise_conv2 = nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ )
+ self.activation = activation
+
+ def forward(self, x):
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2)
+
+ # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channels, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channels, time)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x)
+
+ return x.transpose(1, 2)
+
+
+class MultiLayeredConv1d(torch.nn.Module):
+ """Multi-layered conv1d for Transformer block.
+    This is a module of multi-layered conv1d designed
+    to replace the position-wise feed-forward network
+    in the Transformer block, which is introduced in
+ `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
+ .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
+ https://arxiv.org/pdf/1905.09263.pdf
+ """
+
+ def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
+ """Initialize MultiLayeredConv1d module.
+ Args:
+ in_chans (int): Number of input channels.
+ hidden_chans (int): Number of hidden channels.
+ kernel_size (int): Kernel size of conv1d.
+ dropout_rate (float): Dropout rate.
+ """
+ super(MultiLayeredConv1d, self).__init__()
+ self.w_1 = torch.nn.Conv1d(
+ in_chans,
+ hidden_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.w_2 = torch.nn.Conv1d(
+ hidden_chans,
+ in_chans,
+ kernel_size,
+ stride=1,
+ padding=(kernel_size - 1) // 2,
+ )
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ def forward(self, x):
+ """Calculate forward propagation.
+ Args:
+ x (torch.Tensor): Batch of input tensors (B, T, in_chans).
+ Returns:
+            torch.Tensor: Batch of output tensors (B, T, in_chans).
+ """
+ x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
+ return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
+
+
+class Swish(torch.nn.Module):
+    """Construct a Swish object."""
+
+    def forward(self, x):
+        """Return the Swish activation function."""
+        return x * torch.sigmoid(x)
+
+
+class EncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+ can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ feed_forward_macaron,
+ conv_module,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = LayerNorm(size) # for the FNN module
+ self.norm_mha = LayerNorm(size) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = LayerNorm(size)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = LayerNorm(size) # for the CNN module
+ self.norm_final = LayerNorm(size) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+
+ def forward(self, x_input, mask, cache=None):
+ """Compute encoded features.
+ Args:
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+ - w/o pos emb: Tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+ """
+ if isinstance(x_input, tuple):
+ x, pos_emb = x_input[0], x_input[1]
+ else:
+ x, pos_emb = x_input, None
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+
+ if cache is None:
+ x_q = x
+ else:
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+ x_q = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ mask = None if mask is None else mask[:, -1:, :]
+
+ if pos_emb is not None:
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+ else:
+ x_att = self.self_attn(x_q, x, x, mask)
+
+ if self.concat_after:
+ x_concat = torch.cat((x, x_att), dim=-1)
+ x = residual + self.concat_linear(x_concat)
+ else:
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x = residual + self.dropout(self.conv_module(x))
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
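+
+
+if __name__ == "__main__":
+    # Minimal shape check, a sketch only: assumes the script is run from the repo
+    # root so that the modules.commons.layers import above resolves.
+    x = torch.randn(2, 50, 64)  # (batch, time, channels)
+    conv = ConvolutionModule(channels=64, kernel_size=31)
+    ffn = MultiLayeredConv1d(in_chans=64, hidden_chans=256, kernel_size=3, dropout_rate=0.1)
+    print(conv(x).shape, ffn(x).shape)  # both: torch.Size([2, 50, 64])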
diff --git a/modules/commons/conv.py b/modules/commons/conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..a601f06042c2db37ace11ce72149101a9b8aefe4
--- /dev/null
+++ b/modules/commons/conv.py
@@ -0,0 +1,198 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from modules.commons.layers import LayerNorm, Embedding
+
+
+class LambdaLayer(nn.Module):
+ def __init__(self, lambd):
+ super(LambdaLayer, self).__init__()
+ self.lambd = lambd
+
+ def forward(self, x):
+ return self.lambd(x)
+
+
+def init_weights_func(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv1d") != -1:
+ torch.nn.init.xavier_uniform_(m.weight)
+
+
+class ResidualBlock(nn.Module):
+    """Residual block: (norm -> conv -> GELU -> conv) applied n times, each with a residual connection."""
+
+ def __init__(self, channels, kernel_size, dilation, n=2, norm_type='bn', dropout=0.0,
+ c_multiple=2, ln_eps=1e-12, left_pad=False):
+ super(ResidualBlock, self).__init__()
+
+ if norm_type == 'bn':
+ norm_builder = lambda: nn.BatchNorm1d(channels)
+ elif norm_type == 'in':
+ norm_builder = lambda: nn.InstanceNorm1d(channels, affine=True)
+ elif norm_type == 'gn':
+ norm_builder = lambda: nn.GroupNorm(8, channels)
+ elif norm_type == 'ln':
+ norm_builder = lambda: LayerNorm(channels, dim=1, eps=ln_eps)
+ else:
+ norm_builder = lambda: nn.Identity()
+
+ if left_pad:
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.ConstantPad1d(((dilation * (kernel_size - 1)) // 2 * 2, 0), 0),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation, padding=0),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ nn.GELU(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, padding_mode='reflect'),
+ )
+ for i in range(n)
+ ]
+ else:
+ self.blocks = [
+ nn.Sequential(
+ norm_builder(),
+ nn.Conv1d(channels, c_multiple * channels, kernel_size, dilation=dilation,
+ padding=(dilation * (kernel_size - 1)) // 2, padding_mode='reflect'),
+ LambdaLayer(lambda x: x * kernel_size ** -0.5),
+ nn.GELU(),
+ nn.Conv1d(c_multiple * channels, channels, 1, dilation=dilation, padding_mode='reflect'),
+ )
+ for i in range(n)
+ ]
+
+ self.blocks = nn.ModuleList(self.blocks)
+ self.dropout = dropout
+
+ def forward(self, x):
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ for b in self.blocks:
+ x_ = b(x)
+ if self.dropout > 0 and self.training:
+ x_ = F.dropout(x_, self.dropout, training=self.training)
+ x = x + x_
+ x = x * nonpadding
+ return x
+
+
+class ConvBlocks(nn.Module):
+ """Decodes the expanded phoneme encoding into spectrograms"""
+
+ def __init__(self, hidden_size, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5,
+ init_weights=True, is_BTC=True, num_layers=None, post_net_kernel=3,
+ left_pad=False, c_in=None):
+ super(ConvBlocks, self).__init__()
+ self.is_BTC = is_BTC
+ if num_layers is not None:
+ dilations = [1] * num_layers
+ self.res_blocks = nn.Sequential(
+ *[ResidualBlock(hidden_size, kernel_size, d,
+ n=layers_in_block, norm_type=norm_type, c_multiple=c_multiple,
+ dropout=dropout, ln_eps=ln_eps, left_pad=left_pad)
+ for d in dilations],
+ )
+ if norm_type == 'bn':
+ norm = nn.BatchNorm1d(hidden_size)
+ elif norm_type == 'in':
+ norm = nn.InstanceNorm1d(hidden_size, affine=True)
+ elif norm_type == 'gn':
+ norm = nn.GroupNorm(8, hidden_size)
+ elif norm_type == 'ln':
+ norm = LayerNorm(hidden_size, dim=1, eps=ln_eps)
+ self.last_norm = norm
+ if left_pad:
+ self.post_net1 = nn.Sequential(
+ nn.ConstantPad1d((post_net_kernel // 2 * 2, 0), 0),
+ nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel, padding=0),
+ )
+ else:
+ self.post_net1 = nn.Conv1d(hidden_size, out_dims, kernel_size=post_net_kernel,
+ padding=post_net_kernel // 2, padding_mode='reflect')
+ self.c_in = c_in
+ if c_in is not None:
+ self.in_conv = nn.Conv1d(c_in, hidden_size, kernel_size=1, padding_mode='reflect')
+ if init_weights:
+ self.apply(init_weights_func)
+
+ def forward(self, x, nonpadding=None):
+        """
+
+        :param x: [B, T, H]
+        :return: [B, T, out_dims]
+        """
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ if self.c_in is not None:
+ x = self.in_conv(x)
+ if nonpadding is None:
+ nonpadding = (x.abs().sum(1) > 0).float()[:, None, :]
+ elif self.is_BTC:
+ nonpadding = nonpadding.transpose(1, 2)
+ x = self.res_blocks(x) * nonpadding
+ x = self.last_norm(x) * nonpadding
+ x = self.post_net1(x) * nonpadding
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ return x
+
+
+class TextConvEncoder(ConvBlocks):
+ def __init__(self, dict_size, hidden_size, out_dims, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, num_layers=None, post_net_kernel=3):
+ super().__init__(hidden_size, out_dims, dilations, kernel_size,
+ norm_type, layers_in_block, c_multiple,
+ dropout, ln_eps, init_weights, num_layers=num_layers,
+ post_net_kernel=post_net_kernel)
+ self.dict_size = dict_size
+ if dict_size > 0:
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+ self.embed_scale = math.sqrt(hidden_size)
+
+ def forward(self, txt_tokens, other_embeds=0):
+        """
+
+        :param txt_tokens: [B, T]
+        :return: encoder output of shape [B, T, C]
+        """
+ if self.dict_size > 0:
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
+ else:
+ x = txt_tokens
+ x = x + other_embeds
+ return super().forward(x, nonpadding=(txt_tokens > 0).float()[..., None])
+
+
+class ConditionalConvBlocks(ConvBlocks):
+ def __init__(self, hidden_size, c_cond, c_out, dilations, kernel_size,
+ norm_type='ln', layers_in_block=2, c_multiple=2,
+ dropout=0.0, ln_eps=1e-5, init_weights=True, is_BTC=True, num_layers=None):
+ super().__init__(hidden_size, c_out, dilations, kernel_size,
+ norm_type, layers_in_block, c_multiple,
+ dropout, ln_eps, init_weights, is_BTC=False, num_layers=num_layers)
+ self.g_prenet = nn.Conv1d(c_cond, hidden_size, 3, padding=1, padding_mode='reflect')
+ self.is_BTC_ = is_BTC
+ if init_weights:
+ self.g_prenet.apply(init_weights_func)
+
+ def forward(self, x, cond, nonpadding=None):
+ if self.is_BTC_:
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2)
+ if nonpadding is not None:
+ nonpadding = nonpadding.transpose(1, 2)
+ if nonpadding is None:
+ nonpadding = x.abs().sum(1)[:, None]
+ x = x + self.g_prenet(cond)
+ x = x * nonpadding
+ x = super(ConditionalConvBlocks, self).forward(x) # input needs to be BTC
+ if self.is_BTC_:
+ x = x.transpose(1, 2)
+ return x
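+
+
+if __name__ == "__main__":
+    # Quick shape sketch, not part of the training code: assumes the repo's
+    # modules.commons.layers.LayerNorm is importable (run from the repo root).
+    net = ConvBlocks(hidden_size=64, out_dims=80, dilations=[1, 2, 4], kernel_size=3)
+    x = torch.randn(2, 100, 64)  # [B, T, H]
+    print(net(x).shape)  # torch.Size([2, 100, 80])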
diff --git a/modules/commons/gpt.py b/modules/commons/gpt.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e40349d0fae65107206033583d2cdc55289d09
--- /dev/null
+++ b/modules/commons/gpt.py
@@ -0,0 +1,474 @@
+import math
+import torch
+from typing import Optional, Tuple
+from torch import nn
+from utils.nn.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+# from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
+
+DEFAULT_MAX_SOURCE_POSITIONS = 20000
+DEFAULT_MAX_TARGET_POSITIONS = 20000
+
+
+class RotaryEmbeddings(nn.Module):
+ cos: torch.Tensor
+ sin: torch.Tensor
+ theta: torch.Tensor
+
+ def __init__(
+ self,
+ width: int,
+ *,
+ seq_len: int = 4000,
+ base: int = 10000,
+ device: Optional[torch.device] = None,
+ ):
+ """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
+        will be precomputed for up to 'seq_len' positions. The embedding
+ will be recomputed when a longer sequence is found in the input.
+
+ :param width:
+ Rotary embedding dimensionality, must be even.
+ :param seq_len:
+            Number of positions to initially precompute.
+        :param base:
+            The base used for Θ_i; it determines the cycle length of the
+            embeddings.
+ :param device: Device on which the module is to be initialized.
+ """
+ super().__init__()
+
+ if width % 2:
+ raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
+
+ # Ignore allocations on the meta device as we don't persist our buffer,
+ # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
+ if device is not None and device.type == "meta":
+ device = None
+ # Θ_i = 10000^(-2(i-1)/d)
+ theta = torch.pow(
+ base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
+ )
+ self.register_buffer("theta", theta, persistent=False)
+
+ self._create_rotary_embed(width=width, length=seq_len)
+
+ def _create_rotary_embed(self, *, width: int, length: int):
+ # mΘ
+ position = torch.arange(length, device=self.theta.device).unsqueeze(1)
+ m_theta = position * self.theta.unsqueeze(0)
+
+ # We apply both sin and cos twice (see Eq 15, 34), but the ordering
+ # is changed for compatibility with most common implementations.
+ m_theta = torch.cat([m_theta, m_theta], dim=-1)
+
+ re_cos = m_theta.cos().view([length, width]).half()
+ re_sin = m_theta.sin().view([length, width]).half()
+
+ self.register_buffer("cos", re_cos, persistent=False)
+ self.register_buffer("sin", re_sin, persistent=False)
+
+ def _rotate(self, input: torch.Tensor):
+ """Rotate the input tensor by half of its innermost width.
+
+ input (Tensor): array to rotate.
+ RETURNS (Tensor): rotated array.
+
+ Shapes:
+ input - (..., width)
+ output - (..., width)
+ """
+ half_idx = input.shape[-1] // 2
+ input_1 = -input[..., half_idx:]
+ input_2 = input[..., :half_idx]
+ return torch.cat([input_1, input_2], dim=-1)
+
+ def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
+ """
+ Apply rotary embeddings to an array.
+
+ :param input: Array to apply the rotary embeddings to.
+ :param positions: positions of the inputs. If no positions are
+ provided, they are assumed to be [0, seq_len).
+ :return: Array with the rotary embeddings applied.
+
+ Shapes:
+ input - (batch_size, num_heads, seq_len, width_per_head)
+ positions - (batch_size, seq_len)
+ output - (batch_size, num_heads, seq_len, width_per_head)
+ """
+ batch_size, _, seq_len, width = input.shape
+
+ if positions is None:
+ # Fastpath: positions from [0..seq_len), avoid indexing.
+ if self.cos.size(-2) < seq_len:
+ self._create_rotary_embed(width=width, length=seq_len)
+ rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
+ rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
+ else:
+ max_len = int(positions.max()) + 1
+ if self.cos.size(-2) < max_len:
+ self._create_rotary_embed(width=width, length=max_len)
+
+ # Flatten positions to index cos/sin arrays, then unflatten.
+ #
+ # Example shapes:
+ #
+ # positions_flat - (batch_size * seq_len)
+ # self.cos - (max_len, width)
+ # rot_cos - (batch_size, seq_len, width)
+ positions_flat = positions.view(-1)
+ rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
+ rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
+
+ # Eq 34 with ordering changed for compatibility.
+ return rot_cos * input + rot_sin * self._rotate(input)
+
+
+class LayerNorm(nn.Module):
+    """LayerNorm but with an optional bias, since plain nn.LayerNorm doesn't simply support bias=False."""
+
+ def __init__(self, ndim, bias=False):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(ndim))
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+ def forward(self, input):
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+
+class CausalSelfAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, dropout=0.):
+ super().__init__()
+ # Typically, bias = True in Linears and LayerNorms, like GPT-2. But we set bias = False: a bit better and faster (following https://github.com/karpathy/nanoGPT)
+ assert embed_dim % num_heads == 0
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+ # key, query, value projections for all heads, but in a batch
+ self.c_attn = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
+ # output projection
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
+ # rotary embeddings
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+ # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+ if not self.flash:
+ print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+
+ def forward(
+ self,
+ query, key, value,
+ spk_pos_ids_flat=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ else:
+ saved_state = None
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ q, k, v = self.c_attn(query).split(self.embed_dim, dim=2)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ # Apply rot embedding and store incremental_state
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+ bsz, self.num_heads, -1, self.head_dim)
+ self._set_input_buffer(incremental_state, saved_state)
+ if incremental_state is not None:
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+ else:
+ key_pos = spk_pos_ids_flat
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+ src_len = k.size(1)
+
+ # Start Attention
+ if self.flash:
+ # efficient attention using Flash Attention CUDA kernels
+ attn = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=0,
+ is_causal=False)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+
+ # Flash Attn 2
+ # from flash_attn import flash_attn_func
+ # q, k, v = q.transpose(0, 1)[None, :], k.transpose(0, 1)[None, :], v.transpose(0, 1)[None, :]
+ # attn = flash_attn_func(q, k, v, dropout_p=0.0, causal=False)[0].contiguous().view(tgt_len, bsz, embed_dim)
+
+ attn = self.out_proj(attn)
+ attn_logits = None
+ else:
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2, bias=False)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size, bias=False)
+ )
+ self.ffn_2 = nn.Linear(filter_size, hidden_size, bias=False)
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ T_inp = x.shape[0]
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-T_inp:]
+ # if self.act == 'gelu':
+ # x = F.gelu(x)
+ # if self.act == 'relu':
+ # x = F.relu(x)
+ x = F.silu(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class GPTBlock(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False, norm_cls=LayerNorm):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = norm_cls(c)
+ self.self_attn = CausalSelfAttention(
+ c, num_heads, dropout=attention_dropout
+ )
+ self.layer_norm2 = norm_cls(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+ self.post_ln = post_ln
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ attn_out=None,
+ spk_pos_ids_flat=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm1(x)
+
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask,
+ spk_pos_ids_flat=spk_pos_ids_flat,
+ need_weights=False
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm1(x)
+
+ attn_logits = None
+
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm2(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm2(x)
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+        self.self_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class GPTLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
+ lm_num_layers=10, norm_cls=LayerNorm):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = GPTBlock(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln, norm_cls=norm_cls)
+
+ # init all weights
+ self.apply(self._init_weights)
+ # apply special scaled init to the residual projections, per GPT-2 paper
+ for pn, p in self.named_parameters():
+ if pn.endswith('ffn_2.weight') or pn.endswith('out_proj.weight'):
+ torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * lm_num_layers))
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+ if module.bias is not None:
+ torch.nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Embedding):
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+ @torch.autocast(device_type='cuda')
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
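+
+
+if __name__ == "__main__":
+    # Rotary-embedding sketch: apply RoPE to per-head queries of width 64. Assumes
+    # the repo imports at the top of this file resolve (run from the repo root).
+    rope = RotaryEmbeddings(width=64, seq_len=128)
+    q = torch.randn(2, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
+    positions = torch.arange(16).unsqueeze(0).repeat(2, 1)
+    print(rope(q).shape, rope(q, positions=positions).shape)  # both (2, 8, 16, 64)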
diff --git a/modules/commons/improved_diffusion/__init__.py b/modules/commons/improved_diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9665a0d63f695eab303318d824dad14041c7cde9
--- /dev/null
+++ b/modules/commons/improved_diffusion/__init__.py
@@ -0,0 +1,3 @@
+"""
+Codebase for "Improved Denoising Diffusion Probabilistic Models".
+"""
diff --git a/modules/commons/improved_diffusion/dist_util.py b/modules/commons/improved_diffusion/dist_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f665604d6baaf5df6008f131c86cf0779c8b208a
--- /dev/null
+++ b/modules/commons/improved_diffusion/dist_util.py
@@ -0,0 +1,82 @@
+"""
+Helpers for distributed training.
+"""
+
+import io
+import os
+import socket
+
+import blobfile as bf
+from mpi4py import MPI
+import torch as th
+import torch.distributed as dist
+
+# Change this to reflect your cluster layout.
+# The GPU for a given rank is (rank % GPUS_PER_NODE).
+GPUS_PER_NODE = 8
+
+SETUP_RETRY_COUNT = 3
+
+
+def setup_dist():
+ """
+ Setup a distributed process group.
+ """
+ if dist.is_initialized():
+ return
+
+ comm = MPI.COMM_WORLD
+ backend = "gloo" if not th.cuda.is_available() else "nccl"
+
+ if backend == "gloo":
+ hostname = "localhost"
+ else:
+ hostname = socket.gethostbyname(socket.getfqdn())
+ os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
+ os.environ["RANK"] = str(comm.rank)
+ os.environ["WORLD_SIZE"] = str(comm.size)
+
+ port = comm.bcast(_find_free_port(), root=0)
+ os.environ["MASTER_PORT"] = str(port)
+ dist.init_process_group(backend=backend, init_method="env://")
+
+
+def dev():
+ """
+ Get the device to use for torch.distributed.
+ """
+ if th.cuda.is_available():
+ return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}")
+ return th.device("cpu")
+
+
+def load_state_dict(path, **kwargs):
+ """
+ Load a PyTorch file without redundant fetches across MPI ranks.
+ """
+ if MPI.COMM_WORLD.Get_rank() == 0:
+ with bf.BlobFile(path, "rb") as f:
+ data = f.read()
+ else:
+ data = None
+ data = MPI.COMM_WORLD.bcast(data)
+ return th.load(io.BytesIO(data), **kwargs)
+
+
+def sync_params(params):
+ """
+ Synchronize a sequence of Tensors across ranks from rank 0.
+ """
+ for p in params:
+ with th.no_grad():
+ dist.broadcast(p, 0)
+
+
+def _find_free_port():
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ s.bind(("", 0))
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ return s.getsockname()[1]
+ finally:
+ s.close()
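+
+# Typical usage sketch (requires mpi4py and an MPI launcher), e.g.
+#   mpiexec -n 4 python <your_training_script>.py
+# and inside that script:
+#   setup_dist()
+#   model = model.to(dev())
+#   sync_params(model.parameters())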
diff --git a/modules/commons/improved_diffusion/fp16_util.py b/modules/commons/improved_diffusion/fp16_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e0418153143200a718f56077b3360f30f4c663
--- /dev/null
+++ b/modules/commons/improved_diffusion/fp16_util.py
@@ -0,0 +1,76 @@
+"""
+Helpers to train with 16-bit precision.
+"""
+
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
+
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.half()
+ l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.float()
+ l.bias.data = l.bias.data.float()
+
+
+def make_master_params(model_params):
+ """
+ Copy model parameters into a (differently-shaped) list of full-precision
+ parameters.
+ """
+ master_params = _flatten_dense_tensors(
+ [param.detach().float() for param in model_params]
+ )
+ master_params = nn.Parameter(master_params)
+ master_params.requires_grad = True
+ return [master_params]
+
+
+def model_grads_to_master_grads(model_params, master_params):
+ """
+ Copy the gradients from the model parameters into the master parameters
+ from make_master_params().
+ """
+ master_params[0].grad = _flatten_dense_tensors(
+ [param.grad.data.detach().float() for param in model_params]
+ )
+
+
+def master_params_to_model_params(model_params, master_params):
+ """
+ Copy the master parameter data back into the model parameters.
+ """
+ # Without copying to a list, if a generator is passed, this will
+ # silently not copy any parameters.
+ model_params = list(model_params)
+
+ for param, master_param in zip(
+ model_params, unflatten_master_params(model_params, master_params)
+ ):
+ param.detach().copy_(master_param)
+
+
+def unflatten_master_params(model_params, master_params):
+ """
+ Unflatten the master parameters to look like model_params.
+ """
+ return _unflatten_dense_tensors(master_params[0].detach(), model_params)
+
+
+def zero_grad(model_params):
+ for param in model_params:
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
+ if param.grad is not None:
+ param.grad.detach_()
+ param.grad.zero_()
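+
+
+if __name__ == "__main__":
+    # Round-trip sketch: flatten model parameters into a single fp32 master tensor
+    # and copy the (unchanged) values back into the model.
+    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
+    params = list(model.parameters())
+    master = make_master_params(params)
+    master_params_to_model_params(params, master)
+    print(master[0].shape)  # torch.Size([30]): 4*4 + 4 + 4*2 + 2 parameters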
diff --git a/modules/commons/improved_diffusion/gaussian_diffusion.py b/modules/commons/improved_diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e76eafab7a071e14b92821dbe0d8fd4382bdccd
--- /dev/null
+++ b/modules/commons/improved_diffusion/gaussian_diffusion.py
@@ -0,0 +1,870 @@
+"""
+This code started out as a PyTorch port of Ho et al's diffusion models:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
+
+Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
+"""
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+from .nn import mean_flat
+from .losses import normal_kl, discretized_gaussian_log_likelihood
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ beta_start = scale * 0.0001
+ beta_end = scale * 0.02
+ return np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif schedule_name == "cosine":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
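+
+# Usage sketch: build a diffusion process from a named schedule, e.g.
+#   betas = get_named_beta_schedule("cosine", 1000)  # np.ndarray of shape (1000,)
+#   diffusion = GaussianDiffusion(
+#       betas=betas,
+#       model_mean_type=ModelMeanType.EPSILON,
+#       model_var_type=ModelVarType.FIXED_SMALL,
+#       loss_type=LossType.MSE,
+#   )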
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+
+    Ported directly from here, and then adapted over time for further experimentation.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
+ :param model_var_type: a ModelVarType determining how variance is output.
+ :param loss_type: a LossType determining the loss function to use.
+ :param rescale_timesteps: if True, pass floating point timesteps into the
+ model so that they are always scaled like in the
+ original paper (0 to 1000).
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ rescale_timesteps=False,
+ ):
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+ self.rescale_timesteps = rescale_timesteps
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # log calculation clipped because the posterior variance is 0 at the
+ # beginning of the diffusion chain.
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ )
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev)
+ * np.sqrt(alphas)
+ / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+
+ In other words, sample from q(x_t | x_0).
+
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
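+        # Closed form: x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise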
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+
+ q(x_{t-1} | x_t, x_0)
+
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ )
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
+ assert x_t.shape == xprev.shape
+ return ( # (xprev - coef2*x_t) / coef1
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
+ - _extract_into_tensor(
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
+ )
+ * x_t
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _scale_timesteps(self, t):
+ if self.rescale_timesteps:
+ return t.float() * (1000.0 / self.num_timesteps)
+ return t
+
+ def p_sample(
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ + th.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ if 'sample_merge' in final:
+ return final["sample_merge"]
+ else:
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ # mask = model_kwargs['mask']
+ # img = out["sample"] * mask
+ # if model_kwargs.get('replace_val') is not None:
+ # replace_idx = model_kwargs['replace_idx']
+ # replace_val = model_kwargs['replace_val']
+ # x_t = self.q_sample(replace_val, t - 1) if t > 0 else replace_val
+ # B, T = img.shape[:2]
+ # img = img.reshape(B, T, -1, 3)
+ # img[:, :, replace_idx] = x_t[:, :, replace_idx]
+ # out["sample"] = img = img.flatten(2)
+ # if 'frames_inp' in model_kwargs:
+ # x_t = self.q_sample(model_kwargs['frames_inp'], t - 1) \
+ # if t > 0 else model_kwargs['frames_inp']
+ # img = img * mask + x_t * (1 - mask)
+ # out['sample_merge'] = img
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
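+ # Dividing by log(2) converts nats to bits, matching the bits-per-dim
+ # convention used elsewhere in this file.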
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, C = x_t.shape[:2]
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+
+ # mask = model_kwargs['mask']
+ # if mask.shape != x_start.shape:
+ # mask = mask.expand_as(x_start)
+ # mask = mask.flatten(2)
+ #
+ # terms["mse"] = (target - model_output) ** 2
+ # terms["mse"] = terms["mse"].flatten(2)
+ # terms["mse"] = (terms["mse"] * mask).sum(-1) / mask.sum(-1)
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+ # print(">>>", (target - model_output).abs().mean())
+
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
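+
+ # Usage sketch (not part of the original code): timesteps are typically drawn
+ # from a ScheduleSampler (see resample.py) before calling this method:
+ #   t, weights = schedule_sampler.sample(x_start.shape[0], x_start.device)
+ #   losses = diffusion.training_losses(model, x_start, t)
+ #   loss = (losses["loss"] * weights).mean()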
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+
+ This term can't be optimized, as it only depends on the encoder.
+
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res.expand(broadcast_shape)
diff --git a/modules/commons/improved_diffusion/image_datasets.py b/modules/commons/improved_diffusion/image_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e49d2394622e5b7ea988e4afe9fef117dedf6a9
--- /dev/null
+++ b/modules/commons/improved_diffusion/image_datasets.py
@@ -0,0 +1,106 @@
+from PIL import Image
+import blobfile as bf
+from mpi4py import MPI
+import numpy as np
+from torch.utils.data import DataLoader, Dataset
+
+
+def load_data(
+ *, data_dir, batch_size, image_size, class_cond=False, deterministic=False
+):
+ """
+ For a dataset, create a generator over (images, kwargs) pairs.
+
+ Each images element is an NCHW float tensor, and the kwargs dict contains zero
+ or more keys, each of which maps to a batched Tensor of its own.
+ The kwargs dict can be used for class labels, in which case the key is "y"
+ and the values are integer tensors of class labels.
+
+ :param data_dir: a dataset directory.
+ :param batch_size: the batch size of each returned pair.
+ :param image_size: the size to which images are resized.
+ :param class_cond: if True, include a "y" key in returned dicts for class
+ label. If classes are not available and this is true, an
+ exception will be raised.
+ :param deterministic: if True, yield results in a deterministic order.
+ """
+ if not data_dir:
+ raise ValueError("unspecified data directory")
+ all_files = _list_image_files_recursively(data_dir)
+ classes = None
+ if class_cond:
+ # Assume classes are the first part of the filename,
+ # before an underscore.
+ class_names = [bf.basename(path).split("_")[0] for path in all_files]
+ sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
+ classes = [sorted_classes[x] for x in class_names]
+ dataset = ImageDataset(
+ image_size,
+ all_files,
+ classes=classes,
+ shard=MPI.COMM_WORLD.Get_rank(),
+ num_shards=MPI.COMM_WORLD.Get_size(),
+ )
+ if deterministic:
+ loader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
+ )
+ else:
+ loader = DataLoader(
+ dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
+ )
+ while True:
+ yield from loader
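+
+# Usage sketch (not part of the original code; the directory is a placeholder):
+#   data = load_data(data_dir="/path/to/images", batch_size=16, image_size=64)
+#   batch, cond = next(data)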
+
+
+def _list_image_files_recursively(data_dir):
+ results = []
+ for entry in sorted(bf.listdir(data_dir)):
+ full_path = bf.join(data_dir, entry)
+ ext = entry.split(".")[-1]
+ if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
+ results.append(full_path)
+ elif bf.isdir(full_path):
+ results.extend(_list_image_files_recursively(full_path))
+ return results
+
+
+class ImageDataset(Dataset):
+ def __init__(self, resolution, image_paths, classes=None, shard=0, num_shards=1):
+ super().__init__()
+ self.resolution = resolution
+ self.local_images = image_paths[shard:][::num_shards]
+ self.local_classes = None if classes is None else classes[shard:][::num_shards]
+
+ def __len__(self):
+ return len(self.local_images)
+
+ def __getitem__(self, idx):
+ path = self.local_images[idx]
+ with bf.BlobFile(path, "rb") as f:
+ pil_image = Image.open(f)
+ pil_image.load()
+
+ # We are not on a new enough PIL to support the `reducing_gap`
+ # argument, which uses BOX downsampling at powers of two first.
+ # Thus, we do it by hand to improve downsample quality.
+ while min(*pil_image.size) >= 2 * self.resolution:
+ pil_image = pil_image.resize(
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+ )
+
+ scale = self.resolution / min(*pil_image.size)
+ pil_image = pil_image.resize(
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+ )
+
+ arr = np.array(pil_image.convert("RGB"))
+ crop_y = (arr.shape[0] - self.resolution) // 2
+ crop_x = (arr.shape[1] - self.resolution) // 2
+ arr = arr[crop_y : crop_y + self.resolution, crop_x : crop_x + self.resolution]
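+ # Rescale uint8 pixel values from [0, 255] to floats in [-1, 1].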
+ arr = arr.astype(np.float32) / 127.5 - 1
+
+ out_dict = {}
+ if self.local_classes is not None:
+ out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
+ return np.transpose(arr, [2, 0, 1]), out_dict
diff --git a/modules/commons/improved_diffusion/logger.py b/modules/commons/improved_diffusion/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d856dcfea6b56a2ee8d37b286887430dbfac30
--- /dev/null
+++ b/modules/commons/improved_diffusion/logger.py
@@ -0,0 +1,495 @@
+"""
+Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
+https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
+"""
+
+import os
+import sys
+import shutil
+import os.path as osp
+import json
+import time
+import datetime
+import tempfile
+import warnings
+from collections import defaultdict
+from contextlib import contextmanager
+
+DEBUG = 10
+INFO = 20
+WARN = 30
+ERROR = 40
+
+DISABLED = 50
+
+
+class KVWriter(object):
+ def writekvs(self, kvs):
+ raise NotImplementedError
+
+
+class SeqWriter(object):
+ def writeseq(self, seq):
+ raise NotImplementedError
+
+
+class HumanOutputFormat(KVWriter, SeqWriter):
+ def __init__(self, filename_or_file):
+ if isinstance(filename_or_file, str):
+ self.file = open(filename_or_file, "wt")
+ self.own_file = True
+ else:
+ assert hasattr(filename_or_file, "read"), (
+ "expected file or str, got %s" % filename_or_file
+ )
+ self.file = filename_or_file
+ self.own_file = False
+
+ def writekvs(self, kvs):
+ # Create strings for printing
+ key2str = {}
+ for (key, val) in sorted(kvs.items()):
+ if hasattr(val, "__float__"):
+ valstr = "%-8.3g" % val
+ else:
+ valstr = str(val)
+ key2str[self._truncate(key)] = self._truncate(valstr)
+
+ # Find max widths
+ if len(key2str) == 0:
+ print("WARNING: tried to write empty key-value dict")
+ return
+ else:
+ keywidth = max(map(len, key2str.keys()))
+ valwidth = max(map(len, key2str.values()))
+
+ # Write out the data
+ dashes = "-" * (keywidth + valwidth + 7)
+ lines = [dashes]
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
+ lines.append(
+ "| %s%s | %s%s |"
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
+ )
+ lines.append(dashes)
+ self.file.write("\n".join(lines) + "\n")
+
+ # Flush the output to the file
+ self.file.flush()
+
+ def _truncate(self, s):
+ maxlen = 30
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
+
+ def writeseq(self, seq):
+ seq = list(seq)
+ for (i, elem) in enumerate(seq):
+ self.file.write(elem)
+ if i < len(seq) - 1: # add space unless this is the last one
+ self.file.write(" ")
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ if self.own_file:
+ self.file.close()
+
+
+class JSONOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "wt")
+
+ def writekvs(self, kvs):
+ for k, v in sorted(kvs.items()):
+ if hasattr(v, "dtype"):
+ kvs[k] = float(v)
+ self.file.write(json.dumps(kvs) + "\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class CSVOutputFormat(KVWriter):
+ def __init__(self, filename):
+ self.file = open(filename, "w+t")
+ self.keys = []
+ self.sep = ","
+
+ def writekvs(self, kvs):
+ # Add our current row to the history
+ extra_keys = list(kvs.keys() - self.keys)
+ extra_keys.sort()
+ if extra_keys:
+ self.keys.extend(extra_keys)
+ self.file.seek(0)
+ lines = self.file.readlines()
+ self.file.seek(0)
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ self.file.write(k)
+ self.file.write("\n")
+ for line in lines[1:]:
+ self.file.write(line[:-1])
+ self.file.write(self.sep * len(extra_keys))
+ self.file.write("\n")
+ for (i, k) in enumerate(self.keys):
+ if i > 0:
+ self.file.write(",")
+ v = kvs.get(k)
+ if v is not None:
+ self.file.write(str(v))
+ self.file.write("\n")
+ self.file.flush()
+
+ def close(self):
+ self.file.close()
+
+
+class TensorBoardOutputFormat(KVWriter):
+ """
+ Dumps key/value pairs into TensorBoard's numeric format.
+ """
+
+ def __init__(self, dir):
+ os.makedirs(dir, exist_ok=True)
+ self.dir = dir
+ self.step = 1
+ prefix = "events"
+ path = osp.join(osp.abspath(dir), prefix)
+ import tensorflow as tf
+ from tensorflow.python import pywrap_tensorflow
+ from tensorflow.core.util import event_pb2
+ from tensorflow.python.util import compat
+
+ self.tf = tf
+ self.event_pb2 = event_pb2
+ self.pywrap_tensorflow = pywrap_tensorflow
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
+
+ def writekvs(self, kvs):
+ def summary_val(k, v):
+ kwargs = {"tag": k, "simple_value": float(v)}
+ return self.tf.Summary.Value(**kwargs)
+
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
+ event.step = (
+ self.step
+ ) # is there any reason why you'd want to specify the step?
+ self.writer.WriteEvent(event)
+ self.writer.Flush()
+ self.step += 1
+
+ def close(self):
+ if self.writer:
+ self.writer.Close()
+ self.writer = None
+
+
+def make_output_format(format, ev_dir, log_suffix=""):
+ os.makedirs(ev_dir, exist_ok=True)
+ if format == "stdout":
+ return HumanOutputFormat(sys.stdout)
+ elif format == "log":
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
+ elif format == "json":
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
+ elif format == "csv":
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
+ elif format == "tensorboard":
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
+ else:
+ raise ValueError("Unknown format specified: %s" % (format,))
+
+
+# ================================================================
+# API
+# ================================================================
+
+
+def logkv(key, val):
+ """
+ Log a value of some diagnostic.
+ Call this once for each diagnostic quantity, each iteration.
+ If called many times, the last value will be used.
+ """
+ get_current().logkv(key, val)
+
+
+def logkv_mean(key, val):
+ """
+ The same as logkv(), but if called many times, the values are averaged.
+ """
+ get_current().logkv_mean(key, val)
+
+
+def logkvs(d):
+ """
+ Log a dictionary of key-value pairs
+ """
+ for (k, v) in d.items():
+ logkv(k, v)
+
+
+def dumpkvs():
+ """
+ Write all of the diagnostics from the current iteration.
+ """
+ return get_current().dumpkvs()
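+
+# Typical usage (a minimal sketch; the log directory is a placeholder):
+#   configure(dir="/tmp/diffusion_logs", format_strs=["stdout", "csv"])
+#   logkv("loss", 0.123)
+#   dumpkvs()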
+
+
+def getkvs():
+ return get_current().name2val
+
+
+def log(*args, level=INFO):
+ """
+ Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
+ """
+ get_current().log(*args, level=level)
+
+
+def debug(*args):
+ log(*args, level=DEBUG)
+
+
+def info(*args):
+ log(*args, level=INFO)
+
+
+def warn(*args):
+ log(*args, level=WARN)
+
+
+def error(*args):
+ log(*args, level=ERROR)
+
+
+def set_level(level):
+ """
+ Set logging threshold on current logger.
+ """
+ get_current().set_level(level)
+
+
+def set_comm(comm):
+ get_current().set_comm(comm)
+
+
+def get_dir():
+ """
+ Get the directory that log files are being written to.
+ Will be None if there is no output directory (i.e., if configure() was never called).
+ """
+ return get_current().get_dir()
+
+
+record_tabular = logkv
+dump_tabular = dumpkvs
+
+
+@contextmanager
+def profile_kv(scopename):
+ logkey = "wait_" + scopename
+ tstart = time.time()
+ try:
+ yield
+ finally:
+ get_current().name2val[logkey] += time.time() - tstart
+
+
+def profile(n):
+ """
+ Usage:
+ @profile("my_func")
+ def my_func(): code
+ """
+
+ def decorator_with_name(func):
+ def func_wrapper(*args, **kwargs):
+ with profile_kv(n):
+ return func(*args, **kwargs)
+
+ return func_wrapper
+
+ return decorator_with_name
+
+
+# ================================================================
+# Backend
+# ================================================================
+
+
+def get_current():
+ if Logger.CURRENT is None:
+ _configure_default_logger()
+
+ return Logger.CURRENT
+
+
+class Logger(object):
+ DEFAULT = None # A logger with no output files. (See right below class definition)
+ # So that you can still log to the terminal without setting up any output files
+ CURRENT = None # Current logger being used by the free functions above
+
+ def __init__(self, dir, output_formats, comm=None):
+ self.name2val = defaultdict(float) # values this iteration
+ self.name2cnt = defaultdict(int)
+ self.level = INFO
+ self.dir = dir
+ self.output_formats = output_formats
+ self.comm = comm
+
+ # Logging API, forwarded
+ # ----------------------------------------
+ def logkv(self, key, val):
+ self.name2val[key] = val
+
+ def logkv_mean(self, key, val):
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
+ self.name2cnt[key] = cnt + 1
+
+ def dumpkvs(self):
+ if self.comm is None:
+ d = self.name2val
+ else:
+ d = mpi_weighted_mean(
+ self.comm,
+ {
+ name: (val, self.name2cnt.get(name, 1))
+ for (name, val) in self.name2val.items()
+ },
+ )
+ if self.comm.rank != 0:
+ d["dummy"] = 1 # so we don't get a warning about empty dict
+ out = d.copy() # Return the dict for unit testing purposes
+ for fmt in self.output_formats:
+ if isinstance(fmt, KVWriter):
+ fmt.writekvs(d)
+ self.name2val.clear()
+ self.name2cnt.clear()
+ return out
+
+ def log(self, *args, level=INFO):
+ if self.level <= level:
+ self._do_log(args)
+
+ # Configuration
+ # ----------------------------------------
+ def set_level(self, level):
+ self.level = level
+
+ def set_comm(self, comm):
+ self.comm = comm
+
+ def get_dir(self):
+ return self.dir
+
+ def close(self):
+ for fmt in self.output_formats:
+ fmt.close()
+
+ # Misc
+ # ----------------------------------------
+ def _do_log(self, args):
+ for fmt in self.output_formats:
+ if isinstance(fmt, SeqWriter):
+ fmt.writeseq(map(str, args))
+
+
+def get_rank_without_mpi_import():
+ # check environment variables here instead of importing mpi4py
+ # to avoid calling MPI_Init() when this module is imported
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
+ if varname in os.environ:
+ return int(os.environ[varname])
+ return 0
+
+
+def mpi_weighted_mean(comm, local_name2valcount):
+ """
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
+ Perform a weighted average over dicts that are each on a different node
+ Input: local_name2valcount: dict mapping key -> (value, count)
+ Returns: key -> mean
+ """
+ all_name2valcount = comm.gather(local_name2valcount)
+ if comm.rank == 0:
+ name2sum = defaultdict(float)
+ name2count = defaultdict(float)
+ for n2vc in all_name2valcount:
+ for (name, (val, count)) in n2vc.items():
+ try:
+ val = float(val)
+ except ValueError:
+ if comm.rank == 0:
+ warnings.warn(
+ "WARNING: tried to compute mean on non-float {}={}".format(
+ name, val
+ )
+ )
+ else:
+ name2sum[name] += val * count
+ name2count[name] += count
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
+ else:
+ return {}
+
+
+def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
+ """
+ If comm is provided, average all numerical stats across that comm
+ """
+ if dir is None:
+ dir = os.getenv("OPENAI_LOGDIR")
+ if dir is None:
+ dir = osp.join(
+ tempfile.gettempdir(),
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
+ )
+ assert isinstance(dir, str)
+ dir = os.path.expanduser(dir)
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
+
+ rank = get_rank_without_mpi_import()
+ if rank > 0:
+ log_suffix = log_suffix + "-rank%03i" % rank
+
+ if format_strs is None:
+ if rank == 0:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
+ else:
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
+ format_strs = filter(None, format_strs)
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
+
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
+ if output_formats:
+ log("Logging to %s" % dir)
+
+
+def _configure_default_logger():
+ configure()
+ Logger.DEFAULT = Logger.CURRENT
+
+
+def reset():
+ if Logger.CURRENT is not Logger.DEFAULT:
+ Logger.CURRENT.close()
+ Logger.CURRENT = Logger.DEFAULT
+ log("Reset logger")
+
+
+@contextmanager
+def scoped_configure(dir=None, format_strs=None, comm=None):
+ prevlogger = Logger.CURRENT
+ configure(dir=dir, format_strs=format_strs, comm=comm)
+ try:
+ yield
+ finally:
+ Logger.CURRENT.close()
+ Logger.CURRENT = prevlogger
+
diff --git a/modules/commons/improved_diffusion/losses.py b/modules/commons/improved_diffusion/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..251e42e4f36a31bb5e1aeda874b3a45d722000a2
--- /dev/null
+++ b/modules/commons/improved_diffusion/losses.py
@@ -0,0 +1,77 @@
+"""
+Helpers for various likelihood-based losses. These are ported from the original
+Ho et al. diffusion models codebase:
+https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
+"""
+
+import numpy as np
+
+import torch as th
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
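+ # Elementwise KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2))).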
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+
+ :param x: the target images. It is assumed that these were uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/modules/commons/improved_diffusion/nn.py b/modules/commons/improved_diffusion/nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4cd59c2324b003626b8cf4c7581effd334908d3
--- /dev/null
+++ b/modules/commons/improved_diffusion/nn.py
@@ -0,0 +1,170 @@
+"""
+Various utilities for neural networks.
+"""
+
+import math
+
+import torch as th
+import torch.nn as nn
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * th.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def update_ema(target_params, source_params, rate=0.99):
+ """
+ Update target parameters to be closer to those of source parameters using
+ an exponential moving average.
+
+ :param target_params: the target parameter sequence.
+ :param source_params: the source parameter sequence.
+ :param rate: the EMA rate (closer to 1 means slower).
+ """
+ for targ, src in zip(target_params, source_params):
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
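+
+# Example (not part of the original code):
+#   emb = timestep_embedding(th.arange(4), dim=128)  # -> [4, 128] float tensor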
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
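+
+# Usage sketch (not part of the original code; `block` is a hypothetical module
+# whose forward takes (x, emb)):
+#   y = checkpoint(block.forward, (x, emb), block.parameters(), True)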
+
+
+class CheckpointFunction(th.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ with th.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with th.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = th.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
diff --git a/modules/commons/improved_diffusion/resample.py b/modules/commons/improved_diffusion/resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..c82eccdcd47c468d41e7cbe02de6a731f2c9bf81
--- /dev/null
+++ b/modules/commons/improved_diffusion/resample.py
@@ -0,0 +1,154 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps to sample.
+ :param device: the torch device to place the samples on.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
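+ # Weights 1 / (N * p) keep the weighted loss an unbiased estimate of the
+ # uniform-timestep objective.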
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for _ in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for _ in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
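+ # Weight each timestep by the RMS of its recent losses, then mix in a small
+ # uniform probability so every timestep keeps being sampled.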
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
diff --git a/modules/commons/improved_diffusion/respace.py b/modules/commons/improved_diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..045d58df956e6ddb04216e972bffff47c59bf488
--- /dev/null
+++ b/modules/commons/improved_diffusion/respace.py
@@ -0,0 +1,122 @@
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+ For example, if there are 300 timesteps and the section counts are [10, 15, 20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
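+
+# Examples, consistent with the docstring above:
+#   space_timesteps(1000, "ddim50")    -> 50 evenly strided steps out of 1000
+#   space_timesteps(300, [10, 15, 20]) -> 10, 15 and 20 steps from each third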
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
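+ # Recompute betas for the retained steps so that cumulative products of
+ # (1 - beta) over the spaced schedule match the original alphas_cumprod.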
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ if self.rescale_timesteps:
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
diff --git a/modules/commons/improved_diffusion/train_util.py b/modules/commons/improved_diffusion/train_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1867604145736352dc51ab05b6caae8b541a6ebb
--- /dev/null
+++ b/modules/commons/improved_diffusion/train_util.py
@@ -0,0 +1,356 @@
+import copy
+import functools
+import os
+
+import blobfile as bf
+import numpy as np
+import torch as th
+import torch.distributed as dist
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
+from torch.optim import AdamW
+
+from . import dist_util, logger
+from .fp16_util import (
+ make_master_params,
+ master_params_to_model_params,
+ model_grads_to_master_grads,
+ unflatten_master_params,
+ zero_grad,
+)
+from .nn import update_ema
+from .resample import LossAwareSampler, UniformSampler
+
+# For ImageNet experiments, this was a good default value.
+# We found that the lg_loss_scale quickly climbed to
+# 20-21 within the first ~1K steps of training.
+INITIAL_LOG_LOSS_SCALE = 20.0
+
+
+class TrainLoop:
+ def __init__(
+ self,
+ *,
+ model,
+ diffusion,
+ data,
+ batch_size,
+ microbatch,
+ lr,
+ ema_rate,
+ log_interval,
+ save_interval,
+ resume_checkpoint,
+ use_fp16=False,
+ fp16_scale_growth=1e-3,
+ schedule_sampler=None,
+ weight_decay=0.0,
+ lr_anneal_steps=0,
+ ):
+ self.model = model
+ self.diffusion = diffusion
+ self.data = data
+ self.batch_size = batch_size
+ self.microbatch = microbatch if microbatch > 0 else batch_size
+ self.lr = lr
+ self.ema_rate = (
+ [ema_rate]
+ if isinstance(ema_rate, float)
+ else [float(x) for x in ema_rate.split(",")]
+ )
+ self.log_interval = log_interval
+ self.save_interval = save_interval
+ self.resume_checkpoint = resume_checkpoint
+ self.use_fp16 = use_fp16
+ self.fp16_scale_growth = fp16_scale_growth
+ self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
+ self.weight_decay = weight_decay
+ self.lr_anneal_steps = lr_anneal_steps
+
+ self.step = 0
+ self.resume_step = 0
+ self.global_batch = self.batch_size * dist.get_world_size()
+
+ self.model_params = list(self.model.parameters())
+ self.master_params = self.model_params
+ self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
+ self.sync_cuda = th.cuda.is_available()
+
+ self._load_and_sync_parameters()
+ if self.use_fp16:
+ self._setup_fp16()
+
+ self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)
+ if self.resume_step:
+ self._load_optimizer_state()
+ # Model was resumed, either due to a restart or a checkpoint
+ # being specified at the command line.
+ self.ema_params = [
+ self._load_ema_parameters(rate) for rate in self.ema_rate
+ ]
+ else:
+ self.ema_params = [
+ copy.deepcopy(self.master_params) for _ in range(len(self.ema_rate))
+ ]
+
+ if th.cuda.is_available():
+ self.use_ddp = True
+ self.ddp_model = DDP(
+ self.model,
+ device_ids=[dist_util.dev()],
+ output_device=dist_util.dev(),
+ broadcast_buffers=False,
+ bucket_cap_mb=128,
+ find_unused_parameters=False,
+ )
+ else:
+ if dist.get_world_size() > 1:
+ logger.warn(
+ "Distributed training requires CUDA. "
+ "Gradients will not be synchronized properly!"
+ )
+ self.use_ddp = False
+ self.ddp_model = self.model
+
+ def _load_and_sync_parameters(self):
+ resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
+
+ if resume_checkpoint:
+ self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
+ if dist.get_rank() == 0:
+ logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
+ self.model.load_state_dict(
+ dist_util.load_state_dict(
+ resume_checkpoint, map_location=dist_util.dev()
+ )
+ )
+
+ dist_util.sync_params(self.model.parameters())
+
+ def _load_ema_parameters(self, rate):
+ ema_params = copy.deepcopy(self.master_params)
+
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
+ ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
+ if ema_checkpoint:
+ if dist.get_rank() == 0:
+ logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
+ state_dict = dist_util.load_state_dict(
+ ema_checkpoint, map_location=dist_util.dev()
+ )
+ ema_params = self._state_dict_to_master_params(state_dict)
+
+ dist_util.sync_params(ema_params)
+ return ema_params
+
+ def _load_optimizer_state(self):
+ main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
+ opt_checkpoint = bf.join(
+ bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
+ )
+ if bf.exists(opt_checkpoint):
+ logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
+ state_dict = dist_util.load_state_dict(
+ opt_checkpoint, map_location=dist_util.dev()
+ )
+ self.opt.load_state_dict(state_dict)
+
+ def _setup_fp16(self):
+ self.master_params = make_master_params(self.model_params)
+ self.model.convert_to_fp16()
+
+ def run_loop(self):
+ while (
+ not self.lr_anneal_steps
+ or self.step + self.resume_step < self.lr_anneal_steps
+ ):
+ batch, cond = next(self.data)
+ self.run_step(batch, cond)
+ if self.step % self.log_interval == 0:
+ logger.dumpkvs()
+ if self.step % self.save_interval == 0:
+ self.save()
+ # Run for a finite amount of time in integration tests.
+ if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
+ return
+ self.step += 1
+ # Save the last checkpoint if it wasn't already saved.
+ if (self.step - 1) % self.save_interval != 0:
+ self.save()
+
+ def run_step(self, batch, cond):
+ self.forward_backward(batch, cond)
+ if self.use_fp16:
+ self.optimize_fp16()
+ else:
+ self.optimize_normal()
+ self.log_step()
+
+ def forward_backward(self, batch, cond):
+ zero_grad(self.model_params)
+ for i in range(0, batch.shape[0], self.microbatch):
+ micro = batch[i : i + self.microbatch].to(dist_util.dev())
+ micro_cond = {
+ k: v[i : i + self.microbatch].to(dist_util.dev())
+ for k, v in cond.items()
+ }
+ last_batch = (i + self.microbatch) >= batch.shape[0]
+ t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
+
+ compute_losses = functools.partial(
+ self.diffusion.training_losses,
+ self.ddp_model,
+ micro,
+ t,
+ model_kwargs=micro_cond,
+ )
+
+ if last_batch or not self.use_ddp:
+ losses = compute_losses()
+ else:
+ with self.ddp_model.no_sync():
+ losses = compute_losses()
+
+ if isinstance(self.schedule_sampler, LossAwareSampler):
+ self.schedule_sampler.update_with_local_losses(
+ t, losses["loss"].detach()
+ )
+
+ loss = (losses["loss"] * weights).mean()
+ log_loss_dict(
+ self.diffusion, t, {k: v * weights for k, v in losses.items()}
+ )
+ if self.use_fp16:
+ loss_scale = 2 ** self.lg_loss_scale
+ (loss * loss_scale).backward()
+ else:
+ loss.backward()
+
+ def optimize_fp16(self):
+ if any(not th.isfinite(p.grad).all() for p in self.model_params):
+ self.lg_loss_scale -= 1
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
+ return
+
+ model_grads_to_master_grads(self.model_params, self.master_params)
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
+ self._log_grad_norm()
+ self._anneal_lr()
+ self.opt.step()
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ update_ema(params, self.master_params, rate=rate)
+ master_params_to_model_params(self.model_params, self.master_params)
+ self.lg_loss_scale += self.fp16_scale_growth
+
+ def optimize_normal(self):
+ self._log_grad_norm()
+ self._anneal_lr()
+ self.opt.step()
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ update_ema(params, self.master_params, rate=rate)
+
+ def _log_grad_norm(self):
+ sqsum = 0.0
+ for p in self.master_params:
+ sqsum += (p.grad ** 2).sum().item()
+ logger.logkv_mean("grad_norm", np.sqrt(sqsum))
+
+ def _anneal_lr(self):
+ if not self.lr_anneal_steps:
+ return
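+ # Linearly decay the learning rate from self.lr to zero over lr_anneal_steps.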
+ frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
+ lr = self.lr * (1 - frac_done)
+ for param_group in self.opt.param_groups:
+ param_group["lr"] = lr
+
+ def log_step(self):
+ logger.logkv("step", self.step + self.resume_step)
+ logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
+ if self.use_fp16:
+ logger.logkv("lg_loss_scale", self.lg_loss_scale)
+
+ def save(self):
+ def save_checkpoint(rate, params):
+ state_dict = self._master_params_to_state_dict(params)
+ if dist.get_rank() == 0:
+ logger.log(f"saving model {rate}...")
+ if not rate:
+ filename = f"model{(self.step+self.resume_step):06d}.pt"
+ else:
+ filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
+ with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
+ th.save(state_dict, f)
+
+ save_checkpoint(0, self.master_params)
+ for rate, params in zip(self.ema_rate, self.ema_params):
+ save_checkpoint(rate, params)
+
+ if dist.get_rank() == 0:
+ with bf.BlobFile(
+ bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
+ "wb",
+ ) as f:
+ th.save(self.opt.state_dict(), f)
+
+ dist.barrier()
+
+ def _master_params_to_state_dict(self, master_params):
+ if self.use_fp16:
+ master_params = unflatten_master_params(
+ self.model.parameters(), master_params
+ )
+ state_dict = self.model.state_dict()
+ for i, (name, _value) in enumerate(self.model.named_parameters()):
+ assert name in state_dict
+ state_dict[name] = master_params[i]
+ return state_dict
+
+ def _state_dict_to_master_params(self, state_dict):
+ params = [state_dict[name] for name, _ in self.model.named_parameters()]
+ if self.use_fp16:
+ return make_master_params(params)
+ else:
+ return params
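+
+ # Usage sketch (not part of the original code; model/diffusion/data come from the caller):
+ #   TrainLoop(model=model, diffusion=diffusion, data=data, batch_size=16,
+ #             microbatch=-1, lr=1e-4, ema_rate="0.9999", log_interval=10,
+ #             save_interval=10000, resume_checkpoint="").run_loop()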
+
+
+def parse_resume_step_from_filename(filename):
+ """
+ Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
+ checkpoint's number of steps.
+ """
+ split = filename.split("model")
+ if len(split) < 2:
+ return 0
+ split1 = split[-1].split(".")[0]
+ try:
+ return int(split1)
+ except ValueError:
+ return 0
+
+
+def get_blob_logdir():
+ return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())
+
+
+def find_resume_checkpoint():
+ # On your infrastructure, you may want to override this to automatically
+ # discover the latest checkpoint on your blob storage, etc.
+ return None
+
+
+def find_ema_checkpoint(main_checkpoint, step, rate):
+ if main_checkpoint is None:
+ return None
+ filename = f"ema_{rate}_{(step):06d}.pt"
+ path = bf.join(bf.dirname(main_checkpoint), filename)
+ if bf.exists(path):
+ return path
+ return None
+
+
+def log_loss_dict(diffusion, ts, losses):
+ for key, values in losses.items():
+ logger.logkv_mean(key, values.mean().item())
+ # Log the quantiles (four quartiles, in particular).
+ for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
+ quartile = int(4 * sub_t / diffusion.num_timesteps)
+ logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
diff --git a/modules/commons/layers.py b/modules/commons/layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..88e1c75876050fa05a768a5ae0467fdfc05bb006
--- /dev/null
+++ b/modules/commons/layers.py
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+
+
+class LayerNorm(torch.nn.LayerNorm):
+ """Layer normalization module.
+ :param int nout: output dim size
+ :param int dim: dimension to be normalized
+ """
+
+ def __init__(self, nout, dim=-1, eps=1e-5):
+ """Construct an LayerNorm object."""
+ super(LayerNorm, self).__init__(nout, eps=eps)
+ self.dim = dim
+
+ def forward(self, x):
+ """Apply layer normalization.
+ :param torch.Tensor x: input tensor
+ :return: layer normalized tensor
+ :rtype: torch.Tensor
+ """
+ if self.dim == -1:
+ return super(LayerNorm, self).forward(x)
+ return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
+
+
+class Reshape(nn.Module):
+ def __init__(self, *args):
+ super(Reshape, self).__init__()
+ self.shape = args
+
+ def forward(self, x):
+ return x.view(self.shape)
+
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
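+
+# Example (not part of the original code):
+#   tok_emb = Embedding(num_embeddings=100, embedding_dim=256, padding_idx=0)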
diff --git a/modules/commons/normalizing_flow/glow_modules.py b/modules/commons/normalizing_flow/glow_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c589af0f2eba2b154317912f9ad01a4163b3fd6a
--- /dev/null
+++ b/modules/commons/normalizing_flow/glow_modules.py
@@ -0,0 +1,362 @@
+import scipy
+from torch.nn import functional as F
+import torch
+from torch import nn
+import numpy as np
+from modules.commons.wavenet import WN
+from modules.tts.glow import utils
+
+
+class ActNorm(nn.Module):
+ def __init__(self, channels, ddi=False, **kwargs):
+ super().__init__()
+ self.channels = channels
+ self.initialized = not ddi
+
+ self.logs = nn.Parameter(torch.zeros(1, channels, 1))
+ self.bias = nn.Parameter(torch.zeros(1, channels, 1))
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ if x_mask is None:
+ x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
+ x_len = torch.sum(x_mask, [1, 2])
+ if not self.initialized:
+ self.initialize(x, x_mask)
+ self.initialized = True
+
+ if reverse:
+ z = (x - self.bias) * torch.exp(-self.logs) * x_mask
+ logdet = torch.sum(-self.logs) * x_len
+ else:
+ z = (self.bias + torch.exp(self.logs) * x) * x_mask
+ logdet = torch.sum(self.logs) * x_len # [b]
+ return z, logdet
+
+ def store_inverse(self):
+ pass
+
+ def set_ddi(self, ddi):
+ self.initialized = not ddi
+
+ def initialize(self, x, x_mask):
+ with torch.no_grad():
+ denom = torch.sum(x_mask, [0, 2])
+ m = torch.sum(x * x_mask, [0, 2]) / denom
+ m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
+ v = m_sq - (m ** 2)
+ logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
+
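+ # Data-dependent init: choose bias/logs so that the first batch comes out
+ # with zero mean and unit variance per channel.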
+ bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
+ logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
+
+ self.bias.data.copy_(bias_init)
+ self.logs.data.copy_(logs_init)
+
+
+class InvConvNear(nn.Module):
+ def __init__(self, channels, n_split=4, no_jacobian=False, lu=True, n_sqz=2, **kwargs):
+ super().__init__()
+ assert (n_split % 2 == 0)
+ self.channels = channels
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.no_jacobian = no_jacobian
+
+ w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
+ if torch.det(w_init) < 0:
+ w_init[:, 0] = -1 * w_init[:, 0]
+ self.lu = lu
+ if lu:
+ # LU decomposition can slightly speed up the inverse
+ np_p, np_l, np_u = scipy.linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_init.shape, dtype=float), -1)
+ eye = np.eye(*w_init.shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)), requires_grad=True)
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)), requires_grad=True)
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)), requires_grad=True)
+ self.register_buffer('l_mask', torch.Tensor(l_mask))
+ self.register_buffer('eye', torch.Tensor(eye))
+ else:
+ self.weight = nn.Parameter(w_init)
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ b, c, t = x.size()
+ assert (c % self.n_split == 0)
+ if x_mask is None:
+ x_mask = 1
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+
+ x = x.view(b, self.n_sqz, c // self.n_split, self.n_split // self.n_sqz, t)
+ x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
+
+ if self.lu:
+ self.weight, log_s = self._get_weight()
+ logdet = log_s.sum()
+ logdet = logdet * (c / self.n_split) * x_len
+ else:
+ logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len # [b]
+
+ if reverse:
+ if hasattr(self, "weight_inv"):
+ weight = self.weight_inv
+ else:
+ weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
+ logdet = -logdet
+ else:
+ weight = self.weight
+ if self.no_jacobian:
+ logdet = 0
+
+ weight = weight.view(self.n_split, self.n_split, 1, 1)
+ z = F.conv2d(x, weight)
+
+ z = z.view(b, self.n_sqz, self.n_split // self.n_sqz, c // self.n_split, t)
+ z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
+ return z, logdet
+
+ def _get_weight(self):
+ l, log_s, u = self.l, self.log_s, self.u
+ l = l * self.l_mask + self.eye
+ u = u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(log_s))
+ weight = torch.matmul(self.p, torch.matmul(l, u))
+ return weight, log_s
+
+ def store_inverse(self):
+ weight, _ = self._get_weight()
+ self.weight_inv = torch.inverse(weight.float()).to(next(self.parameters()).device)
+
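+# Minimal usage sketch for InvConvNear (illustrative shapes). `channels` must be
+# divisible by `n_split`, and `n_split` by `n_sqz`; `store_inverse()` caches the
+# inverse weight so the reverse pass avoids recomputing the matrix inverse:
+#   >>> conv = InvConvNear(channels=80, n_split=4, n_sqz=2)
+#   >>> x = torch.randn(2, 80, 100)               # [B, C, T]
+#   >>> z, logdet = conv(x)
+#   >>> conv.store_inverse()
+#   >>> x_rec, _ = conv(z, reverse=True)          # recovers x up to float error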
+
+class InvConv(nn.Module):
+ def __init__(self, channels, no_jacobian=False, lu=True, **kwargs):
+ super().__init__()
+ w_shape = [channels, channels]
+ w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(float)
+ LU_decomposed = lu
+ if not LU_decomposed:
+ # Sample a random orthogonal matrix:
+ self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init)))
+ else:
+ np_p, np_l, np_u = scipy.linalg.lu(w_init)
+ np_s = np.diag(np_u)
+ np_sign_s = np.sign(np_s)
+ np_log_s = np.log(np.abs(np_s))
+ np_u = np.triu(np_u, k=1)
+ l_mask = np.tril(np.ones(w_shape, dtype=float), -1)
+ eye = np.eye(*w_shape, dtype=float)
+
+ self.register_buffer('p', torch.Tensor(np_p.astype(float)))
+ self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(float)))
+ self.l = nn.Parameter(torch.Tensor(np_l.astype(float)))
+ self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(float)))
+ self.u = nn.Parameter(torch.Tensor(np_u.astype(float)))
+ self.l_mask = torch.Tensor(l_mask)
+ self.eye = torch.Tensor(eye)
+ self.w_shape = w_shape
+ self.LU = LU_decomposed
+ self.weight = None
+
+ def get_weight(self, device, reverse):
+ w_shape = self.w_shape
+ self.p = self.p.to(device)
+ self.sign_s = self.sign_s.to(device)
+ self.l_mask = self.l_mask.to(device)
+ self.eye = self.eye.to(device)
+ l = self.l * self.l_mask + self.eye
+ u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s))
+ dlogdet = self.log_s.sum()
+ if not reverse:
+ w = torch.matmul(self.p, torch.matmul(l, u))
+ else:
+ l = torch.inverse(l.double()).float()
+ u = torch.inverse(u.double()).float()
+ w = torch.matmul(u, torch.matmul(l, self.p.inverse()))
+ return w.view(w_shape[0], w_shape[1], 1), dlogdet
+
+ def forward(self, x, x_mask=None, reverse=False, **kwargs):
+ """
+ log-det = log|det(W)| * n_timesteps
+ """
+ b, c, t = x.size()
+ if x_mask is None:
+ x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
+ else:
+ x_len = torch.sum(x_mask, [1, 2])
+ logdet = 0
+ if not reverse:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet + dlogdet * x_len
+ return z, logdet
+ else:
+ if self.weight is None:
+ weight, dlogdet = self.get_weight(x.device, reverse)
+ else:
+ weight, dlogdet = self.weight, self.dlogdet
+ z = F.conv1d(x, weight)
+ if logdet is not None:
+ logdet = logdet - dlogdet * x_len
+ return z, logdet
+
+ def store_inverse(self):
+ self.weight, self.dlogdet = self.get_weight('cuda', reverse=True)
+
+
+class CouplingBlock(nn.Module):
+ def __init__(self, in_channels, hidden_channels, kernel_size, dilation_rate, n_layers,
+ gin_channels=0, p_dropout=0, sigmoid_scale=False, wn=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = gin_channels
+ self.p_dropout = p_dropout
+ self.sigmoid_scale = sigmoid_scale
+
+ start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
+ start = torch.nn.utils.weight_norm(start)
+ self.start = start
+ # Initializing last layer to 0 makes the affine coupling layers
+ # do nothing at first. This helps with training stability
+ end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
+ end.weight.data.zero_()
+ end.bias.data.zero_()
+ self.end = end
+ self.wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels, p_dropout)
+ if wn is not None:
+ self.wn.in_layers = wn.in_layers
+ self.wn.res_skip_layers = wn.res_skip_layers
+
+ def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
+ if x_mask is None:
+ x_mask = 1
+ x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
+
+ x = self.start(x_0) * x_mask
+ x = self.wn(x, x_mask, g)
+ out = self.end(x)
+
+ z_0 = x_0
+ m = out[:, :self.in_channels // 2, :]
+ logs = out[:, self.in_channels // 2:, :]
+ if self.sigmoid_scale:
+ logs = torch.log(1e-6 + torch.sigmoid(logs + 2))
+ if reverse:
+ z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
+ logdet = torch.sum(-logs * x_mask, [1, 2])
+ else:
+ z_1 = (m + torch.exp(logs) * x_1) * x_mask
+ logdet = torch.sum(logs * x_mask, [1, 2])
+ z = torch.cat([z_0, z_1], 1)
+ return z, logdet
+
+ def store_inverse(self):
+ self.wn.remove_weight_norm()
+
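+# Minimal usage sketch for CouplingBlock (illustrative shapes/hyper-parameters).
+# Because `end` is zero-initialized, the block starts out as the identity map;
+# forward and reverse are exact inverses of each other:
+#   >>> block = CouplingBlock(in_channels=80, hidden_channels=192, kernel_size=5,
+#   ...                       dilation_rate=1, n_layers=4)
+#   >>> x = torch.randn(2, 80, 100)               # [B, C, T]
+#   >>> z, logdet = block(x)                      # equals x right after initialization
+#   >>> x_rec, _ = block(z, reverse=True)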
+
+class Glow(nn.Module):
+ def __init__(self,
+ in_channels,
+ hidden_channels,
+ kernel_size,
+ dilation_rate,
+ n_blocks,
+ n_layers,
+ p_dropout=0.,
+ n_split=4,
+ n_sqz=2,
+ sigmoid_scale=False,
+ gin_channels=0,
+ inv_conv_type='near',
+ share_cond_layers=False,
+ share_wn_layers=0,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_blocks = n_blocks
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ self.n_split = n_split
+ self.n_sqz = n_sqz
+ self.sigmoid_scale = sigmoid_scale
+ self.gin_channels = gin_channels
+ self.share_cond_layers = share_cond_layers
+ if gin_channels != 0 and share_cond_layers:
+ cond_layer = torch.nn.Conv1d(gin_channels * n_sqz, 2 * hidden_channels * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+ wn = None
+ self.flows = nn.ModuleList()
+ for b in range(n_blocks):
+ self.flows.append(ActNorm(channels=in_channels * n_sqz))
+ if inv_conv_type == 'near':
+ self.flows.append(InvConvNear(channels=in_channels * n_sqz, n_split=n_split, n_sqz=n_sqz))
+ if inv_conv_type == 'invconv':
+ self.flows.append(InvConv(channels=in_channels * n_sqz))
+ if share_wn_layers > 0:
+ if b % share_wn_layers == 0:
+ wn = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels * n_sqz,
+ p_dropout, share_cond_layers)
+ self.flows.append(
+ CouplingBlock(
+ in_channels * n_sqz,
+ hidden_channels,
+ kernel_size=kernel_size,
+ dilation_rate=dilation_rate,
+ n_layers=n_layers,
+ gin_channels=gin_channels * n_sqz,
+ p_dropout=p_dropout,
+ sigmoid_scale=sigmoid_scale,
+ wn=wn
+ ))
+
+ def forward(self, x, x_mask=None, g=None, reverse=False, return_hiddens=False):
+ logdet_tot = 0
+ if not reverse:
+ flows = self.flows
+ else:
+ flows = reversed(self.flows)
+ if return_hiddens:
+ hs = []
+ if self.n_sqz > 1:
+ x, x_mask_ = utils.squeeze(x, x_mask, self.n_sqz)
+ if g is not None:
+ g, _ = utils.squeeze(g, x_mask, self.n_sqz)
+ x_mask = x_mask_
+ if self.share_cond_layers and g is not None:
+ g = self.cond_layer(g)
+ for f in flows:
+ x, logdet = f(x, x_mask, g=g, reverse=reverse)
+ if return_hiddens:
+ hs.append(x)
+ logdet_tot += logdet
+ if self.n_sqz > 1:
+ x, x_mask = utils.unsqueeze(x, x_mask, self.n_sqz)
+ if return_hiddens:
+ return x, logdet_tot, hs
+ return x, logdet_tot
+
+ def store_inverse(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
+ for f in self.flows:
+ f.store_inverse()
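+
+# Minimal usage sketch for the full Glow stack (illustrative hyper-parameters).
+# Forward maps data to latents and accumulates the log-determinant; after
+# `store_inverse()` the same module can be run with `reverse=True` for sampling:
+#   >>> glow = Glow(in_channels=80, hidden_channels=192, kernel_size=5,
+#   ...             dilation_rate=1, n_blocks=12, n_layers=4)
+#   >>> x = torch.randn(2, 80, 100)               # [B, C, T], T divisible by n_sqz
+#   >>> x_mask = torch.ones(2, 1, 100)
+#   >>> z, logdet = glow(x, x_mask)               # data -> latent
+#   >>> glow.store_inverse()
+#   >>> x_rec, _ = glow(z, x_mask, reverse=True)  # latent -> data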
diff --git a/modules/commons/normalizing_flow/res_flow.py b/modules/commons/normalizing_flow/res_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..428fb7da9e3becb0d11cdf239fff410c86028d95
--- /dev/null
+++ b/modules/commons/normalizing_flow/res_flow.py
@@ -0,0 +1,61 @@
+import torch
+from torch import nn
+from modules.commons.conv import ConditionalConvBlocks
+from modules.commons.wavenet import WN
+
+
+class FlipLayer(nn.Module):
+ def forward(self, x, *args, **kwargs):
+ x = torch.flip(x, [1])
+ return x
+
+
+class CouplingLayer(nn.Module):
+ def __init__(self, c_in, hidden_size, kernel_size, n_layers, p_dropout=0, c_in_g=0, nn_type='wn'):
+ super().__init__()
+ self.channels = c_in
+ self.hidden_size = hidden_size
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.c_half = c_in // 2
+
+ self.pre = nn.Conv1d(self.c_half, hidden_size, 1)
+ if nn_type == 'wn':
+ self.enc = WN(hidden_size, kernel_size, 1, n_layers, p_dropout=p_dropout,
+ c_cond=c_in_g)
+ elif nn_type == 'conv':
+ self.enc = ConditionalConvBlocks(
+ hidden_size, c_in_g, hidden_size, None, kernel_size,
+ layers_in_block=1, is_BTC=False, num_layers=n_layers)
+ self.post = nn.Conv1d(hidden_size, self.c_half, 1)
+
+ def forward(self, x, nonpadding, cond=None, reverse=False):
+ x0, x1 = x[:, :self.c_half], x[:, self.c_half:]
+ x_ = self.pre(x0) * nonpadding
+ x_ = self.enc(x_, nonpadding=nonpadding, cond=cond)
+ m = self.post(x_)
+ x1 = m + x1 if not reverse else x1 - m
+ x = torch.cat([x0, x1], 1)
+ return x * nonpadding
+
+
+class ResFlow(nn.Module):
+ def __init__(self,
+ c_in,
+ hidden_size,
+ kernel_size,
+ n_flow_layers,
+ n_flow_steps=4,
+ c_cond=0,
+ nn_type='wn'):
+ super().__init__()
+ self.flows = nn.ModuleList()
+ for i in range(n_flow_steps):
+ self.flows.append(
+ CouplingLayer(c_in, hidden_size, kernel_size, n_flow_layers, c_in_g=c_cond, nn_type=nn_type))
+ self.flows.append(FlipLayer())
+
+ def forward(self, x, nonpadding, cond=None, reverse=False):
+ for flow in (self.flows if not reverse else reversed(self.flows)):
+ x = flow(x, nonpadding, cond=cond, reverse=reverse)
+ return x
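+
+# Minimal usage sketch for ResFlow (illustrative shapes). Each step is an additive
+# coupling layer followed by a channel flip, so the stack is volume-preserving and
+# exactly invertible; `nonpadding` is a [B, 1, T] mask of valid frames:
+#   >>> flow = ResFlow(c_in=80, hidden_size=128, kernel_size=3, n_flow_layers=4)
+#   >>> x = torch.randn(2, 80, 100)
+#   >>> nonpadding = torch.ones(2, 1, 100)
+#   >>> z = flow(x, nonpadding)
+#   >>> x_rec = flow(z, nonpadding, reverse=True)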
diff --git a/modules/commons/normalizing_flow/utils.py b/modules/commons/normalizing_flow/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eb56ec514bff822ba1a19a6474207ed82492410
--- /dev/null
+++ b/modules/commons/normalizing_flow/utils.py
@@ -0,0 +1,29 @@
+import torch
+
+
+def squeeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ t = (t // n_sqz) * n_sqz
+ x = x[:, :, :t]
+ x_sqz = x.view(b, c, t // n_sqz, n_sqz)
+ x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask[:, :, n_sqz - 1::n_sqz]
+ else:
+ x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_sqz * x_mask, x_mask
+
+
+def unsqueeze(x, x_mask=None, n_sqz=2):
+ b, c, t = x.size()
+
+ x_unsqz = x.view(b, n_sqz, c // n_sqz, t)
+ x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)
+
+ if x_mask is not None:
+ x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
+ else:
+ x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
+ return x_unsqz * x_mask, x_mask
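+
+# Round-trip sketch (illustrative shapes): `squeeze` folds pairs of adjacent frames
+# into the channel axis and `unsqueeze` undoes it exactly, so the pair can wrap a
+# flow that operates at half the temporal resolution:
+#   >>> x = torch.randn(2, 80, 100)
+#   >>> x_mask = torch.ones(2, 1, 100)
+#   >>> x_sqz, m_sqz = squeeze(x, x_mask, n_sqz=2)       # [2, 160, 50], [2, 1, 50]
+#   >>> x_rec, _ = unsqueeze(x_sqz, m_sqz, n_sqz=2)      # back to [2, 80, 100]
+#   >>> assert torch.allclose(x, x_rec)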
diff --git a/modules/commons/rel_transformer.py b/modules/commons/rel_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd41b301a98609391d1a18b118d1f1b3e538af1d
--- /dev/null
+++ b/modules/commons/rel_transformer.py
@@ -0,0 +1,389 @@
+import math
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from modules.commons.layers import Embedding
+
+
+def convert_pad_shape(pad_shape):
+ l = pad_shape[::-1]
+ pad_shape = [item for sublist in l for item in sublist]
+ return pad_shape
+
+
+def shift_1d(x):
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
+ return x
+
+
+def sequence_mask(length, max_length=None):
+ if max_length is None:
+ max_length = length.max()
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+ return x.unsqueeze(0) < length.unsqueeze(1)
+
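+# Worked example (sketch): `convert_pad_shape` turns a per-dimension [left, right]
+# spec (outermost dimension first) into the flat, reversed list that F.pad expects,
+# and `sequence_mask` builds a boolean mask from per-sequence lengths:
+#   >>> convert_pad_shape([[0, 0], [0, 0], [1, 0]])
+#   [1, 0, 0, 0, 0, 0]
+#   >>> sequence_mask(torch.tensor([2, 3]), max_length=4)
+#   tensor([[ True,  True, False, False],
+#           [ True,  True,  True, False]])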
+
+class Encoder(nn.Module):
+ def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.,
+ window_size=None, block_length=None, pre_ln=False, **kwargs):
+ super().__init__()
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.block_length = block_length
+ self.pre_ln = pre_ln
+
+ self.drop = nn.Dropout(p_dropout)
+ self.attn_layers = nn.ModuleList()
+ self.norm_layers_1 = nn.ModuleList()
+ self.ffn_layers = nn.ModuleList()
+ self.norm_layers_2 = nn.ModuleList()
+ for i in range(self.n_layers):
+ self.attn_layers.append(
+ MultiHeadAttention(hidden_channels, hidden_channels, n_heads, window_size=window_size,
+ p_dropout=p_dropout, block_length=block_length))
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
+ self.ffn_layers.append(
+ FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
+ if pre_ln:
+ self.last_ln = LayerNorm(hidden_channels)
+
+ def forward(self, x, x_mask, attn_mask=1):
+ if isinstance(attn_mask, torch.Tensor):
+ attn_mask = attn_mask[:, None]
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) * attn_mask
+ for i in range(self.n_layers):
+ x = x * x_mask
+ x_ = x
+ if self.pre_ln:
+ x = self.norm_layers_1[i](x)
+ y = self.attn_layers[i](x, x, attn_mask)
+ y = self.drop(y)
+ x = x_ + y
+ if not self.pre_ln:
+ x = self.norm_layers_1[i](x)
+
+ x_ = x
+ if self.pre_ln:
+ x = self.norm_layers_2[i](x)
+ y = self.ffn_layers[i](x, x_mask)
+ y = self.drop(y)
+ x = x_ + y
+ if not self.pre_ln:
+ x = self.norm_layers_2[i](x)
+ if self.pre_ln:
+ x = self.last_ln(x)
+ x = x * x_mask
+ return x
+
+
+class MultiHeadAttention(nn.Module):
+ def __init__(self, channels, out_channels, n_heads, window_size=None, heads_share=True, p_dropout=0.,
+ block_length=None, proximal_bias=False, proximal_init=False):
+ super().__init__()
+ assert channels % n_heads == 0
+
+ self.channels = channels
+ self.out_channels = out_channels
+ self.n_heads = n_heads
+ self.window_size = window_size
+ self.heads_share = heads_share
+ self.block_length = block_length
+ self.proximal_bias = proximal_bias
+ self.p_dropout = p_dropout
+ self.attn = None
+
+ self.k_channels = channels // n_heads
+ self.conv_q = nn.Conv1d(channels, channels, 1)
+ self.conv_k = nn.Conv1d(channels, channels, 1)
+ self.conv_v = nn.Conv1d(channels, channels, 1)
+ if window_size is not None:
+ n_heads_rel = 1 if heads_share else n_heads
+ rel_stddev = self.k_channels ** -0.5
+ self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ nn.init.xavier_uniform_(self.conv_q.weight)
+ nn.init.xavier_uniform_(self.conv_k.weight)
+ if proximal_init:
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
+ nn.init.xavier_uniform_(self.conv_v.weight)
+
+ def forward(self, x, c, attn_mask=None):
+ q = self.conv_q(x)
+ k = self.conv_k(c)
+ v = self.conv_v(c)
+
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+ x = self.conv_o(x)
+ return x
+
+ def attention(self, query, key, value, mask=None):
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
+ b, d, t_s, t_t = (*key.size(), query.size(2))
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
+ if self.window_size is not None:
+ assert t_s == t_t, "Relative attention is only available for self-attention."
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
+ scores_local = rel_logits / math.sqrt(self.k_channels)
+ scores = scores + scores_local
+ if self.proximal_bias:
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
+ scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
+ if mask is not None:
+ scores = scores.masked_fill(mask == 0, -1e4)
+ if self.block_length is not None:
+ block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
+ scores = scores * block_mask + -1e4 * (1 - block_mask)
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
+ p_attn = self.drop(p_attn)
+ output = torch.matmul(p_attn, value)
+ if self.window_size is not None:
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
+ value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
+ output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
+ return output, p_attn
+
+ def _matmul_with_relative_values(self, x, y):
+ """
+ x: [b, h, l, m]
+ y: [h or 1, m, d]
+ ret: [b, h, l, d]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0))
+ return ret
+
+ def _matmul_with_relative_keys(self, x, y):
+ """
+ x: [b, h, l, d]
+ y: [h or 1, m, d]
+ ret: [b, h, l, m]
+ """
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+ return ret
+
+ def _get_relative_embeddings(self, relative_embeddings, length):
+ max_relative_position = 2 * self.window_size + 1
+ # Pad first before slice to avoid using cond ops.
+ pad_length = max(length - (self.window_size + 1), 0)
+ slice_start_position = max((self.window_size + 1) - length, 0)
+ slice_end_position = slice_start_position + 2 * length - 1
+ if pad_length > 0:
+ padded_relative_embeddings = F.pad(
+ relative_embeddings,
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
+ else:
+ padded_relative_embeddings = relative_embeddings
+ used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
+ return used_relative_embeddings
+
+ def _relative_position_to_absolute_position(self, x):
+ """
+ x: [b, h, l, 2*l-1]
+ ret: [b, h, l, l]
+ """
+ batch, heads, length, _ = x.size()
+ # Concat columns of pad to shift from relative to absolute indexing.
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+ # Concat extra elements so as to add up to shape (len+1, 2*len-1).
+ x_flat = x.view([batch, heads, length * 2 * length])
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
+
+ # Reshape and slice out the padded elements.
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1:]
+ return x_final
+
+ def _absolute_position_to_relative_position(self, x):
+ """
+ x: [b, h, l, l]
+ ret: [b, h, l, 2*l-1]
+ """
+ batch, heads, length, _ = x.size()
+ # pad along the last (column) dimension
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
+ x_flat = x.view([batch, heads, -1])
+ # add 0's in the beginning that will skew the elements after reshape
+ x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+ return x_final
+
+ def _attention_bias_proximal(self, length):
+ """Bias for self-attention to encourage attention to close positions.
+ Args:
+ length: an integer scalar.
+ Returns:
+ a Tensor with shape [1, 1, length, length]
+ """
+ r = torch.arange(length, dtype=torch.float32)
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
+class FFN(nn.Module):
+ def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.filter_channels = filter_channels
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.activation = activation
+
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, 1)
+ self.drop = nn.Dropout(p_dropout)
+
+ def forward(self, x, x_mask):
+ x = self.conv_1(x * x_mask)
+ if self.activation == "gelu":
+ x = x * torch.sigmoid(1.702 * x)
+ else:
+ x = torch.relu(x)
+ x = self.drop(x)
+ x = self.conv_2(x * x_mask)
+ return x * x_mask
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, channels, eps=1e-4):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+
+ def forward(self, x):
+ n_dims = len(x.shape)
+ mean = torch.mean(x, 1, keepdim=True)
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
+
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
+
+ shape = [1, -1] + [1] * (n_dims - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+class ConvReluNorm(nn.Module):
+ def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout):
+ super().__init__()
+ self.in_channels = in_channels
+ self.hidden_channels = hidden_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.n_layers = n_layers
+ self.p_dropout = p_dropout
+ assert n_layers > 1, "Number of layers should be larger than 1."
+
+ self.conv_layers = nn.ModuleList()
+ self.norm_layers = nn.ModuleList()
+ self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.relu_drop = nn.Sequential(
+ nn.ReLU(),
+ nn.Dropout(p_dropout))
+ for _ in range(n_layers - 1):
+ self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
+ self.norm_layers.append(LayerNorm(hidden_channels))
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
+ self.proj.weight.data.zero_()
+ self.proj.bias.data.zero_()
+
+ def forward(self, x, x_mask):
+ x_org = x
+ for i in range(self.n_layers):
+ x = self.conv_layers[i](x * x_mask)
+ x = self.norm_layers[i](x)
+ x = self.relu_drop(x)
+ x = x_org + self.proj(x)
+ return x * x_mask
+
+
+class RelTransformerEncoder(nn.Module):
+ def __init__(self,
+ n_vocab,
+ out_channels,
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout=0.0,
+ window_size=4,
+ block_length=None,
+ in_channels=None,
+ prenet=True,
+ pre_ln=True,
+ ):
+
+ super().__init__()
+
+ self.n_vocab = n_vocab
+ self.out_channels = out_channels
+ self.hidden_channels = hidden_channels
+ self.filter_channels = filter_channels
+ self.n_heads = n_heads
+ self.n_layers = n_layers
+ self.kernel_size = kernel_size
+ self.p_dropout = p_dropout
+ self.window_size = window_size
+ self.block_length = block_length
+ self.prenet = prenet
+ if n_vocab > 0:
+ self.emb = Embedding(n_vocab, hidden_channels, padding_idx=0)
+
+ if prenet:
+ if in_channels is None:
+ in_channels = hidden_channels
+ self.pre = ConvReluNorm(in_channels, in_channels, in_channels,
+ kernel_size=5, n_layers=3, p_dropout=0)
+ if in_channels is not None and in_channels != hidden_channels:
+ self.encoder_inp_proj = nn.Conv1d(in_channels, hidden_channels, 1)
+ self.encoder = Encoder(
+ hidden_channels,
+ filter_channels,
+ n_heads,
+ n_layers,
+ kernel_size,
+ p_dropout,
+ window_size=window_size,
+ block_length=block_length,
+ pre_ln=pre_ln,
+ )
+
+ def forward(self, x, x_mask=None, other_embeds=0, attn_mask=1):
+ if self.n_vocab > 0:
+ x_lengths = (x > 0).long().sum(-1)
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
+ else:
+ x_lengths = (x.abs().sum(-1) > 0).long().sum(-1)
+ x = x + other_embeds
+ x = torch.transpose(x, 1, -1) # [b, h, t]
+ x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
+
+ if self.prenet:
+ x = self.pre(x, x_mask)
+ self.prenet_out = x.transpose(1, 2)
+ if hasattr(self, 'encoder_inp_proj'):
+ x = self.encoder_inp_proj(x) * x_mask
+ x = self.encoder(x, x_mask, attn_mask)
+ return x.transpose(1, 2)
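+
+# Minimal usage sketch (illustrative vocabulary size and dimensions). Index 0 is the
+# padding id, so lengths and the non-padding mask are derived from the token ids:
+#   >>> enc = RelTransformerEncoder(n_vocab=100, out_channels=192, hidden_channels=192,
+#   ...                             filter_channels=768, n_heads=2, n_layers=4,
+#   ...                             kernel_size=3)
+#   >>> tokens = torch.randint(1, 100, (2, 50))   # [B, T], 0 reserved for padding
+#   >>> h = enc(tokens)                           # [B, T, hidden_channels]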
diff --git a/modules/commons/rnn.py b/modules/commons/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..205c2c76b8fda2de920bc59228a5eec0a20119a9
--- /dev/null
+++ b/modules/commons/rnn.py
@@ -0,0 +1,261 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class PreNet(nn.Module):
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
+ super().__init__()
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
+ self.p = dropout
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=self.training)
+ x = self.fc2(x)
+ x = F.relu(x)
+ x = F.dropout(x, self.p, training=self.training)
+ return x
+
+
+class HighwayNetwork(nn.Module):
+ def __init__(self, size):
+ super().__init__()
+ self.W1 = nn.Linear(size, size)
+ self.W2 = nn.Linear(size, size)
+ self.W1.bias.data.fill_(0.)
+
+ def forward(self, x):
+ x1 = self.W1(x)
+ x2 = self.W2(x)
+ g = torch.sigmoid(x2)
+ y = g * F.relu(x1) + (1. - g) * x
+ return y
+
+
+class BatchNormConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
+ super().__init__()
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
+ self.bnorm = nn.BatchNorm1d(out_channels)
+ self.relu = relu
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(x) if self.relu is True else x
+ return self.bnorm(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+class CBHG(nn.Module):
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
+ super().__init__()
+
+ # List of all rnns to call `flatten_parameters()` on
+ self._to_flatten = []
+
+ self.bank_kernels = [i for i in range(1, K + 1)]
+ self.conv1d_bank = nn.ModuleList()
+ for k in self.bank_kernels:
+ conv = BatchNormConv(in_channels, channels, k)
+ self.conv1d_bank.append(conv)
+
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
+
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
+
+ # Fix the highway input if necessary
+ if proj_channels[-1] != channels:
+ self.highway_mismatch = True
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
+ else:
+ self.highway_mismatch = False
+
+ self.highways = nn.ModuleList()
+ for i in range(num_highways):
+ hn = HighwayNetwork(channels)
+ self.highways.append(hn)
+
+ self.rnn = nn.GRU(channels, channels, batch_first=True, bidirectional=True)
+ self._to_flatten.append(self.rnn)
+
+ # Avoid fragmentation of RNN parameters and associated warning
+ self._flatten_parameters()
+
+ def forward(self, x):
+ # Although we `_flatten_parameters()` on init, when using DataParallel
+ # the model gets replicated, making it no longer guaranteed that the
+ # weights are contiguous in GPU memory. Hence, we must call it again
+ self._flatten_parameters()
+
+ # Save these for later
+ residual = x
+ seq_len = x.size(-1)
+ conv_bank = []
+
+ # Convolution Bank
+ for conv in self.conv1d_bank:
+ c = conv(x) # Convolution
+ conv_bank.append(c[:, :, :seq_len])
+
+ # Stack along the channel axis
+ conv_bank = torch.cat(conv_bank, dim=1)
+
+ # drop the trailing padding so the length matches the residual
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
+
+ # Conv1d projections
+ x = self.conv_project1(x)
+ x = self.conv_project2(x)
+
+ # Residual Connect
+ x = x + residual
+
+ # Through the highways
+ x = x.transpose(1, 2)
+ if self.highway_mismatch is True:
+ x = self.pre_highway(x)
+ for h in self.highways:
+ x = h(x)
+
+ # And then the RNN
+ x, _ = self.rnn(x)
+ return x
+
+ def _flatten_parameters(self):
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
+ to improve efficiency and avoid PyTorch yelling at us."""
+ [m.flatten_parameters() for m in self._to_flatten]
+
+
+class TacotronEncoder(nn.Module):
+ def __init__(self, embed_dims, num_chars, cbhg_channels, K, num_highways, dropout):
+ super().__init__()
+ self.embedding = nn.Embedding(num_chars, embed_dims)
+ self.pre_net = PreNet(embed_dims, embed_dims, embed_dims, dropout=dropout)
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
+ proj_channels=[cbhg_channels, cbhg_channels],
+ num_highways=num_highways)
+ self.proj_out = nn.Linear(cbhg_channels * 2, cbhg_channels)
+
+ def forward(self, x):
+ x = self.embedding(x)
+ x = self.pre_net(x)
+ x.transpose_(1, 2)
+ x = self.cbhg(x)
+ x = self.proj_out(x)
+ return x
+
+
+class RNNEncoder(nn.Module):
+ def __init__(self, num_chars, embedding_dim, n_convolutions=3, kernel_size=5):
+ super(RNNEncoder, self).__init__()
+ self.embedding = nn.Embedding(num_chars, embedding_dim, padding_idx=0)
+ convolutions = []
+ for _ in range(n_convolutions):
+ conv_layer = nn.Sequential(
+ ConvNorm(embedding_dim,
+ embedding_dim,
+ kernel_size=kernel_size, stride=1,
+ padding=int((kernel_size - 1) / 2),
+ dilation=1, w_init_gain='relu'),
+ nn.BatchNorm1d(embedding_dim))
+ convolutions.append(conv_layer)
+ self.convolutions = nn.ModuleList(convolutions)
+
+ self.lstm = nn.LSTM(embedding_dim, int(embedding_dim / 2), 1,
+ batch_first=True, bidirectional=True)
+
+ def forward(self, x):
+ input_lengths = (x > 0).sum(-1)
+ input_lengths = input_lengths.cpu().numpy()
+
+ x = self.embedding(x)
+ x = x.transpose(1, 2) # [B, H, T]
+ for conv in self.convolutions:
+ x = F.dropout(F.relu(conv(x)), 0.5, self.training) + x
+ x = x.transpose(1, 2) # [B, T, H]
+
+ # pytorch tensors are not reversible, hence the conversion
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+
+ self.lstm.flatten_parameters()
+ outputs, _ = self.lstm(x)
+ outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True)
+
+ return outputs
+
+
+class DecoderRNN(torch.nn.Module):
+ def __init__(self, hidden_size, decoder_rnn_dim, dropout):
+ super(DecoderRNN, self).__init__()
+ self.in_conv1d = nn.Sequential(
+ torch.nn.Conv1d(
+ in_channels=hidden_size,
+ out_channels=hidden_size,
+ kernel_size=9, padding=4,
+ ),
+ torch.nn.ReLU(),
+ torch.nn.Conv1d(
+ in_channels=hidden_size,
+ out_channels=hidden_size,
+ kernel_size=9, padding=4,
+ ),
+ )
+ self.ln = nn.LayerNorm(hidden_size)
+ if decoder_rnn_dim == 0:
+ decoder_rnn_dim = hidden_size * 2
+ self.rnn = torch.nn.LSTM(
+ input_size=hidden_size,
+ hidden_size=decoder_rnn_dim,
+ num_layers=1,
+ batch_first=True,
+ bidirectional=True,
+ dropout=dropout
+ )
+ self.rnn.flatten_parameters()
+ self.conv1d = torch.nn.Conv1d(
+ in_channels=decoder_rnn_dim * 2,
+ out_channels=hidden_size,
+ kernel_size=3,
+ padding=1,
+ )
+
+ def forward(self, x):
+ input_masks = x.abs().sum(-1).ne(0).data[:, :, None]
+ input_lengths = input_masks.sum([-1, -2])
+ input_lengths = input_lengths.cpu().numpy()
+
+ x = self.in_conv1d(x.transpose(1, 2)).transpose(1, 2)
+ x = self.ln(x)
+ x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
+ self.rnn.flatten_parameters()
+ x, _ = self.rnn(x) # [B, T, C]
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
+ x = x * input_masks
+ pre_mel = self.conv1d(x.transpose(1, 2)).transpose(1, 2) # [B, T, C]
+ pre_mel = pre_mel * input_masks
+ return pre_mel
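+
+# Minimal usage sketch (illustrative sizes). RNNEncoder maps padded token ids to a
+# [B, T, C] sequence; DecoderRNN refines a [B, T, C] sequence with a BiLSTM and a
+# 1D conv head (all-zero frames are treated as padding):
+#   >>> enc = RNNEncoder(num_chars=100, embedding_dim=256)
+#   >>> tokens = torch.randint(1, 100, (2, 50))   # 0 is the padding index
+#   >>> h = enc(tokens)                           # [2, 50, 256]
+#   >>> dec = DecoderRNN(hidden_size=256, decoder_rnn_dim=0, dropout=0.)
+#   >>> pre_mel = dec(h)                          # [2, 50, 256]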
diff --git a/modules/commons/rot_transformer.py b/modules/commons/rot_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d17c488042b54a70f0b897f4efc488dfbce3b3b3
--- /dev/null
+++ b/modules/commons/rot_transformer.py
@@ -0,0 +1,635 @@
+import math
+import torch
+from typing import Optional, Tuple
+from torch import nn
+from torch.nn import Parameter, Linear
+from torch.cuda.amp import autocast
+from modules.commons.layers import LayerNorm, Embedding
+from modules.commons.transformer import TransformerFFNLayer, MultiheadAttention
+from utils.nn.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+DEFAULT_MAX_SOURCE_POSITIONS = 3000
+DEFAULT_MAX_TARGET_POSITIONS = 3000
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class RotaryEmbeddings(nn.Module):
+ cos: torch.Tensor
+ sin: torch.Tensor
+ theta: torch.Tensor
+
+ def __init__(
+ self,
+ width: int,
+ *,
+ seq_len: int = 4000,
+ base: int = 10000,
+ device: Optional[torch.device] = None,
+ ):
+ """Rotary embeddings (Su et al., 2021) layer. The rotary embedding
+ will be precomputed for up to 'seq_len' positions. The embedding
+ will be recomputed when a longer sequence is found in the input.
+
+ :param width:
+ Rotary embedding dimensionality, must be even.
+ :param seq_len:
+ Number of positions to initially precompute.
+ :param base:
+ The base used for Θ_i; it determines the cycle length of the
+ embeddings.
+ :param device: Device on which the module is to be initialized.
+ """
+ super().__init__()
+
+ if width % 2:
+ raise ValueError(f"Width of rotary embeddings must be even, was: {width}")
+
+ # Ignore allocations on the meta device as we don't persist our buffer,
+ # i.e., we don't expect the backing tensor to be replaced with pretrained weights.
+ if device is not None and device.type == "meta":
+ device = None
+ # Θ_i = 10000^(-2(i-1)/d)
+ theta = torch.pow(
+ base, -torch.arange(0, width, 2, dtype=torch.float, device=device) / width
+ )
+ self.register_buffer("theta", theta, persistent=False)
+
+ self._create_rotary_embed(width=width, length=seq_len)
+
+ def _create_rotary_embed(self, *, width: int, length: int):
+ # mΘ
+ position = torch.arange(length, device=self.theta.device).unsqueeze(1)
+ m_theta = position * self.theta.unsqueeze(0)
+
+ # We apply both sin and cos twice (see Eq 15, 34), but the ordering
+ # is changed for compatibility with most common implementations.
+ m_theta = torch.cat([m_theta, m_theta], dim=-1)
+
+ re_cos = m_theta.cos().view([length, width])
+ re_sin = m_theta.sin().view([length, width])
+
+ self.register_buffer("cos", re_cos, persistent=False)
+ self.register_buffer("sin", re_sin, persistent=False)
+
+ def _rotate(self, input: torch.Tensor):
+ """Rotate the input tensor by half of its innermost width.
+
+ input (Tensor): array to rotate.
+ RETURNS (Tensor): rotated array.
+
+ Shapes:
+ input - (..., width)
+ output - (..., width)
+ """
+ half_idx = input.shape[-1] // 2
+ input_1 = -input[..., half_idx:]
+ input_2 = input[..., :half_idx]
+ return torch.cat([input_1, input_2], dim=-1)
+
+ def forward(self, input: torch.Tensor, *, positions: Optional[torch.Tensor] = None):
+ """
+ Apply rotary embeddings to an array.
+
+ :param input: Array to apply the rotary embeddings to.
+ :param positions: positions of the inputs. If no positions are
+ provided, they are assumed to be [0, seq_len).
+ :return: Array with the rotary embeddings applied.
+
+ Shapes:
+ input - (batch_size, num_heads, seq_len, width_per_head)
+ positions - (batch_size, seq_len)
+ output - (batch_size, num_heads, seq_len, width_per_head)
+ """
+ batch_size, _, seq_len, width = input.shape
+
+ if positions is None:
+ # Fastpath: positions from [0..seq_len), avoid indexing.
+ if self.cos.size(-2) < seq_len:
+ self._create_rotary_embed(width=width, length=seq_len)
+ rot_cos = self.cos[:seq_len, :].view(1, 1, seq_len, width)
+ rot_sin = self.sin[:seq_len, :].view(1, 1, seq_len, width)
+ else:
+ max_len = int(positions.max()) + 1
+ if self.cos.size(-2) < max_len:
+ self._create_rotary_embed(width=width, length=max_len)
+
+ # Flatten positions to index cos/sin arrays, then unflatten.
+ #
+ # Example shapes:
+ #
+ # positions_flat - (batch_size * seq_len)
+ # self.cos - (max_len, width)
+ # rot_cos - (batch_size, seq_len, width)
+ positions_flat = positions.view(-1)
+ rot_cos = self.cos[positions_flat].view(batch_size, 1, seq_len, width)
+ rot_sin = self.sin[positions_flat].view(batch_size, 1, seq_len, width)
+
+ # Eq 34 with ordering changed for compatibility.
+ return rot_cos * input + rot_sin * self._rotate(input)
+
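+# Minimal usage sketch for RotaryEmbeddings (illustrative shapes). The sin/cos table
+# is precomputed for `seq_len` positions and grows on demand; when `positions` is
+# omitted it defaults to [0, seq_len):
+#   >>> rope = RotaryEmbeddings(width=64)
+#   >>> q = torch.randn(2, 4, 100, 64)            # [batch, heads, seq_len, head_dim]
+#   >>> q_default = rope(q)
+#   >>> pos = torch.arange(100).unsqueeze(0).repeat(2, 1)
+#   >>> q_explicit = rope(q, positions=pos)       # same result as q_default
+#   >>> assert torch.allclose(q_default, q_explicit)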
+
+class RotMultiheadAttention(MultiheadAttention):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
+ add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
+ encoder_decoder_attention=encoder_decoder_attention)
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+
+ def forward(
+ self,
+ query, key, value,
+ spk_pos_ids_flat=None,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q = q * self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ # Apply rot embedding and store incremental_state
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+ bsz, self.num_heads, -1, self.head_dim)
+ self._set_input_buffer(incremental_state, saved_state)
+ if incremental_state is not None:
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+ else:
+ key_pos = spk_pos_ids_flat
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+
+class RotMultiheadAttention2(MultiheadAttention):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__(embed_dim, num_heads, kdim=kdim, vdim=vdim, dropout=dropout, bias=bias,
+ add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn, self_attention=self_attention,
+ encoder_decoder_attention=encoder_decoder_attention)
+ self.rotary_embeds = RotaryEmbeddings(width=embed_dim // num_heads)
+
+ def forward(
+ self,
+ query, key, value,
+ spk_pos_ids_flat=None,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ # Apply rot embedding and store incremental_state
+ q = self.rotary_embeds(q[None, :], positions=spk_pos_ids_flat)[0]
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ saved_state['prev_key'], saved_state['prev_value'] = k.view(bsz, self.num_heads, -1, self.head_dim), v.view(
+ bsz, self.num_heads, -1, self.head_dim)
+ self._set_input_buffer(incremental_state, saved_state)
+ key_pos = torch.arange(k.shape[-2], device=q.device).unsqueeze(0)
+ k = self.rotary_embeds(k[None, :], positions=key_pos)[0]
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, attn_mask=attn_mask, dropout_p=0, is_causal=False)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn_logits = None
+ attn_weights = None
+ return attn, (attn_weights, attn_logits)
+
+
+class RotDecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = RotMultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ self.layer_norm2 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+ self.post_ln = post_ln
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ spk_pos_ids_flat=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm1(x)
+
+ x, (attn_weights, _) = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask,
+ spk_pos_ids_flat=spk_pos_ids_flat
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm1(x)
+
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm2(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm2(x)
+ return x, attn_weights
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class RotDecSALayer2(RotDecSALayer):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9,
+ ffn_hidden_size=1024, act='gelu', post_ln=False):
+ super().__init__(c, num_heads, dropout, attention_dropout, relu_dropout, kernel_size, ffn_hidden_size, act,
+ post_ln)
+ self.self_attn = RotMultiheadAttention2(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+
+
+class RotTransformerDecoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=8, ffn_hidden_size=1024, post_ln=False,
+ op_version=1):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if op_version == 1:
+ self.op = RotDecSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln)
+ else:
+ self.op = RotDecSALayer2(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
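+
+# Minimal usage sketch (illustrative sizes; inputs follow the fairseq-style
+# [T, B, C] layout used by the attention modules). `spk_pos_ids_flat` supplies the
+# rotary positions; a plain arange is used here purely for illustration:
+#   >>> layer = RotTransformerDecoderLayer(hidden_size=256, dropout=0.1, num_heads=4)
+#   >>> x = torch.randn(50, 2, 256)               # [T, B, C]
+#   >>> pos = torch.arange(50).unsqueeze(0)       # [1, T]
+#   >>> y, attn = layer(x, spk_pos_ids_flat=pos)  # y: [T, B, C]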
diff --git a/modules/commons/taming_tfm_modules.py b/modules/commons/taming_tfm_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..79418633fbf06fac1afaa2d794a9ef2af9bdb7b3
--- /dev/null
+++ b/modules/commons/taming_tfm_modules.py
@@ -0,0 +1,366 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+class Normalize(nn.Module):
+ def __init__(self, channels, eps=1e-5):
+ super().__init__()
+ self.channels = channels
+ self.eps = eps
+
+ self.gamma = nn.Parameter(torch.ones(channels))
+ self.beta = nn.Parameter(torch.zeros(channels))
+ self.proj = nn.Linear(channels, channels)
+
+ def forward(self, x):
+ x = x.transpose(1, 2)
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
+ x = self.proj(x)
+ return x.transpose(1, 2)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # stride-2 conv with kernel 4 and symmetric padding halves the temporal length
+ self.conv = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x):
+ if self.with_conv:
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.conv2 = torch.nn.Conv1d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv1d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, _, x_mask):
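+ # x: [B, C, T]; the second argument (the timestep-embedding slot) is unused in this 1D variant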
+ x = x * x_mask
+ h = x
+ h = self.norm1(h) * x_mask
+ h = nonlinearity(h) * x_mask
+ h = self.conv1(h) * x_mask
+
+ h = self.norm2(h) * x_mask
+ h = nonlinearity(h) * x_mask
+ h = self.conv2(h) * x_mask
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x) * x_mask
+ else:
+ x = self.nin_shortcut(x) * x_mask
+
+ return (x + h) * x_mask
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv1d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, x_mask):
+ h_ = x * x_mask
+ h_ = self.norm(h_) * x_mask
+ q = self.q(h_) * x_mask
+ k = self.k(h_) * x_mask
+ v = self.v(h_) * x_mask
+
+ # compute attention
+ b, c, h = q.shape
+ w = 1
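+ # w is fixed to 1: the original 2D attention layout is reused for 1D sequences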
+ q = q.reshape(b, c, h * w)
+ q = q.permute(0, 2, 1) # b,hw,c
+ k = k.reshape(b, c, h * w) # b,c,hw
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c) ** (-0.5))
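+ # mask out padded positions along both the query and key axes before the softmax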
+ w_ = w_ + ((1 - x_mask) * -1e8) + ((1 - x_mask) * -1e8).transpose(1, 2)
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h)
+
+ h_ = self.proj_out(h_) * x_mask
+
+ return (x + h_) * x_mask
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
+ resamp_with_conv=False, in_channels):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv1d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ in_ch_mult = (1,) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch))
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv1d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, x_mask):
+ if x_mask is None:
+ x_mask = torch.ones_like(x[:, :, :1])
+ x = x.permute(0, 2, 1)
+ x_mask = x_mask.permute(0, 2, 1)
+
+ temb = None
+ # downsampling
+ hs = [self.conv_in(x) * x_mask]
+ for i_level in range(self.num_resolutions):
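+ # stride the mask so it matches the temporal resolution of the current level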
+ x_mask_ = x_mask[:, :, ::2 ** i_level]
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb, x_mask_) * x_mask_
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h, x_mask_) * x_mask_
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]) * x_mask_[:, :, ::2])
+
+ x_mask_ = x_mask[:, :, ::2 ** (self.num_resolutions - 1)]
+ # middle
+ h = hs[-1] * x_mask_
+ h = self.mid.block_1(h, temb, x_mask_) * x_mask_
+ h = self.mid.attn_1(h, x_mask_) * x_mask_
+ h = self.mid.block_2(h, temb, x_mask_) * x_mask_
+
+ # end
+ h = self.norm_out(h) * x_mask_
+ h = nonlinearity(h) * x_mask_
+ h = self.conv_out(h) * x_mask_
+ h = h.permute(0, 2, 1)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
+ resamp_with_conv=True, in_channels, give_pre_end=False):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv1d(in_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch))
+ block_in = block_out
+ if i_level == self.num_resolutions - 1:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv1d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z, x_mask):
+ if x_mask is None:
+ x_mask = torch.ones_like(z[:, :, :1]).repeat(1, 8, 1)
+ z = z.permute(0, 2, 1)
+ x_mask = x_mask.permute(0, 2, 1)
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
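+ # the bottleneck runs at the coarsest temporal resolution, so the mask is strided accordingly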
+ i_level = self.num_resolutions - 1
+ x_mask_ = x_mask[:, :, ::2 ** i_level]
+ h = self.mid.block_1(h, temb, x_mask_)
+ h = self.mid.attn_1(h, x_mask_)
+ h = self.mid.block_2(h, temb, x_mask_)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ x_mask_ = x_mask[:, :, ::2 ** i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb, x_mask_)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h, x_mask_)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h) * x_mask
+ h = h.permute(0, 2, 1)
+ return h
diff --git a/modules/commons/transformer.py b/modules/commons/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..36e09edfb2a124f7cc8913254b167fefec4f5b96
--- /dev/null
+++ b/modules/commons/transformer.py
@@ -0,0 +1,752 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter, Linear
+from modules.commons.layers import LayerNorm, Embedding
+from utils.nn.seq_utils import get_incremental_state, set_incremental_state, softmax, make_positions
+import torch.nn.functional as F
+
+DEFAULT_MAX_SOURCE_POSITIONS = 3000
+DEFAULT_MAX_TARGET_POSITIONS = 3000
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = Linear(filter_size, hidden_size)
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
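+ # incremental decoding: keep only the last kernel_size frames as left context for the causal conv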
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ prev_input = saved_state['prev_input']
+ x = torch.cat((prev_input, x), dim=0)
+ x = x[-self.kernel_size:]
+ saved_state['prev_input'] = x
+ self._set_input_buffer(incremental_state, saved_state)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-1:]
+ if self.act == 'gelu':
+ x = F.gelu(x)
+ if self.act == 'relu':
+ x = F.relu(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'f',
+ buffer,
+ )
+
+ def clear_buffer(self, incremental_state):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_input' in saved_state:
+ del saved_state['prev_input']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query, key, value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: True).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask)
+ else:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ # previous time steps are cached - no need to recompute
+ # key and value if they are static
+ if static_kv:
+ assert self.encoder_decoder_attention and not self.self_attention
+ key = value = None
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q = q * self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
+ if 'prev_key' in saved_state:
+ prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ k = prev_key
+ else:
+ k = torch.cat((prev_key, k), dim=1)
+ if 'prev_value' in saved_state:
+ prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim)
+ if static_kv:
+ v = prev_value
+ else:
+ v = torch.cat((prev_value, v), dim=1)
+ if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None:
+ prev_key_padding_mask = saved_state['prev_key_padding_mask']
+ if static_kv:
+ key_padding_mask = prev_key_padding_mask
+ else:
+ key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1)
+
+ saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
+ saved_state['prev_key_padding_mask'] = key_padding_mask
+
+ self._set_input_buffer(incremental_state, saved_state)
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e8,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+ def _get_input_buffer(self, incremental_state):
+ return get_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ ) or {}
+
+ def _set_input_buffer(self, incremental_state, buffer):
+ set_incremental_state(
+ self,
+ incremental_state,
+ 'attn_state',
+ buffer,
+ )
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
+
+ def clear_buffer(self, incremental_state=None):
+ if incremental_state is not None:
+ saved_state = self._get_input_buffer(incremental_state)
+ if 'prev_key' in saved_state:
+ del saved_state['prev_key']
+ if 'prev_value' in saved_state:
+ del saved_state['prev_value']
+ self._set_input_buffer(incremental_state, saved_state)
+
+
+class EncSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+ relu_dropout=0.1, kernel_size=9, padding='SAME', act='gelu',
+ ffn_hidden_size=1024):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if num_heads > 0:
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False)
+ self.layer_norm2 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+ x, _, = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+ return x
+
+
+class DecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1,
+ kernel_size=9, ffn_hidden_size=1024, act='gelu', post_ln=False):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ self.layer_norm2 = LayerNorm(c)
+ self.encoder_attn = MultiheadAttention(
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+ )
+ self.layer_norm3 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, ffn_hidden_size, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+ self.post_ln = post_ln
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ self.layer_norm3.training = layer_norm_training
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm1(x)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm1(x)
+
+ attn_logits = None
+ if encoder_out is not None or attn_out is not None:
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm2(x)
+ if encoder_out is not None:
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state,
+ 'enc_dec_attn_constraint_mask'),
+ reset_attn_weight=reset_attn_weight
+ )
+ attn_logits = attn[1]
+ elif attn_out is not None:
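+ # a precomputed attention context is given, so only the value projection is applied to it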
+ x = self.encoder_attn.in_proj_v(attn_out)
+ if encoder_out is not None or attn_out is not None:
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm2(x)
+
+ residual = x
+ if not self.post_ln:
+ x = self.layer_norm3(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ if self.post_ln:
+ x = self.layer_norm3(x)
+ return x, attn_logits
+
+ def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None):
+ self.encoder_attn.clear_buffer(incremental_state)
+ self.ffn.clear_buffer(incremental_state)
+
+ def set_buffer(self, name, tensor, incremental_state):
+ return set_incremental_state(self, incremental_state, name, tensor)
+
+
+class TransformerEncoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = EncSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+
+class TransformerDecoderLayer(nn.Module):
+ def __init__(self, hidden_size, dropout, kernel_size=9, num_heads=2, ffn_hidden_size=1024, post_ln=False):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.op = DecSALayer(
+ hidden_size, num_heads, dropout=dropout,
+ attention_dropout=0.0, relu_dropout=dropout,
+ kernel_size=kernel_size, ffn_hidden_size=ffn_hidden_size,
+ post_ln=post_ln)
+
+ def forward(self, x, **kwargs):
+ return self.op(x, **kwargs)
+
+ def clear_buffer(self, *args):
+ return self.op.clear_buffer(*args)
+
+ def set_buffer(self, *args):
+ return self.op.set_buffer(*args)
+
+
+class FFTBlocks(nn.Module):
+ def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=0.0,
+ num_heads=2, use_pos_embed=True, use_last_norm=True,
+ use_pos_embed_alpha=True, ffn_hidden_size=1024):
+ super().__init__()
+ self.num_layers = num_layers
+ embed_dim = self.hidden_size = hidden_size
+ self.dropout = dropout
+ self.use_pos_embed = use_pos_embed
+ self.use_last_norm = use_last_norm
+ if use_pos_embed:
+ self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
+ self.padding_idx = 0
+ self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ self.layers = nn.ModuleList([])
+ self.layers.extend([
+ TransformerEncoderLayer(self.hidden_size, self.dropout,
+ kernel_size=ffn_kernel_size, num_heads=num_heads,
+ ffn_hidden_size=ffn_hidden_size)
+ for _ in range(self.num_layers)
+ ])
+ if self.use_last_norm:
+ self.layer_norm = nn.LayerNorm(embed_dim)
+ else:
+ self.layer_norm = None
+
+ def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
+ """
+ :param x: [B, T, C]
+ :param padding_mask: [B, T]
+ :return: [B, T, C] or [L, B, T, C]
+ """
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
+ if self.use_pos_embed:
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1) * nonpadding_mask_TB
+ hiddens = []
+ for layer in self.layers:
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+ hiddens.append(x)
+ if self.use_last_norm:
+ x = self.layer_norm(x) * nonpadding_mask_TB
+ if return_hiddens:
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
+ x = x.transpose(1, 2) # [L, B, T, C]
+ else:
+ x = x.transpose(0, 1) # [B, T, C]
+ return x
+
+
+class FastSpeechEncoder(FFTBlocks):
+ def __init__(self, dict_size, hidden_size=256, num_layers=4, kernel_size=9,
+ dropout=0.0, num_heads=2, ffn_hidden_size=1024):
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
+ use_pos_embed=False, dropout=dropout, ffn_hidden_size=ffn_hidden_size)
+ self.embed_tokens = Embedding(dict_size, hidden_size, 0)
+ self.embed_scale = math.sqrt(hidden_size)
+ self.padding_idx = 0
+ self.embed_positions = SinusoidalPositionalEmbedding(
+ hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
+ )
+
+ def forward(self, txt_tokens, attn_mask=None, other_embeds=0):
+ """
+ :param txt_tokens: [B, T]
+ :return: encoder output of shape [B, T, C]
+ """
+ encoder_padding_mask = txt_tokens.eq(self.padding_idx).data
+ x = self.forward_embedding(txt_tokens) + other_embeds # [B, T, H]
+ if self.num_layers > 0:
+ x = super(FastSpeechEncoder, self).forward(x, encoder_padding_mask, attn_mask=attn_mask)
+ return x
+
+ def forward_embedding(self, txt_tokens):
+ # embed tokens and positions
+ x = self.embed_scale * self.embed_tokens(txt_tokens)
+ if self.use_pos_embed:
+ positions = self.embed_positions(txt_tokens)
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ return x
diff --git a/modules/commons/unet1d.py b/modules/commons/unet1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed8a9bc82c22058bcc6d9c2ea59868b35c7fc2d5
--- /dev/null
+++ b/modules/commons/unet1d.py
@@ -0,0 +1,202 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+
+class UNet1d(nn.Module):
+
+ def __init__(self, in_channels=3, out_channels=1, init_features=128, multi=None):
+ super(UNet1d, self).__init__()
+ if multi is None:
+ multi = [1, 2, 2, 4]
+ features = init_features
+ self.encoder1 = UNet1d._block(in_channels, features * multi[0], name="enc1")
+ self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.encoder2 = UNet1d._block(features * multi[0], features * multi[1], name="enc2")
+ self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.encoder3 = UNet1d._block(features * multi[1], features * multi[2], name="enc3")
+ self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
+ self.encoder4 = UNet1d._block(features * multi[2], features * multi[3], name="enc4")
+ self.pool4 = nn.MaxPool1d(kernel_size=2, stride=2)
+
+ self.bottleneck = UNet1d._block(features * multi[3], features * multi[3], name="bottleneck")
+
+ self.upconv4 = nn.ConvTranspose1d(
+ features * multi[3], features * multi[3], kernel_size=2, stride=2
+ )
+ self.decoder4 = UNet1d._block((features * multi[3]) * 2, features * multi[3], name="dec4")
+ self.upconv3 = nn.ConvTranspose1d(
+ features * multi[3], features * multi[2], kernel_size=2, stride=2
+ )
+ self.decoder3 = UNet1d._block((features * multi[2]) * 2, features * multi[2], name="dec3")
+ self.upconv2 = nn.ConvTranspose1d(
+ features * multi[2], features * multi[1], kernel_size=2, stride=2
+ )
+ self.decoder2 = UNet1d._block((features * multi[1]) * 2, features * multi[1], name="dec2")
+ self.upconv1 = nn.ConvTranspose1d(
+ features * multi[1], features * multi[0], kernel_size=2, stride=2
+ )
+ self.decoder1 = UNet1d._block(features * multi[0] * 2, features * multi[0], name="dec1")
+
+ self.conv = nn.Conv1d(
+ in_channels=features * multi[0], out_channels=out_channels, kernel_size=1
+ )
+
+ def forward(self, x, nonpadding=None):
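+ # x: [B, T, C], nonpadding: [B, T, 1]; the convolutions operate on [B, C, T]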
+ if nonpadding is None:
+ nonpadding = torch.ones_like(x)[:, :, :1]
+ enc1 = self.encoder1(x.transpose(1, 2)) * nonpadding.transpose(1, 2)
+ enc2 = self.encoder2(self.pool1(enc1))
+ enc3 = self.encoder3(self.pool2(enc2))
+ enc4 = self.encoder4(self.pool3(enc3))
+
+ bottleneck = self.bottleneck(self.pool4(enc4))
+
+ dec4 = self.upconv4(bottleneck)
+ dec4 = torch.cat((dec4, enc4), dim=1)
+ dec4 = self.decoder4(dec4)
+ dec3 = self.upconv3(dec4)
+ dec3 = torch.cat((dec3, enc3), dim=1)
+ dec3 = self.decoder3(dec3)
+ dec2 = self.upconv2(dec3)
+ dec2 = torch.cat((dec2, enc2), dim=1)
+ dec2 = self.decoder2(dec2)
+ dec1 = self.upconv1(dec2)
+ dec1 = torch.cat((dec1, enc1), dim=1)
+ dec1 = self.decoder1(dec1)
+ return self.conv(dec1).transpose(1, 2) * nonpadding
+
+ @staticmethod
+ def _block(in_channels, features, name):
+ return nn.Sequential(
+ OrderedDict(
+ [
+ (
+ name + "conv1",
+ nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=features,
+ kernel_size=5,
+ padding=2,
+ bias=False,
+ ),
+ ),
+ (name + "norm1", nn.GroupNorm(4, features)),
+ (name + "tanh1", nn.Tanh()),
+ (
+ name + "conv2",
+ nn.Conv1d(
+ in_channels=features,
+ out_channels=features,
+ kernel_size=5,
+ padding=2,
+ bias=False,
+ ),
+ ),
+ (name + "norm2", nn.GroupNorm(4, features)),
+ (name + "tanh2", nn.Tanh()),
+ ]
+ )
+ )
+
+
+class UNet2d(nn.Module):
+ def __init__(self, in_channels=3, out_channels=1, init_features=32, multi=None):
+ super(UNet2d, self).__init__()
+
+ features = init_features
+ self.encoder1 = UNet2d._block(in_channels, features * multi[0], name="enc1")
+ self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.encoder2 = UNet2d._block(features * multi[0], features * multi[1], name="enc2")
+ self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.encoder3 = UNet2d._block(features * multi[1], features * multi[2], name="enc3")
+ self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.encoder4 = UNet2d._block(features * multi[2], features * multi[3], name="enc4")
+ self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+
+ self.bottleneck = UNet2d._block(features * multi[3], features * multi[3], name="bottleneck")
+
+ self.upconv4 = nn.ConvTranspose2d(
+ features * multi[3], features * multi[3], kernel_size=2, stride=2
+ )
+ self.decoder4 = UNet2d._block((features * multi[3]) * 2, features * multi[3], name="dec4")
+ self.upconv3 = nn.ConvTranspose2d(
+ features * multi[3], features * multi[2], kernel_size=2, stride=2
+ )
+ self.decoder3 = UNet2d._block((features * multi[2]) * 2, features * multi[2], name="dec3")
+ self.upconv2 = nn.ConvTranspose2d(
+ features * multi[2], features * multi[1], kernel_size=2, stride=2
+ )
+ self.decoder2 = UNet2d._block((features * multi[1]) * 2, features * multi[1], name="dec2")
+ self.upconv1 = nn.ConvTranspose2d(
+ features * multi[1], features * multi[0], kernel_size=2, stride=2
+ )
+ self.decoder1 = UNet2d._block(features * multi[0] * 2, features * multi[0], name="dec1")
+
+ self.conv = nn.Conv2d(
+ in_channels=features * multi[0], out_channels=out_channels, kernel_size=1
+ )
+
+ def forward(self, x):
+ enc1 = self.encoder1(x)
+ enc2 = self.encoder2(self.pool1(enc1))
+ enc3 = self.encoder3(self.pool2(enc2))
+ enc4 = self.encoder4(self.pool3(enc3))
+
+ bottleneck = self.bottleneck(self.pool4(enc4))
+
+ dec4 = self.upconv4(bottleneck)
+ dec4 = torch.cat((dec4, enc4), dim=1)
+ dec4 = self.decoder4(dec4)
+ dec3 = self.upconv3(dec4)
+ dec3 = torch.cat((dec3, enc3), dim=1)
+ dec3 = self.decoder3(dec3)
+ dec2 = self.upconv2(dec3)
+ dec2 = torch.cat((dec2, enc2), dim=1)
+ dec2 = self.decoder2(dec2)
+ dec1 = self.upconv1(dec2)
+ dec1 = torch.cat((dec1, enc1), dim=1)
+ dec1 = self.decoder1(dec1)
+ x = self.conv(dec1)
+ return x
+
+ @staticmethod
+ def _block(in_channels, features, name):
+ return nn.Sequential(
+ OrderedDict(
+ [
+ (
+ name + "conv1",
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=features,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ),
+ (name + "norm1", nn.GroupNorm(4, features)),
+ (name + "tanh1", nn.Tanh()),
+ (
+ name + "conv2",
+ nn.Conv2d(
+ in_channels=features,
+ out_channels=features,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ),
+ (name + "norm2", nn.GroupNorm(4, features)),
+ (name + "tanh2", nn.Tanh()),
+ (name + "conv3", nn.Conv2d(
+ in_channels=features,
+ out_channels=features,
+ kernel_size=1,
+ padding=0,
+ bias=True,
+ )),
+ ]
+ )
+ )
diff --git a/modules/commons/vqvae.py b/modules/commons/vqvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc259ad1ecbc4aca7397f25476f407c43a032a0
--- /dev/null
+++ b/modules/commons/vqvae.py
@@ -0,0 +1,148 @@
+import torch
+import torch.nn as nn
+from scipy.cluster.vq import kmeans2
+from torch.nn import functional as F
+
+
+class VQEmbeddingEMA(nn.Module):
+ def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, decay=0.999, epsilon=1e-5,
+ print_vq_prob=False):
+ super(VQEmbeddingEMA, self).__init__()
+ self.commitment_cost = commitment_cost
+ self.n_embeddings = n_embeddings
+ self.decay = decay
+ self.epsilon = epsilon
+ self.print_vq_prob = print_vq_prob
+ self.register_buffer('data_initialized', torch.zeros(1))
+
+ init_bound = 1 / 512
+ embedding = torch.Tensor(n_embeddings, embedding_dim)
+ embedding.uniform_(-init_bound, init_bound)
+ self.register_buffer("embedding", embedding)
+ self.register_buffer("ema_count", torch.zeros(n_embeddings))
+ self.register_buffer("ema_weight", self.embedding.clone())
+
+ def encode(self, x):
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+ x_flat = x.detach().reshape(-1, D)
+
+ distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
+ torch.sum(x_flat ** 2, dim=1, keepdim=True),
+ x_flat, self.embedding.t(),
+ alpha=-2.0, beta=1.0) # [B*T_mel, N_vq]
+ indices = torch.argmin(distances.float(), dim=-1) # [B*T_mel]
+ quantized = F.embedding(indices, self.embedding)
+ quantized = quantized.view_as(x)
+ return x_flat, quantized, indices
+
+ def forward(self, x):
+ """
+ :param x: [B, T, D]
+ :return: (quantized [B, T, D], loss, indices [B, T], perplexity)
+ """
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+ # if self.training and self.data_initialized.item() == 0:
+ # print('| running kmeans in VQVAE') # data driven initialization for the embeddings
+ # x_flat = x.detach().reshape(-1, D)
+ # rp = torch.randperm(x_flat.size(0))
+ # kd = kmeans2(x_flat[rp].data.cpu().numpy(), self.n_embeddings, minit='points')
+ # self.embedding.copy_(torch.from_numpy(kd[0]))
+ # x_flat, quantized, indices = self.encode(x)
+ # encodings = F.one_hot(indices, M).float()
+ # self.ema_weight.copy_(torch.matmul(encodings.t(), x_flat))
+ # self.ema_count.copy_(torch.sum(encodings, dim=0))
+
+ x_flat, quantized, indices = self.encode(x)
+ encodings = F.one_hot(indices, M).float()
+ indices = indices.reshape(B, T)
+
+ if self.training and self.data_initialized.item() != 0:
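+ # EMA update of cluster counts and code vectors (with Laplace smoothing), applied only after initialisation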
+ self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)
+
+ n = torch.sum(self.ema_count)
+ self.ema_count = (self.ema_count + self.epsilon) / (n + M * self.epsilon) * n
+
+ dw = torch.matmul(encodings.t(), x_flat)
+ self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * dw
+
+ self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1)
+
+ if self.training and self.data_initialized.item() == 0:
+ self.data_initialized.fill_(1)
+
+ e_latent_loss = F.mse_loss(x, quantized.detach(), reduction='none')
+ nonpadding = (x.abs().sum(-1) > 0).float()
+ e_latent_loss = (e_latent_loss.mean(-1) * nonpadding).sum() / nonpadding.sum()
+ loss = self.commitment_cost * e_latent_loss
+
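+ # straight-through estimator: the forward pass uses the quantized values while gradients flow back to x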
+ quantized = x + (quantized - x).detach()
+
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ if self.print_vq_prob:
+ print("| VQ code avg_probs: ", avg_probs)
+ return quantized, loss, indices, perplexity
+
+
+class VQEmbedding(nn.Module):
+ def __init__(self, n_embeddings, embedding_dim, commitment_cost=0.25, lambda_kl=1.0):
+ super(VQEmbedding, self).__init__()
+ self.commitment_cost = commitment_cost
+ self.lambda_kl = lambda_kl
+ self.n_embeddings = n_embeddings
+ embedding = torch.Tensor(n_embeddings, embedding_dim)
+ self.register_buffer("embedding", embedding)
+ self.register_buffer('data_initialized', torch.zeros(1))
+
+ def encode(self, x):
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+ x_flat = x.detach().reshape(-1, D)
+
+ distances = torch.addmm(torch.sum(self.embedding ** 2, dim=1) +
+ torch.sum(x_flat ** 2, dim=1, keepdim=True),
+ x_flat, self.embedding.t(),
+ alpha=-2.0, beta=1.0) # [B*T_mel, N_vq]
+ indices = torch.argmin(distances.float(), dim=-1) # [B*T_mel]
+ quantized = F.embedding(indices, self.embedding)
+ quantized = quantized.view_as(x)
+ return x_flat, quantized, indices
+
+ def forward(self, x):
+ """
+ :param x: [B, T, D]
+ :return: (quantized [B, T, D], loss, indices [B, T], perplexity)
+ """
+ B, T, _ = x.shape
+ M, D = self.embedding.size()
+
+ x_flat, quantized, indices = self.encode(x)
+ encodings = F.one_hot(indices, M).float()
+ indices = indices.reshape(B, T)
+
+ # DeepMind def does not do this but I find I have to... ;\
+ if self.training and self.data_initialized.item() == 0:
+ print('| running kmeans in VQVAE') # data driven initialization for the embeddings
+ rp = torch.randperm(x_flat.size(0))
+ kd = kmeans2(x_flat[rp].data.cpu().numpy(), self.n_embeddings, minit='points')
+ self.embedding.copy_(torch.from_numpy(kd[0]))
+ self.data_initialized.fill_(1)
+ # TODO: this won't work in multi-GPU setups
+ x_flat, quantized, indices = self.encode(x)
+ encodings = F.one_hot(indices, M).float()
+ indices = indices.reshape(B, T)
+
+ # vector quantization cost that trains the embedding vectors
+ loss = self.commitment_cost * (x.detach() - quantized).pow(2).mean() + \
+ (quantized - x.detach()).pow(2).mean()
+ loss *= self.lambda_kl
+
+ quantized = x + (quantized - x).detach()
+
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ return quantized, loss, indices, perplexity
diff --git a/modules/commons/vqvae_cvq.py b/modules/commons/vqvae_cvq.py
new file mode 100644
index 0000000000000000000000000000000000000000..082039d3566b1b618d9bb54878122ab48de6cdbc
--- /dev/null
+++ b/modules/commons/vqvae_cvq.py
@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+import torch.distributed as dist
+
+from utils.commons.hparams import hparams
+
+
+class ClusteringVectorQuantiser(nn.Module):
+ """
+ Improved vector quantiser with dynamic re-initialisation of unoptimised "dead" codebook entries.
+ num_embed: number of codebook entries
+ embed_dim: dimensionality of each codebook entry
+ beta: weight for the commitment loss
+ distance: distance used to look up the closest code ('l2' or 'cos')
+ anchor: anchor sampling method ('closest', 'random' or 'probrandom')
+ first_batch: if true, only re-initialise during the first batch (offline version of the model)
+ contras_loss: if true, add a contrastive loss to further improve performance
+ """
+ def __init__(self, num_embed=1024, embed_dim=512, beta=0.25, distance='l2',
+ anchor='closest', first_batch=False, contras_loss=True):
+ super().__init__()
+
+ self.num_embed = num_embed
+ self.embed_dim = embed_dim
+ self.beta = beta
+ self.distance = distance
+ self.anchor = anchor
+ self.first_batch = first_batch
+ self.contras_loss = contras_loss
+ self.decay = 0.99
+ self.init = False
+
+ self.pool = FeaturePool(self.num_embed, self.embed_dim)
+ self.embedding = nn.Embedding(self.num_embed, self.embed_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.num_embed, 1.0 / self.num_embed)
+ self.register_buffer("embed_prob", torch.zeros(self.num_embed))
+
+
+ def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
+ if mask is not None:
+ assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
+ assert mask.shape[-1] == 1, (mask.shape,)
+ z = z * mask
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ assert z.shape[-1] == self.embed_dim
+ z_flattened = z.view(-1, self.embed_dim)
+
+ # calculate the distance
+ if self.distance == 'l2':
+ # l2 distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = - torch.sum(z_flattened.detach() ** 2, dim=1, keepdim=True) - \
+ torch.sum(self.embedding.weight ** 2, dim=1) + \
+ 2 * torch.einsum('bd, dn-> bn', z_flattened.detach(), rearrange(self.embedding.weight, 'n d-> d n'))
+ elif self.distance == 'cos':
+ # cosine distances from z to embeddings e_j
+ normed_z_flattened = F.normalize(z_flattened, dim=1).detach()
+ normed_codebook = F.normalize(self.embedding.weight, dim=1)
+ d = torch.einsum('bd,dn->bn', normed_z_flattened, rearrange(normed_codebook, 'n d -> d n'))
+
+ # encoding
+ sort_distance, indices = d.sort(dim=1)
+ # look up the closest point for the indices
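+ # d is a similarity score (negative squared L2 distance or cosine), so the last index after the ascending sort is the closest code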
+ encoding_indices = indices[:,-1]
+ encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=z.device)
+ encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+
+ # quantise and unflatten
+ z_q = torch.matmul(encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = self.beta * (z_q.detach() - z) ** 2 + (z_q - z.detach()) ** 2
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum() / self.embed_dim
+ else:
+ loss = loss.mean()
+ # loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+ # reshape back to match original input shape
+ # z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+ # count
+ # import pdb
+ # pdb.set_trace()
+ avg_probs = torch.mean(encodings, dim=0)
+ # perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+ # min_encodings = encodings
+
+ # online clustered reinitialisation for unoptimized points
+ if self.training:
+ # calculate the average usage of code entries
+ self.embed_prob.mul_(self.decay).add_(avg_probs, alpha= 1 - self.decay)
+ # running average updates
+ if self.anchor in ['closest', 'random', 'probrandom'] and (not self.init):
+ # closest sampling
+ if self.anchor == 'closest':
+ sort_distance, indices = d.sort(dim=0)
+ random_feat = z_flattened.detach()[indices[-1,:]]
+ # feature pool based random sampling
+ elif self.anchor == 'random':
+ random_feat = self.pool.query(z_flattened.detach())
+ # probability-based random sampling
+ elif self.anchor == 'probrandom':
+ norm_distance = F.softmax(d.t(), dim=1)
+ prob = torch.multinomial(norm_distance, num_samples=1).view(-1)
+ random_feat = z_flattened.detach()[prob]
+ # decay parameter based on the average usage
+ decay = torch.exp(-(self.embed_prob*self.num_embed*10)/(1-self.decay)-1e-3).unsqueeze(1).repeat(1, self.embed_dim)
+ if hparams.get('reduce_cvq_embed') and dist.is_initialized():
+ # make sure the embedding weights stay synchronised across all GPUs
+ dist.all_reduce(random_feat.data, op=dist.ReduceOp.SUM)
+ random_feat.data /= dist.get_world_size()
+ self.embedding.weight.data = self.embedding.weight.data * (1 - decay) + random_feat * decay
+ if self.first_batch:
+ self.init = True
+ # contrastive loss
+ if self.contras_loss:
+ sort_distance, indices = d.sort(dim=0)
+ dis_pos = sort_distance[-max(1, int(sort_distance.size(0)/self.num_embed)):,:].mean(dim=0, keepdim=True)
+ dis_neg = sort_distance[:int(sort_distance.size(0)*1/2),:]
+ dis = torch.cat([dis_pos, dis_neg], dim=0).t() / 0.07
+ contra_loss = F.cross_entropy(dis, torch.zeros((dis.size(0),), dtype=torch.long, device=dis.device))
+ loss += contra_loss
+
+ encoding_indices = encoding_indices.reshape(z.shape[:-1])
+ return z_q, loss, encoding_indices
+
+ def get_codebook_entry(self, encoding_indices):
+ # # get quantized latent vectors
+ # print(encoding_indices.shape)
+ # encoding_indices = encoding_indices.view(-1)
+ # encodings = torch.zeros(encoding_indices.unsqueeze(1).shape[0], self.num_embed, device=encoding_indices.device)
+ # print(encodings.shape)
+ # encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
+ # print(encodings.shape)
+ # # quantise and unflatten
+ # z_q = torch.matmul(encodings, self.embedding.weight).view(encoding_indices.shape[0], -1)
+ z_q = self.embedding(encoding_indices)
+ return z_q
+
+class FeaturePool():
+ """
+ This class implements a feature buffer that stores previously encoded features
+
+ This buffer enables us to initialize the codebook using a history of generated features
+ rather than the ones produced by the latest encoder
+ """
+ def __init__(self, pool_size, dim=64):
+ """
+ Initialize the FeaturePool class
+
+ Parameters:
+ pool_size(int) -- the size of the feature buffer
+ """
+ self.pool_size = pool_size
+ if self.pool_size > 0:
+ self.nums_features = 0
+ self.features = (torch.rand((pool_size, dim)) * 2 - 1)/ pool_size
+
+ def query(self, features):
+ """
+ return features from the pool
+ """
+ self.features = self.features.to(features.device)
+ if self.nums_features < self.pool_size:
+ if features.size(0) > self.pool_size: # if the batch size is large enough, directly update the whole codebook
+ random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
+ self.features = features[random_feat_id]
+ self.nums_features = self.pool_size
+ else:
+ # if the mini-batch is not large enough, just store it for the next update
+ num = self.nums_features + features.size(0)
+ self.features[self.nums_features:num] = features
+ self.nums_features = num
+ else:
+ if features.size(0) > int(self.pool_size):
+ random_feat_id = torch.randint(0, features.size(0), (int(self.pool_size),))
+ self.features = features[random_feat_id]
+ else:
+ random_id = torch.randperm(self.pool_size)
+ self.features[random_id[:features.size(0)]] = features
+
+ return self.features
\ No newline at end of file
diff --git a/modules/commons/vqvae_fsq.py b/modules/commons/vqvae_fsq.py
new file mode 100644
index 0000000000000000000000000000000000000000..12ade280e20a2f1cb9701e465e7335d45dee286a
--- /dev/null
+++ b/modules/commons/vqvae_fsq.py
@@ -0,0 +1,72 @@
+"""
+Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
+Code adapted from Jax version in Appendix A.1
+"""
+
+from typing import List
+
+import torch
+import torch.nn as nn
+from torch import Tensor, int32
+
+
+def round_ste(z: Tensor) -> Tensor:
+ """Round with straight through gradients."""
+ zhat = z.round()
+ return z + (zhat - z).detach()
+
+
+class FSQ(nn.Module):
+ def __init__(self, levels: List[int]):
+ super().__init__()
+ _levels = torch.tensor(levels, dtype=int32)
+ self.register_buffer("_levels", _levels)
+
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32)
+ self.register_buffer("_basis", _basis)
+
+ self.dim = len(levels)
+ self.n_codes = self._levels.prod().item()
+ implicit_codebook = self.indices_to_codes(torch.arange(self.n_codes))
+ self.register_buffer("implicit_codebook", implicit_codebook)
+
+ def forward(self, z: Tensor) -> Tensor:
+ zhat = self.quantize(z)
+ indices = self.codes_to_indices(zhat)
+ return zhat, indices
+
+ def bound(self, z: Tensor, eps: float = 1e-3) -> Tensor:
+ """Bound `z`, an array of shape (..., d)."""
+ half_l = (self._levels - 1) * (1 - eps) / 2
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
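+ # an offset of 0.5 is used for even level counts so the quantization grid stays symmetric around zero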
+ shift = (offset / half_l).tan()
+ return (z + shift).tanh() * half_l - offset
+
+ def quantize(self, z: Tensor) -> Tensor:
+ """Quantizes z, returns quantized zhat, same shape as z."""
+ quantized = round_ste(self.bound(z))
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
+ return quantized / half_width
+
+ def _scale_and_shift(self, zhat_normalized: Tensor) -> Tensor:
+ half_width = self._levels // 2
+ return (zhat_normalized * half_width) + half_width
+
+ def _scale_and_shift_inverse(self, zhat: Tensor) -> Tensor:
+ half_width = self._levels // 2
+ return (zhat - half_width) / half_width
+
+ def codes_to_indices(self, zhat: Tensor) -> Tensor:
+ """Converts a `code` to an index in the codebook."""
+ assert zhat.shape[-1] == self.dim
+ zhat = self._scale_and_shift(zhat)
+ return (zhat * self._basis).sum(dim=-1).to(int32)
+
+ def indices_to_codes(self, indices: Tensor) -> Tensor:
+ """Inverse of `codes_to_indices`."""
+ indices = indices.unsqueeze(-1)
+ codes_non_centered = (indices // self._basis) % self._levels
+ return self._scale_and_shift_inverse(codes_non_centered)
+
+ def get_codebook_entry(self, encoding_indices):
+ return self.indices_to_codes(encoding_indices)
diff --git a/modules/commons/vqvae_lfq.py b/modules/commons/vqvae_lfq.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff9b0bf4837caa7bc6944952e02d4ce8e495f0bc
--- /dev/null
+++ b/modules/commons/vqvae_lfq.py
@@ -0,0 +1,276 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
+https://arxiv.org/abs/2309.15505
+"""
+
+from math import log2, ceil
+from collections import namedtuple
+
+import torch
+from torch import nn, Tensor, einsum
+import torch.nn.functional as F
+from torch.nn import Module
+
+from einops import rearrange, reduce, pack, unpack
+
+# constants
+
+# Return = namedtuple('Return', ['quantized', 'indices', 'entropy_aux_loss'])
+
+LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'batch_entropy', 'commitment'])
+
+# helper functions
+
+def exists(v):
+ return v is not None
+
+def default(*args):
+ for arg in args:
+ if exists(arg):
+ return arg() if callable(arg) else arg
+ return None
+
+def pack_one(t, pattern):
+ return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+ return unpack(t, ps, pattern)[0]
+
+# distance
+
+def euclidean_distance_squared(x, y):
+ x2 = reduce(x ** 2, '... n d -> ... n', 'sum')
+ y2 = reduce(y ** 2, 'n d -> n', 'sum')
+ xy = einsum('... i d, j d -> ... i j', x, y) * -2
+ return rearrange(x2, '... i -> ... i 1') + y2 + xy
+
+# entropy
+
+def log(t, eps = 1e-20):
+ return t.clamp(min = eps).log()
+
+def entropy(prob):
+ return -prob * log(prob)
+
+# class
+
+class LFQ(Module):
+ def __init__(
+ self,
+ *,
+ dim = None,
+ codebook_size = None,
+ entropy_loss_weight = 0.1,
+ commitment_loss_weight = 1.,
+ diversity_gamma = 2.5,
+ straight_through_activation = nn.Identity(),
+ num_codebooks = 1,
+ keep_num_codebooks_dim = None,
+ codebook_scale = 1. # for residual LFQ, codebook scaled down by 2x at each layer
+ ):
+ super().__init__()
+
+ # some assert validations
+
+ assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ'
+ assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})'
+
+ codebook_size = default(codebook_size, lambda: 2 ** dim)
+ codebook_dim = int(log2(codebook_size))
+
+ codebook_dims = codebook_dim * num_codebooks
+ dim = default(dim, codebook_dims)
+
+ self.project_in = nn.Linear(dim, codebook_dims) if dim != codebook_dims else nn.Identity()
+ self.project_out = nn.Linear(codebook_dims, dim) if dim != codebook_dims else nn.Identity()
+
+ self.dim = dim
+ self.codebook_dim = codebook_dim
+ self.num_codebooks = num_codebooks
+
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
+
+ # straight through activation
+
+ self.activation = straight_through_activation
+
+ # entropy aux loss related weights
+
+ self.diversity_gamma = diversity_gamma
+ self.entropy_loss_weight = entropy_loss_weight
+
+ # codebook scale
+
+ self.codebook_scale = codebook_scale
+
+ # commitment loss
+
+ self.commitment_loss_weight = commitment_loss_weight
+
+ # for no auxiliary loss, during inference
+
+ self.register_buffer('mask', 2 ** torch.arange(codebook_dim - 1, -1, -1))
+ self.register_buffer('zero', torch.tensor(0.), persistent = False)
+
+ # codes
+
+ all_codes = torch.arange(codebook_size)
+ bits = ((all_codes[..., None].int() & self.mask) != 0).float()
+ codebook = self.bits_to_codes(bits)
+
+ self.register_buffer('codebook', codebook, persistent = False)
+
+ def bits_to_codes(self, bits):
+ return bits * self.codebook_scale * 2 - self.codebook_scale
+
+ @property
+ def dtype(self):
+ return self.codebook.dtype
+
+ def indices_to_codes(
+ self,
+ indices,
+ project_out = True
+ ):
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, '... -> ... 1')
+
+ # indices to codes, which are bits of either -1 or 1
+
+ bits = ((indices[..., None].int() & self.mask) != 0).to(self.dtype)
+
+ codes = self.bits_to_codes(bits)
+
+ codes = rearrange(codes, '... c d -> ... (c d)')
+
+ # whether to project codes out to original dimensions
+ # if the input feature dimensions were not log2(codebook size)
+
+ if project_out:
+ codes = self.project_out(codes)
+
+ # rearrange codes back to original shape
+
+ if is_img_or_video:
+ codes = rearrange(codes, 'b ... d -> b d ...')
+
+ return codes
+
+ def forward(
+ self,
+ x,
+ mask=None,
+ inv_temperature = 1.,
+ return_loss_breakdown = False
+ ):
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension, which is also log2(codebook size)
+ c - number of codebook dim
+ """
+
+ is_img_or_video = x.ndim >= 4
+
+ # standardize image or video into (batch, seq, dimension)
+
+ if is_img_or_video:
+ x = rearrange(x, 'b d ... -> b ... d')
+ x, ps = pack_one(x, 'b * d')
+
+ assert x.shape[-1] == self.dim, f'expected dimension of {self.dim} but received {x.shape[-1]}'
+
+ x = self.project_in(x)
+
+ # split out number of codebooks
+
+ x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)
+
+ # quantize by eq 3.
+
+ original_input = x
+
+ codebook_value = torch.ones_like(x) * self.codebook_scale
+ quantized = torch.where(x > 0, codebook_value, -codebook_value)
+
+ # use straight-through gradients with tanh (or custom activation fn) if training
+
+ if self.training:
+ x = self.activation(x)
+ x = x - x.detach() + quantized
+ else:
+ x = quantized
+
+ # calculate indices
+
+ indices = reduce((x > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')
+
+ # entropy aux loss
+
+ if self.training:
+ distance = euclidean_distance_squared(original_input, self.codebook)
+
+ prob = (-distance * inv_temperature).softmax(dim = -1)
+
+ per_sample_entropy = entropy(prob).mean()
+
+ avg_prob = reduce(prob, 'b n c d -> b c d', 'mean')
+ codebook_entropy = entropy(avg_prob).mean()
+
+ # 1. entropy will be nudged to be low for each code, to encourage the network to output confident predictions
+ # 2. codebook entropy will be nudged to be high, to encourage all codes to be uniformly used within the batch
+
+ entropy_aux_loss = per_sample_entropy - self.diversity_gamma * codebook_entropy
+ else:
+ # if not training, just return dummy 0
+ entropy_aux_loss = per_sample_entropy = codebook_entropy = self.zero
+
+ # commit loss
+
+ if self.training:
+ commit_loss = F.mse_loss(original_input, quantized.detach())
+ else:
+ commit_loss = self.zero
+
+ # merge back codebook dim
+
+ x = rearrange(x, 'b n c d -> b n (c d)')
+
+ # project out to feature dimension if needed
+
+ x = self.project_out(x)
+
+ # reconstitute image or video dimensions
+
+ if is_img_or_video:
+ x = unpack_one(x, ps, 'b * d')
+ x = rearrange(x, 'b ... d -> b d ...')
+
+ indices = unpack_one(indices, ps, 'b * c')
+
+ # whether to remove single codebook dim
+
+ if not self.keep_num_codebooks_dim:
+ indices = rearrange(indices, '... 1 -> ...')
+
+ # complete aux loss
+
+ aux_loss = entropy_aux_loss * self.entropy_loss_weight + commit_loss * self.commitment_loss_weight
+
+ ret = x, aux_loss, indices
+
+ if not return_loss_breakdown:
+ return ret
+
+ return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss)
+
+ def get_codebook_entry(self, encoding_indices):
+ return self.indices_to_codes(encoding_indices)
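+
+
+# --- Minimal usage sketch (illustrative addition; the sizes below are assumptions) ---
+# Shows the expected call signature and tensor shapes for a (batch, seq, dim) input.
+if __name__ == "__main__":
+    lfq = LFQ(codebook_size=2 ** 8, dim=16)   # 8-bit codes, projected in/out of a 16-dim feature
+    feats = torch.randn(2, 32, 16)
+    quantized, aux_loss, indices = lfq(feats)
+    print(quantized.shape, aux_loss.item(), indices.shape)   # (2, 32, 16), scalar, (2, 32)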
diff --git a/modules/commons/vqvae_lfq_y.py b/modules/commons/vqvae_lfq_y.py
new file mode 100644
index 0000000000000000000000000000000000000000..b34ead5d2481801a6a966b7d560b326e8083e310
--- /dev/null
+++ b/modules/commons/vqvae_lfq_y.py
@@ -0,0 +1,109 @@
+"""
+Lookup Free Quantization
+Proposed in https://arxiv.org/abs/2310.05737
+
+basically a 2-level FSQ (Finite Scalar Quantization) with entropy loss
+https://arxiv.org/abs/2309.15505
+"""
+
+import torch
+from einops import rearrange
+from torch.nn import Module
+
+
+# entropy
+
+def binary_entropy(prob):
+ return -prob * log(prob) - (1 - prob) * log(1 - prob)
+
+
+# tensor helpers
+
+def log(t, eps=1e-20):
+ return t.clamp(min=eps).log()
+
+
+# convert to bit representations and back
+
+def decimal_to_bits(x: torch.LongTensor, bits: int) -> torch.FloatTensor:
+ # [b, ...] {0, 1, ..., max - 1} -> [b, ..., d] {-1, 1}
+ mask = 2 ** torch.arange(bits).to(x) # [d]
+ bits = ((x.unsqueeze(-1) & mask) != 0).float() # [b, n, d] {0, 1}
+ return bits * 2 - 1 # {0, 1} -> {-1, 1}
+
+
+def bits_to_decimal(x: torch.FloatTensor) -> torch.LongTensor:
+ # [b, ..., d] {-1, 1} -> [b, ...] {0, 1, ..., max - 1}
+ x = (x > 0).long() # {-1, 1} -> {0, 1}, [b, ..., d]
+ mask = 2 ** torch.arange(x.size(-1)).to(x) # [d]
+ dec = (x * mask).sum(-1) # [b, ...]
+ return dec
+
+
+# class
+
+class LFQY(Module):
+ def __init__(self, dim, entropy_loss_weight=0.1, diversity_gamma=1.0):
+ super().__init__()
+ self.dim = dim
+ self.diversity_gamma = diversity_gamma
+ self.entropy_loss_weight = entropy_loss_weight
+
+ def indices_to_codes(self, indices):
+ codes = decimal_to_bits(indices, self.dim)
+ # codes = rearrange(codes, 'b ... d -> b d ...')
+ return codes
+
+ def forward(self, x, mask=None, inv_temperature=1.):
+ """
+ einstein notation
+ b - batch
+ n - sequence (or flattened spatial dimensions)
+ d - feature dimension, which is also log2(codebook size)
+ """
+ # x = rearrange(x, 'b d ... -> b ... d')
+
+ assert x.shape[-1] == self.dim
+ z = torch.tanh(x / inv_temperature) # (-1, 1)
+
+ # quantize by eq 3.
+ quantized = torch.sign(x) # {-1, 1}
+ z = z + (quantized - z).detach()
+
+ # calculate indices
+ indices = bits_to_decimal(z)
+
+ # entropy aux loss
+ if self.training:
+ prob = torch.sigmoid(x / inv_temperature) # [b, ..., d]
+
+ bit_entropy = binary_entropy(prob).sum(-1).mean()
+ # E[H(q)] = avg(sum(H(q_i)))
+
+ avg_prob = prob.flatten(0, -2).mean(0) # [b, ..., d] -> [n, d] -> [d]
+ codebook_entropy = binary_entropy(avg_prob).sum()
+ # H(E[q]) = sum(H(avg(q_i)))
+
+ """
+ 1. entropy will be nudged to be low for each bit,
+ so each scalar commits to one latent binary bit or the other.
+ 2. codebook entropy will be nudged to be high,
+ to encourage all codes to be uniformly used.
+ """
+
+ entropy_aux_loss = bit_entropy - self.diversity_gamma * codebook_entropy
+ else:
+ # if not training, just return dummy 0
+ entropy_aux_loss = torch.zeros(1).to(z)
+
+ entropy_aux_loss = entropy_aux_loss * self.entropy_loss_weight
+
+ # reconstitute image or video dimensions
+
+ # z = rearrange(z, 'b ... d -> b d ...')
+
+ # bits to decimal for the codebook indices
+ return z, entropy_aux_loss, indices
+
+ def get_codebook_entry(self, encoding_indices):
+ return self.indices_to_codes(encoding_indices)
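+
+
+# --- Minimal usage sketch (illustrative addition; shapes are assumptions) ---
+# LFQY binarizes each channel of the last dimension to {-1, 1}, so `dim` is the number
+# of bits per code and the implicit codebook has 2**dim entries.
+if __name__ == "__main__":
+    lfq = LFQY(dim=10)                 # 2**10 = 1024 implicit codes
+    feats = torch.randn(2, 32, 10)     # (batch, seq, dim)
+    z, aux_loss, indices = lfq(feats)
+    print(z.shape, aux_loss.item(), indices.shape, int(indices.max()))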
diff --git a/modules/commons/vqvae_taming.py b/modules/commons/vqvae_taming.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b7abff0050186aacdd5899f142c5dcbcf49295
--- /dev/null
+++ b/modules/commons/vqvae_taming.py
@@ -0,0 +1,428 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from scipy.cluster.vq import kmeans2
+from torch import einsum
+from einops import rearrange
+import torch.distributed as dist
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+        ## could possibly replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # .........\end
+
+ # with:
+ # .........\start
+ # min_encoding_indices = torch.argmin(d, dim=1)
+ # z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:, None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:, self.used, ...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:, self.used, ...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b * h * w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, legacy=False):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.re_embed = n_e
+
+ def encode(self, z):
+ B, T, _ = z.shape
+ z_flattened = z.reshape(-1, self.e_dim)
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+
+ z_q = z_q.view_as(z)
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ return z_flattened, z_q, min_encoding_indices
+
+ def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
+ if mask is not None:
+ assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
+ assert mask.shape[-1] == 1, (mask.shape,)
+ z = z * mask
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ assert z.shape[-1] == self.e_dim
+ z_flattened = z.reshape(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.matmul(z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+ #torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ perplexity = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * (z_q.detach() - z) ** 2 + \
+ (z_q - z.detach()) ** 2
+ else:
+ loss = (z_q.detach() - z) ** 2 + self.beta * \
+ (z_q - z.detach()) ** 2
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum() / self.e_dim
+ else:
+ loss = loss.mean()
+ return z_q, loss, min_encoding_indices, perplexity
+
+ def get_codebook_entry(self, indices, shape=None):
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class VectorQuantizer4(nn.Module):
+ def __init__(self, n_e, e_dim, beta, legacy=False, kmeans_reset_every=1000):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.re_embed = n_e
+ self.reset_every = kmeans_reset_every
+ self.reset_thres = 20
+ self.z_buffer = []
+ self.register_buffer('use_flag', torch.zeros(n_e))
+ self.register_buffer('steps', torch.zeros(1))
+
+ def encode(self, z):
+ B, T, _ = z.shape
+ z_flattened = z.reshape(-1, self.e_dim)
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+
+ z_q = z_q.view_as(z)
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ return z_flattened, z_q, min_encoding_indices
+
+ def forward(self, z, mask=None, temp=None, rescale_logits=False, return_logits=False):
+ if mask is not None:
+ assert mask.shape[:2] == z.shape[:2], (mask.shape, z.shape)
+ assert mask.shape[-1] == 1, (mask.shape,)
+ z = z * mask
+ assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits == False, "Only for interface compatible with Gumbel"
+ assert return_logits == False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ # z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ assert z.shape[-1] == self.e_dim
+ z_flattened = z.reshape(-1, self.e_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ perplexity = None
+
+ if self.training:
+ self.steps += 1
+ self.use_flag += torch.bincount(min_encoding_indices, minlength=self.n_e)
+ is_master = not dist.is_initialized() or dist.get_rank() == 0
+ if self.reset_every - 100 <= self.steps <= self.reset_every:
+ if dist.is_initialized():
+ z_buffer_ = [None for _ in range(dist.get_world_size())]
+ dist.all_gather_object(z_buffer_, z_flattened.detach().cpu())
+ else:
+ z_buffer_ = [z_flattened.detach().cpu()]
+ self.z_buffer += z_buffer_
+
+ if self.steps % self.reset_every == 0:
+ if dist.is_initialized():
+ dist.all_reduce(self.use_flag)
+ vq_usage = (self.use_flag > self.reset_thres).sum().item() / self.use_flag.shape[0]
+ print("| VQ usage: ", vq_usage)
+ if vq_usage != 1:
+ if is_master:
+ if self.steps.item() == self.reset_every:
+ print('| running kmeans in VQVAE') # data driven initialization for the embeddings
+ z_buffer = torch.cat(self.z_buffer, 0)
+ rp = torch.randperm(z_buffer.shape[0])
+ kd = kmeans2(z_buffer[rp].numpy(), self.n_e, minit='points')[0]
+ self.embedding.weight.data = torch.from_numpy(kd).to(z.device)
+ else:
+ reset_ids = self.use_flag < self.reset_thres
+ keep_ids = self.use_flag >= self.reset_thres
+ t = torch.randint(0, keep_ids.sum(), [reset_ids.sum()], device=self.use_flag.device)
+ keep_ids = torch.where(keep_ids)[0][t]
+ self.embedding.weight.data[reset_ids] = self.embedding.weight.data[keep_ids].clone()
+ if dist.is_initialized():
+ dist.broadcast(self.embedding.weight.data, 0)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).reshape(z.shape)
+ self.use_flag.fill_(0)
+ self.z_buffer = []
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * (z_q.detach() - z) ** 2 + \
+ (z_q - z.detach()) ** 2
+ else:
+ loss = (z_q.detach() - z) ** 2 + self.beta * \
+ (z_q - z.detach()) ** 2
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[:-1])
+ if mask is not None:
+ loss = (loss * mask).sum() / mask.sum() / self.e_dim
+ else:
+ loss = loss.mean()
+ return z_q, loss, min_encoding_indices, perplexity
+
+ def get_codebook_entry(self, indices, shape=None):
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
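+
+
+# --- Minimal usage sketch (illustrative addition; sizes are assumptions) ---
+# VectorQuantizer2 / VectorQuantizer4 operate on channel-last sequences (B, T, e_dim)
+# and return the straight-through quantized features, the VQ loss, the code indices,
+# and a placeholder perplexity.
+if __name__ == "__main__":
+    vq = VectorQuantizer2(n_e=512, e_dim=64, beta=0.25, legacy=False)
+    z = torch.randn(2, 100, 64)
+    z_q, vq_loss, codes, _ = vq(z)
+    print(z_q.shape, vq_loss.item(), codes.shape)   # (2, 100, 64), scalar, (2, 100)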
diff --git a/modules/commons/wavenet.py b/modules/commons/wavenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7809c9b9d3331ba4fd2ffd4caae14e721e4b0732
--- /dev/null
+++ b/modules/commons/wavenet.py
@@ -0,0 +1,97 @@
+import torch
+from torch import nn
+
+
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+ n_channels_int = n_channels[0]
+ in_act = input_a + input_b
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+ acts = t_act * s_act
+ return acts
+
+
+class WN(torch.nn.Module):
+ def __init__(self, hidden_size, kernel_size, dilation_rate, n_layers, c_cond=0,
+ p_dropout=0, share_cond_layers=False, is_BTC=False):
+ super(WN, self).__init__()
+ assert (kernel_size % 2 == 1)
+ assert (hidden_size % 2 == 0)
+ self.is_BTC = is_BTC
+ self.hidden_size = hidden_size
+ self.kernel_size = kernel_size
+ self.dilation_rate = dilation_rate
+ self.n_layers = n_layers
+ self.gin_channels = c_cond
+ self.p_dropout = p_dropout
+ self.share_cond_layers = share_cond_layers
+
+ self.in_layers = torch.nn.ModuleList()
+ self.res_skip_layers = torch.nn.ModuleList()
+ self.drop = nn.Dropout(p_dropout)
+
+ if c_cond != 0 and not share_cond_layers:
+ cond_layer = torch.nn.Conv1d(c_cond, 2 * hidden_size * n_layers, 1)
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+
+ for i in range(n_layers):
+ dilation = dilation_rate ** i
+ padding = int((kernel_size * dilation - dilation) / 2)
+ in_layer = torch.nn.Conv1d(hidden_size, 2 * hidden_size, kernel_size,
+ dilation=dilation, padding=padding)
+ in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+ self.in_layers.append(in_layer)
+
+ # last one is not necessary
+ if i < n_layers - 1:
+ res_skip_channels = 2 * hidden_size
+ else:
+ res_skip_channels = hidden_size
+
+ res_skip_layer = torch.nn.Conv1d(hidden_size, res_skip_channels, 1)
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight')
+ self.res_skip_layers.append(res_skip_layer)
+
+ def forward(self, x, nonpadding=None, cond=None):
+ if self.is_BTC:
+ x = x.transpose(1, 2)
+ cond = cond.transpose(1, 2) if cond is not None else None
+ nonpadding = nonpadding.transpose(1, 2) if nonpadding is not None else None
+ if nonpadding is None:
+ nonpadding = 1
+ output = torch.zeros_like(x)
+ n_channels_tensor = torch.IntTensor([self.hidden_size])
+
+ if cond is not None and not self.share_cond_layers:
+ cond = self.cond_layer(cond)
+
+ for i in range(self.n_layers):
+ x_in = self.in_layers[i](x)
+ x_in = self.drop(x_in)
+ if cond is not None:
+ cond_offset = i * 2 * self.hidden_size
+ cond_l = cond[:, cond_offset:cond_offset + 2 * self.hidden_size, :]
+ else:
+ cond_l = torch.zeros_like(x_in)
+
+ acts = fused_add_tanh_sigmoid_multiply(x_in, cond_l, n_channels_tensor)
+
+ res_skip_acts = self.res_skip_layers[i](acts)
+ if i < self.n_layers - 1:
+ x = (x + res_skip_acts[:, :self.hidden_size, :]) * nonpadding
+ output = output + res_skip_acts[:, self.hidden_size:, :]
+ else:
+ output = output + res_skip_acts
+ output = output * nonpadding
+ if self.is_BTC:
+ output = output.transpose(1, 2)
+ return output
+
+ def remove_weight_norm(self):
+ def remove_weight_norm(m):
+ try:
+ nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(remove_weight_norm)
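+
+
+# --- Minimal usage sketch (illustrative addition; sizes are assumptions) ---
+# WN is a non-causal WaveNet stack of gated dilated convolutions. With is_BTC=True it
+# accepts (batch, time, channel) tensors plus an optional conditioning signal of c_cond channels.
+if __name__ == "__main__":
+    wn = WN(hidden_size=64, kernel_size=3, dilation_rate=2, n_layers=4, c_cond=16, is_BTC=True)
+    x = torch.randn(2, 50, 64)       # (batch, time, hidden_size)
+    cond = torch.randn(2, 50, 16)    # (batch, time, c_cond)
+    y = wn(x, cond=cond)
+    print(y.shape)                   # torch.Size([2, 50, 64])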
diff --git a/modules/eg3ds/camera_utils/pose_sampler.py b/modules/eg3ds/camera_utils/pose_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e36f3bcac364ab993ad59f25ce7f90726f32ceb
--- /dev/null
+++ b/modules/eg3ds/camera_utils/pose_sampler.py
@@ -0,0 +1,216 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""
+Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
+"""
+
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+
+from modules.eg3ds.volumetric_rendering import math_utils
+
+
+class UnifiedCameraPoseSampler():
+ """
+    A unified class for obtaining a camera pose: a 25-dimensional vector consisting of a flattened
+    camera-to-world matrix (4x4) and camera intrinsics (3x3). It utilizes the samplers constructed below.
+ """
+ def get_camera_pose(self, pitch, yaw, lookat_location=None, distance_to_orig=2.7, batch_size=1, device='cpu', roll=None):
+ if lookat_location is None:
+ lookat_location = torch.tensor([0., 0., -0.2], device=device)
+
+ c2w = LookAtPoseSampler.sample(yaw, pitch, lookat_location, 0, 0, distance_to_orig, batch_size, device, roll=roll).reshape([batch_size, 16])
+ intrinsics = torch.tensor([[4.2647, 0, 0.5], [0, 4.2647, 0.5], [0, 0, 1]], device=device).reshape([9,]).unsqueeze(0).repeat([batch_size, 1])
+ # intrinsics = FOV_to_intrinsics(fov_degrees, device=device).reshape([9,]).unsqueeze(0).repeat([batch_size, 1])
+ camera = torch.cat([c2w, intrinsics], dim=1) # [batch, 25]
+ return camera
+
+
+class GaussianCameraPoseSampler:
+ """
+ Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
+ Camera is specified as looking at the origin.
+ If horizontal and vertical stddev (specified in radians) are zero, gives a
+ deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
+ The coordinate system is specified with y-up, z-forward, x-left.
+ Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
+ vertical mean is the polar angle (angle from the y axis) in radians.
+ A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ """
+        horizontal_mean: yaw (azimuth) angle; -pi/2 places the camera at the left, 0 faces forward, pi/2 at the right.
+        vertical_mean: pitch angle; -pi/2 places the camera above (looking down), 0 is horizontal, pi/2 places it below. In practice, 0.2 is a good choice for a frontal face.
+ """
+ assert horizontal_mean < np.pi/2 + 1e-5 and horizontal_mean > - np.pi/2 - 1e-5
+ assert vertical_mean < np.pi/2 + 1e-5 and vertical_mean > - np.pi/2 - 1e-5
+ horizontal_mean += np.pi/2
+ vertical_mean += np.pi/2
+ h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins) # the direction the camera is pointing, pointing to origin in this func
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+class LookAtPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ camera is specified as looking at 'lookat_position', a 3-vector.
+
+ Example:
+ For a camera pose looking at the origin with the camera at position [0, 0, 1]:
+ cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, lookat_position, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu', roll=None):
+ """
+        horizontal_mean: yaw (azimuth) angle; -pi/2 places the camera at the left, 0 faces forward, pi/2 at the right.
+        vertical_mean: pitch angle; -pi/2 places the camera above (looking down), 0 is horizontal, pi/2 places it below. In practice, 0.2 is a good choice for a frontal face.
+ """
+ # assert horizontal_mean < np.pi + 1e-5 and horizontal_mean > - np.pi - 1e-5
+ # assert vertical_mean < np.pi + 1e-5 and vertical_mean > - np.pi - 1e-5
+ horizontal_mean += np.pi/2
+ vertical_mean += np.pi/2
+
+ # if horizontal_mean < -np.pi:
+ # horizontal_mean += 2*np.pi
+ # if vertical_mean < -np.pi:
+ # vertical_mean += 2*np.pi
+ # if horizontal_mean > np.pi:
+ # horizontal_mean -= 2*np.pi
+ # if vertical_mean > np.pi:
+ # vertical_mean -= 2*np.pi
+
+ h = torch.randn((batch_size, 1), device=device) * horizontal_stddev + horizontal_mean
+ v = torch.randn((batch_size, 1), device=device) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+        theta = h  # azimuthal (yaw) angle in the spherical coordinate system
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device)
+
+        # radius*torch.sin(phi) is the projection of the sphere radius onto the horizontal plane;
+        # the x and z components are then derived from the yaw angle, and radius*torch.cos(phi) is the vertical component.
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ # forward_vectors = math_utils.normalize_vecs(-camera_origins)
+ forward_vectors = math_utils.normalize_vecs(lookat_position.to(device) - camera_origins) # the direction the camera is pointing, pointing to the lookat_position
+ return create_cam2world_matrix(forward_vectors, camera_origins, roll)
+
+
+class UniformCameraPoseSampler:
+ """
+ Same as GaussianCameraPoseSampler, except the
+ pose is sampled from a UNIFORM distribution with range +-[horizontal/vertical]_stddev, instead of a GAUSSIAN distribution.
+
+ Example:
+ For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
+
+ cam2worlds = UniformCameraPoseSampler.sample(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
+ """
+
+ @staticmethod
+ def sample(horizontal_mean, vertical_mean, horizontal_stddev=0, vertical_stddev=0, radius=1, batch_size=1, device='cpu'):
+ """
+        horizontal_mean: yaw (azimuth) angle; -pi/2 places the camera at the left, 0 faces forward, pi/2 at the right.
+        vertical_mean: pitch angle; -pi/2 places the camera above (looking down), 0 is horizontal, pi/2 places it below. In practice, 0.2 is a good choice for a frontal face.
+ """
+ assert horizontal_mean < np.pi/2 + 1e-5 and horizontal_mean > - np.pi/2 - 1e-5
+ assert vertical_mean < np.pi/2 + 1e-5 and vertical_mean > - np.pi/2 - 1e-5
+ horizontal_mean += np.pi/2
+ vertical_mean += np.pi/2
+
+ h = (torch.rand((batch_size, 1), device=device) * 2 - 1) * horizontal_stddev + horizontal_mean
+ v = (torch.rand((batch_size, 1), device=device) * 2 - 1) * vertical_stddev + vertical_mean
+ v = torch.clamp(v, 1e-5, math.pi - 1e-5)
+
+ theta = h
+ v = v / math.pi
+ phi = torch.arccos(1 - 2*v)
+
+ camera_origins = torch.zeros((batch_size, 3), device=device) # the location of camera
+
+ camera_origins[:, 0:1] = radius*torch.sin(phi) * torch.cos(math.pi-theta)
+ camera_origins[:, 2:3] = radius*torch.sin(phi) * torch.sin(math.pi-theta)
+ camera_origins[:, 1:2] = radius*torch.cos(phi)
+
+ forward_vectors = math_utils.normalize_vecs(-camera_origins) # the direction the camera is pointing, pointing to origin in this func
+ return create_cam2world_matrix(forward_vectors, camera_origins)
+
+
+def create_cam2world_matrix(forward_vector, origin, roll=None):
+ """
+ Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
+ Works on batches of forward_vectors, origins. Assumes y-axis is up.
+    Modified by yerfor to support roll control.
+    roll: default None, which yields zero roll; or a Tensor of shape [batch_size, 1] with radians in [-pi, pi].
+ """
+
+ batch_size = len(forward_vector)
+ forward_vector = math_utils.normalize_vecs(forward_vector)
+    # up_vector is the camera's "up" direction vector, so rotating it controls the roll
+ up_vector = torch.zeros([batch_size, 3], dtype=forward_vector.dtype, device=forward_vector.device)
+ if roll is None:
+ roll = torch.zeros([batch_size, 1], dtype=forward_vector.dtype, device=forward_vector.device)
+ else:
+ roll = roll.reshape([batch_size, 1])
+
+ up_vector[:, 0] = torch.sin(roll)
+ up_vector[:, 1] = torch.cos(roll)
+
+ right_vector = -math_utils.normalize_vecs(torch.cross(up_vector, forward_vector, dim=-1))
+ up_vector = math_utils.normalize_vecs(torch.cross(forward_vector, right_vector, dim=-1))
+
+ rotation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+ rotation_matrix[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), axis=-1)
+
+ translation_matrix = torch.eye(4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], 1, 1)
+ translation_matrix[:, :3, 3] = origin
+ cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
+ assert(cam2world.shape[1:] == (4, 4))
+ return cam2world
+
+
+def FOV_to_intrinsics(fov_degrees=18.837, device='cpu'):
+ """
+ Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
+ Note the intrinsics are returned as normalized by image size, rather than in pixel units.
+ Assumes principal point is at image center.
+ """
+
+ focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
+ intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
+ return intrinsics
\ No newline at end of file
diff --git a/modules/eg3ds/dnnlib/__init__.py b/modules/eg3ds/dnnlib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd91ed142e955581e83948455fb71cd837215f61
--- /dev/null
+++ b/modules/eg3ds/dnnlib/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+from .util import EasyDict, make_cache_dir_path
diff --git a/modules/eg3ds/dnnlib/util.py b/modules/eg3ds/dnnlib/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b67c4e312cd1b847ca21fd3b929802a57e6f6d
--- /dev/null
+++ b/modules/eg3ds/dnnlib/util.py
@@ -0,0 +1,493 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utility classes and functions."""
+
+import ctypes
+import fnmatch
+import importlib
+import inspect
+import numpy as np
+import os
+import shutil
+import sys
+import types
+import io
+import pickle
+import re
+import requests
+import html
+import hashlib
+import glob
+import tempfile
+import urllib
+import urllib.request
+import uuid
+
+from distutils.util import strtobool
+from typing import Any, List, Tuple, Union
+
+
+# Util classes
+# ------------------------------------------------------------------------------------------
+
+
+class EasyDict(dict):
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
+
+ def __getattr__(self, name: str) -> Any:
+ try:
+ return self[name]
+ except KeyError:
+ raise AttributeError(name)
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ self[name] = value
+
+ def __delattr__(self, name: str) -> None:
+ del self[name]
+
+
+class Logger(object):
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
+
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
+ self.file = None
+
+ if file_name is not None:
+ self.file = open(file_name, file_mode)
+
+ self.should_flush = should_flush
+ self.stdout = sys.stdout
+ self.stderr = sys.stderr
+
+ sys.stdout = self
+ sys.stderr = self
+
+ def __enter__(self) -> "Logger":
+ return self
+
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+ self.close()
+
+ def write(self, text: Union[str, bytes]) -> None:
+ """Write text to stdout (and a file) and optionally flush."""
+ if isinstance(text, bytes):
+ text = text.decode()
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
+ return
+
+ if self.file is not None:
+ self.file.write(text)
+
+ self.stdout.write(text)
+
+ if self.should_flush:
+ self.flush()
+
+ def flush(self) -> None:
+ """Flush written text to both stdout and a file, if open."""
+ if self.file is not None:
+ self.file.flush()
+
+ self.stdout.flush()
+
+ def close(self) -> None:
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
+ self.flush()
+
+ # if using multiple loggers, prevent closing in wrong order
+ if sys.stdout is self:
+ sys.stdout = self.stdout
+ if sys.stderr is self:
+ sys.stderr = self.stderr
+
+ if self.file is not None:
+ self.file.close()
+ self.file = None
+
+
+# Cache directories
+# ------------------------------------------------------------------------------------------
+
+_dnnlib_cache_dir = None
+
+def set_cache_dir(path: str) -> None:
+ global _dnnlib_cache_dir
+ _dnnlib_cache_dir = path
+
+def make_cache_dir_path(*paths: str) -> str:
+ if _dnnlib_cache_dir is not None:
+ return os.path.join(_dnnlib_cache_dir, *paths)
+ if 'DNNLIB_CACHE_DIR' in os.environ:
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
+ if 'HOME' in os.environ:
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
+ if 'USERPROFILE' in os.environ:
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
+
+# Small util functions
+# ------------------------------------------------------------------------------------------
+
+
+def format_time(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
+ else:
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
+
+
+def format_time_brief(seconds: Union[int, float]) -> str:
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
+ s = int(np.rint(seconds))
+
+ if s < 60:
+ return "{0}s".format(s)
+ elif s < 60 * 60:
+ return "{0}m {1:02}s".format(s // 60, s % 60)
+ elif s < 24 * 60 * 60:
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
+ else:
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
+
+
+def ask_yes_no(question: str) -> bool:
+ """Ask the user the question until the user inputs a valid answer."""
+ while True:
+ try:
+ print("{0} [y/n]".format(question))
+ return strtobool(input().lower())
+ except ValueError:
+ pass
+
+
+def tuple_product(t: Tuple) -> Any:
+ """Calculate the product of the tuple elements."""
+ result = 1
+
+ for v in t:
+ result *= v
+
+ return result
+
+
+_str_to_ctype = {
+ "uint8": ctypes.c_ubyte,
+ "uint16": ctypes.c_uint16,
+ "uint32": ctypes.c_uint32,
+ "uint64": ctypes.c_uint64,
+ "int8": ctypes.c_byte,
+ "int16": ctypes.c_int16,
+ "int32": ctypes.c_int32,
+ "int64": ctypes.c_int64,
+ "float32": ctypes.c_float,
+ "float64": ctypes.c_double
+}
+
+
+def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
+ type_str = None
+
+ if isinstance(type_obj, str):
+ type_str = type_obj
+ elif hasattr(type_obj, "__name__"):
+ type_str = type_obj.__name__
+ elif hasattr(type_obj, "name"):
+ type_str = type_obj.name
+ else:
+ raise RuntimeError("Cannot infer type name from input")
+
+ assert type_str in _str_to_ctype.keys()
+
+ my_dtype = np.dtype(type_str)
+ my_ctype = _str_to_ctype[type_str]
+
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
+
+ return my_dtype, my_ctype
+
+
+def is_pickleable(obj: Any) -> bool:
+ try:
+ with io.BytesIO() as stream:
+ pickle.dump(obj, stream)
+ return True
+ except:
+ return False
+
+
+# Functionality to import modules/objects by name, and call functions by name
+# ------------------------------------------------------------------------------------------
+
+def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
+ """Searches for the underlying module behind the name to some python object.
+ Returns the module and the object name (original name with module part removed)."""
+
+ # allow convenience shorthands, substitute them by full names
+ obj_name = re.sub("^np.", "numpy.", obj_name)
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
+
+ # list alternatives for (module_name, local_obj_name)
+ parts = obj_name.split(".")
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
+
+ # try each alternative in turn
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ return module, local_obj_name
+ except:
+ pass
+
+ # maybe some of the modules themselves contain errors?
+ for module_name, _local_obj_name in name_pairs:
+ try:
+ importlib.import_module(module_name) # may raise ImportError
+ except ImportError:
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
+ raise
+
+ # maybe the requested attribute is missing?
+ for module_name, local_obj_name in name_pairs:
+ try:
+ module = importlib.import_module(module_name) # may raise ImportError
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
+ except ImportError:
+ pass
+
+ # we are out of luck, but we have no idea why
+ raise ImportError(obj_name)
+
+
+def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
+ """Traverses the object name and returns the last (rightmost) python object."""
+ if obj_name == '':
+ return module
+ obj = module
+ for part in obj_name.split("."):
+ obj = getattr(obj, part)
+ return obj
+
+
+def get_obj_by_name(name: str) -> Any:
+ """Finds the python object with the given name."""
+ module, obj_name = get_module_from_obj_name(name)
+ return get_obj_from_module(module, obj_name)
+
+
+def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
+ """Finds the python object with the given name and calls it as a function."""
+ assert func_name is not None
+ func_obj = get_obj_by_name(func_name)
+ assert callable(func_obj)
+ return func_obj(*args, **kwargs)
+
+
+def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
+ """Finds the python class with the given name and constructs it with the given arguments."""
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
+
+
+def get_module_dir_by_obj_name(obj_name: str) -> str:
+ """Get the directory path of the module containing the given object name."""
+ module, _ = get_module_from_obj_name(obj_name)
+ return os.path.dirname(inspect.getfile(module))
+
+
+def is_top_level_function(obj: Any) -> bool:
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
+
+
+def get_top_level_function_name(obj: Any) -> str:
+ """Return the fully-qualified name of a top-level function."""
+ assert is_top_level_function(obj)
+ module = obj.__module__
+ if module == '__main__':
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
+ return module + "." + obj.__name__
+
+
+# File system helpers
+# ------------------------------------------------------------------------------------------
+
+def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
+ """List all files recursively in a given directory while ignoring given file and directory names.
+ Returns list of tuples containing both absolute and relative paths."""
+ assert os.path.isdir(dir_path)
+ base_name = os.path.basename(os.path.normpath(dir_path))
+
+ if ignores is None:
+ ignores = []
+
+ result = []
+
+ for root, dirs, files in os.walk(dir_path, topdown=True):
+ for ignore_ in ignores:
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
+
+ # dirs need to be edited in-place
+ for d in dirs_to_remove:
+ dirs.remove(d)
+
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
+
+ absolute_paths = [os.path.join(root, f) for f in files]
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
+
+ if add_base_to_relative:
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
+
+ assert len(absolute_paths) == len(relative_paths)
+ result += zip(absolute_paths, relative_paths)
+
+ return result
+
+
+def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
+ """Takes in a list of tuples of (src, dst) paths and copies files.
+ Will create all necessary directories."""
+ for file in files:
+ target_dir_name = os.path.dirname(file[1])
+
+ # will create all intermediate-level directories
+ if not os.path.exists(target_dir_name):
+ os.makedirs(target_dir_name)
+
+ shutil.copyfile(file[0], file[1])
+
+
+# URL helpers
+# ------------------------------------------------------------------------------------------
+
+def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
+ """Determine whether the given object is a valid URL string."""
+ if not isinstance(obj, str) or not "://" in obj:
+ return False
+ if allow_file_urls and obj.startswith('file://'):
+ return True
+ try:
+ res = requests.compat.urlparse(obj)
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
+ if not res.scheme or not res.netloc or not "." in res.netloc:
+ return False
+ except:
+ return False
+ return True
+
+
+def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
+ """Download the given URL and return a binary-mode file object to access the data."""
+ assert num_attempts >= 1
+ assert not (return_filename and (not cache))
+
+ # Doesn't look like an URL scheme so interpret it as a local filename.
+ if not re.match('^[a-z]+://', url):
+ return url if return_filename else open(url, "rb")
+
+ # Handle file URLs. This code handles unusual file:// patterns that
+ # arise on Windows:
+ #
+ # file:///c:/foo.txt
+ #
+ # which would translate to a local '/c:/foo.txt' filename that's
+ # invalid. Drop the forward slash for such pathnames.
+ #
+ # If you touch this code path, you should test it on both Linux and
+ # Windows.
+ #
+    # Some internet resources suggest using urllib.request.url2pathname(),
+    # but that converts forward slashes to backslashes and this causes
+ # its own set of problems.
+ if url.startswith('file://'):
+ filename = urllib.parse.urlparse(url).path
+ if re.match(r'^/[a-zA-Z]:', filename):
+ filename = filename[1:]
+ return filename if return_filename else open(filename, "rb")
+
+ assert is_url(url)
+
+ # Lookup from cache.
+ if cache_dir is None:
+ cache_dir = make_cache_dir_path('downloads')
+
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
+ if cache:
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
+ if len(cache_files) == 1:
+ filename = cache_files[0]
+ return filename if return_filename else open(filename, "rb")
+
+ # Download.
+ url_name = None
+ url_data = None
+ with requests.Session() as session:
+ if verbose:
+ print("Downloading %s ..." % url, end="", flush=True)
+ for attempts_left in reversed(range(num_attempts)):
+ try:
+ with session.get(url) as res:
+ res.raise_for_status()
+ if len(res.content) == 0:
+ raise IOError("No data received")
+
+ if len(res.content) < 8192:
+ content_str = res.content.decode("utf-8")
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
+ if len(links) == 1:
+ url = requests.compat.urljoin(url, links[0])
+ raise IOError("Google Drive virus checker nag")
+ if "Google Drive - Quota exceeded" in content_str:
+ raise IOError("Google Drive download quota exceeded -- please try again later")
+
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
+ url_name = match[1] if match else url
+ url_data = res.content
+ if verbose:
+ print(" done")
+ break
+ except KeyboardInterrupt:
+ raise
+ except:
+ if not attempts_left:
+ if verbose:
+ print(" failed")
+ raise
+ if verbose:
+ print(".", end="", flush=True)
+
+ # Save to cache.
+ if cache:
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
+ os.makedirs(cache_dir, exist_ok=True)
+ with open(temp_file, "wb") as f:
+ f.write(url_data)
+ os.replace(temp_file, cache_file) # atomic
+ if return_filename:
+ return cache_file
+
+ # Return data as file object.
+ assert not return_filename
+ return io.BytesIO(url_data)
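+
+
+# --- Minimal usage sketch (illustrative addition) ---
+# EasyDict exposes dict entries as attributes; make_cache_dir_path resolves a per-user
+# cache directory (honoring DNNLIB_CACHE_DIR or set_cache_dir() when configured).
+if __name__ == "__main__":
+    cfg = EasyDict(lr=1e-4, batch_size=8)
+    cfg.epochs = 10                       # attribute assignment stores a regular dict key
+    print(cfg["lr"], cfg.batch_size, cfg.epochs)
+    print(make_cache_dir_path("downloads"))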
diff --git a/modules/eg3ds/metrics/__init__.py b/modules/eg3ds/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/modules/eg3ds/metrics/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/modules/eg3ds/metrics/equivariance.py b/modules/eg3ds/metrics/equivariance.py
new file mode 100644
index 0000000000000000000000000000000000000000..d105cb93031d5a9638d7a9c12c65db1d8c4a0860
--- /dev/null
+++ b/modules/eg3ds/metrics/equivariance.py
@@ -0,0 +1,270 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Equivariance metrics (EQ-T, EQ-T_frac, and EQ-R) from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import copy
+import numpy as np
+import torch
+import torch.fft
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+# Utilities.
+
+def sinc(x):
+ y = (x * np.pi).abs()
+ z = torch.sin(y) / y.clamp(1e-30, float('inf'))
+ return torch.where(y < 1e-30, torch.ones_like(x), z)
+
+def lanczos_window(x, a):
+ x = x.abs() / a
+ return torch.where(x < 1, sinc(x), torch.zeros_like(x))
+
+def rotation_matrix(angle):
+ angle = torch.as_tensor(angle).to(torch.float32)
+ mat = torch.eye(3, device=angle.device)
+ mat[0, 0] = angle.cos()
+ mat[0, 1] = angle.sin()
+ mat[1, 0] = -angle.sin()
+ mat[1, 1] = angle.cos()
+ return mat
+
+#----------------------------------------------------------------------------
+# Apply integer translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.1.
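+# Usage sketch (illustrative only; tx/ty are fractions of the image width/height):
+#   z, m = apply_integer_translation(x, 0.25, -0.125)   # x: [N, C, H, W]
+#   # z is the shifted batch, m masks the pixels that remain valid after the shift.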
+
+def apply_integer_translation(x, tx, ty):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.round().to(torch.int64)
+ iy = ty.round().to(torch.int64)
+
+ z = torch.zeros_like(x)
+ m = torch.zeros_like(x)
+ if abs(ix) < W and abs(iy) < H:
+ y = x[:, :, max(-iy,0) : H+min(-iy,0), max(-ix,0) : W+min(-ix,0)]
+ z[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = y
+ m[:, :, max(iy,0) : H+min(iy,0), max(ix,0) : W+min(ix,0)] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional translation to a batch of 2D images. Corresponds to the
+# operator T_x in Appendix E.2.
+
+def apply_fractional_translation(x, tx, ty, a=3):
+ _N, _C, H, W = x.shape
+ tx = torch.as_tensor(tx * W).to(dtype=torch.float32, device=x.device)
+ ty = torch.as_tensor(ty * H).to(dtype=torch.float32, device=x.device)
+ ix = tx.floor().to(torch.int64)
+ iy = ty.floor().to(torch.int64)
+ fx = tx - ix
+ fy = ty - iy
+ b = a - 1
+
+ z = torch.zeros_like(x)
+ zx0 = max(ix - b, 0)
+ zy0 = max(iy - b, 0)
+ zx1 = min(ix + a, 0) + W
+ zy1 = min(iy + a, 0) + H
+ if zx0 < zx1 and zy0 < zy1:
+ taps = torch.arange(a * 2, device=x.device) - b
+ filter_x = (sinc(taps - fx) * sinc((taps - fx) / a)).unsqueeze(0)
+ filter_y = (sinc(taps - fy) * sinc((taps - fy) / a)).unsqueeze(1)
+ y = x
+ y = upfirdn2d.filter2d(y, filter_x / filter_x.sum(), padding=[b,a,0,0])
+ y = upfirdn2d.filter2d(y, filter_y / filter_y.sum(), padding=[0,0,b,a])
+ y = y[:, :, max(b-iy,0) : H+b+a+min(-iy-a,0), max(b-ix,0) : W+b+a+min(-ix-a,0)]
+ z[:, :, zy0:zy1, zx0:zx1] = y
+
+ m = torch.zeros_like(x)
+ mx0 = max(ix + a, 0)
+ my0 = max(iy + a, 0)
+ mx1 = min(ix - b, 0) + W
+ my1 = min(iy - b, 0) + H
+ if mx0 < mx1 and my0 < my1:
+ m[:, :, my0:my1, mx0:mx1] = 1
+ return z, m
+
+#----------------------------------------------------------------------------
+# Construct an oriented low-pass filter that applies the appropriate
+# bandlimit with respect to the input and output of the given affine 2D
+# image transformation.
+
+def construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
+ assert a <= amax < aflt
+ mat = torch.as_tensor(mat).to(torch.float32)
+
+ # Construct 2D filter taps in input & output coordinate spaces.
+ taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
+ yi, xi = torch.meshgrid(taps, taps)
+ xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
+
+ # Convolution of two oriented 2D sinc filters.
+ fi = sinc(xi * cutoff_in) * sinc(yi * cutoff_in)
+ fo = sinc(xo * cutoff_out) * sinc(yo * cutoff_out)
+ f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
+
+ # Convolution of two oriented 2D Lanczos windows.
+ wi = lanczos_window(xi, a) * lanczos_window(yi, a)
+ wo = lanczos_window(xo, a) * lanczos_window(yo, a)
+ w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
+
+ # Construct windowed FIR filter.
+ f = f * w
+
+ # Finalize.
+ c = (aflt - amax) * up
+ f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
+ f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
+ f = f / f.sum([0,2], keepdim=True) / (up ** 2)
+ f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
+ return f
+
+#----------------------------------------------------------------------------
+# Apply the given affine transformation to a batch of 2D images.
+
+def apply_affine_transformation(x, mat, up=4, **filter_kwargs):
+ _N, _C, H, W = x.shape
+ mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
+
+ # Construct filter.
+ f = construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
+ assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
+ p = f.shape[0] // 2
+
+ # Construct sampling grid.
+ theta = mat.inverse()
+ theta[:2, 2] *= 2
+ theta[0, 2] += 1 / up / W
+ theta[1, 2] += 1 / up / H
+ theta[0, :] *= W / (W + p / up * 2)
+ theta[1, :] *= H / (H + p / up * 2)
+ theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
+ g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
+
+ # Resample image.
+ y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
+ z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
+
+ # Form mask.
+ m = torch.zeros_like(y)
+ c = p * 2 + 1
+ m[:, :, c:-c, c:-c] = 1
+ m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
+ return z, m
+
+#----------------------------------------------------------------------------
+# Apply fractional rotation to a batch of 2D images. Corresponds to the
+# operator R_\alpha in Appendix E.3.
+
+def apply_fractional_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(angle)
+ return apply_affine_transformation(x, mat, a=a, amax=a*2, **filter_kwargs)
+
+#----------------------------------------------------------------------------
+# Modify the frequency content of a batch of 2D images as if they had undergone
+# fractional rotation -- but without actually rotating them. Corresponds to
+# the operator R^*_\alpha in Appendix E.3.
+
+def apply_fractional_pseudo_rotation(x, angle, a=3, **filter_kwargs):
+ angle = torch.as_tensor(angle).to(dtype=torch.float32, device=x.device)
+ mat = rotation_matrix(-angle)
+ f = construct_affine_bandlimit_filter(mat, a=a, amax=a*2, up=1, **filter_kwargs)
+ y = upfirdn2d.filter2d(x=x, f=f)
+ m = torch.zeros_like(y)
+ c = f.shape[0] // 2
+ m[:, :, c:-c, c:-c] = 1
+ return y, m
+
+#----------------------------------------------------------------------------
+# Compute the selected equivariance metrics for the given generator.
+
+def compute_equivariance_metrics(opts, num_samples, batch_size, translate_max=0.125, rotate_max=1, compute_eqt_int=False, compute_eqt_frac=False, compute_eqr=False):
+ assert compute_eqt_int or compute_eqt_frac or compute_eqr
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ I = torch.eye(3, device=opts.device)
+ M = getattr(getattr(getattr(G, 'synthesis', None), 'input', None), 'transform', None)
+ if M is None:
+ raise ValueError('Cannot compute equivariance metrics; the given generator does not support user-specified image transformations')
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ sums = None
+ progress = opts.progress.sub(tag='eq sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ s = []
+
+ # Randomize noise buffers, if any.
+ for name, buf in G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Run mapping network.
+ z = torch.randn([batch_size, G.z_dim], device=opts.device)
+ c = next(c_iter)
+ ws = G.mapping(z=z, c=c)
+
+ # Generate reference image.
+ M[:] = I
+ orig = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+
+ # Integer translation (EQ-T).
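+        # Render a second image with the generator's input transform set to a translation and
+        # compare it against the reference image translated correspondingly in pixel space.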
+ if compute_eqt_int:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ t = (t * G.img_resolution).round() / G.img_resolution
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_integer_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Fractional translation (EQ-T_frac).
+ if compute_eqt_frac:
+ t = (torch.rand(2, device=opts.device) * 2 - 1) * translate_max
+ M[:] = I
+ M[:2, 2] = -t
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, mask = apply_fractional_translation(orig, t[0], t[1])
+ s += [(ref - img).square() * mask, mask]
+
+ # Rotation (EQ-R).
+ if compute_eqr:
+ angle = (torch.rand([], device=opts.device) * 2 - 1) * (rotate_max * np.pi)
+ M[:] = rotation_matrix(-angle)
+ img = G.synthesis(ws=ws, noise_mode='const', **opts.G_kwargs)
+ ref, ref_mask = apply_fractional_rotation(orig, angle)
+ pseudo, pseudo_mask = apply_fractional_pseudo_rotation(img, angle)
+ mask = ref_mask * pseudo_mask
+ s += [(ref - pseudo).square() * mask, mask]
+
+ # Accumulate results.
+ s = torch.stack([x.to(torch.float64).sum() for x in s])
+ sums = sums + s if sums is not None else s
+ progress.update(num_samples)
+
+ # Compute PSNRs.
+ if opts.num_gpus > 1:
+ torch.distributed.all_reduce(sums)
+ sums = sums.cpu()
+ mses = sums[0::2] / sums[1::2]
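+    # PSNR against a peak signal of 2 (images live in [-1, 1]):
+    # PSNR = 10 * log10(peak^2 / MSE) = 20 * log10(2) - 10 * log10(MSE).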
+ psnrs = np.log10(2) * 20 - mses.log10() * 10
+ psnrs = tuple(psnrs.numpy())
+ return psnrs[0] if len(psnrs) == 1 else psnrs
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/frechet_inception_distance.py b/modules/eg3ds/metrics/frechet_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..e682de6162066e255b04c0db2f1cc8860c96de7c
--- /dev/null
+++ b/modules/eg3ds/metrics/frechet_inception_distance.py
@@ -0,0 +1,45 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Frechet Inception Distance (FID) from the paper
+"GANs trained by a two time-scale update rule converge to a local Nash
+equilibrium". Matches the original implementation by Heusel et al. at
+https://github.com/bioinf-jku/TTUR/blob/master/fid.py"""
+
+import numpy as np
+import scipy.linalg
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_fid(opts, max_real, num_gen):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov()
+
+ mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov()
+
+ if opts.rank != 0:
+ return float('nan')
+
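+    # Frechet distance between the two feature Gaussians:
+    # FID = ||mu_real - mu_gen||^2 + Tr(Sigma_real + Sigma_gen - 2 * (Sigma_real @ Sigma_gen)^(1/2)).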
+ m = np.square(mu_gen - mu_real).sum()
+ s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+ fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+ return float(fid)
+
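+# Illustrative self-check of the formula above (not part of the upstream EG3D metrics; it reruns
+# the same closed-form computation on random Gaussian toy features). Because of the relative
+# import above, run it as `python -m modules.eg3ds.metrics.frechet_inception_distance`.
+if __name__ == '__main__':
+    rng = np.random.default_rng(0)
+    feat_real = rng.normal(size=(2000, 64))
+    feat_gen = rng.normal(loc=0.05, size=(2000, 64))
+    mu_r, sigma_r = feat_real.mean(0), np.cov(feat_real, rowvar=False)
+    mu_g, sigma_g = feat_gen.mean(0), np.cov(feat_gen, rowvar=False)
+    m = np.square(mu_g - mu_r).sum()
+    s, _ = scipy.linalg.sqrtm(np.dot(sigma_g, sigma_r), disp=False)  # pylint: disable=no-member
+    print('toy FID:', float(np.real(m + np.trace(sigma_g + sigma_r - s * 2))))
+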
+#----------------------------------------------------------------------------
+
diff --git a/modules/eg3ds/metrics/inception_score.py b/modules/eg3ds/metrics/inception_score.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8887595d5d563d391a9f95f193081e70d11caba
--- /dev/null
+++ b/modules/eg3ds/metrics/inception_score.py
@@ -0,0 +1,41 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Inception Score (IS) from the paper "Improved techniques for training
+GANs". Matches the original implementation by Salimans et al. at
+https://github.com/openai/improved-gan/blob/master/inception_score/model.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_is(opts, num_gen, num_splits):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/inception-2015-12-05.pkl'
+ detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer.
+
+ gen_probs = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan'), float('nan')
+
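+    # Per split: IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ); the mean/std over splits is reported.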
+ scores = []
+ for i in range(num_splits):
+ part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits]
+ kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True)))
+ kl = np.mean(np.sum(kl, axis=1))
+ scores.append(np.exp(kl))
+ return float(np.mean(scores)), float(np.std(scores))
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/kernel_inception_distance.py b/modules/eg3ds/metrics/kernel_inception_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a7735f387fb639135a0dd9a63be6b24c9bb3ade
--- /dev/null
+++ b/modules/eg3ds/metrics/kernel_inception_distance.py
@@ -0,0 +1,49 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Kernel Inception Distance (KID) from the paper "Demystifying MMD
+GANs". Matches the original implementation by Binkowski et al. at
+https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py"""
+
+import numpy as np
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size):
+ # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/inception-2015-12-05.pkl'
+ detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer.
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all()
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all()
+
+ if opts.rank != 0:
+ return float('nan')
+
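+    # Unbiased MMD^2 estimate with the polynomial kernel k(x, y) = (x . y / n + 1)^3, where n is
+    # the feature dimension, averaged over `num_subsets` random subsets of size m.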
+ n = real_features.shape[1]
+ m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size)
+ t = 0
+ for _subset_idx in range(num_subsets):
+ x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)]
+ y = real_features[np.random.choice(real_features.shape[0], m, replace=False)]
+ a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3
+ b = (x @ y.T / n + 1) ** 3
+ t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m
+ kid = t / num_subsets / m
+ return float(kid)
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/metric_main.py b/modules/eg3ds/metrics/metric_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..77eadbef168888cd740abb2e638ee111ef15c559
--- /dev/null
+++ b/modules/eg3ds/metrics/metric_main.py
@@ -0,0 +1,155 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Main API for computing and reporting quality metrics."""
+
+import os
+import time
+import json
+import torch
+import modules.eg3ds.dnnlib as dnnlib
+
+from . import metric_utils
+from . import frechet_inception_distance
+from . import kernel_inception_distance
+from . import precision_recall
+from . import perceptual_path_length
+from . import inception_score
+from . import equivariance
+
+#----------------------------------------------------------------------------
+
+_metric_dict = dict() # name => fn
+
+def register_metric(fn):
+ assert callable(fn)
+ _metric_dict[fn.__name__] = fn
+ return fn
+
+def is_valid_metric(metric):
+ return metric in _metric_dict
+
+def list_valid_metrics():
+ return list(_metric_dict.keys())
+
+#----------------------------------------------------------------------------
+
+def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments.
+ assert is_valid_metric(metric)
+ opts = metric_utils.MetricOptions(**kwargs)
+
+ # Calculate.
+ start_time = time.time()
+ results = _metric_dict[metric](opts)
+ total_time = time.time() - start_time
+
+ # Broadcast results.
+ for key, value in list(results.items()):
+ if opts.num_gpus > 1:
+ value = torch.as_tensor(value, dtype=torch.float64, device=opts.device)
+ torch.distributed.broadcast(tensor=value, src=0)
+ value = float(value.cpu())
+ results[key] = value
+
+ # Decorate with metadata.
+ return dnnlib.EasyDict(
+ results = dnnlib.EasyDict(results),
+ metric = metric,
+ total_time = total_time,
+ total_time_str = dnnlib.util.format_time(total_time),
+ num_gpus = opts.num_gpus,
+ )
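+
+# Usage sketch (illustrative; G, dataset_kwargs and device are supplied by the training loop):
+#   result = calc_metric('fid50k_full', G=G, dataset_kwargs=dataset_kwargs,
+#                        num_gpus=1, rank=0, device=device)
+#   report_metric(result, run_dir=run_dir, snapshot_pkl='network-snapshot.pkl')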
+
+#----------------------------------------------------------------------------
+
+def report_metric(result_dict, run_dir=None, snapshot_pkl=None):
+ metric = result_dict['metric']
+ assert is_valid_metric(metric)
+ if run_dir is not None and snapshot_pkl is not None:
+ snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir)
+
+ jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time()))
+ print(jsonl_line)
+ if run_dir is not None and os.path.isdir(run_dir):
+ with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f:
+ f.write(jsonl_line + '\n')
+
+#----------------------------------------------------------------------------
+# Recommended metrics.
+
+@register_metric
+def fid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000)
+ return dict(fid50k_full=fid)
+
+@register_metric
+def kid50k_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k_full=kid)
+
+@register_metric
+def pr50k3_full(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall)
+
+@register_metric
+def ppl2_wend(opts):
+ ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2)
+ return dict(ppl2_wend=ppl)
+
+@register_metric
+def eqt50k_int(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True)
+ return dict(eqt50k_int=psnr)
+
+@register_metric
+def eqt50k_frac(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True)
+ return dict(eqt50k_frac=psnr)
+
+@register_metric
+def eqr50k(opts):
+ opts.G_kwargs.update(force_fp32=True)
+ psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True)
+ return dict(eqr50k=psnr)
+
+#----------------------------------------------------------------------------
+# Legacy metrics.
+
+@register_metric
+def fid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000)
+ return dict(fid50k=fid)
+
+@register_metric
+def kid50k(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000)
+ return dict(kid50k=kid)
+
+@register_metric
+def pr50k3(opts):
+ opts.dataset_kwargs.update(max_size=None)
+ precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000)
+ return dict(pr50k3_precision=precision, pr50k3_recall=recall)
+
+@register_metric
+def is50k(opts):
+ opts.dataset_kwargs.update(max_size=None, xflip=False)
+ mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10)
+ return dict(is50k_mean=mean, is50k_std=std)
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/metric_utils.py b/modules/eg3ds/metrics/metric_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..756169b281ff0cf72bbacb879bafccc2721b5d42
--- /dev/null
+++ b/modules/eg3ds/metrics/metric_utils.py
@@ -0,0 +1,324 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Miscellaneous utilities used internally by the quality metrics."""
+
+import os
+import sys
+sys.path.append("/home/tiger/projects/GeneFace_private/modules/eg3ds")
+
+import time
+import hashlib
+import pickle
+import copy
+import uuid
+import numpy as np
+import torch
+import modules.eg3ds.dnnlib as dnnlib
+
+from tasks.eg3ds.dataset_utils.kv_eg3d_ffhq_dataset import KV_FFHQ_EG3D_Dataset
+from utils.commons.hparams import hparams
+#----------------------------------------------------------------------------
+
+def chunk(iterable, chunk_size):
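+    # Split `iterable` into consecutive lists of `chunk_size` items; the last list may be
+    # shorter, e.g. chunk(range(5), 2) -> [[0, 1], [2, 3], [4]].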
+ final_ret = []
+ cnt = 0
+ ret = []
+ for record in iterable:
+ if cnt == 0:
+ ret = []
+ ret.append(record)
+ cnt += 1
+ if len(ret) == chunk_size:
+ final_ret.append(ret)
+ ret = []
+    if len(ret) > 0:  # keep the trailing partial chunk, if any
+        final_ret.append(ret)
+ return final_ret
+
+class MetricOptions:
+ def __init__(self, G=None, G_kwargs={}, dataset_kwargs={}, num_gpus=1, rank=0, device=None, progress=None, cache=True):
+ assert 0 <= rank < num_gpus
+ self.G = G
+ self.G_kwargs = dnnlib.EasyDict(G_kwargs)
+ self.dataset_kwargs = dnnlib.EasyDict(dataset_kwargs)
+ self.num_gpus = num_gpus
+ self.rank = rank
+ self.device = device if device is not None else torch.device('cuda', rank)
+ self.progress = progress.sub() if progress is not None and rank == 0 else ProgressMonitor()
+ self.cache = cache
+
+#----------------------------------------------------------------------------
+
+_feature_detector_cache = dict()
+
+def get_feature_detector_name(url):
+ return os.path.splitext(url.split('/')[-1])[0]
+
+def get_feature_detector(url, device=torch.device('cpu'), num_gpus=1, rank=0, verbose=False):
+ assert 0 <= rank < num_gpus
+ key = (url, device)
+ if key not in _feature_detector_cache:
+ is_leader = (rank == 0)
+ if not is_leader and num_gpus > 1:
+ torch.distributed.barrier() # leader goes first
+ with dnnlib.util.open_url(url, verbose=(verbose and is_leader)) as f:
+
+ _feature_detector_cache[key] = pickle.load(f).to(device)
+ if is_leader and num_gpus > 1:
+ torch.distributed.barrier() # others follow
+ return _feature_detector_cache[key]
+
+#----------------------------------------------------------------------------
+
+def iterate_random_labels(opts, batch_size):
+ if opts.G.c_dim == 0:
+ c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device)
+ while True:
+ yield c
+ else:
+ # dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if hparams['ds_name'] in ['FFHQ']:
+ dataset = KV_FFHQ_EG3D_Dataset('train', shuffle=False)
+ else:
+ raise NotImplementedError()
+ while True:
+ # c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)]
+ # c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device)
+ index = np.random.randint(len(dataset), size=(batch_size))
+ samples = dataset[index]
+ cameras = [s['real_camera'] for s in samples]
+ c = torch.stack(cameras).pin_memory().to(opts.device)
+ yield c
+
+#----------------------------------------------------------------------------
+
+class FeatureStats:
+ def __init__(self, capture_all=False, capture_mean_cov=False, max_items=None):
+ self.capture_all = capture_all
+ self.capture_mean_cov = capture_mean_cov
+ self.max_items = max_items
+ self.num_items = 0
+ self.num_features = None
+ self.all_features = None
+ self.raw_mean = None
+ self.raw_cov = None
+
+ def set_num_features(self, num_features):
+ if self.num_features is not None:
+ assert num_features == self.num_features
+ else:
+ self.num_features = num_features
+ self.all_features = []
+ self.raw_mean = np.zeros([num_features], dtype=np.float64)
+ self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
+
+ def is_full(self):
+ return (self.max_items is not None) and (self.num_items >= self.max_items)
+
+ def append(self, x):
+ x = np.asarray(x, dtype=np.float32)
+ assert x.ndim == 2
+ if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
+ if self.num_items >= self.max_items:
+ return
+ x = x[:self.max_items - self.num_items]
+
+ self.set_num_features(x.shape[1])
+ self.num_items += x.shape[0]
+ if self.capture_all:
+ self.all_features.append(x)
+ if self.capture_mean_cov:
+ x64 = x.astype(np.float64)
+ self.raw_mean += x64.sum(axis=0)
+ self.raw_cov += x64.T @ x64
+
+ def append_torch(self, x, num_gpus=1, rank=0):
+ assert isinstance(x, torch.Tensor) and x.ndim == 2
+ assert 0 <= rank < num_gpus
+ if num_gpus > 1:
+ ys = []
+ for src in range(num_gpus):
+ y = x.clone()
+ torch.distributed.broadcast(y, src=src)
+ ys.append(y)
+ x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
+ self.append(x.cpu().numpy())
+
+ def get_all(self):
+ assert self.capture_all
+ return np.concatenate(self.all_features, axis=0)
+
+ def get_all_torch(self):
+ return torch.from_numpy(self.get_all())
+
+ def get_mean_cov(self):
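+        # Running estimate from the raw sums accumulated in append(): cov = E[x x^T] - mu mu^T.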
+ assert self.capture_mean_cov
+ mean = self.raw_mean / self.num_items
+ cov = self.raw_cov / self.num_items
+ cov = cov - np.outer(mean, mean)
+ return mean, cov
+
+ def save(self, pkl_file):
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(self.__dict__, f)
+
+ @staticmethod
+ def load(pkl_file):
+ with open(pkl_file, 'rb') as f:
+ s = dnnlib.EasyDict(pickle.load(f))
+ obj = FeatureStats(capture_all=s.capture_all, max_items=s.max_items)
+ obj.__dict__.update(s)
+ return obj
+
+#----------------------------------------------------------------------------
+
+class ProgressMonitor:
+ def __init__(self, tag=None, num_items=None, flush_interval=1000, verbose=False, progress_fn=None, pfn_lo=0, pfn_hi=1000, pfn_total=1000):
+ self.tag = tag
+ self.num_items = num_items
+ self.verbose = verbose
+ self.flush_interval = flush_interval
+ self.progress_fn = progress_fn
+ self.pfn_lo = pfn_lo
+ self.pfn_hi = pfn_hi
+ self.pfn_total = pfn_total
+ self.start_time = time.time()
+ self.batch_time = self.start_time
+ self.batch_items = 0
+ if self.progress_fn is not None:
+ self.progress_fn(self.pfn_lo, self.pfn_total)
+
+ def update(self, cur_items):
+ assert (self.num_items is None) or (cur_items <= self.num_items)
+ if (cur_items < self.batch_items + self.flush_interval) and (self.num_items is None or cur_items < self.num_items):
+ return
+ cur_time = time.time()
+ total_time = cur_time - self.start_time
+ time_per_item = (cur_time - self.batch_time) / max(cur_items - self.batch_items, 1)
+ if (self.verbose) and (self.tag is not None):
+ print(f'{self.tag:<19s} items {cur_items:<7d} time {dnnlib.util.format_time(total_time):<12s} ms/item {time_per_item*1e3:.2f}')
+ self.batch_time = cur_time
+ self.batch_items = cur_items
+
+ if (self.progress_fn is not None) and (self.num_items is not None):
+ self.progress_fn(self.pfn_lo + (self.pfn_hi - self.pfn_lo) * (cur_items / self.num_items), self.pfn_total)
+
+ def sub(self, tag=None, num_items=None, flush_interval=1000, rel_lo=0, rel_hi=1):
+ return ProgressMonitor(
+ tag = tag,
+ num_items = num_items,
+ flush_interval = flush_interval,
+ verbose = self.verbose,
+ progress_fn = self.progress_fn,
+ pfn_lo = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_lo,
+ pfn_hi = self.pfn_lo + (self.pfn_hi - self.pfn_lo) * rel_hi,
+ pfn_total = self.pfn_total,
+ )
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_dataset(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, data_loader_kwargs=None, max_items=None, **stats_kwargs):
+ # dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs)
+ if hparams['ds_name'] in ['FFHQ']:
+ dataset = KV_FFHQ_EG3D_Dataset('train', shuffle=False)
+ else:
+ raise NotImplementedError()
+
+ if data_loader_kwargs is None:
+ data_loader_kwargs = dict(pin_memory=True, num_workers=3, prefetch_factor=2)
+
+ # Try to lookup from cache.
+ cache_file = None
+ if opts.cache:
+ # Choose cache file name.
+ args = dict(dataset_kwargs=opts.dataset_kwargs, detector_url=detector_url, detector_kwargs=detector_kwargs, stats_kwargs=stats_kwargs)
+ md5 = hashlib.md5(repr(sorted(args.items())).encode('utf-8'))
+ ds_name = hparams['ds_name'] + dataset.prefix
+ cache_tag = f'{ds_name}-{get_feature_detector_name(detector_url)}-{md5.hexdigest()}'
+ cache_file = dnnlib.make_cache_dir_path('gan-metrics', cache_tag + '.pkl')
+
+ # Check if the file exists (all processes must agree).
+ flag = os.path.isfile(cache_file) if opts.rank == 0 else False
+ if opts.num_gpus > 1:
+ flag = torch.as_tensor(flag, dtype=torch.float32, device=opts.device)
+ torch.distributed.broadcast(tensor=flag, src=0)
+ flag = (float(flag.cpu()) != 0)
+
+ # Load.
+ if flag:
+ return FeatureStats.load(cache_file)
+
+ # Initialize.
+ num_items = len(dataset)
+ if max_items is not None:
+ num_items = min(num_items, max_items)
+ stats = FeatureStats(max_items=num_items, **stats_kwargs)
+ progress = opts.progress.sub(tag='dataset features', num_items=num_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ item_subset = [(i * opts.num_gpus + opts.rank) % num_items for i in range((num_items - 1) // opts.num_gpus + 1)]
+ item_subset = chunk(item_subset, chunk_size=batch_size)
+ for batch in torch.utils.data.DataLoader(dataset=dataset, sampler=item_subset, batch_size=1, collate_fn=dataset.collater, **data_loader_kwargs):
+ images = batch['real_imgs']
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+
+ if images.dtype != torch.uint8:
+ images = (images * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+
+ features = detector(images.to(opts.device), **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+
+ # Save to cache.
+ if cache_file is not None and opts.rank == 0:
+ os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+ temp_file = cache_file + '.' + uuid.uuid4().hex
+ stats.save(temp_file)
+ os.replace(temp_file, cache_file) # atomic
+ return stats
+
+#----------------------------------------------------------------------------
+
+def compute_feature_stats_for_generator(opts, detector_url, detector_kwargs, rel_lo=0, rel_hi=1, batch_size=64, batch_gen=None, **stats_kwargs):
+ if batch_gen is None:
+ batch_gen = min(batch_size, 4)
+ assert batch_size % batch_gen == 0
+
+ # Setup generator and labels.
+ G = copy.deepcopy(opts.G).eval().requires_grad_(False).to(opts.device)
+ c_iter = iterate_random_labels(opts=opts, batch_size=batch_gen)
+
+ # Initialize.
+ stats = FeatureStats(**stats_kwargs)
+ assert stats.max_items is not None
+ progress = opts.progress.sub(tag='generator features', num_items=stats.max_items, rel_lo=rel_lo, rel_hi=rel_hi)
+ detector = get_feature_detector(url=detector_url, device=opts.device, num_gpus=opts.num_gpus, rank=opts.rank, verbose=progress.verbose)
+
+ # Main loop.
+ while not stats.is_full():
+ images = []
+ for _i in range(batch_size // batch_gen):
+ z = torch.randn([batch_gen, G.z_dim], device=opts.device)
+ img = G(z=z, camera=next(c_iter))['image']
+ # img = G(z=z, c=next(c_iter), **opts.G_kwargs)['image']
+ img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+ images.append(img)
+ images = torch.cat(images)
+ if images.shape[1] == 1:
+ images = images.repeat([1, 3, 1, 1])
+ features = detector(images, **detector_kwargs)
+ stats.append_torch(features, num_gpus=opts.num_gpus, rank=opts.rank)
+ progress.update(stats.num_items)
+ return stats
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/perceptual_path_length.py b/modules/eg3ds/metrics/perceptual_path_length.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e58dac3317733e2ace6d64ee1f97cafa0a38225
--- /dev/null
+++ b/modules/eg3ds/metrics/perceptual_path_length.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Perceptual Path Length (PPL) from the paper "A Style-Based Generator
+Architecture for Generative Adversarial Networks". Matches the original
+implementation by Karras et al. at
+https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py"""
+
+import copy
+import numpy as np
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+# Spherical interpolation of a batch of vectors.
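+# slerp(a, b, t) walks an angle t * theta along the great circle from a towards b,
+# where theta = arccos(<a, b>) after both endpoints are normalized to unit length.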
+def slerp(a, b, t):
+ a = a / a.norm(dim=-1, keepdim=True)
+ b = b / b.norm(dim=-1, keepdim=True)
+ d = (a * b).sum(dim=-1, keepdim=True)
+ p = t * torch.acos(d)
+ c = b - d * a
+ c = c / c.norm(dim=-1, keepdim=True)
+ d = a * torch.cos(p) + c * torch.sin(p)
+ d = d / d.norm(dim=-1, keepdim=True)
+ return d
+
+#----------------------------------------------------------------------------
+
+class PPLSampler(torch.nn.Module):
+ def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16):
+ assert space in ['z', 'w']
+ assert sampling in ['full', 'end']
+ super().__init__()
+ self.G = copy.deepcopy(G)
+ self.G_kwargs = G_kwargs
+ self.epsilon = epsilon
+ self.space = space
+ self.sampling = sampling
+ self.crop = crop
+ self.vgg16 = copy.deepcopy(vgg16)
+
+ def forward(self, c):
+ # Generate random latents and interpolation t-values.
+ t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0)
+ z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2)
+
+ # Interpolate in W or Z.
+ if self.space == 'w':
+ w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2)
+ wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2))
+ wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon)
+ else: # space == 'z'
+ zt0 = slerp(z0, z1, t.unsqueeze(1))
+ zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon)
+ wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2)
+
+ # Randomize noise buffers.
+ for name, buf in self.G.named_buffers():
+ if name.endswith('.noise_const'):
+ buf.copy_(torch.randn_like(buf))
+
+ # Generate images.
+ img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs)
+
+ # Center crop.
+ if self.crop:
+ assert img.shape[2] == img.shape[3]
+ c = img.shape[2] // 8
+ img = img[:, :, c*3 : c*7, c*2 : c*6]
+
+ # Downsample to 256x256.
+ factor = self.G.img_resolution // 256
+ if factor > 1:
+ img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5])
+
+ # Scale dynamic range from [-1,1] to [0,255].
+ img = (img + 1) * (255 / 2)
+ if self.G.img_channels == 1:
+ img = img.repeat([1, 3, 1, 1])
+
+ # Evaluate differential LPIPS.
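+        # PPL distance: ||lpips(img_t) - lpips(img_t+eps)||^2 / epsilon^2, a finite-difference
+        # estimate of the perceptual path length at the sampled interpolation point.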
+ lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2)
+ dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2
+ return dist
+
+#----------------------------------------------------------------------------
+
+def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size):
+ vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose)
+
+ # Setup sampler and labels.
+ sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16)
+ sampler.eval().requires_grad_(False).to(opts.device)
+ c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size)
+
+ # Sampling loop.
+ dist = []
+ progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples)
+ for batch_start in range(0, num_samples, batch_size * opts.num_gpus):
+ progress.update(batch_start)
+ x = sampler(next(c_iter))
+ for src in range(opts.num_gpus):
+ y = x.clone()
+ if opts.num_gpus > 1:
+ torch.distributed.broadcast(y, src=src)
+ dist.append(y)
+ progress.update(num_samples)
+
+ # Compute PPL.
+ if opts.rank != 0:
+ return float('nan')
+ dist = torch.cat(dist)[:num_samples].cpu().numpy()
+ lo = np.percentile(dist, 1, interpolation='lower')
+ hi = np.percentile(dist, 99, interpolation='higher')
+ ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean()
+ return float(ppl)
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/metrics/precision_recall.py b/modules/eg3ds/metrics/precision_recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..6043717d59c53c34d76e35600a58f91e77659e0c
--- /dev/null
+++ b/modules/eg3ds/metrics/precision_recall.py
@@ -0,0 +1,65 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Precision/Recall (PR) from the paper "Improved Precision and Recall
+Metric for Assessing Generative Models". Matches the original implementation
+by Kynkaanniemi et al. at
+https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py"""
+
+import torch
+from . import metric_utils
+
+#----------------------------------------------------------------------------
+
+def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size):
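+    # Pairwise distances are computed in column blocks; each rank processes every num_gpus-th
+    # block and broadcasts its result so that rank 0 can assemble the full [rows, cols] matrix.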
+ assert 0 <= rank < num_gpus
+ num_cols = col_features.shape[0]
+ num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus
+ col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches)
+ dist_batches = []
+ for col_batch in col_batches[rank :: num_gpus]:
+ dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0]
+ for src in range(num_gpus):
+ dist_broadcast = dist_batch.clone()
+ if num_gpus > 1:
+ torch.distributed.broadcast(dist_broadcast, src=src)
+ dist_batches.append(dist_broadcast.cpu() if rank == 0 else None)
+ return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None
+
+#----------------------------------------------------------------------------
+
+def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size):
+ # detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl'
+ detector_url = 'file:///home/tiger/nfs/myenv/cache/useful_ckpts/vgg16.pkl'
+ detector_kwargs = dict(return_features=True)
+
+ real_features = metric_utils.compute_feature_stats_for_dataset(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device)
+
+ gen_features = metric_utils.compute_feature_stats_for_generator(
+ opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
+ rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device)
+
+ results = dict()
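+    # Precision: fraction of generated samples that fall inside the k-NN radius of the real
+    # feature manifold; recall: the same measure with the real/generated roles swapped.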
+ for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]:
+ kth = []
+ for manifold_batch in manifold.split(row_batch_size):
+ dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None)
+ kth = torch.cat(kth) if opts.rank == 0 else None
+ pred = []
+ for probes_batch in probes.split(row_batch_size):
+ dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size)
+ pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None)
+ results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan')
+ return results['precision'], results['recall']
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/models/dual_discriminator.py b/modules/eg3ds/models/dual_discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d82d4148253a341cf3bccf7bd056a39be00e22
--- /dev/null
+++ b/modules/eg3ds/models/dual_discriminator.py
@@ -0,0 +1,374 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Discriminator architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+#
+
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.models.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
+from einops import rearrange
+from utils.commons.hparams import hparams
+
+
+class SingleDiscriminator(torch.nn.Module):
+ def __init__(self,
+ img_resolution, # Input resolution.
+ img_channels =3, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ sr_upsample_factor = 1, # Ignored for SingleDiscriminator
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.camera_dim = 25
+ if hparams['disc_cond_mode'] == 'idexp_lm3d_normalized':
+ self.cond_dim = 204
+ else:
+ self.cond_dim = 0
+ c_dim = self.camera_dim
+ self.c_dim = c_dim
+
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
+ img = img['image']
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ c = camera
+ if self.cond_dim > 0:
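+            # assumes a `cond_encoder` module is attached to the instance elsewhere;
+            # SingleDiscriminator itself does not define one in __init__.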
+ cond_feat = self.cond_encoder(cond)
+ c = torch.cat([c, cond_feat], dim=-1) # [b, 25+8]
+
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
+    is_bcthw_flag = image_orig_tensor.ndim == 5  # True for [B, C, T, H, W] video batches
+ if is_bcthw_flag: # [B, c, T, H, W]
+ n,c,t,h,w = image_orig_tensor.shape
+ image_orig_tensor = rearrange(image_orig_tensor, "n c t h w -> (n t) c h w")
+
+ if filter_mode == 'antialiased':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ elif filter_mode == 'classic':
+ ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
+ ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
+ ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
+ elif filter_mode == 'none':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+ elif type(filter_mode) == float:
+ assert 0 < filter_mode < 1
+
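+        # A fractional filter_mode blends the two resizes:
+        # out = (1 - filter_mode) * aliased + filter_mode * antialiased.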
+ filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=False)
+ ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered
+ if is_bcthw_flag: # [B, c, T, H, W]
+        ada_filtered_64 = rearrange(ada_filtered_64, "(n t) c h w -> n c t h w", n=n, t=t)
+
+ return ada_filtered_64
+
+#----------------------------------------------------------------------------
+
+class DualDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ channel_base = hparams['base_channel']
+ channel_max = hparams['max_channel']
+ conv_clamp = 256
+ cmap_dim = None
+ block_kwargs = {'freeze_layers': 0}
+ mapping_kwargs = {}
+ epilogue_kwargs = {'mbstd_group_size': hparams['group_size_for_mini_batch_std']}
+ architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'.
+
+ img_channels = 3
+ img_channels *= 2
+
+ self.camera_dim = 25
+        c_dim = self.camera_dim
+        self.c_dim = c_dim  # kept as an attribute so extra_repr() below can report it
+
+ self.img_resolution = hparams['final_resolution']
+ self.img_resolution_log2 = int(np.log2(self.img_resolution))
+ self.img_channels = 3
+
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ self.num_fp16_res = hparams['num_fp16_layers_in_discriminator']
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - self.num_fp16_res), 8)
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < self.img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ # use_fp16 = True
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ if hparams.get("disc_cond_mode", 'none') != 'none':
+ """
+            For the discriminator, embedding the condition via the mapping network works well.
+ """
+ self.cond_dim = 204
+ self.mapping = MappingNetwork(z_dim=self.cond_dim, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, img, camera, cond=None, update_emas=False, feature_maps=None, **block_kwargs):
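+        # Dual discrimination: resize the neural-rendered raw image to the final resolution and
+        # concatenate it channel-wise with the super-resolved image before the conv blocks.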
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ img = torch.cat([img['image'], image_raw], 1)
+
+        # added by yerfor: clamp the concatenated images to the valid dynamic range [-1, 1]
+ img = torch.clamp(img, min=-1, max=1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+ if feature_maps is not None:
+ feature_maps.append(x)
+ cmap = None
+
+ c = camera.clone() # prevent inplace modification in sample!
+ if hparams['disc_c_noise'] > 0:
+ if len(c) > 1:
+ c_std = c.std(0)
+ else:
+ # c_std = 1
+ c_std = torch.tensor([0.0664, 0.0295, 0.2720, 0.6971, 0.0279, 0.0178, 0.1280, 0.3284, 0.2721,
+ 0.1274, 0.0679, 0.1642, 0.0000, 0.0000, 0.0000, 0.0000, 0.0079, 0.0000,
+ 0.0000, 0.0000, 0.0079, 0.0000, 0.0000, 0.0000, 0.0000]).to(c.device)
+ c += torch.randn_like(c) * c_std * hparams['disc_c_noise']
+
+ # x: [B, 512, 4, 4], img: None, cmap: [B, 512]
+ if hparams.get("disc_cond_mode", 'none') != 'none':
+ cmap = self.mapping(cond, c)
+ else:
+ cmap = self.mapping(None, c)
+
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+class DummyDualDiscriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels *= 2
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ self.raw_fade = 1
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
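+ # raw_fade decays by 32/500000 per call, i.e. the raw-image branch fades out over ~500k images
+ # if the batch size is 32.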
+ self.raw_fade = max(0, self.raw_fade - 1/(500000/32))
+
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+
+# Tri-discriminator: upsampled image, super-resolved image, and segmentation mask
+# V2: concatenate the images and the segmentation mask first, so a single convolutional backbone suffices
+class MaskDualDiscriminatorV2(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ seg_resolution = 128, # Segmentation mask resolution.
+ seg_channels = 1, # Number of segmentation mask channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ disc_c_noise = 0, # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels = img_channels * 2 + seg_channels
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.seg_resolution = seg_resolution
+ self.seg_channels = seg_channels
+
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+ self.disc_c_noise = disc_c_noise
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter)
+ seg = 2 * seg - 1 # normalize to [-1,1]
+ img = torch.cat([img['image'], image_raw, seg], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ if self.disc_c_noise > 0: c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'c_dim={self.c_dim:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'seg_resolution={self.seg_resolution:d}, seg_channels={self.seg_channels:d}'])
\ No newline at end of file
diff --git a/modules/eg3ds/models/dual_discriminator_cond.py b/modules/eg3ds/models/dual_discriminator_cond.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d6b37470054d002607f05fb764988d160272c80
--- /dev/null
+++ b/modules/eg3ds/models/dual_discriminator_cond.py
@@ -0,0 +1,279 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Discriminator architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.models.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
+from modules.eg3ds.models.cond_encoder import LM3D_Win_Encoder
+
+from utils.commons.hparams import hparams
+
+
+class SingleDiscriminator(torch.nn.Module):
+ def __init__(self,
+ img_resolution, # Input resolution.
+ img_channels = 3, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ sr_upsample_factor = 1, # Ignored for SingleDiscriminator
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.camera_dim = 25
+ if hparams['cond_type'] == 'idexp_lm3d_normalized':
+ self.cond_dim = 204
+ else:
+ self.cond_dim = 0
+ c_dim = self.camera_dim
+ if self.cond_dim > 0:
+ cond_out_dim = hparams['cond_out_dim']
+ c_dim += cond_out_dim
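+ # LM3D_Win_Encoder (see cond_encoder.py) maps the 204-dim landmark condition to cond_out_dim
+ # features; smo_size presumably controls a temporal smoothing window over neighbouring frames.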
+ self.cond_encoder = LM3D_Win_Encoder(self.cond_dim, hid_dim=hparams['cond_hid_dim'], out_dim=cond_out_dim, smo_size=hparams['smo_win_size'])
+ self.c_dim = c_dim
+
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
+ img = img['image']
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ c = camera
+ if self.cond_dim > 0:
+ cond_feat = self.cond_encoder(cond)
+ c = torch.cat([c, cond_feat], dim=-1) # [b, 25+8]
+
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
+ if filter_mode == 'antialiased':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ elif filter_mode == 'classic':
+ ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
+ ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
+ ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
+ elif filter_mode == 'none':
+ ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
+ elif type(filter_mode) == float:
+ assert 0 < filter_mode < 1
+
+ filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
+ aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=False)
+ ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered
+
+ return ada_filtered_64
+
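+# A minimal usage sketch (illustrative, not from the original code): the dual discriminators below
+# upsample the raw neural rendering to the super-resolution size before concatenation, e.g.
+#   image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=resample_filter)
+# Passing a float in (0, 1) as filter_mode blends the aliased and antialiased bilinear results.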
+#----------------------------------------------------------------------------
+
+class DualDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ channel_base = hparams['base_channel']
+ channel_max = hparams['max_channel']
+ conv_clamp = 256
+ cmap_dim = None
+ disc_c_noise = 0.
+ block_kwargs = {'freeze_layers': 0}
+ mapping_kwargs = {}
+ epilogue_kwargs = {'mbstd_group_size': 4}
+ architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'.
+
+ img_channels = 3
+ img_channels *= 2
+
+ self.camera_dim = 25
+ if hparams['cond_type'] == 'idexp_lm3d_normalized':
+ self.cond_dim = 204
+ else:
+ self.cond_dim = 0
+ c_dim = self.camera_dim
+
+ if self.cond_dim > 0:
+ cond_out_dim = hparams['cond_out_dim']
+ c_dim += cond_out_dim
+ self.cond_encoder = LM3D_Win_Encoder(self.cond_dim, hid_dim=hparams['cond_hid_dim'], out_dim=cond_out_dim, smo_size=hparams['smo_win_size'])
+
+ self.c_dim = c_dim # record the final conditioning dim (camera + encoded cond); extra_repr() relies on it
+ self.img_resolution = hparams['final_resolution']
+ self.img_resolution_log2 = int(np.log2(self.img_resolution))
+ self.img_channels = 3
+
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ self.num_fp16_res = hparams['num_fp16_layers_in_discriminator']
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - self.num_fp16_res), 8)
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < self.img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+ self.disc_c_noise = disc_c_noise
+
+ def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+
+ c = camera
+ if self.cond_dim > 0:
+ cond_feat = self.cond_encoder(cond)
+ c = torch.cat([c, cond_feat], dim=-1) # [b, 25+8]
+ if self.disc_c_noise > 0:
+ c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
+
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
+
+class DummyDualDiscriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ img_channels *= 2
+
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ self.raw_fade = 1
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ self.raw_fade = max(0, self.raw_fade - 1/(500000/32))
+
+ image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
+ img = torch.cat([img['image'], image_raw], 1)
+
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
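+
+
+if __name__ == '__main__':
+ # Minimal smoke-test sketch (not part of the original code). It assumes the custom ops
+ # (upfirdn2d, conv2d_resample, bias_act) fall back to their pure-PyTorch reference
+ # implementations when no compiled extension is available.
+ disc = DummyDualDiscriminator(c_dim=25, img_resolution=128, img_channels=3)
+ img = {
+ 'image': torch.randn(2, 3, 128, 128), # super-resolved output
+ 'image_raw': torch.randn(2, 3, 64, 64), # raw rendering, resized inside forward()
+ }
+ c = torch.randn(2, 25) # camera parameters (intrinsics + extrinsics)
+ logits = disc(img, c)
+ print(logits.shape) # expected: torch.Size([2, 1])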
diff --git a/modules/eg3ds/models/networks_stylegan2.py b/modules/eg3ds/models/networks_stylegan2.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d638e3f6898cec32c92c475ad7a73df12e8f9c
--- /dev/null
+++ b/modules/eg3ds/models/networks_stylegan2.py
@@ -0,0 +1,814 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Network architectures from the paper
+"Analyzing and Improving the Image Quality of StyleGAN".
+Matches the original implementation of configs E-F by Karras et al. at
+https://github.com/NVlabs/stylegan2/blob/master/training/networks_stylegan2.py"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+import math
+from modules.eg3ds.torch_utils import misc
+from modules.eg3ds.torch_utils.ops import conv2d_resample
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.torch_utils.ops import bias_act
+from modules.eg3ds.torch_utils.ops import fma
+
+from utils.commons.hparams import hparams
+
+#----------------------------------------------------------------------------
+
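+# "Pixel norm": rescale x to unit second moment along `dim`; used below to normalize z and the
+# embedded label before the mapping layers.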
+@misc.profiled_function
+def normalize_2nd_moment(x, dim=1, eps=1e-8):
+ return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor of shape [batch_size, in_channels, in_height, in_width].
+ weight, # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
+ styles, # Modulation coefficients of shape [batch_size, in_channels].
+ noise = None, # Optional noise tensor to add to the output activations.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ padding = 0, # Padding with respect to the upsampled image.
+ resample_filter = None, # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
+ demodulate = True, # Apply weight demodulation?
+ flip_weight = True, # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
+ fused_modconv = True, # Perform modulation, convolution, and demodulation as a single fused operation?
+):
+ batch_size = x.shape[0]
+ out_channels, in_channels, kh, kw = weight.shape
+ misc.assert_shape(weight, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(styles, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs to avoid FP16 overflow.
+ if x.dtype == torch.float16 and demodulate:
+ weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float('inf'), dim=[1,2,3], keepdim=True)) # max_Ikk
+ styles = styles / styles.norm(float('inf'), dim=1, keepdim=True) # max_I
+
+ # Calculate per-sample weights and demodulation coefficients.
+ w = None
+ dcoefs = None
+ if demodulate or fused_modconv:
+ w = weight.unsqueeze(0) # [NOIkk]
+ w = w * styles.reshape(batch_size, 1, -1, 1, 1) # [NOIkk], multiply the weights by the per-sample styles
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO] demodulation coefficients, e.g. [2, 512, 512, 3, 3] ==> [2, 512]
+ if demodulate and fused_modconv:
+ w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1) # [NOIkk]
+
+ # Execute by scaling the activations before and after the convolution.
+ if not fused_modconv:
+ x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1) # multiply the activations by the styles
+ x = conv2d_resample.conv2d_resample(x=x, w=weight.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, flip_weight=flip_weight) # conv2d forward
+ if demodulate and noise is not None:
+ x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype)) # FusedMultiplyAdd
+ elif demodulate:
+ x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
+ elif noise is not None:
+ x = x.add_(noise.to(x.dtype))
+ return x
+
+ # Execute as one fused op using grouped convolution.
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(batch_size)
+ misc.assert_shape(x, [batch_size, in_channels, None, None])
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=resample_filter, up=up, down=down, padding=padding, groups=batch_size, flip_weight=flip_weight)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ if noise is not None:
+ x = x.add_(noise)
+ return x
+
+#----------------------------------------------------------------------------
+
+
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ lr_multiplier = 1, # Learning rate multiplier.
+ bias_init = 0, # Initial value for the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.bias_init = bias_init
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
+ self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+#----------------------------------------------------------------------------
+
+
+class Conv2dLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ kernel_size, # Width and height of the convolution kernel.
+ bias = True, # Apply additive bias before the activation function?
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ up = 1, # Integer upsampling factor.
+ down = 1, # Integer downsampling factor.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output to +-X, None = disable clamping.
+ channels_last = False, # Expect the input to have memory_format=channels_last?
+ trainable = True, # Update the weights of this layer during training?
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.activation = activation
+ self.up = up
+ self.down = down
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+ self.trainable = trainable
+
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+ bias = torch.zeros([out_channels]) if bias else None
+ if trainable:
+ self.weight = torch.nn.Parameter(weight)
+ self.bias = torch.nn.Parameter(bias) if bias is not None else None
+ else:
+ self.register_buffer('weight', weight)
+ if bias is not None:
+ self.register_buffer('bias', bias)
+ else:
+ self.bias = None
+
+ def forward(self, x, gain=1):
+ w = self.weight * self.weight_gain
+
+ b = self.bias.to(x.dtype) if self.bias is not None else None
+ flip_weight = (self.up == 1) # slightly faster
+ x = conv2d_resample.conv2d_resample(x=x, w=w.to(x.dtype), f=self.resample_filter, up=self.up, down=self.down, padding=self.padding, flip_weight=flip_weight)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, b, act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, activation={self.activation:s},',
+ f'up={self.up}, down={self.down}'])
+
+#----------------------------------------------------------------------------
+
+
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality, 0 = no latent.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no label.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output, None = do not broadcast.
+ num_layers = 8, # Number of mapping layers.
+ embed_features = None, # Label embedding dimensionality, None = same as w_dim.
+ layer_features = None, # Number of intermediate features in the mapping layers, None = same as w_dim.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ last_activation = None, # added by PanoHead: activation of the last mapping layer, None = use the default activation.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training, None = do not track.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ if embed_features is None:
+ embed_features = w_dim
+ if c_dim == 0:
+ embed_features = 0
+ if layer_features is None:
+ layer_features = w_dim
+ features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
+
+ if c_dim > 0:
+ self.embed = FullyConnectedLayer(c_dim, embed_features)
+ for idx in range(num_layers):
+ in_features = features_list[idx]
+ out_features = features_list[idx + 1]
+ if idx == num_layers - 1 and last_activation:
+ layer = FullyConnectedLayer(in_features, out_features, activation=last_activation, lr_multiplier=lr_multiplier)
+ else:
+ layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+
+ if num_ws is not None and w_avg_beta is not None:
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
+ # Embed, normalize, and concat inputs.
+ x = None
+ with torch.autograd.profiler.record_function('input'):
+ if self.z_dim > 0:
+ misc.assert_shape(z, [None, self.z_dim])
+ x = normalize_2nd_moment(z.to(torch.float32))
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Main layers.
+ for idx in range(self.num_layers):
+ layer = getattr(self, f'fc{idx}')
+ x = layer(x)
+
+ # Update moving average of W.
+ if update_emas and self.w_avg_beta is not None:
+ with torch.autograd.profiler.record_function('update_w_avg'):
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast.
+ if self.num_ws is not None:
+ with torch.autograd.profiler.record_function('broadcast'):
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+
+ # Apply truncation trick.
+ if truncation_psi != 1:
+ with torch.autograd.profiler.record_function('truncate'):
+ assert self.w_avg_beta is not None
+ if self.num_ws is None or truncation_cutoff is None:
+ x = self.w_avg.lerp(x, truncation_psi)
+ else:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi) # move from w_avg toward x; truncation_psi in [0, 1] sets how far
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this layer.
+ kernel_size = 3, # Convolution kernel size.
+ up = 1, # Integer upsampling factor.
+ use_noise = True, # Enable noise input?
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ channels_last = False, # Use channels_last format for the weights?
+ **other_args
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.up = up
+ self.use_noise = use_noise
+ self.activation = activation
+ self.conv_clamp = conv_clamp
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.padding = kernel_size // 2
+ self.act_gain = bias_act.activation_funcs[activation].def_gain
+
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ if use_noise:
+ self.register_buffer('noise_const', torch.randn([resolution, resolution]))
+ self.noise_strength = torch.nn.Parameter(torch.zeros([]))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+
+ def forward(self, x, w, noise_mode='random', fused_modconv=True, gain=1, **kwargs):
+ assert noise_mode in ['random', 'const', 'none']
+ in_resolution = self.resolution // self.up
+ misc.assert_shape(x, [None, self.in_channels, in_resolution, in_resolution])
+ styles = self.affine(w)
+
+ noise = None
+ if self.use_noise and noise_mode == 'random':
+ noise = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device) * self.noise_strength
+ if self.use_noise and noise_mode == 'const':
+ noise = self.noise_const * self.noise_strength
+
+ flip_weight = (self.up == 1) # slightly faster
+ weight = self.weight
+ x = modulated_conv2d(x=x, weight=weight, styles=styles, noise=noise, up=self.up,
+ padding=self.padding, resample_filter=self.resample_filter, flip_weight=flip_weight, fused_modconv=fused_modconv)
+
+ act_gain = self.act_gain * gain
+ act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=act_gain, clamp=act_clamp)
+ return x
+
+ def extra_repr(self):
+ return ' '.join([
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d},',
+ f'resolution={self.resolution:d}, up={self.up}, activation={self.activation:s}'])
+
+#----------------------------------------------------------------------------
+
+
+class ToRGBLayer(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.w_dim = w_dim
+ self.conv_clamp = conv_clamp
+ self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
+ memory_format = torch.channels_last if channels_last else torch.contiguous_format
+ self.weight = torch.nn.Parameter(torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format))
+ self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
+ self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+
+ def forward(self, x, w, fused_modconv=True):
+ styles = self.affine(w) * self.weight_gain
+ weight = self.weight
+ x = modulated_conv2d(x=x, weight=weight, styles=styles, demodulate=False, fused_modconv=fused_modconv) # demodulation is disabled for the ToRGB layer
+ x = bias_act.bias_act(x, self.bias.to(x.dtype), clamp=self.conv_clamp)
+ return x
+
+ def extra_repr(self):
+ return f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}, w_dim={self.w_dim:d}'
+
+#----------------------------------------------------------------------------
+
+class SynthesisBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution, up=2,
+ resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+
+ # ToRGB.
+ if img is not None:
+ misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ **block_kwargs, # Arguments for SynthesisBlock.
+ ):
+ assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
+ super().__init__()
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.num_fp16_res = num_fp16_res
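+ # Synthesis runs coarse-to-fine: block resolutions go from 4x4 up to the output resolution.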
+ self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ self.num_ws = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res // 2] if res > 4 else 0
+ out_channels = channels_dict[res]
+ use_fp16 = (res >= fp16_resolution)
+ is_last = (res == self.img_resolution)
+ block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
+ img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
+ self.num_ws += block.num_conv
+ if is_last:
+ self.num_ws += block.num_torgb
+ setattr(self, f'b{res}', block)
+
+ def forward(self, ws, **block_kwargs):
+ block_ws = []
+ with torch.autograd.profiler.record_function('split_ws'):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32)
+ w_idx = 0
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb)) # [B, num_conv_and_rgb, w_dim]
+ w_idx += block.num_conv
+
+ x = img = None
+ for res, cur_ws in zip(self.block_resolutions, block_ws):
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, cur_ws, **block_kwargs)
+ return img
+
+ def extra_repr(self):
+ return ' '.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_fp16_res={self.num_fp16_res:d}'])
+
+#----------------------------------------------------------------------------
+
+
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+ if hparams.get("gen_cond_mode", 'none') == 'mapping': # comes from a attemp to inject landmark condition
+ self.cond_dim = 204
+ self.cond_mapping = MappingNetwork(z_dim=self.cond_dim, c_dim=0, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, z, c, cond=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ if hparams.get("gen_cond_mode", 'none') == 'mapping':
+ d_ws = self.cond_mapping(cond, 0, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ ws = ws * 0.5 + d_ws * 0.5
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
+
+#----------------------------------------------------------------------------
+
+
+class DiscriminatorBlock(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ tmp_channels, # Number of intermediate channels.
+ out_channels, # Number of output channels.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ first_layer_idx, # Index of the first layer.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ freeze_layers = 0, # Freeze-D: Number of layers to freeze.
+ ):
+ assert in_channels in [0, tmp_channels]
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.first_layer_idx = first_layer_idx
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+
+ self.num_layers = 0
+ def trainable_gen():
+ while True:
+ layer_idx = self.first_layer_idx + self.num_layers
+ trainable = (layer_idx >= freeze_layers)
+ self.num_layers += 1
+ yield trainable
+ trainable_iter = trainable_gen()
+
+ if in_channels == 0 or architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, tmp_channels, kernel_size=1, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv0 = Conv2dLayer(tmp_channels, tmp_channels, kernel_size=3, activation=activation,
+ trainable=next(trainable_iter), conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ self.conv1 = Conv2dLayer(tmp_channels, out_channels, kernel_size=3, activation=activation, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, conv_clamp=conv_clamp, channels_last=self.channels_last)
+
+ if architecture == 'resnet':
+ self.skip = Conv2dLayer(tmp_channels, out_channels, kernel_size=1, bias=False, down=2,
+ trainable=next(trainable_iter), resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, force_fp32=False):
+ if (x if x is not None else img).device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+
+ # Input.
+ if x is not None:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # FromRGB.
+ if self.in_channels == 0 or self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ y = self.fromrgb(img)
+ x = x + y if x is not None else y
+ img = upfirdn2d.downsample2d(img, self.resample_filter) if self.architecture == 'skip' else None
+
+ # Main layers.
+ if self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x)
+ x = self.conv1(x, gain=np.sqrt(0.5))
+ x = y.add_(x)
+ else:
+ x = self.conv0(x)
+ x = self.conv1(x)
+
+ assert x.dtype == dtype
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+#----------------------------------------------------------------------------
+
+
+class MinibatchStdLayer(torch.nn.Module):
+ def __init__(self, group_size, num_channels=1):
+ super().__init__()
+ self.group_size = group_size
+ self.num_channels = num_channels
+
+ def forward(self, x):
+ N, C, H, W = x.shape
+ with misc.suppress_tracer_warnings(): # as_tensor results are registered as constants
+ G = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) if self.group_size is not None else N
+ F = self.num_channels
+ c = C // F
+
+ y = x.reshape(G, -1, F, c, H, W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+ y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group.
+ y = y.square().mean(dim=0) # [nFcHW] Calc variance over group.
+ y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group.
+ y = y.mean(dim=[2,3,4]) # [nF] Take average over channels and pixels.
+ y = y.reshape(-1, F, 1, 1) # [nF11] Add missing dimensions.
+ y = y.repeat(G, 1, H, W) # [NFHW] Replicate over group and pixels.
+ x = torch.cat([x, y], dim=1) # [NCHW] Append to input as new channels.
+ return x
+
+ def extra_repr(self):
+ return f'group_size={self.group_size}, num_channels={self.num_channels:d}'
+
+#----------------------------------------------------------------------------
+
+
+class DiscriminatorEpilogue(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels.
+ cmap_dim, # Dimensionality of mapped conditioning label, 0 = no label.
+ resolution, # Resolution of this block.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ mbstd_group_size = 2, # Group size for the minibatch standard deviation layer, None = entire minibatch.
+ mbstd_num_channels = 1, # Number of features for the minibatch standard deviation layer, 0 = disable.
+ activation = 'lrelu', # Activation function: 'relu', 'lrelu', etc.
+ conv_clamp = None, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.cmap_dim = cmap_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.architecture = architecture
+
+ if architecture == 'skip':
+ self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)
+ self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels) if mbstd_num_channels > 0 else None
+ self.conv = Conv2dLayer(in_channels + mbstd_num_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)
+ self.fc = FullyConnectedLayer(in_channels * (resolution ** 2), in_channels, activation=activation)
+ self.out = FullyConnectedLayer(in_channels, 1 if cmap_dim == 0 else cmap_dim)
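+ # When a conditioning label is present, forward() projects the features onto the mapped label
+ # (projection-discriminator style); otherwise a single real/fake logit is produced.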
+
+ def forward(self, x, img, cmap, force_fp32=False):
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution]) # [NCHW]
+ _ = force_fp32 # unused
+ dtype = torch.float32
+ memory_format = torch.contiguous_format
+
+ # FromRGB.
+ x = x.to(dtype=dtype, memory_format=memory_format)
+ if self.architecture == 'skip':
+ misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
+ img = img.to(dtype=dtype, memory_format=memory_format)
+ x = x + self.fromrgb(img)
+
+ # Main layers.
+ if self.mbstd is not None:
+ x = self.mbstd(x)
+ x = self.conv(x)
+ x = self.fc(x.flatten(1))
+ x = self.out(x)
+
+ # Conditioning.
+ if self.cmap_dim > 0:
+ misc.assert_shape(cmap, [None, self.cmap_dim])
+ x = (x * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
+
+ assert x.dtype == dtype
+ return x
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+#----------------------------------------------------------------------------
+
+
+class Discriminator(torch.nn.Module):
+ def __init__(self,
+ c_dim, # Conditioning label (C) dimensionality.
+ img_resolution, # Input resolution.
+ img_channels, # Number of input color channels.
+ architecture = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ cmap_dim = None, # Dimensionality of mapped conditioning label, None = default.
+ block_kwargs = {}, # Arguments for DiscriminatorBlock.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ epilogue_kwargs = {}, # Arguments for DiscriminatorEpilogue.
+ ):
+ super().__init__()
+ self.c_dim = c_dim
+ self.img_resolution = img_resolution
+ self.img_resolution_log2 = int(np.log2(img_resolution))
+ self.img_channels = img_channels
+ self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
+ channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
+ fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)
+
+ if cmap_dim is None:
+ cmap_dim = channels_dict[4]
+ if c_dim == 0:
+ cmap_dim = 0
+
+ common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
+ cur_layer_idx = 0
+ for res in self.block_resolutions:
+ in_channels = channels_dict[res] if res < img_resolution else 0
+ tmp_channels = channels_dict[res]
+ out_channels = channels_dict[res // 2]
+ use_fp16 = (res >= fp16_resolution)
+ block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
+ first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
+ setattr(self, f'b{res}', block)
+ cur_layer_idx += block.num_layers
+ if c_dim > 0:
+ self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
+ self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
+
+ def forward(self, img, c, update_emas=False, **block_kwargs):
+ _ = update_emas # unused
+ x = None
+ for res in self.block_resolutions:
+ block = getattr(self, f'b{res}')
+ x, img = block(x, img, **block_kwargs)
+
+ cmap = None
+ if self.c_dim > 0:
+ cmap = self.mapping(None, c)
+ x = self.b4(x, img, cmap)
+ return x
+
+ def extra_repr(self):
+ return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'
+
+#----------------------------------------------------------------------------
\ No newline at end of file
diff --git a/modules/eg3ds/models/networks_stylegan3.py b/modules/eg3ds/models/networks_stylegan3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c38853db600f4006c3f6e0045a8df1e707ee85
--- /dev/null
+++ b/modules/eg3ds/models/networks_stylegan3.py
@@ -0,0 +1,516 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Generator architecture from the paper
+"Alias-Free Generative Adversarial Networks"."""
+
+import numpy as np
+import scipy.signal
+import scipy.optimize
+import torch
+from modules.eg3ds.torch_utils import misc
+from modules.eg3ds.torch_utils.ops import conv2d_gradfix
+from modules.eg3ds.torch_utils.ops import filtered_lrelu
+from modules.eg3ds.torch_utils.ops import bias_act
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def modulated_conv2d(
+ x, # Input tensor: [batch_size, in_channels, in_height, in_width]
+ w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
+ s, # Style tensor: [batch_size, in_channels]
+ demodulate = True, # Apply weight demodulation?
+ padding = 0, # Padding: int or [padH, padW]
+ input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
+):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ batch_size = int(x.shape[0])
+ out_channels, in_channels, kh, kw = w.shape
+ misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
+ misc.assert_shape(s, [batch_size, in_channels]) # [NI]
+
+ # Pre-normalize inputs.
+ if demodulate:
+ w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
+ s = s * s.square().mean().rsqrt()
+
+ # Modulate weights.
+ w = w.unsqueeze(0) # [NOIkk]
+ w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Demodulate weights.
+ if demodulate:
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
+ w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Apply input scaling.
+ if input_gain is not None:
+ input_gain = input_gain.expand(batch_size, in_channels) # [NI]
+ w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
+
+ # Execute as one fused op using grouped convolution.
+ x = x.reshape(1, -1, *x.shape[2:])
+ w = w.reshape(-1, in_channels, kh, kw)
+ x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
+ x = x.reshape(batch_size, -1, *x.shape[2:])
+ return x
+
+#----------------------------------------------------------------------------
+
+
+class FullyConnectedLayer(torch.nn.Module):
+ def __init__(self,
+ in_features, # Number of input features.
+ out_features, # Number of output features.
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
+ bias = True, # Apply additive bias before the activation function?
+ lr_multiplier = 1, # Learning rate multiplier.
+ weight_init = 1, # Initial standard deviation of the weight tensor.
+ bias_init = 0, # Initial value of the additive bias.
+ ):
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.activation = activation
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
+ bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
+ self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
+ self.bias_gain = lr_multiplier
+
+ def forward(self, x):
+ w = self.weight.to(x.dtype) * self.weight_gain
+ b = self.bias
+ if b is not None:
+ b = b.to(x.dtype)
+ if self.bias_gain != 1:
+ b = b * self.bias_gain
+ if self.activation == 'linear' and b is not None:
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
+ else:
+ x = x.matmul(w.t())
+ x = bias_act.bias_act(x, b, act=self.activation)
+ return x
+
+ def extra_repr(self):
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
+
+#----------------------------------------------------------------------------
+
+
+class MappingNetwork(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ num_ws, # Number of intermediate latents to output.
+ num_layers = 2, # Number of mapping layers.
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.num_ws = num_ws
+ self.num_layers = num_layers
+ self.w_avg_beta = w_avg_beta
+
+ # Construct layers.
+ self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
+ features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
+ for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
+ layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
+ setattr(self, f'fc{idx}', layer)
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
+ misc.assert_shape(z, [None, self.z_dim])
+ if truncation_cutoff is None:
+ truncation_cutoff = self.num_ws
+
+ # Embed, normalize, and concatenate inputs.
+ x = z.to(torch.float32)
+ x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ if self.c_dim > 0:
+ misc.assert_shape(c, [None, self.c_dim])
+ y = self.embed(c.to(torch.float32))
+ y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
+ x = torch.cat([x, y], dim=1) if x is not None else y
+
+ # Execute layers.
+ for idx in range(self.num_layers):
+ x = getattr(self, f'fc{idx}')(x)
+
+ # Update moving average of W.
+ if update_emas:
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
+
+ # Broadcast and apply truncation.
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
+ if truncation_psi != 1:
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
+ return x
+
+ def extra_repr(self):
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisInput(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ channels, # Number of output channels.
+ size, # Output spatial size: int or [width, height].
+ sampling_rate, # Output sampling rate.
+ bandwidth, # Output bandwidth.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.channels = channels
+ self.size = np.broadcast_to(np.asarray(size), [2])
+ self.sampling_rate = sampling_rate
+ self.bandwidth = bandwidth
+
+ # Draw random frequencies from uniform 2D disc.
+ freqs = torch.randn([self.channels, 2])
+ radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
+ freqs /= radii * radii.square().exp().pow(0.25)
+ freqs *= bandwidth
+ phases = torch.rand([self.channels]) - 0.5
+
+ # Setup parameters and buffers.
+ self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
+ self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
+ self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
+ self.register_buffer('freqs', freqs)
+ self.register_buffer('phases', phases)
+
+ def forward(self, w):
+ # Introduce batch dimension.
+ transforms = self.transform.unsqueeze(0) # [batch, row, col]
+ freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
+ phases = self.phases.unsqueeze(0) # [batch, channel]
+
+ # Apply learned transformation.
+ t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
+ t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
+ m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
+ m_r[:, 0, 0] = t[:, 0] # r'_c
+ m_r[:, 0, 1] = -t[:, 1] # r'_s
+ m_r[:, 1, 0] = t[:, 1] # r'_s
+ m_r[:, 1, 1] = t[:, 0] # r'_c
+ m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
+ m_t[:, 0, 2] = -t[:, 2] # t'_x
+ m_t[:, 1, 2] = -t[:, 3] # t'_y
+ transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
+
+ # Transform frequencies.
+ phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
+ freqs = freqs @ transforms[:, :2, :2]
+
+ # Dampen out-of-band frequencies that may occur due to the user-specified transform.
+ amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
+
+ # Construct sampling grid.
+ theta = torch.eye(2, 3, device=w.device)
+ theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
+ theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
+ grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
+
+ # Compute Fourier features.
+ x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
+ x = x + phases.unsqueeze(1).unsqueeze(2)
+ x = torch.sin(x * (np.pi * 2))
+ x = x * amplitudes.unsqueeze(1).unsqueeze(2)
+
+ # Apply trainable mapping.
+ weight = self.weight / np.sqrt(self.channels)
+ x = x @ weight.t()
+
+ # Ensure correct shape.
+ x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
+ misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
+ return x
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
+ f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisLayer(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ is_torgb, # Is this the final ToRGB layer?
+ is_critically_sampled, # Does this layer use critical sampling?
+ use_fp16, # Does this layer use FP16?
+
+ # Input & output specifications.
+ in_channels, # Number of input channels.
+ out_channels, # Number of output channels.
+ in_size, # Input spatial size: int or [width, height].
+ out_size, # Output spatial size: int or [width, height].
+ in_sampling_rate, # Input sampling rate (s).
+ out_sampling_rate, # Output sampling rate (s).
+ in_cutoff, # Input cutoff frequency (f_c).
+ out_cutoff, # Output cutoff frequency (f_c).
+ in_half_width, # Input transition band half-width (f_h).
+ out_half_width, # Output transition band half-width (f_h).
+
+ # Hyperparameters.
+ conv_kernel = 3, # Convolution kernel size. Ignored for the final ToRGB layer.
+ filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
+ lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer.
+ use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
+ conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
+ magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.is_torgb = is_torgb
+ self.is_critically_sampled = is_critically_sampled
+ self.use_fp16 = use_fp16
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.in_size = np.broadcast_to(np.asarray(in_size), [2])
+ self.out_size = np.broadcast_to(np.asarray(out_size), [2])
+ self.in_sampling_rate = in_sampling_rate
+ self.out_sampling_rate = out_sampling_rate
+ self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
+ self.in_cutoff = in_cutoff
+ self.out_cutoff = out_cutoff
+ self.in_half_width = in_half_width
+ self.out_half_width = out_half_width
+ self.conv_kernel = 1 if is_torgb else conv_kernel
+ self.conv_clamp = conv_clamp
+ self.magnitude_ema_beta = magnitude_ema_beta
+
+ # Setup parameters and buffers.
+ self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
+ self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
+ self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
+ self.register_buffer('magnitude_ema', torch.ones([]))
+
+ # Design upsampling filter.
+ self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
+ assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
+ self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
+ self.register_buffer('up_filter', self.design_lowpass_filter(
+ numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))
+
+ # Design downsampling filter.
+ self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
+ assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
+ self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
+ self.down_radial = use_radial_filters and not self.is_critically_sampled
+ self.register_buffer('down_filter', self.design_lowpass_filter(
+ numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))
+
+ # Compute padding.
+ pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
+ pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
+ pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
+ pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
+ pad_hi = pad_total - pad_lo
+ self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]
+
+ def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
+ assert noise_mode in ['random', 'const', 'none'] # unused
+ misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
+ misc.assert_shape(w, [x.shape[0], self.w_dim])
+
+ # Track input magnitude.
+ if update_emas:
+ with torch.autograd.profiler.record_function('update_magnitude_ema'):
+ magnitude_cur = x.detach().to(torch.float32).square().mean()
+ self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
+ input_gain = self.magnitude_ema.rsqrt()
+
+ # Execute affine layer.
+ styles = self.affine(w)
+ if self.is_torgb:
+ weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
+ styles = styles * weight_gain
+
+ # Execute modulated conv2d.
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
+ x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
+ padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)
+
+ # Execute bias, filtered leaky ReLU, and clamping.
+ gain = 1 if self.is_torgb else np.sqrt(2)
+ slope = 1 if self.is_torgb else 0.2
+ x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
+ up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)
+
+ # Ensure correct shape and dtype.
+ misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
+ assert x.dtype == dtype
+ return x
+
+ @staticmethod
+ def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
+ assert numtaps >= 1
+
+ # Identity filter.
+ if numtaps == 1:
+ return None
+
+ # Separable Kaiser low-pass filter.
+ if not radial:
+ f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
+ return torch.as_tensor(f, dtype=torch.float32)
+
+ # Radially symmetric jinc-based filter.
+ x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
+ r = np.hypot(*np.meshgrid(x, x))
+ f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
+ beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
+ w = np.kaiser(numtaps, beta)
+ f *= np.outer(w, w)
+ f /= np.sum(f)
+ return torch.as_tensor(f, dtype=torch.float32)
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
+ f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
+ f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
+ f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
+ f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
+ f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisNetwork(torch.nn.Module):
+ def __init__(self,
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output image resolution.
+ img_channels, # Number of color channels.
+ channel_base = 32768, # Overall multiplier for the number of channels.
+ channel_max = 512, # Maximum number of channels in any layer.
+ num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
+ num_critical = 2, # Number of critically sampled layers at the end.
+ first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
+ first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}).
+ last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
+ margin_size = 10, # Number of additional pixels outside the image.
+ output_scale = 0.25, # Scale factor for the output image.
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ super().__init__()
+ self.w_dim = w_dim
+ self.num_ws = num_layers + 2
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.num_layers = num_layers
+ self.num_critical = num_critical
+ self.margin_size = margin_size
+ self.output_scale = output_scale
+ self.num_fp16_res = num_fp16_res
+
+ # Geometric progression of layer cutoffs and min. stopbands.
+ last_cutoff = self.img_resolution / 2 # f_{c,N}
+ last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
+ exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
+ cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
+ stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]
+
+ # Compute remaining layer parameters.
+ sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
+ half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
+ sizes = sampling_rates + self.margin_size * 2
+ sizes[-2:] = self.img_resolution
+ channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
+ channels[-1] = self.img_channels
+
+ # Construct layers.
+ self.input = SynthesisInput(
+ w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
+ sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
+ self.layer_names = []
+ for idx in range(self.num_layers + 1):
+ prev = max(idx - 1, 0)
+ is_torgb = (idx == self.num_layers)
+ is_critically_sampled = (idx >= self.num_layers - self.num_critical)
+ use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
+ layer = SynthesisLayer(
+ w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
+ in_channels=int(channels[prev]), out_channels=int(channels[idx]),
+ in_size=int(sizes[prev]), out_size=int(sizes[idx]),
+ in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
+ in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
+ in_half_width=half_widths[prev], out_half_width=half_widths[idx],
+ **layer_kwargs)
+ name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
+ setattr(self, name, layer)
+ self.layer_names.append(name)
+
+ def forward(self, ws, **layer_kwargs):
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
+ ws = ws.to(torch.float32).unbind(dim=1)
+
+ # Execute layers.
+ x = self.input(ws[0])
+ for name, w in zip(self.layer_names, ws[1:]):
+ x = getattr(self, name)(x, w, **layer_kwargs)
+ if self.output_scale != 1:
+ x = x * self.output_scale
+
+ # Ensure correct shape and dtype.
+ misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
+ x = x.to(torch.float32)
+ return x
+
+ def extra_repr(self):
+ return '\n'.join([
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
+ f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
+ f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])
+
+#----------------------------------------------------------------------------
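The geometric progression of cutoffs, stopbands, and sampling rates set up in SynthesisNetwork.__init__ can be tabulated outside the module. The following standalone sketch re-runs the same arithmetic for a hypothetical 512x512, 14-layer configuration; every formula is taken from the code above, and the concrete numbers are only for inspection:

    import numpy as np

    img_resolution, num_layers, num_critical, margin_size = 512, 14, 2, 10
    first_cutoff, first_stopband, last_stopband_rel = 2, 2 ** 2.1, 2 ** 0.3

    last_cutoff = img_resolution / 2                                   # f_{c,N}
    last_stopband = last_cutoff * last_stopband_rel                    # f_{t,N}
    exponents = np.minimum(np.arange(num_layers + 1) / (num_layers - num_critical), 1)
    cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents          # f_c[i]
    stopbands = first_stopband * (last_stopband / first_stopband) ** exponents  # f_t[i]
    sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, img_resolution))))  # s[i]
    sizes = sampling_rates + margin_size * 2
    sizes[-2:] = img_resolution
    print(np.stack([cutoffs, stopbands, sampling_rates, sizes]).round(1))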
+
+
+class Generator(torch.nn.Module):
+ def __init__(self,
+ z_dim, # Input latent (Z) dimensionality.
+ c_dim, # Conditioning label (C) dimensionality.
+ w_dim, # Intermediate latent (W) dimensionality.
+ img_resolution, # Output resolution.
+ img_channels, # Number of output color channels.
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
+ ):
+ super().__init__()
+ self.z_dim = z_dim
+ self.c_dim = c_dim
+ self.w_dim = w_dim
+ self.img_resolution = img_resolution
+ self.img_channels = img_channels
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
+ self.num_ws = self.synthesis.num_ws
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
+
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ return img
+
+#----------------------------------------------------------------------------
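A hypothetical end-to-end call into the Generator defined in this file, assuming the package and its custom ops (bias_act, filtered_lrelu, conv2d_gradfix) import and fall back cleanly on the current device; the z_dim/w_dim/resolution values are illustrative, not taken from the repo's configs:

    import torch
    from modules.eg3ds.models.networks_stylegan3 import Generator

    G = Generator(z_dim=512, c_dim=0, w_dim=512, img_resolution=256, img_channels=3)
    z = torch.randn(2, G.z_dim)          # latent codes
    c = torch.zeros(2, G.c_dim)          # no conditioning labels since c_dim=0
    img = G(z, c, truncation_psi=0.7)    # psi < 1 pulls ws toward the tracked w_avg
    print(img.shape)                     # expected: torch.Size([2, 3, 256, 256])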
diff --git a/modules/eg3ds/models/superresolution.py b/modules/eg3ds/models/superresolution.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb1bf50ae0a3153600c297090a053b3d5f5111e1
--- /dev/null
+++ b/modules/eg3ds/models/superresolution.py
@@ -0,0 +1,360 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Superresolution network architectures from the paper
+"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from modules.eg3ds.models.networks_stylegan2 import Conv2dLayer, SynthesisLayer, ToRGBLayer
+from modules.eg3ds.torch_utils.ops import upfirdn2d
+from modules.eg3ds.torch_utils import misc
+
+from modules.eg3ds.models.networks_stylegan2 import SynthesisBlock
+from modules.eg3ds.models.networks_stylegan3 import SynthesisLayer as AFSynthesisLayer
+from utils.commons.hparams import hparams
+
+
+#----------------------------------------------------------------------------
+
+# for 512x512 generation
+class SuperresolutionHybrid8X(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 512
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 128
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlock(channels, 128, w_dim=512, resolution=256,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=512,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+# for 256x256 generation
+
+class SuperresolutionHybrid4X(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 256
+ use_fp16 = sr_num_fp16_res > 0
+ self.sr_antialias = sr_antialias
+ self.input_resolution = 128
+ self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] < self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+# for 128x128 generation
+
+class SuperresolutionHybrid2X(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 128
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 64
+ self.sr_antialias = sr_antialias
+ self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=64,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=128,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+# TODO: Delete (here for backwards compatibility with old 256x256 models)
+
+class SuperresolutionHybridDeepfp32(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res,
+ num_fp16_res=4, conv_clamp=None, channel_base=None, channel_max=None,# IGNORE
+ **block_kwargs):
+ super().__init__()
+ assert img_resolution == 256
+ use_fp16 = sr_num_fp16_res > 0
+
+ self.input_resolution = 128
+ self.block0 = SynthesisBlockNoUp(channels, 128, w_dim=512, resolution=128,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(128, 64, w_dim=512, resolution=256,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] < self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+
+#----------------------------------------------------------------------------
+
+
+class SynthesisBlockNoUp(torch.nn.Module):
+ def __init__(self,
+ in_channels, # Number of input channels, 0 = first block.
+ out_channels, # Number of output channels.
+ w_dim, # Intermediate latent (W) dimensionality.
+ resolution, # Resolution of this block.
+ img_channels, # Number of output color channels.
+ is_last, # Is this the last block?
+ architecture = 'skip', # Architecture: 'orig', 'skip', 'resnet'.
+ resample_filter = [1,3,3,1], # Low-pass filter to apply when resampling activations.
+ conv_clamp = 256, # Clamp the output of convolution layers to +-X, None = disable clamping.
+ use_fp16 = False, # Use FP16 for this block?
+ fp16_channels_last = False, # Use channels-last memory format with FP16?
+ fused_modconv_default = True, # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
+ **layer_kwargs, # Arguments for SynthesisLayer.
+ ):
+ assert architecture in ['orig', 'skip', 'resnet']
+ super().__init__()
+ self.in_channels = in_channels
+ self.w_dim = w_dim
+ self.resolution = resolution
+ self.img_channels = img_channels
+ self.is_last = is_last
+ self.architecture = architecture
+ self.use_fp16 = use_fp16
+ self.channels_last = (use_fp16 and fp16_channels_last)
+ self.fused_modconv_default = fused_modconv_default
+ self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
+ self.num_conv = 0
+ self.num_torgb = 0
+
+ if in_channels == 0:
+ self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))
+
+ if in_channels != 0:
+ self.conv0 = SynthesisLayer(in_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ self.conv1 = SynthesisLayer(out_channels, out_channels, w_dim=w_dim, resolution=resolution,
+ conv_clamp=conv_clamp, channels_last=self.channels_last, **layer_kwargs)
+ self.num_conv += 1
+
+ if is_last or architecture == 'skip':
+ self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim,
+ conv_clamp=conv_clamp, channels_last=self.channels_last)
+ self.num_torgb += 1
+
+ if in_channels != 0 and architecture == 'resnet':
+ self.skip = Conv2dLayer(in_channels, out_channels, kernel_size=1, bias=False, up=2,
+ resample_filter=resample_filter, channels_last=self.channels_last)
+
+ def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
+ _ = update_emas # unused
+ misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
+ w_iter = iter(ws.unbind(dim=1))
+ if ws.device.type != 'cuda':
+ force_fp32 = True
+ dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
+ memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
+ if fused_modconv is None:
+ fused_modconv = self.fused_modconv_default
+ if fused_modconv == 'inference_only':
+ fused_modconv = (not self.training)
+
+ # Input.
+ if self.in_channels == 0:
+ x = self.const.to(dtype=dtype, memory_format=memory_format)
+ x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
+ else:
+ misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
+ x = x.to(dtype=dtype, memory_format=memory_format)
+
+ # Main layers.
+ if self.in_channels == 0:
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ elif self.architecture == 'resnet':
+ y = self.skip(x, gain=np.sqrt(0.5))
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
+ x = y.add_(x)
+ else:
+ x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+ x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
+
+ # ToRGB.
+ # if img is not None:
+ # misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
+ # img = upfirdn2d.upsample2d(img, self.resample_filter)
+ if self.is_last or self.architecture == 'skip':
+ y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
+ y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+ img = img.add_(y) if img is not None else y
+
+ assert x.dtype == dtype
+ assert img is None or img.dtype == torch.float32
+ return x, img
+
+ def extra_repr(self):
+ return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
+
+
+#----------------------------------------------------------------------------
+# for 512x512 generation (SuperresolutionHybrid8XDC and its optional large-SR blocks)
+class ResBlock2d(nn.Module):
+ """
+ Residual block that preserves the spatial resolution.
+ """
+
+ def __init__(self, in_features, kernel_size, padding):
+ super(ResBlock2d, self).__init__()
+ self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+ padding=padding)
+ self.act = nn.ReLU(inplace=False)
+ # self.act = nn.LeakyReLU(inplace=False) # run3
+ # self.norm1 = nn.BatchNorm2d(in_features, affine=True)
+ # self.norm2 = nn.BatchNorm2d(in_features, affine=True)
+
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.act(out)
+ out = self.conv2(out)
+ out = self.act(out)
+ out = out + x
+ return out
+
+ # def forward(self, x):
+ # out = self.norm1(x)
+ # out = F.relu(out)
+ # out = self.conv1(out)
+ # out = self.norm2(out)
+ # out = F.relu(out)
+ # out = self.conv2(out)
+ # out = x + out
+ # return out
+
+
+class LargeSynthesisBlock0(nn.Module):
+ def __init__(self, channels, use_fp16, **block_kwargs):
+ super().__init__()
+ self.block = SynthesisBlock(channels, 256, w_dim=512, resolution=256,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.resblocks = nn.Sequential(*[
+ ResBlock2d(256, kernel_size=3, padding=1) for _ in range(hparams['resblocks_in_large_sr'])
+ ])
+ self.to_rgb = nn.Conv2d(256, 3, kernel_size=1)
+
+ def forward(self, x, rgb, ws, **block_kwargs):
+ x, rgb = self.block(x, rgb, ws, **block_kwargs)
+ x = self.resblocks(x)
+ rgb = rgb + self.to_rgb(x)
+ return x, rgb
+
+class LargeSynthesisBlock1(nn.Module):
+ def __init__(self, use_fp16, **block_kwargs):
+ super().__init__()
+ self.block = SynthesisBlock(256, 128, w_dim=512, resolution=512,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.resblocks = nn.Sequential(*[
+ ResBlock2d(128, kernel_size=3, padding=1) for _ in range(hparams['resblocks_in_large_sr'])
+ ])
+ self.to_rgb = nn.Conv2d(128, 3, kernel_size=1)
+
+ def forward(self, x, rgb, ws, **block_kwargs):
+ x, rgb = self.block(x, rgb, ws, **block_kwargs)
+ x = self.resblocks(x)
+ rgb = rgb + self.to_rgb(x)
+ return x, rgb
+
+class SuperresolutionHybrid8XDC(torch.nn.Module):
+ def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, large_sr=False, **block_kwargs):
+ super().__init__()
+ assert img_resolution == 512
+
+ use_fp16 = sr_num_fp16_res > 0
+ self.input_resolution = 128
+ self.sr_antialias = sr_antialias
+ if large_sr is True:
+ self.block0 = LargeSynthesisBlock0(channels, use_fp16=sr_num_fp16_res > 0, **block_kwargs)
+ self.block1 = LargeSynthesisBlock1(use_fp16=sr_num_fp16_res > 0, **block_kwargs)
+ else:
+ self.block0 = SynthesisBlock(channels, 256, w_dim=512, resolution=256,
+ img_channels=3, is_last=False, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+ self.block1 = SynthesisBlock(256, 128, w_dim=512, resolution=512,
+ img_channels=3, is_last=True, use_fp16=use_fp16, conv_clamp=(256 if use_fp16 else None), **block_kwargs)
+
+ def forward(self, rgb, x, ws, **block_kwargs):
+ ws = ws[:, -1:, :].repeat(1, 3, 1)
+
+ if x.shape[-1] != self.input_resolution:
+ x = torch.nn.functional.interpolate(x, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+ rgb = torch.nn.functional.interpolate(rgb, size=(self.input_resolution, self.input_resolution),
+ mode='bilinear', align_corners=False, antialias=self.sr_antialias)
+
+ x, rgb = self.block0(x, rgb, ws, **block_kwargs)
+ x, rgb = self.block1(x, rgb, ws, **block_kwargs)
+ return rgb
+#----------------------------------------------------------------------------
\ No newline at end of file
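A hypothetical shape walk-through for SuperresolutionHybrid8XDC from this file (128 -> 256 -> 512), assuming the repo's utils and the vendored networks_stylegan2 blocks import cleanly. The feature/ws shapes mirror how triplane.py calls it, but the values are illustrative:

    import torch
    from modules.eg3ds.models.superresolution import SuperresolutionHybrid8XDC

    sr = SuperresolutionHybrid8XDC(channels=32, img_resolution=512,
                                   sr_num_fp16_res=0, sr_antialias=True)
    feat = torch.randn(1, 32, 128, 128)   # raw neural-rendering feature image
    rgb = feat[:, :3]                     # low-resolution RGB = first 3 feature channels
    ws = torch.randn(1, 14, 512)          # only the last w is used (repeated 3x inside forward)
    out = sr(rgb, feat, ws, noise_mode='none')
    print(out.shape)                      # expected: torch.Size([1, 3, 512, 512])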
diff --git a/modules/eg3ds/models/triplane.py b/modules/eg3ds/models/triplane.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c582f66693b05a8830fe5f1a46ea3cf21f21634
--- /dev/null
+++ b/modules/eg3ds/models/triplane.py
@@ -0,0 +1,189 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import torch
+import torch.nn as nn
+from modules.eg3ds.models.networks_stylegan2 import Generator as StyleGAN2Backbone
+from modules.eg3ds.models.networks_stylegan2 import FullyConnectedLayer
+from modules.eg3ds.volumetric_rendering.renderer import ImportanceRenderer
+from modules.eg3ds.volumetric_rendering.ray_sampler import RaySampler
+from modules.eg3ds.models.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X, SuperresolutionHybrid8X, SuperresolutionHybrid8XDC
+
+import copy
+from utils.commons.hparams import hparams
+
+
+class TriPlaneGenerator(torch.nn.Module):
+ def __init__(self, hp=None):
+ super().__init__()
+ global hparams
+ self.hparams = copy.copy(hparams) if hp is None else copy.copy(hp)
+ hparams = self.hparams
+
+ self.z_dim = hparams['z_dim']
+ self.camera_dim = 25
+ self.w_dim=hparams['w_dim']
+
+ self.img_resolution = hparams['final_resolution']
+ self.img_channels = 3
+ self.renderer = ImportanceRenderer(hp=hparams)
+ self.renderer.triplane_feature_type = 'triplane'
+ self.ray_sampler = RaySampler()
+
+ self.neural_rendering_resolution = hparams['neural_rendering_resolution']
+
+ mapping_kwargs = {'num_layers': hparams['mapping_network_depth']}
+ synthesis_kwargs = {'channel_base': hparams['base_channel'], 'channel_max': hparams['max_channel'], 'fused_modconv_default': 'inference_only', 'num_fp16_res': hparams['num_fp16_layers_in_generator'], 'conv_clamp': None}
+
+ triplane_c_dim = self.camera_dim
+
+ # If gen_cond_mode == 'mapping', a cond_mapping network is added inside the backbone.
+ self.backbone = StyleGAN2Backbone(self.z_dim, triplane_c_dim, self.w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
+ self.decoder = OSGDecoder(32, {'decoder_lr_mul': 1, 'decoder_output_dim': 32})
+
+ self.rendering_kwargs = {'image_resolution': hparams['final_resolution'],
+ 'disparity_space_sampling': False,
+ 'clamp_mode': 'softplus',
+ 'gpc_reg_prob': hparams['gpc_reg_prob'],
+ 'c_scale': 1.0,
+ 'superresolution_noise_mode': 'none',
+ 'density_reg': hparams['lambda_density_reg'], 'density_reg_p_dist': hparams['density_reg_p_dist'],
+ 'reg_type': 'l1', 'decoder_lr_mul': 1.0,
+ 'sr_antialias': True,
+ 'depth_resolution': hparams['num_samples_coarse'],
+ 'depth_resolution_importance': hparams['num_samples_fine'],
+ 'ray_start': hparams['ray_near'], 'ray_end': hparams['ray_far'],
+ 'box_warp': hparams['box_warp'],
+ 'avg_camera_radius': 2.7, # Only used by the pose sampler at inference time, where the camera moves on a sphere of constant radius; this radius is the camera's distance from the world-coordinate origin.
+ 'avg_camera_pivot': [0, 0, 0.2], # Only used by the pose sampler at inference time; the point the camera looks at, which determines the view direction. [0., 0., 0.2] is roughly the philtrum of the 3DMM face.
+ 'white_back': False, # Worth enabling when the background is pure white: empty (zero-density) space renders as black by default, and this option makes it render as white instead, so the network does not have to model a thin shell of voxels just to produce a white background.
+ }
+
+ sr_num_fp16_res = hparams['num_fp16_layers_in_super_resolution']
+ sr_kwargs = {'channel_base': hparams['base_channel'], 'channel_max': hparams['max_channel'], 'fused_modconv_default': 'inference_only'}
+ self.superresolution = SuperresolutionHybrid8XDC(channels=32, img_resolution=self.img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=True, **sr_kwargs)
+
+ def mapping(self, z, camera, cond=None, truncation_psi=0.7, truncation_cutoff=None, update_emas=False):
+ """
+ Generate the intermediate latents ws by running the mapping network.
+
+ z: latent sampled from N(0,1): [B, z_dim=512]
+ camera: flattened extrinsic 4x4 matrix and intrinsic 3x3 matrix: [B, c=16+9]
+ cond: auxiliary condition, such as idexp_lm3d: [B, c=68*3]
+ truncation_psi: strength of the truncation trick (as used in BigGAN/StyleGAN); 1.0 means no truncation, 0.0 returns the mean ws, and values in between interpolate linearly between the two.
+ truncation_cutoff: number of ws layers to truncate. None (default) applies truncation to all ws; an integer applies it only to the first that many layers.
+ """
+ c = camera
+ ws = self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ if hparams.get("gen_cond_mode", 'none') == 'mapping':
+ d_ws = self.backbone.cond_mapping(cond, None, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ ws = ws * 0.5 + d_ws * 0.5
+ return ws
+
+ def synthesis(self, ws, camera, cond=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
+ """
+ Run the backbone to synthesize images from the ws produced by self.mapping.
+ """
+ ret = {}
+
+ cam2world_matrix = camera[:, :16].view(-1, 4, 4)
+ intrinsics = camera[:, 16:25].view(-1, 3, 3)
+
+ neural_rendering_resolution = self.neural_rendering_resolution
+
+ # Create a batch of rays for volume rendering
+ ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
+
+ # Create triplanes by running StyleGAN backbone
+ N, M, _ = ray_origins.shape
+ if use_cached_backbone and self._last_planes is not None:
+ planes = self._last_planes
+ else:
+ planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ if cache_backbone:
+ self._last_planes = planes
+
+ # Reshape output into three 32-channel planes
+ planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])
+
+ # Perform volume rendering
+ feature_samples, depth_samples, weights_samples, is_ray_valid = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last
+
+ # Reshape into 'raw' neural-rendered image
+ H = W = self.neural_rendering_resolution
+ feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
+ if hparams.get("mask_invalid_rays", False):
+ is_ray_valid_mask = is_ray_valid.reshape([feature_samples.shape[0], 1,self.neural_rendering_resolution,self.neural_rendering_resolution]) # [B, 1, H, W]
+ feature_image[~is_ray_valid_mask.repeat([1,feature_image.shape[1],1,1])] = -1
+ depth_image[~is_ray_valid_mask] = depth_image[is_ray_valid_mask].min().item()
+
+ # Run superresolution to get final image
+ rgb_image = feature_image[:, :3]
+ ws_to_sr = ws
+ if hparams['ones_ws_for_sr']:
+ ws_to_sr = torch.ones_like(ws)
+ sr_image = self.superresolution(rgb_image, feature_image, ws_to_sr, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
+
+ rgb_image = rgb_image.clamp(-1,1)
+ sr_image = sr_image.clamp(-1,1)
+ ret.update({'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image, 'image_feature': feature_image[:, 3:], 'plane': planes})
+ return ret
+
+ def sample(self, coordinates, directions, z, camera, cond=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ """
+ Compute RGB features and density for arbitrary 3D coordinates. Mostly used for extracting shapes.
+ The outputs are not aggregated into pixels; they are returned in world coordinates.
+ """
+ ws = self.mapping(z, camera, cond=cond, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+ return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
+
+ def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
+ """
+ Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
+ """
+ planes = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
+ return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
+
+ def forward(self, z, camera, cond=None, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
+ """
+ Render a batch of generated images.
+ """
+ ws = self.mapping(z, camera, cond=cond, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
+ return self.synthesis(ws, camera, cond=cond, update_emas=update_emas, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
+
+
+class OSGDecoder(torch.nn.Module):
+ def __init__(self, n_features, options):
+ super().__init__()
+ self.hidden_dim = 64
+
+ self.net = torch.nn.Sequential(
+ FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']),
+ torch.nn.Softplus(),
+ FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])
+ )
+
+ def forward(self, sampled_features, ray_directions):
+ # Aggregate features
+ sampled_features = sampled_features.mean(1)
+ x = sampled_features
+
+ N, M, C = x.shape
+ x = x.view(N*M, C)
+
+ x = self.net(x)
+ x = x.view(N, M, -1)
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
+ sigma = x[..., 0:1]
+ return {'rgb': rgb, 'sigma': sigma}
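A hypothetical forward pass through TriPlaneGenerator, assuming the global hparams dict has already been populated (e.g., by the repo's config loader) with the keys read in __init__ and that the renderer/backbone modules import cleanly. The 25-dim camera vector is the flattened 4x4 cam2world extrinsic followed by the flattened 3x3 intrinsic, exactly as synthesis() unpacks it; the intrinsic values below are placeholders:

    import torch
    from modules.eg3ds.models.triplane import TriPlaneGenerator

    G = TriPlaneGenerator()                               # reads the global hparams dict
    z = torch.randn(1, G.z_dim)
    cam2world = torch.eye(4).reshape(1, 16)               # flattened 4x4 extrinsic
    intrinsics = torch.tensor([[4.26, 0., 0.5, 0., 4.26, 0.5, 0., 0., 1.]])  # flattened 3x3, normalized focal
    camera = torch.cat([cam2world, intrinsics], dim=1)    # [1, 25]
    out = G(z, camera)                                    # dict: 'image', 'image_raw', 'image_depth', 'plane', ...
    print(out['image'].shape, out['image_raw'].shape)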
diff --git a/modules/eg3ds/torch_utils/__init__.py b/modules/eg3ds/torch_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/modules/eg3ds/torch_utils/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/modules/eg3ds/torch_utils/custom_ops.py b/modules/eg3ds/torch_utils/custom_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2524f47ab3d5b8750cfb868cc14012f424acc8
--- /dev/null
+++ b/modules/eg3ds/torch_utils/custom_ops.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import glob
+import hashlib
+import importlib
+import os
+import re
+import shutil
+import uuid
+
+import torch
+import torch.utils.cpp_extension
+from torch.utils.file_baton import FileBaton
+
+#----------------------------------------------------------------------------
+# Global options.
+
+verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
+
+#----------------------------------------------------------------------------
+# Internal helper funcs.
+
+def _find_compiler_bindir():
+ patterns = [
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
+ 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
+ ]
+ for pattern in patterns:
+ matches = sorted(glob.glob(pattern))
+ if len(matches):
+ return matches[-1]
+ return None
+
+#----------------------------------------------------------------------------
+
+def _get_mangled_gpu_name():
+ name = torch.cuda.get_device_name().lower()
+ out = []
+ for c in name:
+ if re.match('[a-z0-9_-]+', c):
+ out.append(c)
+ else:
+ out.append('-')
+ return ''.join(out)
+
+#----------------------------------------------------------------------------
+# Main entry point for compiling and loading C++/CUDA plugins.
+
+_cached_plugins = dict()
+
+def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
+ assert verbosity in ['none', 'brief', 'full']
+ if headers is None:
+ headers = []
+ if source_dir is not None:
+ sources = [os.path.join(source_dir, fname) for fname in sources]
+ headers = [os.path.join(source_dir, fname) for fname in headers]
+
+ # Already cached?
+ if module_name in _cached_plugins:
+ return _cached_plugins[module_name]
+
+ # Print status.
+ if verbosity == 'full':
+ print(f'Setting up PyTorch plugin "{module_name}"...')
+ elif verbosity == 'brief':
+ print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
+ verbose_build = (verbosity == 'full')
+
+ # Compile and load.
+ try: # pylint: disable=too-many-nested-blocks
+ # Make sure we can find the necessary compiler binaries.
+ if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
+ compiler_bindir = _find_compiler_bindir()
+ if compiler_bindir is None:
+ raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
+ os.environ['PATH'] += ';' + compiler_bindir
+
+ # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
+ # break the build or unnecessarily restrict what's available to nvcc.
+ # Unset it to let nvcc decide based on what's available on the
+ # machine.
+ os.environ['TORCH_CUDA_ARCH_LIST'] = ''
+
+ # Incremental build md5sum trickery. Copies all the input source files
+ # into a cached build directory under a combined md5 digest of the input
+ # source files. Copying is done only if the combined digest has changed.
+ # This keeps input file timestamps and filenames the same as in previous
+ # extension builds, allowing for fast incremental rebuilds.
+ #
+ # This optimization is done only in case all the source files reside in
+ # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
+ # environment variable is set (we take this as a signal that the user
+ # actually cares about this.)
+ #
+ # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
+ # around the *.cu dependency bug in ninja config.
+ #
+ all_source_files = sorted(sources + headers)
+ all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
+ if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
+
+ # Compute combined hash digest for all source files.
+ hash_md5 = hashlib.md5()
+ for src in all_source_files:
+ with open(src, 'rb') as f:
+ hash_md5.update(f.read())
+
+ # Select cached build directory name.
+ source_digest = hash_md5.hexdigest()
+ build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
+ cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
+
+ if not os.path.isdir(cached_build_dir):
+ tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
+ os.makedirs(tmpdir)
+ for src in all_source_files:
+ shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
+ try:
+ os.replace(tmpdir, cached_build_dir) # atomic
+ except OSError:
+ # source directory already exists, delete tmpdir and its contents.
+ shutil.rmtree(tmpdir)
+ if not os.path.isdir(cached_build_dir): raise
+
+ # Compile.
+ cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
+ torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
+ verbose=verbose_build, sources=cached_sources, **build_kwargs)
+ else:
+ torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
+
+ # Load.
+ module = importlib.import_module(module_name)
+
+ except:
+ if verbosity == 'brief':
+ print('Failed!')
+ raise
+
+ # Print status and add to cache dict.
+ if verbosity == 'full':
+ print(f'Done setting up PyTorch plugin "{module_name}".')
+ elif verbosity == 'brief':
+ print('Done.')
+ _cached_plugins[module_name] = module
+ return module
+
+#----------------------------------------------------------------------------
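The incremental-build caching in get_plugin() keys the cached build directory on an md5 digest of the sorted source/header files plus the mangled GPU name. A standalone sketch of just that key computation (illustrative only; the real code also handles the tmpdir rename and the actual cpp_extension.load call):

    import hashlib

    def cached_build_dirname(sources, headers, gpu_name='nvidia geforce rtx 3090'):
        hash_md5 = hashlib.md5()
        for src in sorted(sources + headers):
            with open(src, 'rb') as f:
                hash_md5.update(f.read())
        # Mirror _get_mangled_gpu_name(): keep [a-z0-9_-], replace everything else with '-'.
        mangled = ''.join(c if (c.isascii() and (c.isalnum() or c in '_-')) else '-' for c in gpu_name.lower())
        return f'{hash_md5.hexdigest()}-{mangled}'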
diff --git a/modules/eg3ds/torch_utils/misc.py b/modules/eg3ds/torch_utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8d56d3d55fda85709ed63716485c7d55514bd1c
--- /dev/null
+++ b/modules/eg3ds/torch_utils/misc.py
@@ -0,0 +1,268 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import re
+import contextlib
+import numpy as np
+import torch
+import warnings
+from modules.eg3ds import dnnlib
+
+#----------------------------------------------------------------------------
+# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
+# same constant is used multiple times.
+
+_constant_cache = dict()
+
+def constant(value, shape=None, dtype=None, device=None, memory_format=None):
+ value = np.asarray(value)
+ if shape is not None:
+ shape = tuple(shape)
+ if dtype is None:
+ dtype = torch.get_default_dtype()
+ if device is None:
+ device = torch.device('cpu')
+ if memory_format is None:
+ memory_format = torch.contiguous_format
+
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
+ tensor = _constant_cache.get(key, None)
+ if tensor is None:
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
+ if shape is not None:
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
+ tensor = tensor.contiguous(memory_format=memory_format)
+ _constant_cache[key] = tensor
+ return tensor
+
+#----------------------------------------------------------------------------
+# Replace NaN/Inf with specified numerical values.
+
+try:
+ nan_to_num = torch.nan_to_num # 1.8.0a0
+except AttributeError:
+ def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
+ assert isinstance(input, torch.Tensor)
+ if posinf is None:
+ posinf = torch.finfo(input.dtype).max
+ if neginf is None:
+ neginf = torch.finfo(input.dtype).min
+ assert nan == 0
+ return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
+
+#----------------------------------------------------------------------------
+# Symbolic assert.
+
+try:
+ symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
+except AttributeError:
+ symbolic_assert = torch.Assert # 1.7.0
+
+#----------------------------------------------------------------------------
+# Context manager to temporarily suppress known warnings in torch.jit.trace().
+# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
+
+@contextlib.contextmanager
+def suppress_tracer_warnings():
+ flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
+ warnings.filters.insert(0, flt)
+ yield
+ warnings.filters.remove(flt)
+
+#----------------------------------------------------------------------------
+# Assert that the shape of a tensor matches the given list of integers.
+# None indicates that the size of a dimension is allowed to vary.
+# Performs symbolic assertion when used in torch.jit.trace().
+
+def assert_shape(tensor, ref_shape):
+ if tensor.ndim != len(ref_shape):
+ raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
+ for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
+ if ref_size is None:
+ pass
+ elif isinstance(ref_size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
+ elif isinstance(size, torch.Tensor):
+ with suppress_tracer_warnings(): # as_tensor results are registered as constants
+ symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
+ elif size != ref_size:
+ raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
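+
+# Usage sketch (illustrative): `None` marks a dimension whose size may vary.
+#   x = torch.zeros(8, 3, 256, 256)
+#   assert_shape(x, [None, 3, 256, 256])  # passes for any batch size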
+
+#----------------------------------------------------------------------------
+# Function decorator that calls torch.autograd.profiler.record_function().
+
+def profiled_function(fn):
+ def decorator(*args, **kwargs):
+ with torch.autograd.profiler.record_function(fn.__name__):
+ return fn(*args, **kwargs)
+ decorator.__name__ = fn.__name__
+ return decorator
+
+#----------------------------------------------------------------------------
+# Sampler for torch.utils.data.DataLoader that loops over the dataset
+# indefinitely, shuffling items as it goes.
+
+class InfiniteSampler(torch.utils.data.Sampler):
+ def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
+ assert len(dataset) > 0
+ assert num_replicas > 0
+ assert 0 <= rank < num_replicas
+ assert 0 <= window_size <= 1
+ super().__init__(dataset)
+ self.dataset = dataset
+ self.rank = rank
+ self.num_replicas = num_replicas
+ self.shuffle = shuffle
+ self.seed = seed
+ self.window_size = window_size
+
+ def __iter__(self):
+ order = np.arange(len(self.dataset))
+ rnd = None
+ window = 0
+ if self.shuffle:
+ rnd = np.random.RandomState(self.seed)
+ rnd.shuffle(order)
+ window = int(np.rint(order.size * self.window_size))
+
+ idx = 0
+ while True:
+ i = idx % order.size
+ if idx % self.num_replicas == self.rank:
+ yield order[i]
+ if window >= 2:
+ j = (i - rnd.randint(window)) % order.size
+ order[i], order[j] = order[j], order[i]
+ idx += 1
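+
+# Usage sketch (illustrative; `dataset`, `rank`, `world_size`, and `batch_size` are placeholders):
+#   sampler = InfiniteSampler(dataset, rank=rank, num_replicas=world_size, seed=0)
+#   loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=batch_size)
+# The iterator never raises StopIteration, so the training loop can call next() indefinitely.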
+
+#----------------------------------------------------------------------------
+# Utilities for operating with torch.nn.Module parameters and buffers.
+
+def params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.parameters()) + list(module.buffers())
+
+def named_params_and_buffers(module):
+ assert isinstance(module, torch.nn.Module)
+ return list(module.named_parameters()) + list(module.named_buffers())
+
+def copy_params_and_buffers(src_module, dst_module, require_all=False):
+ assert isinstance(src_module, torch.nn.Module)
+ assert isinstance(dst_module, torch.nn.Module)
+ src_tensors = dict(named_params_and_buffers(src_module))
+ for name, tensor in named_params_and_buffers(dst_module):
+ assert (name in src_tensors) or (not require_all)
+ if name in src_tensors:
+ tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
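+
+# Usage sketch (illustrative; `snapshot_G` and `G` are placeholder modules): load matching
+# weights from a checkpointed module while tolerating missing entries.
+#   copy_params_and_buffers(src_module=snapshot_G, dst_module=G, require_all=False)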
+
+#----------------------------------------------------------------------------
+# Context manager for easily enabling/disabling DistributedDataParallel
+# synchronization.
+
+@contextlib.contextmanager
+def ddp_sync(module, sync):
+ assert isinstance(module, torch.nn.Module)
+ if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
+ yield
+ else:
+ with module.no_sync():
+ yield
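+
+# Usage sketch (illustrative; `ddp_module`, `micro_batches`, and `loss_fn` are placeholders):
+# skip the gradient all-reduce on all but the last micro-batch of a gradient accumulation loop.
+#   for i, micro_batch in enumerate(micro_batches):
+#       with ddp_sync(ddp_module, sync=(i == len(micro_batches) - 1)):
+#           loss_fn(micro_batch).backward()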
+
+#----------------------------------------------------------------------------
+# Check DistributedDataParallel consistency across processes.
+
+def check_ddp_consistency(module, ignore_regex=None):
+ assert isinstance(module, torch.nn.Module)
+ for name, tensor in named_params_and_buffers(module):
+ fullname = type(module).__name__ + '.' + name
+ if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
+ continue
+ tensor = tensor.detach()
+ if tensor.is_floating_point():
+ tensor = nan_to_num(tensor)
+ other = tensor.clone()
+ torch.distributed.broadcast(tensor=other, src=0)
+ assert (tensor == other).all(), fullname
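+
+# Usage sketch (illustrative; the module names and the regex are placeholders): call once
+# after wrapping the networks to verify that every rank holds the same weights as rank 0.
+#   for module in (G, D):
+#       check_ddp_consistency(module, ignore_regex=r'.*\.w_avg')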
+
+#----------------------------------------------------------------------------
+# Print summary table of module hierarchy.
+
+def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
+ assert isinstance(module, torch.nn.Module)
+ assert not isinstance(module, torch.jit.ScriptModule)
+ assert isinstance(inputs, (tuple, list))
+
+ # Register hooks.
+ entries = []
+ nesting = [0]
+ def pre_hook(_mod, _inputs):
+ nesting[0] += 1
+ def post_hook(mod, _inputs, outputs):
+ nesting[0] -= 1
+ if nesting[0] <= max_nesting:
+ outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
+ outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
+ entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
+ hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
+ hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
+
+ # Run module.
+ outputs = module(*inputs)
+ for hook in hooks:
+ hook.remove()
+
+ # Identify unique outputs, parameters, and buffers.
+ tensors_seen = set()
+ for e in entries:
+ e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
+ e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
+ e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
+ tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
+
+ # Filter out redundant entries.
+ if skip_redundant:
+ entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
+
+ # Construct table.
+ rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
+ rows += [['---'] * len(rows[0])]
+ param_total = 0
+ buffer_total = 0
+ submodule_names = {mod: name for name, mod in module.named_modules()}
+ for e in entries:
+ name = '' if e.mod is module else submodule_names[e.mod]
+ param_size = sum(t.numel() for t in e.unique_params)
+ buffer_size = sum(t.numel() for t in e.unique_buffers)
+ output_shapes = [str(list(t.shape)) for t in e.outputs]
+ output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
+ rows += [[
+ name + (':0' if len(e.outputs) >= 2 else ''),
+ str(param_size) if param_size else '-',
+ str(buffer_size) if buffer_size else '-',
+ (output_shapes + ['-'])[0],
+ (output_dtypes + ['-'])[0],
+ ]]
+ for idx in range(1, len(e.outputs)):
+ rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
+ param_total += param_size
+ buffer_total += buffer_size
+ rows += [['---'] * len(rows[0])]
+ rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
+
+ # Print table.
+ widths = [max(len(cell) for cell in column) for column in zip(*rows)]
+ print()
+ for row in rows:
+ print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
+ print()
+ return outputs
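+
+# Usage sketch (illustrative; `G`, `z`, and `c` are placeholders for a network and its inputs):
+#   img = print_module_summary(G, [z, c])  # prints the table and returns G(z, c)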
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/__init__.py b/modules/eg3ds/torch_utils/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfebd04f47e6f6b1b44984c14c23b57d56f72240
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+# empty
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.cpp b/modules/eg3ds/torch_utils/ops/bias_act.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ee6f6d0caaf4f84b94851d223e384344e1109cdc
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.cpp
@@ -0,0 +1,103 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+
+static bool has_same_layout(torch::Tensor x, torch::Tensor y)
+{
+ if (x.dim() != y.dim())
+ return false;
+ for (int64_t i = 0; i < x.dim(); i++)
+ {
+ if (x.size(i) != y.size(i))
+ return false;
+ if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
+ return false;
+ }
+ return true;
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
+{
+ // Validate arguments.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
+ TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
+ TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
+ TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
+ TORCH_CHECK(b.dim() == 1, "b must have rank 1");
+ TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
+ TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
+ TORCH_CHECK(grad >= 0, "grad must be non-negative");
+
+ // Validate layout.
+ TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
+ TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
+ TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
+ TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
+ TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
+
+ // Create output tensor.
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+ torch::Tensor y = torch::empty_like(x);
+ TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
+
+ // Initialize CUDA kernel parameters.
+ bias_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.b = (b.numel()) ? b.data_ptr() : NULL;
+ p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
+ p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
+ p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
+ p.y = y.data_ptr();
+ p.grad = grad;
+ p.act = act;
+ p.alpha = alpha;
+ p.gain = gain;
+ p.clamp = clamp;
+ p.sizeX = (int)x.numel();
+ p.sizeB = (int)b.numel();
+ p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
+
+ // Choose CUDA kernel.
+ void* kernel;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
+ {
+ kernel = choose_bias_act_kernel<scalar_t>(p);
+ });
+ TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
+
+ // Launch CUDA kernel.
+ p.loopX = 4;
+ int blockSize = 4 * 32;
+ int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
+ void* args[] = {&p};
+ AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
+ return y;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("bias_act", &bias_act);
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.cu b/modules/eg3ds/torch_utils/ops/bias_act.cu
new file mode 100644
index 0000000000000000000000000000000000000000..71ca3900deda41e62d80044f0e409875f4c794b5
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.cu
@@ -0,0 +1,177 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <c10/util/Half.h>
+#include "bias_act.h"
+
+//------------------------------------------------------------------------
+// Helpers.
+
+template <class T> struct InternalType;
+template <> struct InternalType<double> { typedef double scalar_t; };
+template <> struct InternalType<float> { typedef float scalar_t; };
+template <> struct InternalType<c10::Half> { typedef float scalar_t; };
+
+//------------------------------------------------------------------------
+// CUDA kernel.
+
+template <class T, int A>
+__global__ void bias_act_kernel(bias_act_kernel_params p)
+{
+ typedef typename InternalType::scalar_t scalar_t;
+ int G = p.grad;
+ scalar_t alpha = (scalar_t)p.alpha;
+ scalar_t gain = (scalar_t)p.gain;
+ scalar_t clamp = (scalar_t)p.clamp;
+ scalar_t one = (scalar_t)1;
+ scalar_t two = (scalar_t)2;
+ scalar_t expRange = (scalar_t)80;
+ scalar_t halfExpRange = (scalar_t)40;
+ scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
+ scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
+
+ // Loop over elements.
+ int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
+ for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
+ {
+ // Load.
+ scalar_t x = (scalar_t)((const T*)p.x)[xi];
+ scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
+ scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
+ scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
+ scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
+ scalar_t yy = (gain != 0) ? yref / gain : 0;
+ scalar_t y = 0;
+
+ // Apply bias.
+ ((G == 0) ? x : xref) += b;
+
+ // linear
+ if (A == 1)
+ {
+ if (G == 0) y = x;
+ if (G == 1) y = x;
+ }
+
+ // relu
+ if (A == 2)
+ {
+ if (G == 0) y = (x > 0) ? x : 0;
+ if (G == 1) y = (yy > 0) ? x : 0;
+ }
+
+ // lrelu
+ if (A == 3)
+ {
+ if (G == 0) y = (x > 0) ? x : x * alpha;
+ if (G == 1) y = (yy > 0) ? x : x * alpha;
+ }
+
+ // tanh
+ if (A == 4)
+ {
+ if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
+ if (G == 1) y = x * (one - yy * yy);
+ if (G == 2) y = x * (one - yy * yy) * (-two * yy);
+ }
+
+ // sigmoid
+ if (A == 5)
+ {
+ if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
+ if (G == 1) y = x * yy * (one - yy);
+ if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
+ }
+
+ // elu
+ if (A == 6)
+ {
+ if (G == 0) y = (x >= 0) ? x : exp(x) - one;
+ if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
+ }
+
+ // selu
+ if (A == 7)
+ {
+ if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
+ if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
+ if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
+ }
+
+ // softplus
+ if (A == 8)
+ {
+ if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
+ if (G == 1) y = x * (one - exp(-yy));
+ if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
+ }
+
+ // swish
+ if (A == 9)
+ {
+ if (G == 0)
+ y = (x < -expRange) ? 0 : x / (exp(-x) + one);
+ else
+ {
+ scalar_t c = exp(xref);
+ scalar_t d = c + one;
+ if (G == 1)
+ y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
+ else
+ y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
+ yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
+ }
+ }
+
+ // Apply gain.
+ y *= gain * dy;
+
+ // Clamp.
+ if (clamp >= 0)
+ {
+ if (G == 0)
+ y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
+ else
+ y = (yref > -clamp & yref < clamp) ? y : 0;
+ }
+
+ // Store.
+ ((T*)p.y)[xi] = (T)y;
+ }
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
+{
+ if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
+ if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
+ if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
+ if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
+ if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
+ if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
+ if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
+ if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
+ if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
+ return NULL;
+}
+
+//------------------------------------------------------------------------
+// Template specializations.
+
+template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
+template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.h b/modules/eg3ds/torch_utils/ops/bias_act.h
new file mode 100644
index 0000000000000000000000000000000000000000..8994bfb4e9cae790865348e08de5f685152d3344
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.h
@@ -0,0 +1,42 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct bias_act_kernel_params
+{
+ const void* x; // [sizeX]
+ const void* b; // [sizeB] or NULL
+ const void* xref; // [sizeX] or NULL
+ const void* yref; // [sizeX] or NULL
+ const void* dy; // [sizeX] or NULL
+ void* y; // [sizeX]
+
+ int grad;
+ int act;
+ float alpha;
+ float gain;
+ float clamp;
+
+ int sizeX;
+ int sizeB;
+ int stepB;
+ int loopX;
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/bias_act.py b/modules/eg3ds/torch_utils/ops/bias_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..3984639c54faae2233837175ccb210a63016426c
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/bias_act.py
@@ -0,0 +1,211 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom PyTorch ops for efficient bias and activation."""
+
+import os
+import numpy as np
+import torch
+from modules.eg3ds import dnnlib
+
+from .. import custom_ops
+from .. import misc
+
+#----------------------------------------------------------------------------
+
+activation_funcs = {
+ 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
+ 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
+ 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
+ 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
+ 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
+ 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
+ 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
+ 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
+ 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
+}
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+_null_tensor = torch.empty([0])
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='bias_act_plugin',
+ sources=['bias_act.cpp', 'bias_act.cu'],
+ headers=['bias_act.h'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+#----------------------------------------------------------------------------
+
+def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
+ r"""Fused bias and activation function.
+
+ Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
+ and scales the result by `gain`. Each of the steps is optional. In most cases,
+ the fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports first and second order gradients,
+ but not third order gradients.
+
+ Args:
+ x: Input activation tensor. Can be of any shape.
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The shape must be known, and it must match the dimension of `x`
+ corresponding to `dim`.
+ dim: The dimension in `x` corresponding to the elements of `b`.
+ The value of `dim` is ignored if `b` is not specified.
+ act: Name of the activation function to evaluate, or `"linear"` to disable.
+ Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
+ See `activation_funcs` for a full list. `None` is not allowed.
+ alpha: Shape parameter for the activation function, or `None` to use the default.
+ gain: Scaling factor for the output tensor, or `None` to use default.
+ See `activation_funcs` for the default scaling of each activation function.
+ If unsure, consider specifying 1.
+ clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
+ the clamping (default).
+ impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
+
+ Returns:
+ Tensor of the same shape and datatype as `x`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
+ return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
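+
+# Usage sketch (illustrative shapes; CPU tensors take the reference path automatically):
+#   x = torch.randn(4, 512, 16, 16)
+#   b = torch.randn(512)
+#   y = bias_act(x, b, dim=1, act='lrelu', gain=np.sqrt(2))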
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Slow reference implementation of `bias_act()` using standard PyTorch ops.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Add bias.
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.ndim == 1
+ assert 0 <= dim < x.ndim
+ assert b.shape[0] == x.shape[dim]
+ x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
+
+ # Evaluate activation function.
+ alpha = float(alpha)
+ x = spec.func(x, alpha=alpha)
+
+ # Scale by gain.
+ gain = float(gain)
+ if gain != 1:
+ x = x * gain
+
+ # Clamp.
+ if clamp >= 0:
+ x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
+ return x
+
+#----------------------------------------------------------------------------
+
+_bias_act_cuda_cache = dict()
+
+def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
+ """Fast CUDA implementation of `bias_act()` using custom ops.
+ """
+ # Parse arguments.
+ assert clamp is None or clamp >= 0
+ spec = activation_funcs[act]
+ alpha = float(alpha if alpha is not None else spec.def_alpha)
+ gain = float(gain if gain is not None else spec.def_gain)
+ clamp = float(clamp if clamp is not None else -1)
+
+ # Lookup from cache.
+ key = (dim, act, alpha, gain, clamp)
+ if key in _bias_act_cuda_cache:
+ return _bias_act_cuda_cache[key]
+
+ # Forward op.
+ class BiasActCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, b): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
+ x = x.contiguous(memory_format=ctx.memory_format)
+ b = b.contiguous() if b is not None else _null_tensor
+ y = x
+ if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
+ y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
+ y if 'y' in spec.ref else _null_tensor)
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ dy = dy.contiguous(memory_format=ctx.memory_format)
+ x, b, y = ctx.saved_tensors
+ dx = None
+ db = None
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ dx = dy
+ if act != 'linear' or gain != 1 or clamp >= 0:
+ dx = BiasActCudaGrad.apply(dy, x, b, y)
+
+ if ctx.needs_input_grad[1]:
+ db = dx.sum([i for i in range(dx.ndim) if i != dim])
+
+ return dx, db
+
+ # Backward op.
+ class BiasActCudaGrad(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
+ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
+ dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
+ ctx.save_for_backward(
+ dy if spec.has_2nd_grad else _null_tensor,
+ x, b, y)
+ return dx
+
+ @staticmethod
+ def backward(ctx, d_dx): # pylint: disable=arguments-differ
+ d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
+ dy, x, b, y = ctx.saved_tensors
+ d_dy = None
+ d_x = None
+ d_b = None
+ d_y = None
+
+ if ctx.needs_input_grad[0]:
+ d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
+
+ if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
+ d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
+
+ if spec.has_2nd_grad and ctx.needs_input_grad[2]:
+ d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
+
+ return d_dy, d_x, d_b, d_y
+
+ # Add to cache.
+ _bias_act_cuda_cache[key] = BiasActCuda
+ return BiasActCuda
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/conv2d_gradfix.py b/modules/eg3ds/torch_utils/ops/conv2d_gradfix.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a177cc1c0b6eabf16908cf9afaa4387e7716b72
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/conv2d_gradfix.py
@@ -0,0 +1,199 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""Custom replacement for `torch.nn.functional.conv2d` that supports
+arbitrarily high order gradients with zero performance penalty."""
+
+import contextlib
+import torch
+
+# pylint: disable=redefined-builtin
+# pylint: disable=arguments-differ
+# pylint: disable=protected-access
+
+#----------------------------------------------------------------------------
+
+enabled = False # Enable the custom op by setting this to true.
+weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights.
+
+@contextlib.contextmanager
+def no_weight_gradients(disable=True):
+ global weight_gradients_disabled
+ old = weight_gradients_disabled
+ if disable:
+ weight_gradients_disabled = True
+ yield
+ weight_gradients_disabled = old
+
+#----------------------------------------------------------------------------
+
+def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
+
+def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
+ if _should_use_custom_op(input):
+ return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
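+
+# Usage sketch (illustrative; `img` and `out` are placeholders): the custom op is opt-in via
+# the module-level `enabled` flag, and `no_weight_gradients()` is typically wrapped around
+# regularization terms that only need gradients w.r.t. the input.
+#   conv2d_gradfix.enabled = True
+#   with conv2d_gradfix.no_weight_gradients():
+#       grad = torch.autograd.grad(outputs=[out.sum()], inputs=[img], create_graph=True)[0]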
+
+#----------------------------------------------------------------------------
+
+def _should_use_custom_op(input):
+ assert isinstance(input, torch.Tensor)
+ if (not enabled) or (not torch.backends.cudnn.enabled):
+ return False
+ if input.device.type != 'cuda':
+ return False
+ return True
+
+def _tuple_of_ints(xs, ndim):
+ xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
+ assert len(xs) == ndim
+ assert all(isinstance(x, int) for x in xs)
+ return xs
+
+#----------------------------------------------------------------------------
+
+_conv2d_gradfix_cache = dict()
+_null_tensor = torch.empty([0])
+
+def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
+ # Parse arguments.
+ ndim = 2
+ weight_shape = tuple(weight_shape)
+ stride = _tuple_of_ints(stride, ndim)
+ padding = _tuple_of_ints(padding, ndim)
+ output_padding = _tuple_of_ints(output_padding, ndim)
+ dilation = _tuple_of_ints(dilation, ndim)
+
+ # Lookup from cache.
+ key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
+ if key in _conv2d_gradfix_cache:
+ return _conv2d_gradfix_cache[key]
+
+ # Validate arguments.
+ assert groups >= 1
+ assert len(weight_shape) == ndim + 2
+ assert all(stride[i] >= 1 for i in range(ndim))
+ assert all(padding[i] >= 0 for i in range(ndim))
+ assert all(dilation[i] >= 0 for i in range(ndim))
+ if not transpose:
+ assert all(output_padding[i] == 0 for i in range(ndim))
+ else: # transpose
+ assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))
+
+ # Helpers.
+ common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)
+ def calc_output_padding(input_shape, output_shape):
+ if transpose:
+ return [0, 0]
+ return [
+ input_shape[i + 2]
+ - (output_shape[i + 2] - 1) * stride[i]
+ - (1 - 2 * padding[i])
+ - dilation[i] * (weight_shape[i + 2] - 1)
+ for i in range(ndim)
+ ]
+
+ # Forward & backward.
+ class Conv2d(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, weight, bias):
+ assert weight.shape == weight_shape
+ ctx.save_for_backward(
+ input if weight.requires_grad else _null_tensor,
+ weight if input.requires_grad else _null_tensor,
+ )
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):
+ a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)
+ c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)
+ c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)
+ c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ if transpose:
+ return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
+ return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight = ctx.saved_tensors
+ input_shape = ctx.input_shape
+ grad_input = None
+ grad_weight = None
+ grad_bias = None
+
+ if ctx.needs_input_grad[0]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad_input = op.apply(grad_output, weight, None)
+ assert grad_input.shape == input_shape
+
+ if ctx.needs_input_grad[1] and not weight_gradients_disabled:
+ grad_weight = Conv2dGradWeight.apply(grad_output, input, weight)
+ assert grad_weight.shape == weight_shape
+
+ if ctx.needs_input_grad[2]:
+ grad_bias = grad_output.sum([0, 2, 3])
+
+ return grad_input, grad_weight, grad_bias
+
+ # Gradient with respect to the weights.
+ class Conv2dGradWeight(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, grad_output, input, weight):
+ ctx.save_for_backward(
+ grad_output if input.requires_grad else _null_tensor,
+ input if grad_output.requires_grad else _null_tensor,
+ )
+ ctx.grad_output_shape = grad_output.shape
+ ctx.input_shape = input.shape
+
+ # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).
+ if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0):
+ a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)
+ c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)
+ return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))
+
+ # General case => cuDNN.
+ return torch.ops.aten.convolution_backward(grad_output=grad_output, input=input, weight=weight, bias_sizes=None, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
+
+
+ @staticmethod
+ def backward(ctx, grad2_grad_weight):
+ grad_output, input = ctx.saved_tensors
+ grad_output_shape = ctx.grad_output_shape
+ input_shape = ctx.input_shape
+ grad2_grad_output = None
+ grad2_input = None
+
+ if ctx.needs_input_grad[0]:
+ grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
+ assert grad2_grad_output.shape == grad_output_shape
+
+ if ctx.needs_input_grad[1]:
+ p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)
+ op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)
+ grad2_input = op.apply(grad_output, grad2_grad_weight, None)
+ assert grad2_input.shape == input_shape
+
+ return grad2_grad_output, grad2_input
+
+ _conv2d_gradfix_cache[key] = Conv2d
+ return Conv2d
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/conv2d_resample.py b/modules/eg3ds/torch_utils/ops/conv2d_resample.py
new file mode 100644
index 0000000000000000000000000000000000000000..5daad2efadcd79513aaf8aee9ecb08a5ce04797e
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/conv2d_resample.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+"""2D convolution with optional up/downsampling."""
+
+import torch
+
+from .. import misc
+from . import conv2d_gradfix
+from . import upfirdn2d
+from .upfirdn2d import _parse_padding
+from .upfirdn2d import _get_filter_size
+
+#----------------------------------------------------------------------------
+
+def _get_weight_shape(w):
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
+ shape = [int(sz) for sz in w.shape]
+ misc.assert_shape(w, shape)
+ return shape
+
+#----------------------------------------------------------------------------
+
+def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
+ """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
+ """
+ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)
+
+ # Flip weight if requested.
+ # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+ if not flip_weight and (kw > 1 or kh > 1):
+ w = w.flip([2, 3])
+
+ # Execute using conv2d_gradfix.
+ op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d
+ return op(x, w, stride=stride, padding=padding, groups=groups)
+
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+ r"""2D convolution with optional up/downsampling.
+
+ Padding is performed only once at the beginning, not between the operations.
+
+ Args:
+ x: Input tensor of shape
+ `[batch_size, in_channels, in_height, in_width]`.
+ w: Weight tensor of shape
+ `[out_channels, in_channels//groups, kernel_height, kernel_width]`.
+ f: Low-pass filter for up/downsampling. Must be prepared beforehand by
+ calling upfirdn2d.setup_filter(). None = identity (default).
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ groups: Split input channels into N groups (default: 1).
+ flip_weight: False = convolution, True = correlation (default: True).
+ flip_filter: False = convolution, True = correlation (default: False).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ # Validate arguments.
+ assert isinstance(x, torch.Tensor) and (x.ndim == 4)
+ assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
+ if isinstance(f, torch.Tensor) and f.dtype == torch.float16:
+ f = f.float()
+ assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+ assert isinstance(up, int) and (up >= 1)
+ assert isinstance(down, int) and (down >= 1)
+ assert isinstance(groups, int) and (groups >= 1)
+ out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
+ fw, fh = _get_filter_size(f)
+ px0, px1, py0, py1 = _parse_padding(padding)
+
+ # Adjust padding to account for up/downsampling.
+ if up > 1:
+ px0 += (fw + up - 1) // 2
+ px1 += (fw - up) // 2
+ py0 += (fh + up - 1) // 2
+ py1 += (fh - up) // 2
+ if down > 1:
+ px0 += (fw - down + 1) // 2
+ px1 += (fw - down) // 2
+ py0 += (fh - down + 1) // 2
+ py1 += (fh - down) // 2
+
+ # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
+ if kw == 1 and kh == 1 and (down > 1 and up == 1):
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
+ if kw == 1 and kh == 1 and (up > 1 and down == 1):
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ return x
+
+ # Fast path: downsampling only => use strided convolution.
+ if down > 1 and up == 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+ return x
+
+ # Fast path: upsampling with optional downsampling => use transpose strided convolution.
+ if up > 1:
+ if groups == 1:
+ w = w.transpose(0, 1)
+ else:
+ w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
+ w = w.transpose(1, 2)
+ w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+ px0 -= kw - 1
+ px1 -= kw - up
+ py0 -= kh - 1
+ py1 -= kh - up
+ pxt = max(min(-px0, -px1), 0)
+ pyt = max(min(-py0, -py1), 0)
+ x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))
+ x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
+
+ # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
+ if up == 1 and down == 1:
+ if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
+ return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)
+
+ # Fallback: Generic reference implementation.
+ x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)
+ x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
+ if down > 1:
+ x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
+ return x
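+
+# Usage sketch (illustrative; `x` and `w` are placeholder input and weight tensors):
+# a 3x3 convolution fused with 2x upsampling through a prepared low-pass filter.
+#   f = upfirdn2d.setup_filter([1, 3, 3, 1], device=x.device)
+#   y = conv2d_resample(x, w, f=f, up=2, padding=1)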
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.cpp b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4f55466235a020b0f5e150350bfdcd8b2a1e579d
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cpp
@@ -0,0 +1,304 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include "filtered_lrelu.h"
+
+//------------------------------------------------------------------------
+
+static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(
+ torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,
+ int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device");
+ TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32");
+ TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32");
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2");
+ TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large");
+ TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large");
+ TORCH_CHECK(fu.numel() > 0, "fu is empty");
+ TORCH_CHECK(fd.numel() > 0, "fd is empty");
+ TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x");
+ TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1");
+
+ // Figure out how much shared memory is available on the device.
+ int maxSharedBytes = 0;
+ AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));
+ int sharedKB = maxSharedBytes >> 10;
+
+ // Populate enough launch parameters to check if a CUDA kernel exists.
+ filtered_lrelu_kernel_params p;
+ p.up = up;
+ p.down = down;
+ p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.
+ p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);
+ filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);
+ if (!test_spec.exec)
+ {
+ // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
+ return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);
+ }
+
+ // Input/output element size.
+ int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4;
+
+ // Input sizes.
+ int64_t xw = (int)x.size(3);
+ int64_t xh = (int)x.size(2);
+ int64_t fut_w = (int)fu.size(-1) - 1;
+ int64_t fut_h = (int)fu.size(0) - 1;
+ int64_t fdt_w = (int)fd.size(-1) - 1;
+ int64_t fdt_h = (int)fd.size(0) - 1;
+
+ // Logical size of upsampled buffer.
+ int64_t cw = xw * up + (px0 + px1) - fut_w;
+ int64_t ch = xh * up + (py0 + py1) - fut_h;
+ TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter");
+ TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large");
+
+ // Compute output size and allocate.
+ int64_t yw = (cw - fdt_w + (down - 1)) / down;
+ int64_t yh = (ch - fdt_h + (down - 1)) / down;
+ TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1");
+ TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large");
+ torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());
+
+ // Allocate sign tensor.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ int64_t sw_active = 0; // Active width of sign tensor.
+ if (writeSigns)
+ {
+ sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements.
+ int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height.
+ int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16.
+ TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large");
+ s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+ else if (readSigns)
+ sw_active = s.size(3) << 2;
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large");
+ }
+
+ // Populate rest of CUDA kernel parameters.
+ p.x = x.data_ptr();
+ p.y = y.data_ptr();
+ p.b = b.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.fu = fu.data_ptr();
+ p.fd = fd.data_ptr();
+ p.pad0 = make_int2(px0, py0);
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.flip = (flip_filters) ? 1 : 0;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+ p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes.
+
+ // x, y, b strides are in bytes.
+ p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));
+ p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));
+ p.bStride = sz * b.stride(0);
+
+ // fu, fd strides are in elements.
+ p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);
+ p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);
+
+ // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those.
+ bool index64b = false;
+ if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;
+ if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true;
+ if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;
+ if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true;
+ if (s.numel() > INT_MAX) index64b = true;
+
+ // Choose CUDA kernel.
+ filtered_lrelu_kernel_spec spec = { 0 };
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&]
+ {
+ if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation.
+ {
+ // Choose kernel based on index type, datatype and sign read/write modes.
+ if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB);
+ else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true>(p, sharedKB);
+ else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);
+ else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB);
+ else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true>(p, sharedKB);
+ else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);
+ }
+ });
+ TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists.
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = spec.numWarps * 32;
+ int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;
+ int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;
+ int gz = p.yShape.z * p.yShape.w;
+
+ // Repeat multiple horizontal tiles in a CTA?
+ if (spec.xrep)
+ {
+ p.tilesXrep = spec.xrep;
+ p.tilesXdim = gx;
+
+ gx = (gx + p.tilesXrep - 1) / p.tilesXrep;
+ std::swap(gx, gy);
+ }
+ else
+ {
+ p.tilesXrep = 0;
+ p.tilesXdim = 0;
+ }
+
+ // Launch filter setup kernel.
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));
+
+ // Copy kernels to constant memory.
+ if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true>(at::cuda::getCurrentCUDAStream())));
+ else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));
+
+ // Set cache and shared memory configurations for main kernel.
+ AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));
+ if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?
+ AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));
+ AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));
+
+ // Launch main kernel.
+ const int maxSubGz = 65535; // CUDA maximum for block z dimension.
+ for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
+ {
+ p.blockZofs = zofs;
+ int subGz = std::min(maxSubGz, gz - zofs);
+ AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));
+ }
+
+ // Done.
+ return std::make_tuple(y, so, 0);
+}
+
+//------------------------------------------------------------------------
+
+static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)
+{
+ // Set CUDA device.
+ TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
+
+ // Validate arguments.
+ TORCH_CHECK(x.dim() == 4, "x must be rank 4");
+ TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large");
+ TORCH_CHECK(x.numel() > 0, "x is empty");
+ TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64");
+
+ // Output signs if we don't have sign input.
+ torch::Tensor so;
+ torch::Tensor s = si;
+ bool readSigns = !!s.numel();
+ if (writeSigns)
+ {
+ int64_t sw = x.size(3);
+ sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.
+ s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);
+ }
+
+ // Validate sign tensor if in use.
+ if (readSigns || writeSigns)
+ {
+ TORCH_CHECK(s.is_contiguous(), "signs must be contiguous");
+ TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8");
+ TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x");
+ TORCH_CHECK(s.dim() == 4, "signs must be rank 4");
+ TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x");
+ TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large");
+ }
+
+ // Initialize CUDA kernel parameters.
+ filtered_lrelu_act_kernel_params p;
+ p.x = x.data_ptr();
+ p.s = (readSigns || writeSigns) ? s.data_ptr() : 0;
+ p.gain = gain;
+ p.slope = slope;
+ p.clamp = clamp;
+ p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
+ p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));
+ p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous.
+ p.sOfs = make_int2(sx, sy);
+
+ // Choose CUDA kernel.
+ void* func = 0;
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&]
+ {
+ if (writeSigns)
+ func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();
+ else if (readSigns)
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();
+ else
+ func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();
+ });
+ TORCH_CHECK(func, "internal error - CUDA kernel not found");
+
+ // Launch CUDA kernel.
+ void* args[] = {&p};
+ int bx = 128; // 4 warps per block.
+
+ // Logical size of launch = writeSigns ? p.s : p.x
+ uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;
+ uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;
+ uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.
+ gx = (gx - 1) / bx + 1;
+
+ // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.
+ const uint32_t gmax = 65535;
+ gy = std::min(gy, gmax);
+ gz = std::min(gz, gmax);
+
+ // Launch.
+ AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));
+ return so;
+}
+
+//------------------------------------------------------------------------
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+ m.def("filtered_lrelu", &filtered_lrelu); // The whole thing.
+ m.def("filtered_lrelu_act_", &filtered_lrelu_act); // Activation and sign tensor handling only. Modifies data tensor in-place.
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..aaac95408365f023ffaa4cb89348d499d3b948f0
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.cu
@@ -0,0 +1,1288 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <c10/util/Half.h>
+#include "filtered_lrelu.h"
+#include <cstdint>
+
+//------------------------------------------------------------------------
+// Helpers.
+
+enum // Filter modes.
+{
+ MODE_SUSD = 0, // Separable upsampling, separable downsampling.
+ MODE_FUSD = 1, // Full upsampling, separable downsampling.
+ MODE_SUFD = 2, // Separable upsampling, full downsampling.
+ MODE_FUFD = 3, // Full upsampling, full downsampling.
+};
+
+template <class T> struct InternalType;
+template <> struct InternalType<double>
+{
+ typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
+ __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
+};
+template <> struct InternalType<float>
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+template <> struct InternalType<c10::Half>
+{
+ typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
+ __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
+ __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
+ __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
+};
+
+#define MIN(A, B) ((A) < (B) ? (A) : (B))
+#define MAX(A, B) ((A) > (B) ? (A) : (B))
+#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
+ ((B)==2) ? ((int)((A)+1) >> 1) : \
+ ((B)==4) ? ((int)((A)+3) >> 2) : \
+ (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))
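+// CEIL_DIV computes ceil(A/B): arithmetic shifts for B in {1, 2, 4}, and a sign-aware integer
+// division otherwise; e.g. CEIL_DIV(-3, 2) == -1 and CEIL_DIV(5, 4) == 2.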
+
+// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
+template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
+{
+ if ((N & (N-1)) && N <= 256)
+ y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
+ else
+ y = i/N;
+
+ x = i - y*N;
+}
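+// Note on the fast path above: multiplying by ((1<<24)/N + 1) and shifting right by 24 is a
+// fixed-point reciprocal that reproduces i/N under the stated assumptions (N <= 256, i < N*256);
+// power-of-two N takes the plain i/N branch, which the compiler reduces to a shift.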
+
+// Type cast stride before reading it.
+template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
+{
+ return *reinterpret_cast<const T*>(&x);
+}
+
+//------------------------------------------------------------------------
+// Filters, setup kernel, copying function.
+
+#define MAX_FILTER_SIZE 32
+
+// Combined up/down filter buffers so that transfer can be done with one copy.
+__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
+__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.
+
+// Accessors to combined buffers to index up/down filters individually.
+#define c_fu (c_fbuf)
+#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+#define g_fu (g_fbuf)
+#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
+
+// Set up filters into global memory buffer.
+static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
+{
+ for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
+ {
+ int x, y;
+ fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);
+
+ int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
+ int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
+ if (p.fuShape.y > 0)
+ g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
+ else
+ g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];
+
+ int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
+ int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
+ if (p.fdShape.y > 0)
+ g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
+ else
+ g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
+ }
+}
+
+// Host function to copy filters written by setup kernel into constant buffer for main kernel.
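+// The two (unnamed) bool template parameters only mirror the signWrite/signRead specialization
+// chosen at the call site; the copy itself is identical in every case.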
+template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
+{
+ void* src = 0;
+ cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
+ if (err) return err;
+ return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
+}
+
+//------------------------------------------------------------------------
+// Coordinate spaces:
+// - Relative to input tensor: inX, inY, tileInX, tileInY
+// - Relative to input tile: relInX, relInY, tileInW, tileInH
+// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
+// - Relative to output tile: relOutX, relOutY, tileOutW, tileOutH
+// - Relative to output tensor: outX, outY, tileOutX, tileOutY
+//
+// Relationships between coordinate spaces:
+// - inX = tileInX + relInX
+// - inY = tileInY + relInY
+// - relUpX = relInX * up + phaseInX
+// - relUpY = relInY * up + phaseInY
+// - relUpX = relOutX * down
+// - relUpY = relOutY * down
+// - outX = tileOutX + relOutX
+// - outY = tileOutY + relOutY
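+//
+// For example, with up=2, down=2 and phaseInX=1: input pixel relInX=3 lands at upsampled
+// column relUpX = 3*2+1 = 7, while output pixel relOutX=4 reads upsampled column relUpX = 4*2 = 8.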
+
+extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.
+
+template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
+static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
+{
+ // Check that we don't try to support non-existing filter modes.
+ static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported");
+ static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported");
+ static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor");
+ static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor");
+ static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor");
+ static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor");
+ static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE");
+ static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters");
+ static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters");
+ static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4");
+ static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4");
+
+ // Static definitions.
+ typedef typename InternalType<T>::scalar_t scalar_t;
+ typedef typename InternalType<T>::vec2_t vec2_t;
+ typedef typename InternalType<T>::vec4_t vec4_t;
+ const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4.
+ const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height.
+ const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width.
+ const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height.
+ const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up.
+ const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4.
+
+ // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.
+ const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));
+
+ // Sizes of logical buffers.
+ const int szIn = tileInH_up * tileInW;
+ const int szUpX = tileInH_up * tileUpW;
+ const int szUpXY = downInline ? 0 : (tileUpH * tileUpW);
+ const int szDownX = tileUpH * tileOutW;
+
+ // Sizes for shared memory arrays.
+ const int s_buf0_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :
+ (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :
+ (filterMode == MODE_FUFD) ? szIn :
+ -1;
+ const int s_buf1_size_base =
+ (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :
+ (filterMode == MODE_FUSD) ? szUpXY :
+ (filterMode == MODE_SUFD) ? szUpX :
+ (filterMode == MODE_FUFD) ? szUpXY :
+ -1;
+
+ // Ensure U128 alignment.
+ const int s_buf0_size = (s_buf0_size_base + 3) & ~3;
+ const int s_buf1_size = (s_buf1_size_base + 3) & ~3;
+
+ // Check at compile time that we don't use too much shared memory.
+ static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow");
+
+ // Declare shared memory arrays.
+ scalar_t* s_buf0;
+ scalar_t* s_buf1;
+ if (sharedKB <= 48)
+ {
+ // Allocate shared memory arrays here.
+ __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.
+ s_buf0 = s_buf0_st;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+ else
+ {
+ // Use the dynamically allocated shared memory array.
+ s_buf0 = (scalar_t*)s_buf_raw;
+ s_buf1 = s_buf0 + s_buf0_size;
+ }
+
+ // Pointers to the buffers.
+ scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY]
+ scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX]
+ scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX]
+ scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX]
+ if (filterMode == MODE_SUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ s_tileDownX = s_buf1;
+ }
+ else if (filterMode == MODE_FUSD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ s_tileDownX = s_buf0;
+ }
+ else if (filterMode == MODE_SUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpX = s_buf1;
+ s_tileUpXY = s_buf0;
+ }
+ else if (filterMode == MODE_FUFD)
+ {
+ s_tileIn = s_buf0;
+ s_tileUpXY = s_buf1;
+ }
+
+ // Allow large grids in z direction via per-launch offset.
+ int channelIdx = blockIdx.z + p.blockZofs;
+ int batchIdx = channelIdx / p.yShape.z;
+ channelIdx -= batchIdx * p.yShape.z;
+
+ // Offset to output feature map. In bytes.
+ index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w);
+
+ // Sign shift amount.
+ uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;
+
+ // Inner tile loop.
+ #pragma unroll 1
+ for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)
+ {
+ // Locate output tile.
+ int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;
+ int tileOutX = tileX * tileOutW;
+ int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;
+
+ // Locate input tile.
+ int tmpX = tileOutX * down - p.pad0.x;
+ int tmpY = tileOutY * down - p.pad0.y;
+ int tileInX = CEIL_DIV(tmpX, up);
+ int tileInY = CEIL_DIV(tmpY, up);
+ const int phaseInX = tileInX * up - tmpX;
+ const int phaseInY = tileInY * up - tmpY;
+
+ // Extra sync if input and output buffers are the same and we are not on first tile.
+ if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))
+ __syncthreads();
+
+ // Load input tile & apply bias. Unrolled.
+ scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride)));
+ index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w);
+ int idx = threadIdx.x;
+ const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);
+ #pragma unroll
+ for (int loop = 0; loop < loopCountIN; loop++)
+ {
+ int relInX, relInY;
+ fast_div_mod(relInX, relInY, idx);
+ int inX = tileInX + relInX;
+ int inY = tileInY + relInY;
+ scalar_t v = 0;
+
+ if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)
+ v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b;
+
+ bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);
+ if (!skip)
+ s_tileIn[idx] = v;
+
+ idx += threadsPerBlock;
+ }
+
+ if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.
+ {
+ // Horizontal upsampling.
+ __syncthreads();
+ if (up == 4)
+ {
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ scalar_t a = s_tileIn[src0];
+ if (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInX == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInX == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ s_tileUpX[dst+2] = v.z;
+ s_tileUpX[dst+3] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ bool p0 = (phaseInX == 0);
+ for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)
+ {
+ int relUpX0, relInY;
+ fast_div_mod(relUpX0, relInY, idx);
+ int relInX0 = relUpX0 / up;
+ int src0 = relInX0 + tileInW * relInY;
+ int dst = relInY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ scalar_t a = s_tileIn[src0];
+ if (p0) // (phaseInX == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInX == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileIn[src0 + step + 1];
+ }
+ }
+ s_tileUpX[dst+0] = v.x;
+ s_tileUpX[dst+1] = v.y;
+ }
+ }
+
+ // Vertical upsampling & nonlinearity.
+
+ __syncthreads();
+ int groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.
+ if (up == 4)
+ {
+ minY -= 3; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec4_t v = InternalType::zero_vec4();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 3];
+ v.z += a * (scalar_t)c_fu[step * up + 2];
+ v.w += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else if (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.z += a * (scalar_t)c_fu[step * up + 3];
+ v.w += a * (scalar_t)c_fu[step * up + 2];
+ }
+ }
+ else if (phaseInY == 2)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 2];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ v.z += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.w += a * (scalar_t)c_fu[step * up + 3];
+ }
+ }
+ else // (phaseInY == 3)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 3];
+ v.y += a * (scalar_t)c_fu[step * up + 2];
+ v.z += a * (scalar_t)c_fu[step * up + 1];
+ v.w += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+ index_t si2 = si0 + p.sShape.x * 2;
+ index_t si3 = si0 + p.sShape.x * 3;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ int sz = __float_as_uint(v.z) >> 31 << 16;
+ int sw = __float_as_uint(v.w) >> 31 << 24;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (sz) v.z *= p.slope;
+ if (sw) v.w *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ // Combine signs.
+ uint32_t s = sx + sy + sw + sz;
+ s <<= (signX & 3) << 1;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }
+ if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ int ss = (signX & 3) << 1;
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }
+ if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[dst + 0 * tileUpW] = v.x;
+ if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;
+ if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;
+ if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;
+ }
+ }
+ else if (up == 2)
+ {
+ minY -= 1; // Adjust according to block height.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)
+ {
+ int relUpX, relInY0;
+ fast_div_mod(relUpX, relInY0, idx);
+ int relUpY0 = relInY0 * up;
+ int src0 = relInY0 * tileUpW + relUpX;
+ int dst = relUpY0 * tileUpW + relUpX;
+ vec2_t v = InternalType::zero_vec2();
+
+ scalar_t a = s_tileUpX[src0];
+ if (phaseInY == 0)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ v.y += a * (scalar_t)c_fu[step * up + 1];
+ }
+ }
+ else // (phaseInY == 1)
+ {
+ #pragma unroll
+ for (int step = 0; step < fuSize / up; step++)
+ {
+ v.x += a * (scalar_t)c_fu[step * up + 1];
+ v.y += a * (scalar_t)c_fu[step * up + 0];
+ a = s_tileUpX[src0 + (step + 1) * tileUpW];
+ }
+ }
+
+ int x = tileOutX * down + relUpX;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ index_t si1 = si0 + p.sShape.x;
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31 << 0;
+ int sy = __float_as_uint(v.y) >> 31 << 8;
+ if (sx) v.x *= p.slope;
+ if (sy) v.y *= p.slope;
+ if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); }
+
+ // Combine signs.
+ int s = sx + sy;
+ s <<= signXo;
+ s |= __shfl_xor_sync(groupMask, s, 1);
+ s |= __shfl_xor_sync(groupMask, s, 2);
+
+ // Write signs.
+ if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); }
+ if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); }
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read signs and apply.
+ {
+ if ((uint32_t)signXb < p.swLimit)
+ {
+ if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }
+ if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ }
+
+ if (!downInline)
+ {
+ // Write into temporary buffer.
+ s_tileUpXY[dst] = v.x;
+ if (relUpY0 < tileUpH - 1)
+ s_tileUpXY[dst + tileUpW] = v.y;
+ }
+ else
+ {
+ // Write directly into output buffer.
+ if ((uint32_t)x < p.yShape.x)
+ {
+ int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);
+ index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut;
+ if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]);
+ if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+ }
+ }
+ else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)
+ {
+ // Full upsampling filter.
+
+ if (up == 2)
+ {
+ // 2 x 2-wide.
+ __syncthreads();
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.
+ for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);
+ int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);
+ int src0 = relInX0 + tileInW * relInY0;
+ int tap0y = (relInY0 * up + phaseInY - relUpY0);
+
+ #define X_LOOP(TAPY, PX) \
+ for (int sx = 0; sx < fuSize / up; sx++) \
+ { \
+ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \
+ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \
+ }
+
+ vec4_t v = InternalType::zero_vec4();
+ if (tap0y == 0 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 0) }
+ if (tap0y == 0 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(0, 1) }
+ if (tap0y == 1 && phaseInX == 0)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 0) }
+ if (tap0y == 1 && phaseInX == 1)
+ #pragma unroll
+ for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];
+ #pragma unroll
+ X_LOOP(1, 1) }
+
+ #undef X_LOOP
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+
+ v.x *= (scalar_t)((float)up * (float)up * p.gain);
+ v.y *= (scalar_t)((float)up * (float)up * p.gain);
+ v.z *= (scalar_t)((float)up * (float)up * p.gain);
+ v.w *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write signs.
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ }
+ else
+ {
+ // Determine and write signs.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ int sx = __float_as_uint(v.x) >> 31;
+ int sy = __float_as_uint(v.y) >> 31;
+ int sz = __float_as_uint(v.z) >> 31;
+ int sw = __float_as_uint(v.w) >> 31;
+ if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); }
+ if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); }
+ if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); }
+ if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); }
+
+ p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);
+ }
+ else
+ {
+ // Just compute the values.
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+ }
+ }
+ else if (signRead) // Read sign and apply.
+ {
+ if ((uint32_t)signY < p.sShape.y)
+ {
+ int s = 0;
+ if ((uint32_t)signXb < p.swLimit) s = p.s[si];
+ if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;
+ s >>= (signX & 3) << 1;
+ if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;
+ if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;
+ if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;
+ if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp);
+ if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp);
+ if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp);
+ if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp);
+ }
+
+ s_tileUpXY[idx + 0] = v.x;
+ s_tileUpXY[idx + 1] = v.y;
+ s_tileUpXY[idx + 2] = v.z;
+ s_tileUpXY[idx + 3] = v.w;
+ }
+ }
+ else if (up == 1)
+ {
+ __syncthreads();
+ uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);
+ int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.
+ for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)
+ {
+ int relUpX0, relUpY0;
+ fast_div_mod(relUpX0, relUpY0, idx);
+ scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.
+
+ int x = tileOutX * down + relUpX0;
+ int y = tileOutY * down + relUpY0;
+ int signX = x + p.sOfs.x;
+ int signY = y + p.sOfs.y;
+ int signZ = blockIdx.z + p.blockZofs;
+ int signXb = signX >> 2;
+ index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);
+ v *= (scalar_t)((float)up * (float)up * p.gain);
+
+ if (signWrite)
+ {
+ if (!enableWriteSkip)
+ {
+ // Determine and write sign.
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ }
+ else
+ {
+ // Determine and write sign.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)
+ {
+ uint32_t s = 0;
+ uint32_t signXbit = (1u << signXo);
+ if (v < 0.f)
+ {
+ s = signXbit;
+ v *= p.slope;
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ s = signXbit * 2;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ s += __shfl_xor_sync(groupMask, s, 1); // Coalesce.
+ s += __shfl_xor_sync(groupMask, s, 2); // Coalesce.
+ p.s[si] = s; // Write.
+ }
+ else
+ {
+ // Just compute the value.
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+ }
+ }
+ else if (signRead)
+ {
+ // Read sign and apply if within sign tensor bounds.
+ if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)
+ {
+ int s = p.s[si];
+ s >>= signXo;
+ if (s & 1) v *= p.slope;
+ if (s & 2) v = 0.f;
+ }
+ }
+ else // Forward pass with no sign write.
+ {
+ if (v < 0.f) v *= p.slope;
+ v = InternalType::clamp(v, p.clamp);
+ }
+
+ if (!downInline) // Write into temporary buffer.
+ s_tileUpXY[idx] = v;
+ else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer
+ *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);
+ }
+ }
+ }
+
+ // Downsampling.
+ if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)
+ {
+ // Horizontal downsampling.
+ __syncthreads();
+ if (down == 4 && tileOutW % 4 == 0)
+ {
+ // Calculate 4 pixels at a time.
+ for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec4_t v = InternalType::zero_vec4();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step];
+ v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step];
+ v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ s_tileDownX[idx+2] = v.z;
+ s_tileDownX[idx+3] = v.w;
+ }
+ }
+ else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))
+ {
+ // Calculate 2 pixels at a time.
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src0 = relUpY * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step];
+ v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];
+ }
+ s_tileDownX[idx+0] = v.x;
+ s_tileDownX[idx+1] = v.y;
+ }
+ }
+ else
+ {
+ // Calculate 1 pixel at a time.
+ for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)
+ {
+ int relOutX0, relUpY;
+ fast_div_mod(relOutX0, relUpY, idx);
+ int relUpX0 = relOutX0 * down;
+ int src = relUpY * tileUpW + relUpX0;
+ scalar_t v = 0.f;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];
+ s_tileDownX[idx] = v;
+ }
+ }
+
+ // Vertical downsampling & store output tile.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX, relOutY0;
+ fast_div_mod(relOutX, relOutY0, idx);
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileOutW + relOutX;
+ scalar_t v = 0;
+ #pragma unroll
+ for (int step = 0; step < fdSize; step++)
+ v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];
+
+ int outX = tileOutX + relOutX;
+ int outY = tileOutY + relOutY0;
+
+ if (outX < p.yShape.x & outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)
+ {
+ // Full downsampling filter.
+ if (down == 2)
+ {
+ // 2-wide.
+ __syncthreads();
+ for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ int relUpX0 = relOutX0 * down;
+ int relUpY0 = relOutY0 * down;
+ int src0 = relUpY0 * tileUpW + relUpX0;
+ vec2_t v = InternalType::zero_vec2();
+ #pragma unroll
+ for (int sy = 0; sy < fdSize; sy++)
+ #pragma unroll
+ for (int sx = 0; sx < fdSize; sx++)
+ {
+ v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];
+ }
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outY < p.yShape.y)
+ {
+ index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut;
+ if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;
+ if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y;
+ }
+ }
+ }
+ else if (down == 1 && !downInline)
+ {
+ // Thread per pixel.
+ __syncthreads();
+ for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)
+ {
+ int relOutX0, relOutY0;
+ fast_div_mod(relOutX0, relOutY0, idx);
+ scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.
+
+ int outX = tileOutX + relOutX0;
+ int outY = tileOutY + relOutY0;
+ if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)
+ *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v;
+ }
+ }
+ }
+
+ if (!enableXrep)
+ break;
+ }
+}
+
+//------------------------------------------------------------------------
+// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.
+// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used.
+
+template <class T, bool signWrite, bool signRead>
+static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)
+{
+ typedef typename InternalType<T>::scalar_t scalar_t;
+
+ // Indexing.
+ int32_t x = threadIdx.x + blockIdx.x * blockDim.x;
+ int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;
+ int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.
+
+ // Loop to accommodate oversized tensors.
+ for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)
+ for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)
+ {
+ // Extract z and w (channel, minibatch index).
+ int32_t w = q / p.xShape.z;
+ int32_t z = q - w * p.xShape.z;
+
+ // Choose behavior based on sign read/write mode.
+ if (signWrite)
+ {
+ // Process value if in p.x.
+ uint32_t s = 0;
+ if (x < p.xShape.x && y < p.xShape.y)
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+
+ // Gain, LReLU, clamp.
+ v *= p.gain;
+ if (v < 0.f)
+ {
+ v *= p.slope;
+ s = 1; // Sign.
+ }
+ if (fabsf(v) > p.clamp)
+ {
+ v = InternalType<T>::clamp(v, p.clamp);
+ s = 2; // Clamp.
+ }
+
+ *pv = (T)v; // Write value.
+ }
+
+ // Coalesce into threads 0 and 16 of warp.
+ uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu;
+ s <<= ((threadIdx.x & 15) << 1); // Shift into place.
+ s |= __shfl_xor_sync(m, s, 1); // Distribute.
+ s |= __shfl_xor_sync(m, s, 2);
+ s |= __shfl_xor_sync(m, s, 4);
+ s |= __shfl_xor_sync(m, s, 8);
+
+ // Write signs if leader and in p.s.
+ if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.
+ {
+ uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.
+ ((uint32_t*)p.s)[is >> 4] = s;
+ }
+ }
+ else if (signRead)
+ {
+ // Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+
+ // Apply sign buffer offset.
+ uint32_t sx = x + p.sOfs.x;
+ uint32_t sy = y + p.sOfs.y;
+
+ // Read and apply signs if we land inside valid region of sign buffer.
+ if (sx < p.sShape.x && sy < p.sShape.y)
+ {
+ uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.
+ unsigned char s = p.s[is];
+ s >>= (sx & 3) << 1; // Shift into place.
+ if (s & 1) // Sign?
+ v *= p.slope;
+ if (s & 2) // Clamp?
+ v = 0.f;
+ }
+
+ *pv = (T)v; // Write value.
+ }
+ }
+ else
+ {
+ // Forward pass with no sign write. Process value if in p.x.
+ if (x < p.xShape.x) // y is always in.
+ {
+ int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;
+ T* pv = ((T*)p.x) + ix;
+ scalar_t v = (scalar_t)(*pv);
+ v *= p.gain;
+ if (v < 0.f)
+ v *= p.slope;
+ if (fabsf(v) > p.clamp)
+ v = InternalType<T>::clamp(v, p.clamp);
+ *pv = (T)v; // Write value.
+ }
+ }
+ }
+}
+
+template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)
+{
+ return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;
+}
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)
+{
+ filtered_lrelu_kernel_spec s = { 0 };
+
+ // Return the first matching kernel.
+#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \
+ if (sharedKB >= SH) \
+ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \
+ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \
+ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \
+ { \
+ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \
+ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \
+ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \
+ s.setup = (void*)setup_filters_kernel; \
+ s.exec = (void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \
+ s.tileOut = make_int2(TW, TH); \
+ s.numWarps = W; \
+ s.xrep = XR; \
+ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \
+ return s; \
+ }
+
+ // Launch parameters for various kernel specializations.
+ // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.
+ // Kernels that use more shared memory must be listed before those that use less, for the same reason.
+
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/1,1, /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64, 178, 32, 0, 0) // 1t-upf1-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95, 16, 0, 0) // 4t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 22, 16, 0, 0) // 4t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 29, 16, 11, 0) // 4t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60, 28, 16, 0, 0) // 4t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 28, 16, 0, 0) // 4t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56, 31, 16, 11, 0) // 4t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56, 36, 16, 0, 0) // 4t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 22, 16, 12, 0) // 4t-ups2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,8, /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29, 15, 16, 0, 0) // 4t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96, 150, 28, 0, 0) // 6t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 35, 24, 0, 0) // 6t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 16, 10, 0) // 6t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58, 28, 24, 8, 0) // 6t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52, 28, 16, 0, 0) // 6t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 51, 16, 5, 0) // 6t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 56, 16, 6, 0) // 6t-ups4-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 18, 16, 12, 0) // 6t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 31, 32, 6, 0) // 6t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27, 13, 24, 0, 0) // 6t-upf2-downs4
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89, 24, 0, 0) // 8t-ups2-downf1
+ CASE(/*sharedKB*/48, /*up,fu*/1,1, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32, 31, 16, 5, 0) // 8t-upf1-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 41, 16, 9, 0) // 8t-ups2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56, 26, 24, 0, 0) // 8t-upf2-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 40, 16, 0, 0) // 8t-ups2-downf2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32, 46, 24, 5, 0) // 8t-ups4-downs2
+ CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32, 50, 16, 0, 0) // 8t-ups4-downf2
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24, 24, 32, 12, 1) // 8t-ups2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16, 13, 16, 10, 1) // 8t-ups2-downs4
+ CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 28, 28, 4, 0) // 8t-upf2-downs4 96kB
+ CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25, 10, 24, 0, 0) // 8t-upf2-downs4
+
+ #undef CASE
+ return s; // No kernel found.
+}
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.h b/modules/eg3ds/torch_utils/ops/filtered_lrelu.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2bfd1dd537909de9cd3b14765a482056391683b
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.h
@@ -0,0 +1,94 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include <cuda_runtime.h>
+
+//------------------------------------------------------------------------
+// CUDA kernel parameters.
+
+struct filtered_lrelu_kernel_params
+{
+ // These parameters decide which kernel to use.
+ int up; // upsampling ratio (1, 2, 4)
+ int down; // downsampling ratio (1, 2, 4)
+ int2 fuShape; // [size, 1] | [size, size]
+ int2 fdShape; // [size, 1] | [size, size]
+
+ int _dummy; // Alignment.
+
+ // Rest of the parameters.
+ const void* x; // Input tensor.
+ void* y; // Output tensor.
+ const void* b; // Bias tensor.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+ const float* fu; // Upsampling filter.
+ const float* fd; // Downsampling filter.
+
+ int2 pad0; // Left/top padding.
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+ int flip; // Filter kernel flip for gradient computation.
+
+ int tilesXdim; // Original number of horizontal output tiles.
+ int tilesXrep; // Number of horizontal tiles per CTA.
+ int blockZofs; // Block z offset to support large minibatch, channel dimensions.
+
+ int4 xShape; // [width, height, channel, batch]
+ int4 yShape; // [width, height, channel, batch]
+ int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+ int swLimit; // Active width of sign tensor in bytes.
+
+ longlong4 xStride; // Strides of all tensors except signs, same component order as shapes.
+ longlong4 yStride; //
+ int64_t bStride; //
+ longlong3 fuStride; //
+ longlong3 fdStride; //
+};
+
+struct filtered_lrelu_act_kernel_params
+{
+ void* x; // Input/output, modified in-place.
+ unsigned char* s; // Sign tensor in/out. NULL if unused.
+
+ float gain; // Additional gain factor.
+ float slope; // Leaky ReLU slope on negative side.
+ float clamp; // Clamp after nonlinearity.
+
+ int4 xShape; // [width, height, channel, batch]
+ longlong4 xStride; // Input/output tensor strides, same order as in shape.
+ int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused.
+ int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel specialization.
+
+struct filtered_lrelu_kernel_spec
+{
+ void* setup; // Function for filter kernel setup.
+ void* exec; // Function for main operation.
+ int2 tileOut; // Width/height of launch tile.
+ int numWarps; // Number of warps per thread block, determines launch block size.
+ int xrep; // For processing multiple horizontal tiles per thread block.
+ int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
+};
+
+//------------------------------------------------------------------------
+// CUDA kernel selection.
+
+template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
+template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
+template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);
+
+//------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu.py b/modules/eg3ds/torch_utils/ops/filtered_lrelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..2047b7e19320e8d03e444ca1cb03fe00d0c5e96e
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu.py
@@ -0,0 +1,276 @@
+# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+#
+# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+# property and proprietary rights in and to this material, related
+# documentation and any modifications thereto. Any use, reproduction,
+# disclosure or distribution of this material and related documentation
+# without an express license agreement from NVIDIA CORPORATION or
+# its affiliates is strictly prohibited.
+
+import os
+import numpy as np
+import torch
+import warnings
+
+from .. import custom_ops
+from .. import misc
+from . import upfirdn2d
+from . import bias_act
+
+#----------------------------------------------------------------------------
+
+_plugin = None
+
+def _init():
+ global _plugin
+ if _plugin is None:
+ _plugin = custom_ops.get_plugin(
+ module_name='filtered_lrelu_plugin',
+ sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
+ headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
+ source_dir=os.path.dirname(__file__),
+ extra_cuda_cflags=['--use_fast_math'],
+ )
+ return True
+
+def _get_filter_size(f):
+ if f is None:
+ return 1, 1
+ assert isinstance(f, torch.Tensor)
+ assert 1 <= f.ndim <= 2
+ return f.shape[-1], f.shape[0] # width, height
+
+def _parse_padding(padding):
+ if isinstance(padding, int):
+ padding = [padding, padding]
+ assert isinstance(padding, (list, tuple))
+ assert all(isinstance(x, (int, np.integer)) for x in padding)
+ padding = [int(x) for x in padding]
+ if len(padding) == 2:
+ px, py = padding
+ padding = [px, px, py, py]
+ px0, px1, py0, py1 = padding
+ return px0, px1, py0, py1
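+# For example, padding=2 expands to (2, 2, 2, 2) and padding=[1, 3] expands to (1, 1, 3, 3).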
+
+#----------------------------------------------------------------------------
+
+def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
+ r"""Filtered leaky ReLU for a batch of 2D images.
+
+ Performs the following sequence of operations for each channel:
+
+ 1. Add channel-specific bias if provided (`b`).
+
+ 2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
+
+ 3. Pad the image with the specified number of zeros on each side (`padding`).
+ Negative padding corresponds to cropping the image.
+
+ 4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
+ so that the footprint of all output pixels lies within the input image.
+
+ 5. Multiply each value by the provided gain factor (`gain`).
+
+ 6. Apply leaky ReLU activation function to each value.
+
+ 7. Clamp each value between -clamp and +clamp, if the `clamp` parameter is provided.
+
+ 8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
+ it so that the footprint of all output pixels lies within the input image.
+
+ 9. Downsample the image by keeping every Nth pixel (`down`).
+
+ The fused op is considerably more efficient than performing the same calculation
+ using standard PyTorch ops. It supports gradients of arbitrary order.
+
+ Args:
+ x: Float32/float16/float64 input tensor of the shape
+ `[batch_size, num_channels, in_height, in_width]`.
+ fu: Float32 upsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ fd: Float32 downsampling FIR filter of the shape
+ `[filter_height, filter_width]` (non-separable),
+ `[filter_taps]` (separable), or
+ `None` (identity).
+ b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
+ as `x`. The length of the vector must match the channel dimension of `x`.
+ up: Integer upsampling factor (default: 1).
+ down: Integer downsampling factor (default: 1).
+ padding: Padding with respect to the upsampled image. Can be a single number
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
+ (default: 0).
+ gain: Overall scaling factor for signal magnitude (default: sqrt(2)).
+ slope: Slope on the negative side of leaky ReLU (default: 0.2).
+ clamp: Maximum magnitude for leaky ReLU output (default: None).
+ flip_filter: False = convolution, True = correlation (default: False).
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
+
+ Returns:
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
+ """
+ assert isinstance(x, torch.Tensor)
+ assert impl in ['ref', 'cuda']
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
+ return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
+ return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
+
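+# Illustrative usage sketch, added for documentation only; it is not part of the original
+# NVIDIA op. The 4-tap filter below is an arbitrary placeholder rather than a recommended
+# design (callers typically build filters with e.g. upfirdn2d.setup_filter).
+def _example_usage():
+    """Run filtered_lrelu once on random data via the pure-PyTorch reference path."""
+    x = torch.randn([1, 3, 64, 64])         # NCHW input batch.
+    f = torch.tensor([1.0, 3.0, 3.0, 1.0])  # Separable lowpass taps (placeholder values).
+    f = f / f.sum()                          # Normalize to unit DC gain.
+    # 2x upsample -> leaky ReLU -> 2x downsample; impl='ref' avoids compiling the CUDA plugin.
+    return filtered_lrelu(x, fu=f, fd=f, up=2, down=2, padding=1, impl='ref')
+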
+#----------------------------------------------------------------------------
+
+@misc.profiled_function
+def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
+ existing `upfirdn2d()` and `bias_act()` ops.
+ """
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+ fu_w, fu_h = _get_filter_size(fu)
+ fd_w, fd_h = _get_filter_size(fd)
+ if b is not None:
+ assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
+ misc.assert_shape(b, [x.shape[1]])
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ assert slope == float(slope) and slope >= 0
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+
+ # Calculate output size.
+ batch_size, channels, in_h, in_w = x.shape
+ in_dtype = x.dtype
+ out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
+ out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down
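+ # Worked example (illustrative numbers only): in_w=16, up=2, px0=px1=5, a 12-tap fu,
+ # a 6-tap fd and down=2 give out_w = (32 + 10 - 11 - 5 + 1) // 2 = 13.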
+
+ # Compute using existing ops.
+ x = bias_act.bias_act(x=x, b=b) # Apply bias.
+ x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
+ x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Check output shape & dtype.
+ misc.assert_shape(x, [batch_size, channels, out_h, out_w])
+ assert x.dtype == in_dtype
+ return x
+
+#----------------------------------------------------------------------------
+
+_filtered_lrelu_cuda_cache = dict()
+
+def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
+ """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
+ """
+ assert isinstance(up, int) and up >= 1
+ assert isinstance(down, int) and down >= 1
+ px0, px1, py0, py1 = _parse_padding(padding)
+ assert gain == float(gain) and gain > 0
+ gain = float(gain)
+ assert slope == float(slope) and slope >= 0
+ slope = float(slope)
+ assert clamp is None or (clamp == float(clamp) and clamp >= 0)
+ clamp = float(clamp if clamp is not None else 'inf')
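+ # clamp=None maps to +inf, i.e. clamping is effectively a no-op in the kernels.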
+
+ # Lookup from cache.
+ key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
+ if key in _filtered_lrelu_cuda_cache:
+ return _filtered_lrelu_cuda_cache[key]
+
+ # Forward op.
+ class FilteredLReluCuda(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
+
+ # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
+ if fu is None:
+ fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ if fd is None:
+ fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
+ assert 1 <= fu.ndim <= 2
+ assert 1 <= fd.ndim <= 2
+
+ # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
+ if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
+ fu = fu.square()[None]
+ if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
+ fd = fd.square()[None]
+
+ # Missing sign input tensor.
+ if si is None:
+ si = torch.empty([0])
+
+ # Missing bias tensor.
+ if b is None:
+ b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)
+
+ # Construct internal sign tensor only if gradients are needed.
+ write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)
+
+ # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
+ strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
+ if any(a < b for a, b in zip(strides[:-1], strides[1:])):
+ warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)
+
+ # Call C++/Cuda plugin if datatype is supported.
+ if x.dtype in [torch.float16, torch.float32]:
+ if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
+ warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
+ y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
+ else:
+ return_code = -1
+
+ # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because
+ # only the bit-packed sign tensor is retained for gradient computation.
+ if return_code < 0:
+ warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
+
+ y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
+ y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
+ so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
+ y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.
+
+ # Prepare for gradient computation.
+ ctx.save_for_backward(fu, fd, (si if si.numel() else so))
+ ctx.x_shape = x.shape
+ ctx.y_shape = y.shape
+ ctx.s_ofs = sx, sy
+ return y
+
+ @staticmethod
+ def backward(ctx, dy): # pylint: disable=arguments-differ
+ fu, fd, si = ctx.saved_tensors
+ _, _, xh, xw = ctx.x_shape
+ _, _, yh, yw = ctx.y_shape
+ sx, sy = ctx.s_ofs
+ dx = None # 0
+ dfu = None; assert not ctx.needs_input_grad[1]
+ dfd = None; assert not ctx.needs_input_grad[2]
+ db = None # 3
+ dsi = None; assert not ctx.needs_input_grad[4]
+ dsx = None; assert not ctx.needs_input_grad[5]
+ dsy = None; assert not ctx.needs_input_grad[6]
+
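+ # The gradient of the fused op is itself a filtered_lrelu call: up/down and fu/fd are
+ # swapped, the filters are flipped, gain is rescaled by up^2/down^2, and the saved
+ # bit-packed sign tensor (si) replays the forward-pass lrelu/clamp decisions; pp
+ # re-derives the padding so the gradient aligns with the original input footprint.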
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:
+ pp = [
+ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,
+ xw * up - yw * down + px0 - (up - 1),
+ (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,
+ xh * up - yh * down + py0 - (up - 1),
+ ]
+ gg = gain * (up ** 2) / (down ** 2)
+ ff = (not flip_filter)
+ sx = sx - (fu.shape[-1] - 1) + px0
+ sy = sy - (fu.shape[0] - 1) + py0
+ dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)
+
+ if ctx.needs_input_grad[3]:
+ db = dx.sum([0, 2, 3])
+
+ return dx, dfu, dfd, db, dsi, dsx, dsy
+
+ # Add to cache.
+ _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda
+ return FilteredLReluCuda
+
+#----------------------------------------------------------------------------
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu_ns.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu_ns.cu
new file mode 100644
index 0000000000000000000000000000000000000000..8a3eae46215c3babea2c54e3ae255b05f4d777af
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu_ns.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for no signs mode (no gradients required).
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);
+template void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);
+template void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);
+
+// Copy filters to constant memory.
+template cudaError_t copy_filters<false, false>(cudaStream_t stream);
diff --git a/modules/eg3ds/torch_utils/ops/filtered_lrelu_rd.cu b/modules/eg3ds/torch_utils/ops/filtered_lrelu_rd.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3cd43ec0648d3db05e5808299fc0ee318e5ceaa6
--- /dev/null
+++ b/modules/eg3ds/torch_utils/ops/filtered_lrelu_rd.cu
@@ -0,0 +1,31 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "filtered_lrelu.cu"
+
+// Template/kernel specializations for sign read mode.
+
+// Full op, 32-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Full op, 64-bit indexing.
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);
+
+// Activation/signs only for generic variant. 64-bit indexing.
+template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);