diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..94c22a8ce9b689713cf92cdc9fea1025e3d3dc30 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,22 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/camera_control.png filter=lfs diff=lfs merge=lfs -text
+assets/cases/dog.png filter=lfs diff=lfs merge=lfs -text
+assets/cases/dog_pure_camera_motion_1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/cases/dog_pure_camera_motion_2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/cases/dog_pure_obj_motion.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/func_1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/logo.png filter=lfs diff=lfs merge=lfs -text
+assets/logo_generated2.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/logo_generated2_single.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/pose_files/complex_1.png filter=lfs diff=lfs merge=lfs -text
+assets/pose_files/complex_2.png filter=lfs diff=lfs merge=lfs -text
+assets/pose_files/complex_3.png filter=lfs diff=lfs merge=lfs -text
+assets/pose_files/complex_4.png filter=lfs diff=lfs merge=lfs -text
+assets/sea.jpg filter=lfs diff=lfs merge=lfs -text
+assets/syn_video_control1.mp4 filter=lfs diff=lfs merge=lfs -text
+assets/syn_video_control2.mp4 filter=lfs diff=lfs merge=lfs -text
+data/folders/007401_007450_1018898026/video.mp4 filter=lfs diff=lfs merge=lfs -text
+data/folders/046001_046050_1011035429/video.mp4 filter=lfs diff=lfs merge=lfs -text
+data/folders/188701_188750_1026109505/video.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/READEME.md b/READEME.md
new file mode 100644
index 0000000000000000000000000000000000000000..f1e33d4334326d6dc74d3f22f4969a03abe0793f
--- /dev/null
+++ b/READEME.md
@@ -0,0 +1,159 @@
+# MotionPro
+
+
+
+
+
+
+ 🖥️ GitHub    |    🌐 Project Page    |   🤗 Hugging Face   |    📑 Paper    |    📖 PDF   
+
+
+[**MotionPro: A Precise Motion Controller for Image-to-Video Generation**](https://zhw-zhang.github.io/MotionPro-page/)
+
+🔆 If you find MotionPro useful, please give this repo a ⭐; stars matter a lot to open-source projects. Thanks!
+
+In this repository, we introduce **MotionPro**, an image-to-video generation model built on SVD. MotionPro learns object and camera motion control from **in-the-wild** video datasets (e.g., WebVid-10M) without applying special data filtering. The model offers the following key features:
+
+- **User-friendly interaction.** Our model requires only simple conditional inputs, allowing users to control image-to-video generation through brushing and dragging.
+- **Simultaneous control of object and camera motion.** MotionPro supports simultaneous object and camera motion control. Moreover, it achieves precise pose-driven camera control without being trained on a dedicated camera-pose paired dataset. [More Details](assets/camera_control.png)
+- **Synchronized video generation.** This is an extension of our model. By combining MotionPro and MotionPro-Dense, we can achieve synchronized video generation. [More Details](assets/README_syn.md)
+
+
+Additionally, our repository provides tools to support the research community's development:
+
+- **Memory optimization for training.** We provide a training framework based on PyTorch Lightning, optimized for memory efficiency, enabling SVD fine-tuning with a batch size of 8 per NVIDIA A100 GPU.
+- **Data construction tools.** We offer scripts for constructing training data, along with dataset-loading code in two formats, supporting video input from both folders (Dataset) and tar files (WebDataset).
+- **MC-Bench and evaluation code.** We constructed MC-Bench with 1.1K user-annotated image-trajectory pairs, along with evaluation scripts for comprehensive assessments. All the images showcased on the project page can be found here.
+
+## Video Demos
+
+
+
+
+Examples of different motion control types by our MotionPro.
+
+
+
+
+## 🔥 Updates
+- [x] **\[2025.03.26\]** Release inference and training code.
+- [ ] **\[2025.03.27\]** Upload gradio demo usage video.
+- [ ] **\[2025.03.29\]** Release MC-Bench and evaluation code.
+- [ ] **\[2025.03.30\]** Upload annotation tool for image-trajectory pair construction.
+
+## 🏃🏼 Inference
+
+### Environment Requirements
+
+Clone the repo:
+```
+git clone https://github.com/HiDream-ai/MotionPro.git
+```
+
+Install dependencies:
+```
+conda create -n motionpro python=3.10.0
+conda activate motionpro
+pip install -r requirements.txt
+```
+
+
+
+### Model Download
+
+
+| Models | Download Link | Notes |
+|-------------------|-------------------------------------------------------------------------------|--------------------------------------------|
+| MotionPro | 🤗[Huggingface](https://huggingface.co/zzwustc/MotionPro/blob/main/MotionPro-gs_16k.pt) | Supports both object and camera control. This is the default model mentioned in the paper. |
+| MotionPro-Dense | 🤗[Huggingface](https://huggingface.co/zzwustc/MotionPro/blob/main/MotionPro_Dense-gs_14k.pt) | Supports synchronized video generation when combined with MotionPro. MotionPro-Dense shares the same architecture as MotionPro, but its input conditions are modified to include dense optical flow and per-frame visibility masks relative to the first frame. |
+
+
+Download the models from Hugging Face at high speed (30-80 MB/s):
+```
+cd tools/huggingface_down
+bash download_hfd.sh
+```
+
+
+
+
+### Run Motion Control
+
+This section of the code supports simultaneous object and camera motion control. We provide a user-friendly Gradio demo that lets users control motion with simple brushing and dragging operations. An instructional video is available at `assets/demo.mp4` (please note the Gradio version it uses).
+
+```
+python demo_sparse_flex_wh.py
+```
+When you expect all pixels to move (e.g., for camera control), use the brush to cover the entire image. You can also test the demo with `assets/logo.png`.
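+
+For intuition, the sketch below shows one plausible way a brush-and-drag input could be turned into a sparse flow map plus mask, i.e., the kind of simple conditional input described above. It is an illustration only, not the demo's actual code; the function name and array layout are assumptions.
+
+```python
+# Illustrative sketch (not the demo's implementation): rasterize a brushed region and a
+# drag vector into a sparse flow map and a mask.
+import numpy as np
+
+def sparse_flow_condition(brush_mask: np.ndarray, start_xy, end_xy):
+    """brush_mask: (H, W) boolean array painted by the user; start_xy/end_xy: drag endpoints."""
+    h, w = brush_mask.shape
+    flow = np.zeros((h, w, 2), dtype=np.float32)
+    dx, dy = end_xy[0] - start_xy[0], end_xy[1] - start_xy[1]
+    flow[brush_mask] = (dx, dy)  # every brushed pixel is assigned the dragged offset
+    return flow, brush_mask.astype(np.float32)
+```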
+
+Users can also generate controllable image-to-video results using pre-defined camera trajectories. Note that our model has not been trained on a specific camera-control dataset. Test the demo with `assets/sea.jpg`.
+
+```
+python demo_sparse_flex_wh_pure_camera.py
+```
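+
+The pre-defined trajectories live in `assets/pose_files` as paired `.png`/`.txt` (or `.pth`) files. Below is a hedged reading aid for the `.txt` layout, assuming the common convention of a source URL on the first line followed by per-frame rows of `timestamp, fx, fy, cx, cy, 0, 0` plus a flattened 3x4 world-to-camera matrix; it is not necessarily the project's own loader.
+
+```python
+# Hedged sketch: parse a pose .txt from assets/pose_files into intrinsics and 3x4 extrinsics.
+import numpy as np
+
+def load_camera_trajectory(path):
+    with open(path) as f:
+        url = f.readline().strip()                    # first line: source video URL
+        frames = []
+        for line in f:
+            vals = [float(v) for v in line.split()]
+            timestamp = vals[0]
+            fx, fy, cx, cy = vals[1:5]                # normalized intrinsics (assumed)
+            w2c = np.array(vals[7:19]).reshape(3, 4)  # flattened 3x4 extrinsic matrix
+            frames.append((timestamp, (fx, fy, cx, cy), w2c))
+    return url, frames
+```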
+
+
+
+
+### Run Synchronized Video Generation and Video Recapture
+
+By combining MotionPro and MotionPro-Dense, we can achieve the following functionalities:
+- Synchronized video generation. We assume that two videos, `pure_obj_motion.mp4` and `pure_camera_motion.mp4`, have been generated using the respective demos. By combining their motion flows and using the result as a condition for MotionPro-Dense, we obtain `final_video`. By pairing the same object motion with different camera motions, we can generate `synchronized videos` where the object motion remains consistent while the camera motion varies. [More Details](assets/README_syn.md)
+
+First, download the CoTracker [model weights](https://huggingface.co/zzwustc/MotionPro/tree/main/tools/co-tracker/checkpoints) and place them in the `tools/co-tracker/checkpoints` directory.
+
+```
+python inference_dense.py --ori_video 'assets/cases/dog_pure_obj_motion.mp4' --camera_video 'assets/cases/dog_pure_camera_motion_1.mp4' --save_name 'syn_video.mp4' --ckpt_path 'MotionPro-Dense CKPT-PATH'
+```
+
+
+
+## 🚀 Training
+
+
+### Data Preparation
+
+We have packaged several demo videos to help users debug the training code. Simply 🤗[download](https://huggingface.co/zzwustc/MotionPro/tree/main/data), extract the files, and place them in the `./data` directory.
+
+Additionally, `./data/dot_single_video` contains code for processing raw videos using [DOT](https://github.com/16lemoing/dot) to generate the necessary conditions for training, making it easier for the community to create training datasets.
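+
+For reference, the demo data unpacks into per-clip folders such as `data/folders/<clip_id>/video.mp4`. Below is a minimal sketch of a folder-style loader over that layout; the class name and the choice to return file paths rather than decoded frames are illustrative assumptions.
+
+```python
+# Hedged sketch: enumerate the folder-style demo data, one clip directory per sample.
+import os
+from torch.utils.data import Dataset
+
+class FolderVideoDataset(Dataset):
+    def __init__(self, root="data/folders"):
+        self.clips = sorted(
+            os.path.join(root, d) for d in os.listdir(root)
+            if os.path.isfile(os.path.join(root, d, "video.mp4"))
+        )
+
+    def __len__(self):
+        return len(self.clips)
+
+    def __getitem__(self, idx):
+        return os.path.join(self.clips[idx], "video.mp4")  # decode with your video reader
+```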
+
+
+
+
+
+### Train
+
+Simply run the following command to train MotionPro:
+```
+bash train_server_1.sh
+```
+In addition to loading video data from folders, we also support [WebDataset](https://rom1504.github.io/webdataset/), allowing videos to be read directly from tar files for training. This can be enabled by modifying the config file:
+```
+train_debug_from_folder.yaml -> train_debug_from_tar.yaml
+```
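+
+As a rough illustration of the tar-based path, the snippet below iterates video samples from tar shards with the WebDataset API; the shard pattern and sample keys are assumptions, not this repo's exact configuration.
+
+```python
+# Hedged sketch (not the repo's loader): stream video samples out of tar shards.
+import webdataset as wds
+
+# Files sharing a basename inside a tar form one sample dict keyed by extension.
+dataset = wds.WebDataset("data/shards/{000000..000009}.tar").to_tuple("mp4", "json")
+
+for video_bytes, meta_bytes in dataset:
+    ...  # decode video_bytes with your preferred video reader
+```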
+
+Furthermore, to train the **MotionPro-Dense** model, simply modify the `train_debug_from_tar.yaml` file by changing `VidTar` to `VidTar_all_flow` and updating the `ckpt_path`.
+
+
+
+## 🌟 Star and Citation
+If you find our work helpful for your research, please consider giving this repository a star ⭐ and citing our work 📝.
+```
+@inproceedings{2025motionpro,
+ title={MotionPro: A Precise Motion Controller for Image-to-Video Generation},
+ author={Zhongwei Zhang and Fuchen Long and Zhaofan Qiu and Yingwei Pan and Wu Liu and Ting Yao and Tao Mei},
+ booktitle={CVPR},
+ year={2025}
+}
+```
+
+
+## 💖 Acknowledgement
+
+
+Our code is inspired by several works, including [SVD](https://github.com/Stability-AI/generative-models), [DragNUWA](https://github.com/ProjectNUWA/DragNUWA), [DOT](https://github.com/16lemoing/dot), and [CoTracker](https://github.com/facebookresearch/co-tracker). Thanks to all the contributors!
+
diff --git a/assets/README_syn.md b/assets/README_syn.md
new file mode 100644
index 0000000000000000000000000000000000000000..ffb3d37fe0c0bdeac545d8ddfe03e537a72d9f80
--- /dev/null
+++ b/assets/README_syn.md
@@ -0,0 +1,38 @@
+## MotionPro-Dense Video Generation Pipeline
+
+
+This document provides an introduction to the MotionPro-Dense video generation pipeline, detailing its functionality and workflow. The pipeline is illustrated in the diagram below.
+
+### Pipeline Description
+
+1. **Video Generation with Base Motion Control**
+ - First, MotionPro can be used to generate videos with controllable object motion and camera motion.
+
+2. **Optical Flow and Visibility Mask Extraction and Merging**
+   - The generated videos are processed with CoTracker to extract per-frame optical flow and visibility masks.
+   - The extracted optical flows are accumulated through summation.
+   - The per-frame visibility masks from both sequences are intersected to obtain the final visibility mask (see the sketch after this list).
+
+3. **Final Video Generation with Combined Motions**
+ - The aggregated motion conditions are used as input for **MotionPro-Dense**, which generates the final video with seamlessly integrated object and camera motions.
+
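+Concretely, step 2 above amounts to something like the following minimal sketch; the function name and tensor shapes are assumptions, and the arrays may be NumPy or torch tensors:
+
+```python
+# Hedged sketch: merge two motion conditions before feeding MotionPro-Dense.
+# flow_*: (T, H, W, 2) optical flow of each frame relative to the first frame
+# vis_* : (T, H, W) boolean visibility masks relative to the first frame
+def combine_motion_conditions(flow_obj, vis_obj, flow_cam, vis_cam):
+    combined_flow = flow_obj + flow_cam   # accumulate the two flows by summation
+    combined_vis = vis_obj & vis_cam      # intersect the per-frame visibility masks
+    return combined_flow, combined_vis
+```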
+
+
+
+
+Figure 1: Illustration of video generation with combined motions.
+
+
+
+### Synchronized Video Generation
+
+Additionally, the pipeline enables the generation of **synchronized videos**, where a consistent object motion is paired with different camera motions.
+
+
+
+
+Figure 2: Illustration of synchronized video generation.
+
+
+
+
diff --git a/assets/camera_control.png b/assets/camera_control.png
new file mode 100644
index 0000000000000000000000000000000000000000..17a204aa230202a9b5c0eb69bb4d1fbd715754ad
--- /dev/null
+++ b/assets/camera_control.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c3f1accf32a6a665324a33993433d16ae33ac74d7e8329b137c86d8e2e05b884
+size 853223
diff --git a/assets/cases/dog.png b/assets/cases/dog.png
new file mode 100644
index 0000000000000000000000000000000000000000..3138c5f2cf4b2e53337d3b04a94fef1f0a70d31b
--- /dev/null
+++ b/assets/cases/dog.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12475a897c6b7daed611edbf74cdfdef4149deab15aded2087e5f0782dd3df20
+size 187872
diff --git a/assets/cases/dog_pure_camera_motion_1.mp4 b/assets/cases/dog_pure_camera_motion_1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..aea58e0af372837886db0d87addae476597c4658
--- /dev/null
+++ b/assets/cases/dog_pure_camera_motion_1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8137d4babcd1ad94e988a0d6768320e8ca924d8754e1a42f2a2dfda9c0f5e761
+size 208964
diff --git a/assets/cases/dog_pure_camera_motion_2.mp4 b/assets/cases/dog_pure_camera_motion_2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..a116bc0fec38c5c515992fcfb49843105069d355
--- /dev/null
+++ b/assets/cases/dog_pure_camera_motion_2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddde37b57e4d6b3d48d18ef7f0ef814e0993a5a51443189841045bc62fadc0c2
+size 454983
diff --git a/assets/cases/dog_pure_obj_motion.mp4 b/assets/cases/dog_pure_obj_motion.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d57b10afb0fe99af665138bb409a823eb4727b46
--- /dev/null
+++ b/assets/cases/dog_pure_obj_motion.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7c59c97ac7a59b79ae5146ee90ff684f97fed220b30cde646321965e43ffc187
+size 136490
diff --git a/assets/func_1.mp4 b/assets/func_1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..8d1acffab25ac3bb48adc3479ff0d0137434f73b
--- /dev/null
+++ b/assets/func_1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:862c9a0b58b9536d8479122154bade05077e2ee0ed71a2fe6c8c8c87d553d207
+size 4606221
diff --git a/assets/logo.png b/assets/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f4502e14c60f470bff7849dc0d3f0dabc9e7496
--- /dev/null
+++ b/assets/logo.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e9b0c2525d869a606061f80a3f25fd7237bb61465a77cc7799a365fc63ffcf15
+size 2398622
diff --git a/assets/logo_generated2.mp4 b/assets/logo_generated2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..10a03c0d15c8157e439f58e00e81e4bc00ef79d8
--- /dev/null
+++ b/assets/logo_generated2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:30a5f7d9ad8a4ba8f6389a16acf024783d4be507d10578db542fba37e258837d
+size 333258
diff --git a/assets/logo_generated2_single.mp4 b/assets/logo_generated2_single.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..fdb3b697eb42f21464f763d28729a87948969987
--- /dev/null
+++ b/assets/logo_generated2_single.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2e58a7533f59b2c490833bbf264c3029d93350e5f7205ff6e25aead38276b95
+size 104022
diff --git a/assets/pose_files/0bf152ef84195293.png b/assets/pose_files/0bf152ef84195293.png
new file mode 100644
index 0000000000000000000000000000000000000000..fc50953854077793f19d90b8c69e141047f2f1e2
Binary files /dev/null and b/assets/pose_files/0bf152ef84195293.png differ
diff --git a/assets/pose_files/0bf152ef84195293.txt b/assets/pose_files/0bf152ef84195293.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e87e03c57fefdae8e277555762b056c0f701bb9b
--- /dev/null
+++ b/assets/pose_files/0bf152ef84195293.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=QShWPZxTDoE
+158692025 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.780003667 0.059620168 -0.622928321 0.726968666 -0.062449891 0.997897983 0.017311305 0.217967188 0.622651041 0.025398925 0.782087326 -1.002211444
+158958959 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.743836701 0.064830206 -0.665209770 0.951841944 -0.068305343 0.997446954 0.020830527 0.206496789 0.664861917 0.029942872 0.746365905 -1.084913992
+159225893 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.697046876 0.070604131 -0.713540971 1.208789672 -0.074218854 0.996899366 0.026138915 0.196421447 0.713174045 0.034738146 0.700125754 -1.130142078
+159526193 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.635762572 0.077846259 -0.767949164 1.465161122 -0.080595709 0.996158004 0.034256749 0.157107229 0.767665446 0.040114246 0.639594078 -1.136893070
+159793126 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.593250692 0.083153486 -0.800711632 1.635091834 -0.085384794 0.995539784 0.040124334 0.135863998 0.800476789 0.044564810 0.597704709 -1.166997229
+160093427 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.555486798 0.087166689 -0.826943994 1.803789619 -0.089439675 0.994984210 0.044799786 0.145490422 0.826701283 0.049075913 0.560496747 -1.243827350
+160360360 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.523399472 0.090266660 -0.847292721 1.945815368 -0.093254104 0.994468153 0.048340045 0.174777447 0.846969128 0.053712368 0.528921843 -1.336914479
+160660661 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.491546303 0.092127070 -0.865964711 2.093852892 -0.095617607 0.994085968 0.051482171 0.196702533 0.865586221 0.057495601 0.497448236 -1.439709380
+160927594 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.475284129 0.093297184 -0.874871790 2.200792438 -0.096743606 0.993874133 0.053430639 0.209217395 0.874497354 0.059243519 0.481398523 -1.547068315
+161227895 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.464444131 0.093880348 -0.880612373 2.324141986 -0.097857766 0.993716478 0.054326952 0.220651207 0.880179226 0.060942926 0.470712721 -1.712512928
+161494828 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.458157241 0.093640216 -0.883925021 2.443100890 -0.098046601 0.993691206 0.054448847 0.257385043 0.883447111 0.061719712 0.464447916 -1.885672329
+161795128 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.457354397 0.093508720 -0.884354591 2.543246338 -0.097820736 0.993711591 0.054482624 0.281562244 0.883888066 0.061590351 0.463625461 -2.094829165
+162062062 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.465170115 0.093944497 -0.880222261 2.606377358 -0.097235762 0.993758380 0.054675922 0.277376127 0.879864752 0.060155477 0.471401453 -2.299280675
+162362362 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.511845231 0.090872414 -0.854257941 2.576774100 -0.093636356 0.994366586 0.049672548 0.270516319 0.853959382 0.054564942 0.517470777 -2.624374352
+162629296 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.590568483 0.083218277 -0.802685261 2.398318316 -0.085610889 0.995516419 0.040222570 0.282138215 0.802433550 0.044964414 0.595045030 -3.012309268
+162929596 0.474812461 0.844111024 0.500000000 0.500000000 0.000000000 0.000000000 0.684302032 0.072693504 -0.725566208 2.086323553 -0.074529484 0.996780157 0.029575195 0.310959312 0.725379944 0.033837710 0.687516510 -3.456740526
diff --git a/assets/pose_files/0c11dbe781b1c11c.png b/assets/pose_files/0c11dbe781b1c11c.png
new file mode 100644
index 0000000000000000000000000000000000000000..ee9cf43d38f0d8d5ca88ae088f936679b892f470
Binary files /dev/null and b/assets/pose_files/0c11dbe781b1c11c.png differ
diff --git a/assets/pose_files/0c11dbe781b1c11c.txt b/assets/pose_files/0c11dbe781b1c11c.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ae116b187cb39c26269b8585cddbfc2de318405f
--- /dev/null
+++ b/assets/pose_files/0c11dbe781b1c11c.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=a-Unpcomk5k
+89889800 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.959632158 -0.051068146 0.276583046 0.339363991 0.046715312 0.998659134 0.022308502 0.111317310 -0.277351439 -0.008487292 0.960731030 -0.353512177
+90156733 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.939171016 -0.057914909 0.338531673 0.380727498 0.052699961 0.998307705 0.024584483 0.134404073 -0.339382589 -0.005248427 0.940633774 -0.477942109
+90423667 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.913449824 -0.063028678 0.402040780 0.393354042 0.056629892 0.998008251 0.027794635 0.151535333 -0.402991891 -0.002621480 0.915199816 -0.622810637
+90723967 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.879072070 -0.069992281 0.471522361 0.381271678 0.062575974 0.997545719 0.031412520 0.175549569 -0.472563744 0.001892101 0.881294429 -0.821022008
+90990900 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.846152365 -0.078372896 0.527146876 0.360267421 0.071291871 0.996883452 0.033775900 0.212440374 -0.528151155 0.009001731 0.849102676 -1.013792538
+91291200 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.806246638 -0.086898506 0.585162342 0.297888150 0.078344196 0.996124208 0.039983708 0.243578507 -0.586368918 0.013607344 0.809929788 -1.248063630
+91558133 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.771091938 -0.093814306 0.629774630 0.223948432 0.087357447 0.995320201 0.041307874 0.293608807 -0.630702674 0.023163332 0.775678813 -1.459775674
+91858433 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.737968326 -0.099363215 0.667480111 0.145501271 0.093257703 0.994626462 0.044957232 0.329381977 -0.668360531 0.029070651 0.743269205 -1.688460978
+92125367 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.716826320 -0.101809755 0.689778805 0.086545731 0.098867603 0.994127929 0.043986596 0.379651732 -0.690206647 0.036666028 0.722682774 -1.885393814
+92425667 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.703360021 -0.101482928 0.703552365 0.039205180 0.098760851 0.994108558 0.044659954 0.417778776 -0.703939617 0.038071405 0.709238708 -2.106152155
+92692600 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.699525177 -0.101035394 0.707429409 0.029387371 0.096523918 0.994241416 0.046552572 0.439027166 -0.708059072 0.035719164 0.705249250 -2.314481674
+92992900 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.698582709 -0.101620331 0.708276451 0.018437890 0.096638583 0.994193733 0.047326516 0.478349552 -0.708973348 0.035385344 0.704347014 -2.540820022
+93259833 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.704948425 -0.098988213 0.702316940 0.047566428 0.095107265 0.994462848 0.044701166 0.517456396 -0.702853024 0.035283424 0.710459530 -2.724204596
+93560133 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.714113414 -0.104848787 0.692133486 0.107161588 0.100486010 0.993833601 0.046875130 0.568063228 -0.692780316 0.036075566 0.720245779 -2.948379150
+93827067 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.717699587 -0.112323314 0.687234104 0.118765931 0.105546549 0.993049562 0.052081093 0.593900230 -0.688307464 0.035156611 0.724566638 -3.140363331
+94127367 0.485388169 0.862912326 0.500000000 0.500000000 0.000000000 0.000000000 0.715531290 -0.122954883 0.687675118 0.089455249 0.115526602 0.991661787 0.057100743 0.643643035 -0.688961923 0.038587399 0.723769605 -3.310401931
diff --git a/assets/pose_files/0c9b371cc6225682.png b/assets/pose_files/0c9b371cc6225682.png
new file mode 100644
index 0000000000000000000000000000000000000000..a45f42678035d0d0c8c906a3732b5e57ba9b7830
Binary files /dev/null and b/assets/pose_files/0c9b371cc6225682.png differ
diff --git a/assets/pose_files/0c9b371cc6225682.txt b/assets/pose_files/0c9b371cc6225682.txt
new file mode 100644
index 0000000000000000000000000000000000000000..bc2f45f0ec995e31e83d824f8fc212d99775cbb1
--- /dev/null
+++ b/assets/pose_files/0c9b371cc6225682.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=_ca03xP_KUU
+211244000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.984322786 0.006958477 -0.176239252 0.004217217 -0.005594095 0.999950409 0.008237306 -0.107944544 0.176287830 -0.007122268 0.984312892 -0.571743822
+211511000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981951714 0.008860772 -0.188924149 0.000856103 -0.007234470 0.999930620 0.009296093 -0.149397579 0.188993424 -0.007761548 0.981947660 -0.776566486
+211778000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981318414 0.010160952 -0.192122281 -0.005546933 -0.008323869 0.999911606 0.010366773 -0.170816348 0.192210630 -0.008573905 0.981316268 -0.981924227
+212078000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981108844 0.010863926 -0.193151161 0.019480142 -0.008781361 0.999893725 0.011634931 -0.185801323 0.193257034 -0.009719004 0.981100023 -1.207220396
+212345000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.981263518 0.010073495 -0.192407012 0.069708411 -0.008015377 0.999902070 0.011472094 -0.203594876 0.192503735 -0.009714933 0.981248140 -1.408936391
+212646000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980964184 0.009669405 -0.193947718 0.166020848 -0.007467276 0.999899149 0.012082115 -0.219176122 0.194044977 -0.010403861 0.980937481 -1.602649833
+212913000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980841637 0.009196524 -0.194589555 0.262465567 -0.006609587 0.999880970 0.013939449 -0.224018296 0.194694594 -0.012386235 0.980785728 -1.740759996
+213212000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980620921 0.008805701 -0.195716679 0.389752858 -0.006055873 0.999874413 0.014644019 -0.230312701 0.195821062 -0.013174997 0.980551124 -1.890949759
+213479000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980327129 0.009317402 -0.197159693 0.505632551 -0.006113928 0.999839306 0.016850581 -0.230702867 0.197285011 -0.015313662 0.980226576 -2.016199670
+213779000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980493963 0.009960363 -0.196296573 0.623893674 -0.006936011 0.999846518 0.016088497 -0.223079036 0.196426690 -0.014413159 0.980412602 -2.137999468
+214046000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.980032921 0.010318150 -0.198567480 0.754726451 -0.007264129 0.999843955 0.016102606 -0.222246314 0.198702648 -0.014338664 0.979954958 -2.230292399
+214347000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.976653159 0.010179597 -0.214580998 0.946523963 -0.006709154 0.999834776 0.016895246 -0.210005171 0.214717537 -0.015061138 0.976560056 -2.305666573
+214614000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.971478105 0.011535713 -0.236848563 1.096604956 -0.007706031 0.999824286 0.017088750 -0.192895049 0.237004071 -0.014776184 0.971396267 -2.365701917
+214914000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.965282261 0.014877280 -0.260785013 1.237534109 -0.014124592 0.999888897 0.004760279 -0.136261458 0.260826856 -0.000911531 0.965385139 -2.458136272
+215181000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.961933076 0.016891202 -0.272762626 1.331672110 -0.022902885 0.999559581 -0.018870916 -0.076291319 0.272323757 0.024399608 0.961896241 -2.579417067
+215481000 0.479272232 0.852039479 0.500000000 0.500000000 0.000000000 0.000000000 0.959357142 0.017509742 -0.281651050 1.417338469 -0.039949402 0.996448219 -0.074127860 0.083949011 0.279352754 0.082366876 0.956649244 -2.712094466
diff --git a/assets/pose_files/0f47577ab3441480.png b/assets/pose_files/0f47577ab3441480.png
new file mode 100644
index 0000000000000000000000000000000000000000..8bb402ea77289870f77d271f76fd0d54c5545a61
Binary files /dev/null and b/assets/pose_files/0f47577ab3441480.png differ
diff --git a/assets/pose_files/0f47577ab3441480.txt b/assets/pose_files/0f47577ab3441480.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8b3ff33ec400ca157b4a66c5983a79ac41ae8d6f
--- /dev/null
+++ b/assets/pose_files/0f47577ab3441480.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=in69BD2eZqg
+195562033 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999749303 -0.004518872 0.021929268 0.038810557 0.004613766 0.999980211 -0.004278630 0.328177052 -0.021909500 0.004378735 0.999750376 -0.278403591
+195828967 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999336481 -0.006239665 0.035883281 0.034735125 0.006456365 0.999961615 -0.005926326 0.417233500 -0.035844926 0.006154070 0.999338388 -0.270773664
+196095900 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998902142 -0.007417044 0.046254709 0.033849936 0.007582225 0.999965489 -0.003396692 0.504852301 -0.046227921 0.003743677 0.998923898 -0.256677740
+196396200 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998096347 -0.008753631 0.061049398 0.026475959 0.009088391 0.999945164 -0.005207890 0.583593760 -0.061000463 0.005752816 0.998121142 -0.236166024
+196663133 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997463286 -0.009214416 0.070583619 0.014842158 0.009590282 0.999941587 -0.004988078 0.634675512 -0.070533529 0.005652342 0.997493386 -0.198663134
+196963433 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996699810 -0.009558053 0.080611102 0.003250557 0.009986609 0.999938071 -0.004914839 0.670145924 -0.080559134 0.005703651 0.996733487 -0.141256339
+197230367 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.996102691 -0.010129508 0.087617576 -0.013035317 0.010638822 0.999929130 -0.005347892 0.673139255 -0.087557197 0.006259197 0.996139824 -0.073934910
+197530667 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995880842 -0.009925503 0.090126604 -0.036202423 0.010367444 0.999936402 -0.004436717 0.655632681 -0.090076834 0.005352824 0.995920420 0.017267095
+197797600 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995802402 -0.010077500 0.090972595 -0.060858524 0.010445373 0.999939084 -0.003568561 0.618604505 -0.090931088 0.004503824 0.995846987 0.133592270
+198097900 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995846093 -0.010148350 0.090484887 -0.077962281 0.010412642 0.999942780 -0.002449236 0.561822755 -0.090454854 0.003381249 0.995894790 0.274195378
+198364833 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.995989919 -0.009936163 0.088912196 -0.082315587 0.010200773 0.999944806 -0.002522171 0.520613290 -0.088882230 0.003419030 0.996036291 0.395169547
+198665133 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.997159958 -0.009351323 0.074730076 -0.068472873 0.009822783 0.999934077 -0.005943770 0.466061412 -0.074669570 0.006660947 0.997186065 0.549834051
+198932067 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998626053 -0.008290987 0.051742285 -0.037270541 0.008407482 0.999962568 -0.002034174 0.410440195 -0.051723484 0.002466401 0.998658419 0.690111645
+199232367 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.999980092 -0.004756952 0.004140501 -0.005957613 0.004773445 0.999980688 -0.003982662 0.354437092 -0.004121476 0.004002347 0.999983490 0.842797271
+199499300 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.998872638 0.001147069 -0.047456335 0.002603018 -0.001435435 0.999980688 -0.006042828 0.295339877 0.047448486 0.006104136 0.998855054 0.988644188
+199799600 0.507650910 0.902490531 0.500000000 0.500000000 0.000000000 0.000000000 0.992951691 0.008710741 -0.118199304 -0.030798243 -0.009495872 0.999936402 -0.006080875 0.208803899 0.118138820 0.007160421 0.992971301 1.161643267
diff --git a/assets/pose_files/0f68374b76390082.png b/assets/pose_files/0f68374b76390082.png
new file mode 100644
index 0000000000000000000000000000000000000000..632ae498002c03cfcb07fb407fa10c3da267c8eb
Binary files /dev/null and b/assets/pose_files/0f68374b76390082.png differ
diff --git a/assets/pose_files/0f68374b76390082.txt b/assets/pose_files/0f68374b76390082.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d6fc78bb4f4bc95aeb5e9a7b5fc2e1272e296637
--- /dev/null
+++ b/assets/pose_files/0f68374b76390082.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=-aldZQifF2U
+103736967 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.804089785 -0.073792785 0.589910388 -2.686968354 0.081914566 0.996554494 0.013005137 0.128970374 -0.588837504 0.037864953 0.807363987 -1.789505608
+104003900 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.772824645 -0.077280566 0.629896700 -2.856354365 0.084460691 0.996253133 0.018602582 0.115028772 -0.628974140 0.038824979 0.776456118 -1.799931844
+104270833 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.740043461 -0.078656308 0.667943776 -3.017167990 0.086847030 0.995998919 0.021066183 0.116867188 -0.666928232 0.042419042 0.743913531 -1.815074499
+104571133 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.696879685 -0.073477358 0.713414192 -3.221640235 0.086792909 0.996067226 0.017807571 0.133618379 -0.711916924 0.049509555 0.700516284 -1.784051774
+104838067 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.654997289 -0.066671766 0.752684176 -3.418233112 0.086666502 0.996154904 0.012819566 0.161623584 -0.750644684 0.056835718 0.658256948 -1.733288907
+105138367 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.603833497 -0.059696361 0.794871926 -3.619566170 0.087576874 0.996123314 0.008281946 0.184519895 -0.792284906 0.064611480 0.606720686 -1.643568460
+105405300 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.555575073 -0.055402864 0.829618514 -3.768244320 0.089813948 0.995938241 0.006363695 0.197587954 -0.826601386 0.070975810 0.558294415 -1.559717271
+105705600 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.501615226 -0.052979972 0.863467038 -3.914896511 0.093892507 0.995560884 0.006539768 0.201989601 -0.859980464 0.077792637 0.504362881 -1.476983336
+105972533 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.454045177 -0.052372806 0.889438093 -4.034987790 0.099656843 0.994991958 0.007714771 0.211683202 -0.885387778 0.085135736 0.456990600 -1.405070279
+106272833 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.397668689 -0.051785514 0.916066527 -4.178181130 0.105599925 0.994354606 0.010369749 0.208751884 -0.911431968 0.092612833 0.400892258 -1.295093582
+106539767 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.345666498 -0.052948993 0.936862350 -4.285116664 0.110631727 0.993743002 0.015344846 0.195070069 -0.931812882 0.098342501 0.349361509 -1.182773054
+106840067 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.284817457 -0.055293880 0.956985712 -4.392320606 0.115495987 0.993041575 0.023003323 0.168523273 -0.951598525 0.103976257 0.289221793 -1.053514096
+107107000 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.228878200 -0.056077410 0.971838534 -4.485196000 0.120451130 0.992298782 0.028890507 0.159180748 -0.965974271 0.110446639 0.233870149 -0.923927626
+107407300 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.162932962 -0.053445265 0.985188544 -4.601126217 0.124115810 0.991709769 0.033272449 0.152041098 -0.978799343 0.116856292 0.168215603 -0.758111250
+107674233 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.102818660 -0.051196381 0.993381739 -4.691710857 0.127722457 0.991087139 0.037858382 0.141352300 -0.986466050 0.122984610 0.108441174 -0.599244073
+107974533 0.474175212 0.842978122 0.500000000 0.500000000 0.000000000 0.000000000 0.034108389 -0.050325166 0.998150289 -4.758242879 0.132215530 0.990180492 0.045405328 0.118994547 -0.990633965 0.130422264 0.040427230 -0.433560831
diff --git a/assets/pose_files/2c80f9eb0d3b2bb4.png b/assets/pose_files/2c80f9eb0d3b2bb4.png
new file mode 100644
index 0000000000000000000000000000000000000000..258aef1a9006c4334a602608837d822fea73248c
Binary files /dev/null and b/assets/pose_files/2c80f9eb0d3b2bb4.png differ
diff --git a/assets/pose_files/2c80f9eb0d3b2bb4.txt b/assets/pose_files/2c80f9eb0d3b2bb4.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b7873f812fa8cd65bbfde387c6ebb7f34d44c94d
--- /dev/null
+++ b/assets/pose_files/2c80f9eb0d3b2bb4.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=sLIFyXD2ujI
+77444033 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.980310440 0.030424286 -0.195104495 -0.195846403 -0.034550700 0.999244750 -0.017780757 0.034309913 0.194416180 0.024171660 0.980621278 -0.178639121
+77610867 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.973806441 0.034138829 -0.224801064 -0.221452338 -0.039088678 0.999080658 -0.017603843 0.038706263 0.223993421 0.025929911 0.974245667 -0.219951444
+77777700 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.965910375 0.037603889 -0.256131083 -0.242696017 -0.043735024 0.998875856 -0.018281631 0.046505467 0.255155712 0.028860316 0.966469169 -0.265310453
+77944533 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.956261098 0.040829532 -0.289650917 -0.252766079 -0.048421524 0.998644531 -0.019089982 0.054620904 0.288478881 0.032280345 0.956941962 -0.321621308
+78144733 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.941536248 0.043692805 -0.334066480 -0.250198162 -0.053955212 0.998311937 -0.021497937 0.069548726 0.332563221 0.038265716 0.942304313 -0.401964240
+78311567 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.926658213 0.046350952 -0.373036444 -0.239336491 -0.058738846 0.998033047 -0.021904159 0.077439241 0.371287435 0.042209402 0.927558064 -0.474019461
+78478400 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.909880757 0.048629351 -0.412009954 -0.218247042 -0.063676558 0.997708619 -0.022863906 0.088967126 0.409954011 0.047038805 0.910892427 -0.543114491
+78645233 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.891359746 0.050869841 -0.450433195 -0.185763327 -0.067926541 0.997452736 -0.021771761 0.093745158 0.448178291 0.050002839 0.892544627 -0.611223637
+78845433 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.870108902 0.054302387 -0.489858925 -0.153515269 -0.074510135 0.996981323 -0.021829695 0.107765162 0.487194777 0.055493668 0.871528387 -0.691303250
+79012267 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.852772951 0.056338910 -0.519234240 -0.128052677 -0.078825951 0.996660411 -0.021319628 0.116291007 0.516299069 0.059109934 0.854366004 -0.760654136
+79179100 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.835254073 0.059146367 -0.546673834 -0.101344556 -0.084243484 0.996225357 -0.020929486 0.126763936 0.543372452 0.063535146 0.837083995 -0.832841061
+79345933 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.818755865 0.062443536 -0.570736051 -0.077325807 -0.089739971 0.995768547 -0.019791666 0.136091605 0.567085147 0.067422375 0.820895016 -0.908256727
+79546133 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.798365474 0.066207208 -0.598522484 -0.043774887 -0.096616283 0.995144248 -0.018795265 0.150808225 0.594371796 0.072832510 0.800885499 -0.994657638
+79712967 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.781648815 0.069040783 -0.619885862 -0.013285614 -0.101820730 0.994646847 -0.017611075 0.161173621 0.615351617 0.076882906 0.784494340 -1.070102980
+79879800 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.765168309 0.072694756 -0.639713168 0.019850080 -0.108554602 0.993946910 -0.016894773 0.177612448 0.634612799 0.082371153 0.768428028 -1.147576811
+80080000 0.483930168 0.860320329 0.500000000 0.500000000 0.000000000 0.000000000 0.745406330 0.077463314 -0.662094295 0.062107046 -0.117075674 0.993000031 -0.015629012 0.200140798 0.656248927 0.089165099 0.749257565 -1.238600776
diff --git a/assets/pose_files/2f25826f0d0ef09a.png b/assets/pose_files/2f25826f0d0ef09a.png
new file mode 100644
index 0000000000000000000000000000000000000000..b17a8f1369b58d514dcc718c5db486c7f3b85854
Binary files /dev/null and b/assets/pose_files/2f25826f0d0ef09a.png differ
diff --git a/assets/pose_files/2f25826f0d0ef09a.txt b/assets/pose_files/2f25826f0d0ef09a.txt
new file mode 100644
index 0000000000000000000000000000000000000000..316703a9f682933e388f5e2e52dd825dd9921a80
--- /dev/null
+++ b/assets/pose_files/2f25826f0d0ef09a.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=t-mlAKnESzQ
+167200000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991872013 -0.011311784 0.126735851 0.400533760 0.012037775 0.999915242 -0.004963919 -0.047488550 -0.126668960 0.006449190 0.991924107 -0.414499612
+167467000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.991945148 -0.011409644 0.126153216 0.506974565 0.012122569 0.999914587 -0.004884966 -0.069421149 -0.126086697 0.006374919 0.991998732 -0.517325825
+167734000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.992271781 -0.010751382 0.123616949 0.590358341 0.011312660 0.999928653 -0.003839425 -0.085158661 -0.123566844 0.005208189 0.992322564 -0.599035085
+168034000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993287027 -0.009973313 0.115245141 0.673577580 0.010455138 0.999938965 -0.003577147 -0.104263255 -0.115202427 0.004758038 0.993330657 -0.691557669
+168301000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.993988216 -0.009955904 0.109033749 0.753843765 0.010435819 0.999938190 -0.003831771 -0.106670354 -0.108988866 0.004946592 0.994030654 -0.805538867
+168602000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.994774222 -0.010583352 0.101549298 0.846176230 0.011122120 0.999926925 -0.004740742 -0.089426372 -0.101491705 0.005845411 0.994819224 -0.933449460
+168869000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.995415390 -0.010595482 0.095057681 0.913119395 0.011053002 0.999929726 -0.004287821 -0.072756893 -0.095005572 0.005318835 0.995462537 -1.037255409
+169169000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996029556 -0.009414902 0.088523701 0.977259045 0.009879347 0.999939620 -0.004809874 -0.042104006 -0.088473074 0.005665333 0.996062458 -1.127427189
+169436000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.996695757 -0.008890423 0.080737323 1.025351476 0.009221899 0.999950528 -0.003733651 -0.007486727 -0.080700137 0.004465866 0.996728420 -1.188659636
+169736000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997404695 -0.008404067 0.071506783 1.073562767 0.008649707 0.999957681 -0.003126226 0.054879890 -0.071477488 0.003736625 0.997435212 -1.216979926
+170003000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.997887254 -0.008228444 0.064446673 1.110116903 0.008409287 0.999961436 -0.002535321 0.124372514 -0.064423330 0.003071915 0.997917950 -1.231904045
+170303000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998332024 -0.007790270 0.057205603 1.136173895 0.007975516 0.999963641 -0.003010646 0.212542522 -0.057180069 0.003461868 0.998357892 -1.242942079
+170570000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998471320 -0.007715963 0.054730706 1.159189486 0.007868989 0.999965727 -0.002581036 0.310163907 -0.054708913 0.003007766 0.998497844 -1.245661417
+170871000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998552144 -0.007742116 0.053231847 1.173763753 0.007991423 0.999958038 -0.004472161 0.412779543 -0.053194992 0.004891084 0.998572171 -1.229165757
+171137000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998553872 -0.007909958 0.053175092 1.179029258 0.008138723 0.999958515 -0.004086919 0.509089997 -0.053140558 0.004513786 0.998576820 -1.196146494
+171438000 0.470983989 0.837304886 0.500000000 0.500000000 0.000000000 0.000000000 0.998469293 -0.008281939 0.054685175 1.181414517 0.008542870 0.999953210 -0.004539483 0.618089736 -0.054645021 0.004999703 0.998493314 -1.159911786
diff --git a/assets/pose_files/3c35b868a8ec3433.png b/assets/pose_files/3c35b868a8ec3433.png
new file mode 100644
index 0000000000000000000000000000000000000000..630b1531dc75b26ce41936b6a74172ac20d291c8
Binary files /dev/null and b/assets/pose_files/3c35b868a8ec3433.png differ
diff --git a/assets/pose_files/3c35b868a8ec3433.txt b/assets/pose_files/3c35b868a8ec3433.txt
new file mode 100644
index 0000000000000000000000000000000000000000..671dc924f15cfe4228560b9e72103f5d3ab1e2f5
--- /dev/null
+++ b/assets/pose_files/3c35b868a8ec3433.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=bJyPo9pESu0
+189622767 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.966956913 0.041186374 -0.251590967 0.235831829 -0.037132759 0.999092996 0.020840336 0.069818943 0.252221137 -0.010809440 0.967609227 -0.850289525
+189789600 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.967445135 0.041703269 -0.249621317 0.217678822 -0.037349533 0.999056637 0.022154763 0.078295447 0.250309765 -0.012110277 0.968090057 -0.818677483
+189956433 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.967587769 0.043305319 -0.248794302 0.196350216 -0.038503598 0.998966932 0.024136283 0.085749990 0.249582499 -0.013774496 0.968255579 -0.778043636
+190123267 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.967742383 0.044170257 -0.248039767 0.170234078 -0.039154600 0.998917341 0.025120445 0.090556068 0.248880804 -0.014598221 0.968424082 -0.733500964
+190323467 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.973553717 0.043272153 -0.224322766 0.120337922 -0.038196862 0.998907626 0.026917407 0.091227451 0.225242496 -0.017637115 0.974143088 -0.680520640
+190490300 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.984184802 0.039637893 -0.172653258 0.065019106 -0.035194401 0.998967648 0.028723357 0.090969669 0.173613548 -0.022192664 0.984563768 -0.638603728
+190657133 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.993078411 0.035358477 -0.112004772 0.011571313 -0.032207530 0.999036312 0.029818388 0.092482656 0.112951167 -0.026004599 0.993260205 -0.588118143
+190823967 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.997807920 0.031166473 -0.058378015 -0.027908508 -0.029339414 0.999060452 0.031897116 0.092538838 0.059317287 -0.030114418 0.997784853 -0.529325066
+191024167 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.999654651 0.026247790 -0.001263706 -0.064570799 -0.026190240 0.999087334 0.033742432 0.091922841 0.002148218 -0.033697683 0.999429762 -0.448626929
+191191000 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.998773992 0.022079065 0.044305529 -0.084478169 -0.023622099 0.999121666 0.034611158 0.087434649 -0.043502431 -0.035615314 0.998418272 -0.371306296
+191357833 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.995725632 0.017435640 0.090699598 -0.094868572 -0.020876031 0.999092638 0.037122324 0.082208324 -0.089970052 -0.038857099 0.995186150 -0.290596011
+191524667 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.989503622 0.013347236 0.143890470 -0.096537122 -0.019140780 0.999057651 0.038954727 0.079283141 -0.143234938 -0.041300017 0.988826632 -0.207477308
+191724867 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.975660741 0.006981443 0.219174415 -0.085240259 -0.016479453 0.999001026 0.041537181 0.072219148 -0.218665481 -0.044138070 0.974801123 -0.112100139
+191891700 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.955792487 -0.000511726 0.294041574 -0.064476318 -0.012924311 0.998958945 0.043749433 0.061688334 -0.293757826 -0.045615666 0.954790831 -0.034724173
+192058533 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.925362229 -0.009678445 0.378960580 -0.029417786 -0.008889219 0.998845160 0.047216032 0.058476640 -0.378979892 -0.047060598 0.924207509 0.042010383
+192258733 0.474122545 0.842884498 0.500000000 0.500000000 0.000000000 0.000000000 0.872846186 -0.021581186 0.487517983 0.038433307 -0.004890797 0.998584569 0.052961230 0.057516307 -0.487970918 -0.048611358 0.871505201 0.124675285
diff --git a/assets/pose_files/3f79dc32d575bcdc.png b/assets/pose_files/3f79dc32d575bcdc.png
new file mode 100644
index 0000000000000000000000000000000000000000..9dc6cc6eebe5029ee4cf3cc28c7e97030e2af38b
Binary files /dev/null and b/assets/pose_files/3f79dc32d575bcdc.png differ
diff --git a/assets/pose_files/3f79dc32d575bcdc.txt b/assets/pose_files/3f79dc32d575bcdc.txt
new file mode 100644
index 0000000000000000000000000000000000000000..9ece1affc7fd92eb5986758c84b2f19d3a8edefd
--- /dev/null
+++ b/assets/pose_files/3f79dc32d575bcdc.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=1qVpRlWxam4
+86319567 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999183893 0.038032386 -0.013605987 -0.249154748 -0.038085770 0.999267697 -0.003686040 0.047875167 0.013455833 0.004201226 0.999900639 -0.566803149
+86586500 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999392629 0.034676589 -0.003445767 -0.282371175 -0.034685481 0.999395013 -0.002555777 0.057086778 0.003355056 0.002673743 0.999990821 -0.624021456
+86853433 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999498725 0.028301919 0.014187563 -0.320995587 -0.028301118 0.999599397 -0.000257314 0.061367205 -0.014189162 -0.000144339 0.999899328 -0.706664680
+87153733 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999064565 0.022049030 0.037200645 -0.371910835 -0.022201553 0.999746680 0.003691827 0.063911726 -0.037109818 -0.004514286 0.999301016 -0.799748814
+87420667 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998171926 0.018552339 0.057520505 -0.440220060 -0.018887693 0.999807596 0.005291941 0.070160264 -0.057411261 -0.006368696 0.998330295 -0.853433007
+87720967 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997675776 0.016262729 0.066170901 -0.486385324 -0.016915560 0.999813497 0.009317505 0.069230577 -0.066007033 -0.010415167 0.997764826 -0.912234761
+87987900 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.998019218 0.015867118 0.060876362 -0.497549423 -0.016505934 0.999813735 0.010005167 0.076295227 -0.060706269 -0.010990170 0.998095155 -0.980435972
+88288200 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999152124 0.018131699 0.036962789 -0.468507446 -0.018461898 0.999792457 0.008611582 0.087696066 -0.036798976 -0.009286684 0.999279559 -1.074633197
+88555133 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999717414 0.022977378 0.006097841 -0.420528982 -0.023013741 0.999717355 0.005961678 0.101216630 -0.005959134 -0.006100327 0.999963641 -1.169004730
+88855433 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.999106526 0.030726369 -0.029017152 -0.374249594 -0.030677194 0.999527037 0.002138488 0.120936030 0.029069137 -0.001246413 0.999576628 -1.251082317
+89122367 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.997359693 0.039784521 -0.060752310 -0.335843098 -0.039773725 0.999207735 0.001387495 0.132824955 0.060759377 0.001032514 0.998151898 -1.312258423
+89422667 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.992973983 0.050480653 -0.107025139 -0.253623964 -0.050627887 0.998716712 0.001342622 0.144421611 0.106955573 0.004085269 0.994255424 -1.394020432
+89689600 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.986886561 0.059628733 -0.149997801 -0.173418608 -0.059660275 0.998209476 0.004293700 0.142984494 0.149985254 0.004711515 0.988677025 -1.462588413
+89989900 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.978200734 0.067550205 -0.196367815 -0.089199207 -0.067402542 0.997698128 0.007442682 0.141665403 0.196418539 0.005955252 0.980502069 -1.524381413
+90256833 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.967793405 0.073765829 -0.240695804 0.013635864 -0.073441446 0.997246027 0.010330606 0.134276795 0.240794986 0.007679154 0.970545650 -1.588498428
+90557133 0.487278048 0.866272132 0.500000000 0.500000000 0.000000000 0.000000000 0.953711152 0.081056722 -0.289594263 0.148156165 -0.081249826 0.996628821 0.011376631 0.129987979 0.289540142 0.012679463 0.957081914 -1.633951355
diff --git a/assets/pose_files/4a2d6753676df096.png b/assets/pose_files/4a2d6753676df096.png
new file mode 100644
index 0000000000000000000000000000000000000000..20481d7ea0c8e3de101ab6d30e87b984cce03547
Binary files /dev/null and b/assets/pose_files/4a2d6753676df096.png differ
diff --git a/assets/pose_files/4a2d6753676df096.txt b/assets/pose_files/4a2d6753676df096.txt
new file mode 100644
index 0000000000000000000000000000000000000000..74b9d4583c1ba540796caa2ad01edaf887b27d51
--- /dev/null
+++ b/assets/pose_files/4a2d6753676df096.txt
@@ -0,0 +1,17 @@
+https://www.youtube.com/watch?v=mGFQkgadzRQ
+123665000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.996869564 0.002875770 -0.079011612 -0.427841466 -0.002861131 0.999995887 0.000298484 -0.005788880 0.079012141 -0.000071487 0.996873677 0.132732609
+123999000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.993462563 0.003229393 -0.114112593 -0.472377562 -0.003208589 0.999994814 0.000365978 -0.005932507 0.114113182 0.000002555 0.993467748 0.123959606
+124332000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.988605380 0.003602870 -0.150487319 -0.517270184 -0.003599323 0.999993503 0.000295953 -0.005751638 0.150487408 0.000249071 0.988611877 0.113156366
+124708000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.981692851 0.004048047 -0.190427750 -0.566330350 -0.004096349 0.999991596 0.000139980 -0.007622665 0.190426722 0.000642641 0.981701195 0.098572887
+125041000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.974759340 0.004326052 -0.223216295 -0.606091424 -0.004403458 0.999990284 0.000150970 -0.009427620 0.223214790 0.000835764 0.974768937 0.084984909
+125417000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.965238512 0.004419941 -0.261333257 -0.651601078 -0.004571608 0.999989569 0.000027561 -0.007437027 0.261330664 0.001168111 0.965248644 0.068577736
+125750000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.953956902 0.004390486 -0.299911648 -0.697081969 -0.004806366 0.999988258 -0.000648964 -0.003676960 0.299905270 0.002060569 0.953966737 0.050264043
+126126000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.940579295 0.004839818 -0.339539677 -0.744385684 -0.005527717 0.999984145 -0.001058831 -0.001820489 0.339529186 0.002872794 0.940591156 0.028560147
+126459000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.928297341 0.004980532 -0.371805429 -0.781716025 -0.005848793 0.999982178 -0.001207554 -0.001832299 0.371792793 0.003295582 0.928309917 0.009470658
+126835000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.913324535 0.005156573 -0.407199889 -0.824074795 -0.006227055 0.999979734 -0.001303667 -0.001894351 0.407184929 0.003726327 0.913338125 -0.013179829
+127168000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.898822486 0.005400294 -0.438279599 -0.860775204 -0.006702366 0.999976516 -0.001423908 -0.001209170 0.438261628 0.004217350 0.898837566 -0.034594674
+127544000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.880397439 0.005455900 -0.474205226 -0.903308447 -0.007032821 0.999974072 -0.001551900 -0.000798134 0.474184483 0.004701289 0.880412936 -0.061250069
+127877000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.862660766 0.005402398 -0.505754173 -0.939888304 -0.007276668 0.999972045 -0.001730187 -0.000489221 0.505730629 0.005172769 0.862675905 -0.086411685
+128253000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.841714203 0.005442667 -0.539895892 -0.978630821 -0.007698633 0.999968529 -0.001921765 0.000975953 0.539868414 0.005774037 0.841729641 -0.115983579
+128587000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.823229551 0.005282366 -0.567684054 -1.010071242 -0.007977336 0.999965608 -0.002263572 0.002284809 0.567652583 0.006392045 0.823243380 -0.141444392
+128962000 0.591609280 1.051749871 0.500000000 0.500000000 0.000000000 0.000000000 0.802855015 0.005112482 -0.596152425 -1.042319682 -0.008217614 0.999963105 -0.002491409 0.003637235 0.596117735 0.006899191 0.802867413 -0.169369454
diff --git a/assets/pose_files/color_bar.png b/assets/pose_files/color_bar.png
new file mode 100644
index 0000000000000000000000000000000000000000..6c73f221640b043844ab2f7b9c504ddba25e5b0f
Binary files /dev/null and b/assets/pose_files/color_bar.png differ
diff --git a/assets/pose_files/complex_1.png b/assets/pose_files/complex_1.png
new file mode 100644
index 0000000000000000000000000000000000000000..3ae49c98f2eb8ceace2e86ff04dd6c7b4c712e84
--- /dev/null
+++ b/assets/pose_files/complex_1.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:12e6411790a6d0d21b9ee25cf6b540016d58ddc9dfe11943c02b23d33cb36c40
+size 184778
diff --git a/assets/pose_files/complex_1.pth b/assets/pose_files/complex_1.pth
new file mode 100644
index 0000000000000000000000000000000000000000..d421750a27a93d51f18b175578f6d93d8ab42087
--- /dev/null
+++ b/assets/pose_files/complex_1.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fa358d46a1d516f4a9186928b697af3167a57c8996a050bca61d3ae62316c2e2
+size 1521
diff --git a/assets/pose_files/complex_2.png b/assets/pose_files/complex_2.png
new file mode 100644
index 0000000000000000000000000000000000000000..09868a1dde24a3735648dc161d14a46332369e87
--- /dev/null
+++ b/assets/pose_files/complex_2.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c2efb207f514382d5d8927e853e94c3707727f576687284cc338dd6e3fdc58f1
+size 183569
diff --git a/assets/pose_files/complex_2.pth b/assets/pose_files/complex_2.pth
new file mode 100644
index 0000000000000000000000000000000000000000..0cd94f6ae58e09d2c923a8a30e8fd031e9274016
--- /dev/null
+++ b/assets/pose_files/complex_2.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4643626ae50680c8ff1c14b9507ffb840e6bbd21a73d7845bc45c454591073ba
+size 1521
diff --git a/assets/pose_files/complex_3.png b/assets/pose_files/complex_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..f40c9e800fb6aefd3e35444abca9f0c0de1ff926
--- /dev/null
+++ b/assets/pose_files/complex_3.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6342bd2532e81cc5257891667c0c76e1eec7d09bcfafee41098306169d8e2e63
+size 199701
diff --git a/assets/pose_files/complex_3.pth b/assets/pose_files/complex_3.pth
new file mode 100644
index 0000000000000000000000000000000000000000..4040a9cb930ba6c3935cea5bd220677b240c9ab1
--- /dev/null
+++ b/assets/pose_files/complex_3.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82af5c5211af04db50521b4e5ca64196a4965ee8f3074ab18d52023c488834d0
+size 1521
diff --git a/assets/pose_files/complex_4.png b/assets/pose_files/complex_4.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab7a1594000f9cbef74ed0f7ca246d90b4131a02
--- /dev/null
+++ b/assets/pose_files/complex_4.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d36a39d30bdec7a43029683f34107ff9b645801eed0fe9b9d823689ff71227b6
+size 185156
diff --git a/assets/pose_files/complex_4.pth b/assets/pose_files/complex_4.pth
new file mode 100644
index 0000000000000000000000000000000000000000..77074ddac59dc6f2d02c9fe2fb77f8d2b5bd36f2
--- /dev/null
+++ b/assets/pose_files/complex_4.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c045a7176f6f3ee95c517706e8fabf6c548d1a6fe0d8cf139a4a46955196d50d
+size 1521
diff --git a/assets/sea.jpg b/assets/sea.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..deece29c1be91ec9d48d9ff1d8f3dc2b0ccc7dd6
--- /dev/null
+++ b/assets/sea.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bef05aeada45bbd68ca6e39f267dcf29a9fd223deb6c23aca03ca4af063a74fd
+size 1500066
diff --git a/assets/syn_video_control1.mp4 b/assets/syn_video_control1.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..b2cd75f20fc87064f7274dc1161f4fa60f533cd7
--- /dev/null
+++ b/assets/syn_video_control1.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:62a819a54311b19c9187d0561f404ffe1a37bea0d1cea065ef742ec880a91dce
+size 597937
diff --git a/assets/syn_video_control2.mp4 b/assets/syn_video_control2.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d034c947c5abcdf41c5cd5ebcaa18d096e180af9
--- /dev/null
+++ b/assets/syn_video_control2.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:efee5c51d2db1bd3f075c4269ba62482f454d735b54760d91a34358893696cb0
+size 952736
diff --git a/data/dot_single_video/checkpoints/cvo_raft_patch_8.pth b/data/dot_single_video/checkpoints/cvo_raft_patch_8.pth
new file mode 100644
index 0000000000000000000000000000000000000000..c0f920d7c2fabb676de4325c77fd4cc7b5bd84b2
--- /dev/null
+++ b/data/dot_single_video/checkpoints/cvo_raft_patch_8.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c1422383b0e61f3c8543c4bedf6c64087675421cc37616a60faa580bbadddc51
+size 21093617
diff --git a/data/dot_single_video/checkpoints/movi_f_cotracker2_patch_4_wind_8.pth b/data/dot_single_video/checkpoints/movi_f_cotracker2_patch_4_wind_8.pth
new file mode 100644
index 0000000000000000000000000000000000000000..2113d57ff1b4ccae0645ee86f39f91f84ed65209
--- /dev/null
+++ b/data/dot_single_video/checkpoints/movi_f_cotracker2_patch_4_wind_8.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:89e7585d3e95d6e2bc4ff74ce072a98f70377047e669d44fa0b5c01311d4f54c
+size 204379466
diff --git a/data/dot_single_video/checkpoints/movi_f_cotracker_patch_4_wind_8.pth b/data/dot_single_video/checkpoints/movi_f_cotracker_patch_4_wind_8.pth
new file mode 100644
index 0000000000000000000000000000000000000000..2bd5a940c61fd3bb20c21fdb3942f4df658e3374
--- /dev/null
+++ b/data/dot_single_video/checkpoints/movi_f_cotracker_patch_4_wind_8.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8788efe1d7c462757200605a2bcfd357a76d385ae57769be605016d7f9bbb1d5
+size 96650477
diff --git a/data/dot_single_video/checkpoints/movi_f_raft_patch_4_alpha.pth b/data/dot_single_video/checkpoints/movi_f_raft_patch_4_alpha.pth
new file mode 100644
index 0000000000000000000000000000000000000000..7841f89df64c6278eb5cfd57cc44e0a6bb80db02
--- /dev/null
+++ b/data/dot_single_video/checkpoints/movi_f_raft_patch_4_alpha.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae2b478a99f1948384c237f0935eff2a61be4f7b736493c922b2fb5223117fb9
+size 23173365
diff --git a/data/dot_single_video/configs/cotracker2_patch_4_wind_8.json b/data/dot_single_video/configs/cotracker2_patch_4_wind_8.json
new file mode 100644
index 0000000000000000000000000000000000000000..c777e3bc25ee578ccb183bfef0c922b6f9e221aa
--- /dev/null
+++ b/data/dot_single_video/configs/cotracker2_patch_4_wind_8.json
@@ -0,0 +1,5 @@
+{
+ "name": "cotracker2",
+ "patch_size": 4,
+ "wind_size": 8
+}
\ No newline at end of file
diff --git a/data/dot_single_video/configs/cotracker_patch_4_wind_8.json b/data/dot_single_video/configs/cotracker_patch_4_wind_8.json
new file mode 100644
index 0000000000000000000000000000000000000000..2d24aabd72ca712de22263f0116c7900509ebdde
--- /dev/null
+++ b/data/dot_single_video/configs/cotracker_patch_4_wind_8.json
@@ -0,0 +1,5 @@
+{
+ "name": "cotracker",
+ "patch_size": 4,
+ "wind_size": 8
+}
\ No newline at end of file
diff --git a/data/dot_single_video/configs/dot_single_video_1105.yaml b/data/dot_single_video/configs/dot_single_video_1105.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0185107f11f1f2a0533692694b8c233c2678c8d7
--- /dev/null
+++ b/data/dot_single_video/configs/dot_single_video_1105.yaml
@@ -0,0 +1,23 @@
+dot_model:
+ height: 320
+ width: 512
+ tracker_config: data/dot_single_video/configs/cotracker2_patch_4_wind_8.json
+ tracker_path: data/dot_single_video/checkpoints/movi_f_cotracker2_patch_4_wind_8.pth
+ estimator_config: data/dot_single_video/configs/raft_patch_8.json
+ estimator_path: data/dot_single_video/checkpoints/cvo_raft_patch_8.pth
+ refiner_config: data/dot_single_video/configs/raft_patch_4_alpha.json
+ refiner_path: data/dot_single_video/checkpoints/movi_f_raft_patch_4_alpha.pth
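+ # Note: the checkpoint (.pth) and config (.json) paths above refer to the files
+ # added under data/dot_single_video/checkpoints and data/dot_single_video/configs in this repo.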
+
+inference_config:
+ mode: tracks_from_first_to_every_other_frame
+ return_flow: true # ! important params
+
+ num_tracks: 8192
+ sim_tracks: 2048
+ sample_mode: all
+
+ is_train: false
+ interpolation_version: torch3d
+ alpha_thresh: 0.8
+
+
diff --git a/data/dot_single_video/configs/raft_patch_4_alpha.json b/data/dot_single_video/configs/raft_patch_4_alpha.json
new file mode 100644
index 0000000000000000000000000000000000000000..13706299bc40b4b73b4c9fc93690bc8cf1fa488e
--- /dev/null
+++ b/data/dot_single_video/configs/raft_patch_4_alpha.json
@@ -0,0 +1,8 @@
+{
+ "name": "raft",
+ "patch_size": 4,
+ "num_iter": 4,
+ "refine_alpha": true,
+ "norm_fnet": "instance",
+ "norm_cnet": "instance"
+}
\ No newline at end of file
diff --git a/data/dot_single_video/configs/raft_patch_8.json b/data/dot_single_video/configs/raft_patch_8.json
new file mode 100644
index 0000000000000000000000000000000000000000..5fc0080c4af486337b916e772b9cbb93c967eedf
--- /dev/null
+++ b/data/dot_single_video/configs/raft_patch_8.json
@@ -0,0 +1,8 @@
+{
+ "name": "raft",
+ "patch_size": 8,
+ "num_iter": 12,
+ "refine_alpha": false,
+ "norm_fnet": "instance",
+ "norm_cnet": "batch"
+}
\ No newline at end of file
diff --git a/data/dot_single_video/dot/__init__.py b/data/dot_single_video/dot/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/dot_single_video/dot/data/cvo_dataset.py b/data/dot_single_video/dot/data/cvo_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..58e1fc8fe359eb67093734bf79026f1791eabc64
--- /dev/null
+++ b/data/dot_single_video/dot/data/cvo_dataset.py
@@ -0,0 +1,129 @@
+import os
+import os.path as osp
+from collections import OrderedDict
+
+import lmdb
+import torch
+import numpy as np
+import pickle as pkl
+from einops import rearrange
+from torch.utils.data import Dataset, DataLoader
+
+from dot.utils.torch import get_alpha_consistency
+
+
+class CVO_sampler_lmdb:
+ """Data sampling"""
+
+ all_keys = ["imgs", "imgs_blur", "fflows", "bflows", "delta_fflows", "delta_bflows"]
+
+ def __init__(self, data_root, keys=None, split=None):
+ if split == "extended":
+ self.db_path = osp.join(data_root, "cvo_test_extended.lmdb")
+ else:
+ self.db_path = osp.join(data_root, "cvo_test.lmdb")
+ self.split = split
+
+ self.env = lmdb.open(
+ self.db_path,
+ subdir=os.path.isdir(self.db_path),
+ readonly=True,
+ lock=False,
+ readahead=False,
+ meminit=False,
+ )
+ with self.env.begin(write=False) as txn:
+ self.samples = pkl.loads(txn.get(b"__samples__"))
+ self.length = len(self.samples)
+
+ self.keys = self.all_keys if keys is None else [x.lower() for x in keys]
+ self._check_keys(self.keys)
+
+ def _check_keys(self, keys):
+ # check keys are supported:
+ for k in keys:
+ assert k in self.all_keys, f"Invalid key value: {k}"
+
+ def __len__(self):
+ return self.length
+
+ def sample(self, index):
+ sample = OrderedDict()
+ with self.env.begin(write=False) as txn:
+ for k in self.keys:
+ key = "{:05d}_{:s}".format(index, k)
+ value = pkl.loads(txn.get(key.encode()))
+ if "flow" in key and self.split in ["clean", "final"]: # Convert Int to Floating
+ value = value.astype(np.float32)
+ value = (value - 2 ** 15) / 128.0
+ if "imgs" in k:
+ k = "imgs"
+ sample[k] = value
+ return sample
+
+
+class CVO(Dataset):
+ all_keys = ["fflows", "bflows"]
+
+ def __init__(self, data_root, keys=None, split="clean", crop_size=256):
+ keys = self.all_keys if keys is None else [x.lower() for x in keys]
+ self._check_keys(keys)
+ if split == "final":
+ keys.append("imgs_blur")
+ else:
+ keys.append("imgs")
+ self.split = split
+ self.sampler = CVO_sampler_lmdb(data_root, keys, split)
+
+ def __getitem__(self, index):
+ sample = self.sampler.sample(index)
+
+ video = torch.from_numpy(sample["imgs"].copy())
+ video = video / 255.0
+ video = rearrange(video, "h w (t c) -> t c h w", c=3)
+
+ fflow = torch.from_numpy(sample["fflows"].copy())
+ fflow = rearrange(fflow, "h w (t c) -> t h w c", c=2)[-1]
+
+ bflow = torch.from_numpy(sample["bflows"].copy())
+ bflow = rearrange(bflow, "h w (t c) -> t h w c", c=2)[-1]
+
+ if self.split in ["clean", "final"]:
+ thresh_1 = 0.01
+ thresh_2 = 0.5
+ elif self.split == "extended":
+ thresh_1 = 0.1
+ thresh_2 = 0.5
+ else:
+ raise ValueError(f"Unknown split {self.split}")
+
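+ # Forward/backward flow consistency yields a per-pixel alpha (visibility) mask.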
+ alpha = get_alpha_consistency(bflow[None], fflow[None], thresh_1=thresh_1, thresh_2=thresh_2)[0]
+
+ data = {
+ "video": video,
+ "alpha": alpha,
+ "flow": bflow
+ }
+
+ return data
+
+ def _check_keys(self, keys):
+ # check keys are supported:
+ for k in keys:
+ assert k in self.all_keys, f"Invalid key value: {k}"
+
+ def __len__(self):
+ return len(self.sampler)
+
+
+def create_optical_flow_dataset(args):
+ dataset = CVO(args.data_root, split=args.split)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ shuffle=False,
+ num_workers=0,
+ drop_last=False,
+ )
+ return dataloader
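+
+# Minimal usage sketch (assumes an argparse-style namespace and a local CVO lmdb
+# export; adjust the paths and fields to your setup):
+#
+# from argparse import Namespace
+# args = Namespace(data_root="./cvo_data", split="clean", batch_size=4)
+# loader = create_optical_flow_dataset(args)
+# batch = next(iter(loader)) # dict with "video", "alpha" and "flow" tensors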
\ No newline at end of file
diff --git a/data/dot_single_video/dot/data/movi_f_dataset.py b/data/dot_single_video/dot/data/movi_f_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a7e3a3ad17a62dfde3ee49b04eb7e3747e5503d
--- /dev/null
+++ b/data/dot_single_video/dot/data/movi_f_dataset.py
@@ -0,0 +1,124 @@
+import os
+from glob import glob
+import random
+import numpy as np
+import torch
+from torch.utils import data
+
+from dot.utils.io import read_video, read_tracks
+
+
+def create_point_tracking_dataset(args, batch_size=1, split="train", num_workers=None, verbose=False):
+ dataset = Dataset(args, split, verbose)
+ dataloader = DataLoader(args, dataset, batch_size, split, num_workers)
+ return dataloader
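+
+# Minimal usage sketch (field names inferred from the Dataset/DataLoader classes
+# below; values are placeholders):
+#
+# args must provide: data_root, in_track_name, out_track_name, num_in_tracks,
+# num_out_tracks, valid_ratio, num_workers, world_size, rank
+#
+# loader = create_point_tracking_dataset(args, batch_size=2, split="train")
+# batch = loader.next() # src/tgt frames plus sampled point correspondences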
+
+
+class DataLoader:
+ def __init__(self, args, dataset, batch_size=1, split="train", num_workers=None):
+ num_workers = args.num_workers if num_workers is None else num_workers
+ is_train = split == "train"
+ self.sampler = data.distributed.DistributedSampler(dataset, args.world_size, args.rank) if is_train else None
+ self.loader = data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ sampler=self.sampler,
+ )
+ self.epoch = -1
+ self.reinit()
+
+ def reinit(self):
+ self.epoch += 1
+ if self.sampler:
+ self.sampler.set_epoch(self.epoch)
+ self.iter = iter(self.loader)
+
+ def next(self):
+ try:
+ return next(self.iter)
+ except StopIteration:
+ self.reinit()
+ return next(self.iter)
+
+
+def get_correspondences(track_path, src_step, tgt_step, num_tracks, height, width, vis_src_only):
+ H, W = height, width
+ tracks = torch.from_numpy(read_tracks(track_path))
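+ # tracks has shape [num_tracks, num_steps, 3] with (x, y, visibility);
+ # x and y are normalized to [0, 1] below.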
+ tracks[..., 0] = tracks[..., 0] / (W - 1)
+ tracks[..., 1] = tracks[..., 1] / (H - 1)
+ src_points = tracks[:, src_step]
+ tgt_points = tracks[:, tgt_step]
+ if vis_src_only:
+ src_alpha = src_points[..., 2]
+ vis_idx = torch.nonzero(src_alpha, as_tuple=True)[0]
+ num_vis = vis_idx.shape[0]
+ if num_vis == 0:
+ return False, None
+ samples = np.random.choice(num_vis, num_tracks, replace=num_tracks > num_vis)
+ idx = vis_idx[samples]
+ else:
+ idx = np.random.choice(tracks.size(0), num_tracks, replace=num_tracks > tracks.size(0))
+ return True, (src_points[idx], tgt_points[idx])
+
+
+class Dataset(data.Dataset):
+ def __init__(self, args, split="train", verbose=False):
+ super().__init__()
+ self.video_folder = os.path.join(args.data_root, "video")
+ self.in_track_folder = os.path.join(args.data_root, args.in_track_name)
+ self.out_track_folder = os.path.join(args.data_root, args.out_track_name)
+ self.num_in_tracks = args.num_in_tracks
+ self.num_out_tracks = args.num_out_tracks
+ num_videos = len(glob(os.path.join(self.video_folder, "*")))
+ self.video_steps = [
+ len(glob(os.path.join(self.video_folder, str(video_idx), "*"))) for video_idx in range(num_videos)
+ ]
+ video_indices = list(range(num_videos))
+ if split == "valid":
+ video_indices = video_indices[:int(num_videos * args.valid_ratio)]
+ elif split == "train":
+ video_indices = video_indices[int(num_videos * args.valid_ratio):]
+ self.video_indices = video_indices
+ self.num_videos = len(video_indices)
+ if verbose:
+ print(f"Created {split} dataset of length {self.num_videos}")
+
+ def __len__(self):
+ return self.num_videos
+
+ def __getitem__(self, idx):
+ idx = idx % self.num_videos
+ video_idx = self.video_indices[idx]
+ time_steps = self.video_steps[video_idx]
+ src_step = random.randrange(time_steps)
+ tgt_step = random.randrange(time_steps - 1)
+ tgt_step = (src_step + tgt_step) % time_steps
+
+ video_path = os.path.join(self.video_folder, str(video_idx))
+ src_frame = read_video(video_path, start_step=src_step, time_steps=1)[0]
+ tgt_frame = read_video(video_path, start_step=tgt_step, time_steps=1)[0]
+ _, H, W = src_frame.shape
+
+ in_track_path = os.path.join(self.in_track_folder, f"{video_idx}.npy")
+ out_track_path = os.path.join(self.out_track_folder, f"{video_idx}.npy")
+ vis_src_only = False
+ _, corr = get_correspondences(in_track_path, src_step, tgt_step, self.num_in_tracks, H, W, vis_src_only)
+ src_points, tgt_points = corr
+
+ vis_src_only = True
+ success, corr = get_correspondences(out_track_path, src_step, tgt_step, self.num_out_tracks, H, W, vis_src_only)
+ if not success:
+ return self[idx + 1]
+ out_src_points, out_tgt_points = corr
+
+ data = {
+ "src_frame": src_frame,
+ "tgt_frame": tgt_frame,
+ "src_points": src_points,
+ "tgt_points": tgt_points,
+ "out_src_points": out_src_points,
+ "out_tgt_points": out_tgt_points,
+ }
+
+ return data
diff --git a/data/dot_single_video/dot/data/movi_f_tf_dataset.py b/data/dot_single_video/dot/data/movi_f_tf_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2347e6daffb5daab282a72a6f885b0a5c6702cd4
--- /dev/null
+++ b/data/dot_single_video/dot/data/movi_f_tf_dataset.py
@@ -0,0 +1,1005 @@
+# Copyright 2023 The Kubric Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Kubric dataset with point tracking."""
+
+import functools
+import itertools
+
+import matplotlib.pyplot as plt
+import mediapy as media
+import numpy as np
+import tensorflow.compat.v1 as tf
+import tensorflow_datasets as tfds
+from tensorflow_graphics.geometry.transformation import rotation_matrix_3d
+
+
+def project_point(cam, point3d, num_frames):
+ """Compute the image space coordinates [0, 1] for a set of points.
+
+ Args:
+ cam: The camera parameters, as returned by kubric. 'matrix_world' and
+ 'intrinsics' have a leading axis num_frames.
+ point3d: Points in 3D world coordinates. It has shape [num_frames,
+ num_points, 3].
+ num_frames: The number of frames in the video.
+
+ Returns:
+ Image coordinates in 2D. The last coordinate is an indicator of whether
+ the point is behind the camera.
+ """
+
+ homo_transform = tf.linalg.inv(cam['matrix_world'])
+ homo_intrinsics = tf.zeros((num_frames, 3, 1), dtype=tf.float32)
+ homo_intrinsics = tf.concat([cam['intrinsics'], homo_intrinsics], axis=2)
+
+ point4d = tf.concat([point3d, tf.ones_like(point3d[:, :, 0:1])], axis=2)
+ projected = tf.matmul(point4d, tf.transpose(homo_transform, (0, 2, 1)))
+ projected = tf.matmul(projected, tf.transpose(homo_intrinsics, (0, 2, 1)))
+ image_coords = projected / projected[:, :, 2:3]
+ image_coords = tf.concat(
+ [image_coords[:, :, :2],
+ tf.sign(projected[:, :, 2:])], axis=2)
+ return image_coords
+
+
+def unproject(coord, cam, depth):
+ """Unproject points.
+
+ Args:
+ coord: Points in 2D coordinates. It has shape [num_points, 2]. Coord is in
+ integer (y, x) order because of the way meshgrid is constructed.
+ cam: The camera parameters, as returned by kubric. 'matrix_world' and
+ 'intrinsics' have a leading axis num_frames.
+ depth: Depth map for the scene.
+
+ Returns:
+ Image coordinates in 3D.
+ """
+ shp = tf.convert_to_tensor(tf.shape(depth))
+ idx = coord[:, 0] * shp[1] + coord[:, 1]
+ coord = tf.cast(coord[..., ::-1], tf.float32)
+ shp = tf.cast(shp[1::-1], tf.float32)[tf.newaxis, ...]
+
+ # Need to convert from pixel to raster coordinate.
+ projected_pt = (coord + 0.5) / shp
+
+ projected_pt = tf.concat(
+ [
+ projected_pt,
+ tf.ones_like(projected_pt[:, -1:]),
+ ],
+ axis=-1,
+ )
+
+ camera_plane = projected_pt @ tf.linalg.inv(tf.transpose(cam['intrinsics']))
+ camera_ball = camera_plane / tf.sqrt(
+ tf.reduce_sum(
+ tf.square(camera_plane),
+ axis=1,
+ keepdims=True,
+ ), )
+ camera_ball *= tf.gather(tf.reshape(depth, [-1]), idx)[:, tf.newaxis]
+
+ camera_ball = tf.concat(
+ [
+ camera_ball,
+ tf.ones_like(camera_plane[:, 2:]),
+ ],
+ axis=1,
+ )
+ points_3d = camera_ball @ tf.transpose(cam['matrix_world'])
+ return points_3d[:, :3] / points_3d[:, 3:]
+
+
+def reproject(coords, camera, camera_pos, num_frames, bbox=None):
+ """Reconstruct points in 3D and reproject them to pixels.
+
+ Args:
+ coords: Points in 3D. It has shape [num_points, 3]. If bbox is specified,
+ these are assumed to be in local box coordinates (as specified by kubric),
+ and bbox will be used to put them into world coordinates; otherwise they
+ are assumed to be in world coordinates.
+ camera: the camera intrinsic parameters, as returned by kubric.
+ 'matrix_world' and 'intrinsics' have a leading axis num_frames.
+ camera_pos: the camera positions. It has shape [num_frames, 3]
+ num_frames: the number of frames in the video.
+ bbox: The kubric bounding box for the object. Its first axis is num_frames.
+
+ Returns:
+ Image coordinates in 2D and their respective depths. For the points,
+ the last coordinate is an indicator of whether the point is behind the
+ camera. They are of shape [num_points, num_frames, 3] and
+ [num_points, num_frames] respectively.
+ """
+ # First, reconstruct points in the local object coordinate system.
+ if bbox is not None:
+ coord_box = list(itertools.product([-.5, .5], [-.5, .5], [-.5, .5]))
+ coord_box = np.array([np.array(x) for x in coord_box])
+ coord_box = np.concatenate(
+ [coord_box, np.ones_like(coord_box[:, 0:1])], axis=1)
+ coord_box = tf.tile(coord_box[tf.newaxis, ...], [num_frames, 1, 1])
+ bbox_homo = tf.concat([bbox, tf.ones_like(bbox[:, :, 0:1])], axis=2)
+
+ local_to_world = tf.linalg.lstsq(tf.cast(coord_box, tf.float32), bbox_homo)
+ world_coords = tf.matmul(
+ tf.cast(
+ tf.concat([coords, tf.ones_like(coords[:, 0:1])], axis=1),
+ tf.float32)[tf.newaxis, :, :], local_to_world)
+ world_coords = world_coords[:, :, 0:3] / world_coords[:, :, 3:]
+ else:
+ world_coords = tf.tile(coords[tf.newaxis, :, :], [num_frames, 1, 1])
+
+ # Compute depths by taking the distance between the points and the camera
+ # center.
+ depths = tf.sqrt(
+ tf.reduce_sum(
+ tf.square(world_coords - camera_pos[:, np.newaxis, :]),
+ axis=2,
+ ), )
+
+ # Project each point back to the image using the camera.
+ projections = project_point(camera, world_coords, num_frames)
+
+ return (
+ tf.transpose(projections, (1, 0, 2)),
+ tf.transpose(depths),
+ tf.transpose(world_coords, (1, 0, 2)),
+ )
+
+
+def estimate_occlusion_by_depth_and_segment(
+ data,
+ segments,
+ x,
+ y,
+ num_frames,
+ thresh,
+ seg_id,
+):
+ """Estimate depth at a (floating point) x,y position.
+
+ We prefer overestimating depth at the point, so we take the max over the 4
+ neighboring pixels.
+
+ Args:
+ data: depth map. First axis is num_frames.
+ segments: segmentation map. First axis is num_frames.
+ x: x coordinate. First axis is num_frames.
+ y: y coordinate. First axis is num_frames.
+ num_frames: number of frames.
+ thresh: Depth threshold at which we consider the point occluded.
+ seg_id: Original segment id. Assume occlusion if there's a mismatch.
+
+ Returns:
+ Depth for each point.
+ """
+
+ # need to convert from raster to pixel coordinates
+ x = x - 0.5
+ y = y - 0.5
+
+ x0 = tf.cast(tf.floor(x), tf.int32)
+ x1 = x0 + 1
+ y0 = tf.cast(tf.floor(y), tf.int32)
+ y1 = y0 + 1
+
+ shp = tf.shape(data)
+ assert len(data.shape) == 3
+ x0 = tf.clip_by_value(x0, 0, shp[2] - 1)
+ x1 = tf.clip_by_value(x1, 0, shp[2] - 1)
+ y0 = tf.clip_by_value(y0, 0, shp[1] - 1)
+ y1 = tf.clip_by_value(y1, 0, shp[1] - 1)
+
+ data = tf.reshape(data, [-1])
+ rng = tf.range(num_frames)[:, tf.newaxis]
+ i1 = tf.gather(data, rng * shp[1] * shp[2] + y0 * shp[2] + x0)
+ i2 = tf.gather(data, rng * shp[1] * shp[2] + y1 * shp[2] + x0)
+ i3 = tf.gather(data, rng * shp[1] * shp[2] + y0 * shp[2] + x1)
+ i4 = tf.gather(data, rng * shp[1] * shp[2] + y1 * shp[2] + x1)
+
+ depth = tf.maximum(tf.maximum(tf.maximum(i1, i2), i3), i4)
+
+ segments = tf.reshape(segments, [-1])
+ i1 = tf.gather(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x0)
+ i2 = tf.gather(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x0)
+ i3 = tf.gather(segments, rng * shp[1] * shp[2] + y0 * shp[2] + x1)
+ i4 = tf.gather(segments, rng * shp[1] * shp[2] + y1 * shp[2] + x1)
+
+ depth_occluded = tf.less(tf.transpose(depth), thresh)
+ seg_occluded = True
+ for i in [i1, i2, i3, i4]:
+ i = tf.cast(i, tf.int32)
+ seg_occluded = tf.logical_and(seg_occluded, tf.not_equal(seg_id, i))
+
+ return tf.logical_or(depth_occluded, tf.transpose(seg_occluded))
+
+
+def get_camera_matrices(
+ cam_focal_length,
+ cam_positions,
+ cam_quaternions,
+ cam_sensor_width,
+ input_size,
+ num_frames=None,
+):
+ """Tf function that converts camera positions into projection matrices."""
+ intrinsics = []
+ matrix_world = []
+ assert cam_quaternions.shape[0] == num_frames
+ for frame_idx in range(cam_quaternions.shape[0]):
+ focal_length = tf.cast(cam_focal_length, tf.float32)
+ sensor_width = tf.cast(cam_sensor_width, tf.float32)
+ f_x = focal_length / sensor_width
+ f_y = focal_length / sensor_width * input_size[0] / input_size[1]
+ p_x = 0.5
+ p_y = 0.5
+ intrinsics.append(
+ tf.stack([
+ tf.stack([f_x, 0., -p_x]),
+ tf.stack([0., -f_y, -p_y]),
+ tf.stack([0., 0., -1.]),
+ ]))
+
+ position = cam_positions[frame_idx]
+ quat = cam_quaternions[frame_idx]
+ rotation_matrix = rotation_matrix_3d.from_quaternion(
+ tf.concat([quat[1:], quat[0:1]], axis=0))
+ transformation = tf.concat(
+ [rotation_matrix, position[:, tf.newaxis]],
+ axis=1,
+ )
+ transformation = tf.concat(
+ [transformation,
+ tf.constant([0.0, 0.0, 0.0, 1.0])[tf.newaxis, :]],
+ axis=0,
+ )
+ matrix_world.append(transformation)
+
+ return (
+ tf.cast(tf.stack(intrinsics), tf.float32),
+ tf.cast(tf.stack(matrix_world), tf.float32),
+ )
+
+
+def quat2rot(quats):
+ """Convert a list of quaternions to rotation matrices."""
+ rotation_matrices = []
+ for frame_idx in range(quats.shape[0]):
+ quat = quats[frame_idx]
+ rotation_matrix = rotation_matrix_3d.from_quaternion(
+ tf.concat([quat[1:], quat[0:1]], axis=0))
+ rotation_matrices.append(rotation_matrix)
+ return tf.cast(tf.stack(rotation_matrices), tf.float32)
+
+
+def rotate_surface_normals(
+ world_frame_normals,
+ point_3d,
+ cam_pos,
+ obj_rot_mats,
+ frame_for_query,
+):
+ """Points are occluded if the surface normal points away from the camera."""
+ query_obj_rot_mat = tf.gather(obj_rot_mats, frame_for_query)
+ obj_frame_normals = tf.einsum(
+ 'boi,bi->bo',
+ tf.linalg.inv(query_obj_rot_mat),
+ world_frame_normals,
+ )
+ world_frame_normals_frames = tf.einsum(
+ 'foi,bi->bfo',
+ obj_rot_mats,
+ obj_frame_normals,
+ )
+ cam_to_pt = point_3d - cam_pos[tf.newaxis, :, :]
+ dots = tf.reduce_sum(world_frame_normals_frames * cam_to_pt, axis=-1)
+ faces_away = dots > 0
+
+ # If the query point also faces away, it's probably a bug in the meshes, so
+ # ignore the result of the test.
+ faces_away_query = tf.reduce_sum(
+ tf.cast(faces_away, tf.int32)
+ * tf.one_hot(frame_for_query, tf.shape(faces_away)[1], dtype=tf.int32),
+ axis=1,
+ keepdims=True,
+ )
+ faces_away = tf.logical_and(faces_away, tf.logical_not(faces_away_query > 0))
+ return faces_away
+
+
+def single_object_reproject(
+ bbox_3d=None,
+ pt=None,
+ pt_segments=None,
+ camera=None,
+ cam_positions=None,
+ num_frames=None,
+ depth_map=None,
+ segments=None,
+ window=None,
+ input_size=None,
+ quat=None,
+ normals=None,
+ frame_for_pt=None,
+ trust_normals=None,
+):
+ """Reproject points for a single object.
+
+ Args:
+ bbox_3d: The object bounding box from Kubric. If none, assume it's
+ background.
+ pt: The set of points in 3D, with shape [num_points, 3]
+ pt_segments: The segment each point came from, with shape [num_points]
+ camera: Camera intrinsic parameters
+ cam_positions: Camera positions, with shape [num_frames, 3]
+ num_frames: Number of frames
+ depth_map: Depth map video for the camera
+ segments: Segmentation map video for the camera
+ window: the window inside which we're sampling points
+ input_size: [height, width] of the input images.
+ quat: Object quaternion [num_frames, 4]
+ normals: Point normals on the query frame [num_points, 3]
+ frame_for_pt: Integer frame where the query point came from [num_points]
+ trust_normals: Boolean flag for whether the surface normals for each query
+ are trustworthy [num_points]
+
+ Returns:
+ Position for each point, of shape [num_points, num_frames, 2], in pixel
+ coordinates, and an occlusion flag for each point, of shape
+ [num_points, num_frames]. These are with respect to the image frame, not the
+ window.
+
+ """
+ # Finally, reproject
+ reproj, depth_proj, world_pos = reproject(
+ pt,
+ camera,
+ cam_positions,
+ num_frames,
+ bbox=bbox_3d,
+ )
+
+ occluded = tf.less(reproj[:, :, 2], 0)
+ reproj = reproj[:, :, 0:2] * np.array(input_size[::-1])[np.newaxis,
+ np.newaxis, :]
+ occluded = tf.logical_or(
+ occluded,
+ estimate_occlusion_by_depth_and_segment(
+ depth_map[:, :, :, 0],
+ segments[:, :, :, 0],
+ tf.transpose(reproj[:, :, 0]),
+ tf.transpose(reproj[:, :, 1]),
+ num_frames,
+ depth_proj * .99,
+ pt_segments,
+ ),
+ )
+ obj_occ = occluded
+ obj_reproj = reproj
+
+ obj_occ = tf.logical_or(obj_occ, tf.less(obj_reproj[:, :, 1], window[0]))
+ obj_occ = tf.logical_or(obj_occ, tf.less(obj_reproj[:, :, 0], window[1]))
+ obj_occ = tf.logical_or(obj_occ, tf.greater(obj_reproj[:, :, 1], window[2]))
+ obj_occ = tf.logical_or(obj_occ, tf.greater(obj_reproj[:, :, 0], window[3]))
+
+ if quat is not None:
+ faces_away = rotate_surface_normals(
+ normals,
+ world_pos,
+ cam_positions,
+ quat2rot(quat),
+ frame_for_pt,
+ )
+ faces_away = tf.logical_and(faces_away, trust_normals)
+ else:
+ # world is convex; can't face away from cam.
+ faces_away = tf.zeros([tf.shape(pt)[0], num_frames], dtype=tf.bool)
+
+ return obj_reproj, tf.logical_or(faces_away, obj_occ)
+
+
+def get_num_to_sample(counts, max_seg_id, max_sampled_frac, tracks_to_sample):
+ """Computes the number of points to sample for each object.
+
+ Args:
+ counts: The number of points available per object. An int array of length
+ n, where n is the number of objects.
+ max_seg_id: The maximum number of segment id's in the video.
+ max_sampled_frac: The maximum fraction of points to sample from each
+ object, out of all points that lie on the sampling grid.
+ tracks_to_sample: Total number of tracks to sample per video.
+
+ Returns:
+ The number of points to sample for each object. An int array of length n.
+ """
+ seg_order = tf.argsort(counts)
+ sorted_counts = tf.gather(counts, seg_order)
+ initializer = (0, tracks_to_sample, 0)
+
+ def scan_fn(prev_output, count_seg):
+ index = prev_output[0]
+ remaining_needed = prev_output[1]
+ desired_frac = 1 / (tf.shape(seg_order)[0] - index)
+ want_to_sample = (
+ tf.cast(remaining_needed, tf.float32) *
+ tf.cast(desired_frac, tf.float32))
+ want_to_sample = tf.cast(tf.round(want_to_sample), tf.int32)
+ max_to_sample = (
+ tf.cast(count_seg, tf.float32) * tf.cast(max_sampled_frac, tf.float32))
+ max_to_sample = tf.cast(tf.round(max_to_sample), tf.int32)
+ num_to_sample = tf.minimum(want_to_sample, max_to_sample)
+
+ remaining_needed = remaining_needed - num_to_sample
+ return (index + 1, remaining_needed, num_to_sample)
+
+ # outputs 0 and 1 are just bookkeeping; output 2 is the actual number of
+ # points to sample per object.
+ res = tf.scan(scan_fn, sorted_counts, initializer)[2]
+ invert = tf.argsort(seg_order)
+ num_to_sample = tf.gather(res, invert)
+ num_to_sample = tf.concat(
+ [
+ num_to_sample,
+ tf.zeros([max_seg_id - tf.shape(num_to_sample)[0]], dtype=tf.int32),
+ ],
+ axis=0,
+ )
+ return num_to_sample
+
+
+# pylint: disable=cell-var-from-loop
+
+
+def track_points(
+ object_coordinates,
+ depth,
+ depth_range,
+ segmentations,
+ surface_normals,
+ bboxes_3d,
+ obj_quat,
+ cam_focal_length,
+ cam_positions,
+ cam_quaternions,
+ cam_sensor_width,
+ window,
+ tracks_to_sample=256,
+ sampling_stride=4,
+ max_seg_id=25,
+ max_sampled_frac=0.1,
+):
+ """Track points in 2D using Kubric data.
+
+ Args:
+ object_coordinates: Video of coordinates for each pixel in the object's
+ local coordinate frame. Shape [num_frames, height, width, 3]
+ depth: uint16 depth video from Kubric. Shape [num_frames, height, width]
+ depth_range: Values needed to normalize Kubric's int16 depth values into
+ metric depth.
+ segmentations: Integer object id for each pixel. Shape
+ [num_frames, height, width]
+ surface_normals: uint16 surface normal map. Shape
+ [num_frames, height, width, 3]
+ bboxes_3d: The set of all object bounding boxes from Kubric
+ obj_quat: Quaternion rotation for each object. Shape
+ [num_objects, num_frames, 4]
+ cam_focal_length: Camera focal length
+ cam_positions: Camera positions, with shape [num_frames, 3]
+ cam_quaternions: Camera orientations, with shape [num_frames, 4]
+ cam_sensor_width: Camera sensor width parameter
+ window: the window inside which we're sampling points. Integer valued
+ in the format [x_min, y_min, x_max, y_max], where min is inclusive and
+ max is exclusive.
+ tracks_to_sample: Total number of tracks to sample per video.
+ sampling_stride: For efficiency, query points are sampled from a random grid
+ of this stride.
+ max_seg_id: The maximum segment id in the video.
+ max_sampled_frac: The maximum fraction of points to sample from each
+ object, out of all points that lie on the sampling grid.
+
+ Returns:
+ A set of queries, randomly sampled from the video (with a bias toward
+ objects), of shape [num_points, 3]. Each point is [t, y, x], where
+ t is time. All points are in pixel/frame coordinates.
+ The trajectory for each query point, of shape [num_points, num_frames, 3].
+ Each point is [x, y]. Points are in pixel coordinates
+ Occlusion flag for each point, of shape [num_points, num_frames]. This is
+ a boolean, where True means the point is occluded.
+
+ """
+ chosen_points = []
+ all_reproj = []
+ all_occ = []
+
+ # Convert to metric depth
+
+ depth_range_f32 = tf.cast(depth_range, tf.float32)
+ depth_min = depth_range_f32[0]
+ depth_max = depth_range_f32[1]
+ depth_f32 = tf.cast(depth, tf.float32)
+ depth_map = depth_min + depth_f32 * (depth_max - depth_min) / 65535
+
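+ # Surface normals are stored as uint16; rescale from [0, 65535] to [-1, 1].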
+ surface_normal_map = surface_normals / 65535 * 2. - 1.
+
+ input_size = object_coordinates.shape.as_list()[1:3]
+ num_frames = object_coordinates.shape.as_list()[0]
+
+ # We first sample query points within the given window. That means first
+ # extracting the window from the segmentation tensor, because we want to have
+ # a bias toward moving objects.
+ # Note: for speed we sample points on a grid. The grid start position is
+ # randomized within the window.
+ start_vec = [
+ tf.random.uniform([], minval=0, maxval=sampling_stride, dtype=tf.int32)
+ for _ in range(3)
+ ]
+ start_vec[1] += window[0]
+ start_vec[2] += window[1]
+ end_vec = [num_frames, window[2], window[3]]
+
+ def extract_box(x):
+ x = x[start_vec[0]::sampling_stride, start_vec[1]:window[2]:sampling_stride,
+ start_vec[2]:window[3]:sampling_stride]
+ return x
+
+ segmentations_box = extract_box(segmentations)
+ object_coordinates_box = extract_box(object_coordinates)
+
+ # Next, get the number of points to sample from each object. First count
+ # how many points are available for each object.
+
+ cnt = tf.math.bincount(tf.cast(tf.reshape(segmentations_box, [-1]), tf.int32))
+ num_to_sample = get_num_to_sample(
+ cnt,
+ max_seg_id,
+ max_sampled_frac,
+ tracks_to_sample,
+ )
+ num_to_sample.set_shape([max_seg_id])
+ intrinsics, matrix_world = get_camera_matrices(
+ cam_focal_length,
+ cam_positions,
+ cam_quaternions,
+ cam_sensor_width,
+ input_size,
+ num_frames=num_frames,
+ )
+
+ # If the normal map is very rough, it's often because the normals come from a normal
+ # map rather than the mesh. These aren't trustworthy, and the normal test
+ # may fail (i.e. the normal is pointing away from the camera even though the
+ # point is still visible). So don't use the normal test when inferring
+ # occlusion.
+ trust_sn = True
+ sn_pad = tf.pad(surface_normal_map, [(0, 0), (1, 1), (1, 1), (0, 0)])
+ shp = surface_normal_map.shape
+ sum_thresh = 0
+ for i in [0, 2]:
+ for j in [0, 2]:
+ diff = sn_pad[:, i: shp[1] + i, j: shp[2] + j, :] - surface_normal_map
+ diff = tf.reduce_sum(tf.square(diff), axis=-1)
+ sum_thresh += tf.cast(diff > 0.05 * 0.05, tf.int32)
+ trust_sn = tf.logical_and(trust_sn, (sum_thresh <= 2))[..., tf.newaxis]
+ surface_normals_box = extract_box(surface_normal_map)
+ trust_sn_box = extract_box(trust_sn)
+
+ def get_camera(fr=None):
+ if fr is None:
+ return {'intrinsics': intrinsics, 'matrix_world': matrix_world}
+ return {'intrinsics': intrinsics[fr], 'matrix_world': matrix_world[fr]}
+
+ # Construct pixel coordinates for each pixel within the window.
+ window = tf.cast(window, tf.float32)
+ z, y, x = tf.meshgrid(
+ *[
+ tf.range(st, ed, sampling_stride)
+ for st, ed in zip(start_vec, end_vec)
+ ],
+ indexing='ij')
+ pix_coords = tf.reshape(tf.stack([z, y, x], axis=-1), [-1, 3])
+
+ for i in range(max_seg_id):
+ # sample points on object i in the first frame. obj_id is the position
+ # within the object_coordinates array, which is one lower than the value
+ # in the segmentation mask (0 in the segmentation mask is the background
+ # object, which has no bounding box).
+ obj_id = i - 1
+ mask = tf.equal(tf.reshape(segmentations_box, [-1]), i)
+ pt = tf.boolean_mask(tf.reshape(object_coordinates_box, [-1, 3]), mask)
+ normals = tf.boolean_mask(tf.reshape(surface_normals_box, [-1, 3]), mask)
+ trust_sn_mask = tf.boolean_mask(tf.reshape(trust_sn_box, [-1, 1]), mask)
+ idx = tf.cond(
+ tf.shape(pt)[0] > 0,
+ lambda: tf.multinomial( # pylint: disable=g-long-lambda
+ tf.zeros(tf.shape(pt)[0:1])[tf.newaxis, :],
+ tf.gather(num_to_sample, i))[0],
+ lambda: tf.zeros([0], dtype=tf.int64))
+ # note: pt_coords is pixel coordinates, not raster coordinates.
+ pt_coords = tf.gather(tf.boolean_mask(pix_coords, mask), idx)
+ normals = tf.gather(normals, idx)
+ trust_sn_gather = tf.gather(trust_sn_mask, idx)
+
+ pixel_to_raster = tf.constant([0.0, 0.5, 0.5])[tf.newaxis, :]
+
+ if obj_id == -1:
+ # For the background object, no bounding box is available. However,
+ # this doesn't move, so we use the depth map to backproject these points
+ # into 3D and use those positions throughout the video.
+ pt_3d = []
+ pt_coords_reorder = []
+ for fr in range(num_frames):
+ # We need to loop over frames because we need to use the correct depth
+ # map for each frame.
+ pt_coords_chunk = tf.boolean_mask(pt_coords,
+ tf.equal(pt_coords[:, 0], fr))
+ pt_coords_reorder.append(pt_coords_chunk)
+
+ pt_3d.append(
+ unproject(pt_coords_chunk[:, 1:], get_camera(fr), depth_map[fr]))
+ pt = tf.concat(pt_3d, axis=0)
+ chosen_points.append(
+ tf.cast(tf.concat(pt_coords_reorder, axis=0), tf.float32) +
+ pixel_to_raster)
+ bbox = None
+ quat = None
+ frame_for_pt = None
+ else:
+ # For any other object, we just use the point coordinates supplied by
+ # kubric.
+ pt = tf.gather(pt, idx)
+ pt = pt / np.iinfo(np.uint16).max - .5
+ chosen_points.append(tf.cast(pt_coords, tf.float32) + pixel_to_raster)
+ # if obj_id>num_objects, then we won't have a box. We also won't have
+ # points, so just use a dummy to prevent tf from crashing.
+ bbox = tf.cond(obj_id >= tf.shape(bboxes_3d)[0], lambda: bboxes_3d[0, :],
+ lambda: bboxes_3d[obj_id, :])
+ quat = tf.cond(obj_id >= tf.shape(obj_quat)[0], lambda: obj_quat[0, :],
+ lambda: obj_quat[obj_id, :])
+ frame_for_pt = pt_coords[..., 0]
+
+ # Finally, compute the reprojections for this particular object.
+ obj_reproj, obj_occ = tf.cond(
+ tf.shape(pt)[0] > 0,
+ functools.partial(
+ single_object_reproject,
+ bbox_3d=bbox,
+ pt=pt,
+ pt_segments=i,
+ camera=get_camera(),
+ cam_positions=cam_positions,
+ num_frames=num_frames,
+ depth_map=depth_map,
+ segments=segmentations,
+ window=window,
+ input_size=input_size,
+ quat=quat,
+ normals=normals,
+ frame_for_pt=frame_for_pt,
+ trust_normals=trust_sn_gather,
+ ),
+ lambda: # pylint: disable=g-long-lambda
+ (tf.zeros([0, num_frames, 2], dtype=tf.float32),
+ tf.zeros([0, num_frames], dtype=tf.bool)))
+ all_reproj.append(obj_reproj)
+ all_occ.append(obj_occ)
+
+ # Points are currently in pixel coordinates of the original video. We now
+ # convert them to coordinates within the window frame, and rescale to
+ # pixel coordinates. Note that this produces the pixel coordinates after
+ # the window gets cropped and rescaled to the full image size.
+ wd = tf.concat(
+ [np.array([0.0]), window[0:2],
+ np.array([num_frames]), window[2:4]],
+ axis=0)
+ wd = wd[tf.newaxis, tf.newaxis, :]
+ coord_multiplier = [num_frames, input_size[0], input_size[1]]
+ all_reproj = tf.concat(all_reproj, axis=0)
+ # We need to extract x,y, but the format of the window is [t1,y1,x1,t2,y2,x2]
+ window_size = wd[:, :, 5:3:-1] - wd[:, :, 2:0:-1]
+ window_top_left = wd[:, :, 2:0:-1]
+ all_reproj = (all_reproj - window_top_left) / window_size
+ all_reproj = all_reproj * coord_multiplier[2:0:-1]
+ all_occ = tf.concat(all_occ, axis=0)
+
+ # chosen_points is [num_points, (z,y,x)]
+ chosen_points = tf.concat(chosen_points, axis=0)
+
+ chosen_points = tf.cast(chosen_points, tf.float32)
+
+ # renormalize so the box corners are at [-1,1]
+ chosen_points = (chosen_points - wd[:, 0, :3]) / (wd[:, 0, 3:] - wd[:, 0, :3])
+ chosen_points = chosen_points * coord_multiplier
+ # Note: all_reproj is in (x,y) format, but chosen_points is in (z,y,x) format
+
+ return tf.cast(chosen_points, tf.float32), tf.cast(all_reproj,
+ tf.float32), all_occ
+
+
+def _get_distorted_bounding_box(
+ jpeg_shape,
+ bbox,
+ min_object_covered,
+ aspect_ratio_range,
+ area_range,
+ max_attempts,
+):
+ """Sample a crop window to be used for cropping."""
+ bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(
+ jpeg_shape,
+ bounding_boxes=bbox,
+ min_object_covered=min_object_covered,
+ aspect_ratio_range=aspect_ratio_range,
+ area_range=area_range,
+ max_attempts=max_attempts,
+ use_image_if_no_bounding_boxes=True)
+
+ # Crop the image to the specified bounding box.
+ offset_y, offset_x, _ = tf.unstack(bbox_begin)
+ target_height, target_width, _ = tf.unstack(bbox_size)
+ crop_window = tf.stack(
+ [offset_y, offset_x, offset_y + target_height, offset_x + target_width])
+ return crop_window
+
+
+def add_tracks(data,
+ train_size=(256, 256),
+ vflip=False,
+ random_crop=True,
+ tracks_to_sample=256,
+ sampling_stride=4,
+ max_seg_id=25,
+ max_sampled_frac=0.1):
+ """Track points in 2D using Kubric data.
+
+ Args:
+ data: Kubric data, including RGB/depth/object coordinate/segmentation
+ videos and camera parameters.
+ train_size: Cropped output will be at this resolution. Ignored if
+ random_crop is False.
+ vflip: whether to vertically flip images and tracks (to test generalization)
+ random_crop: Whether to randomly crop videos
+ tracks_to_sample: Total number of tracks to sample per video.
+ sampling_stride: For efficiency, query points are sampled from a random grid
+ of this stride.
+ max_seg_id: The maximum segment id in the video.
+ max_sampled_frac: The maximum fraction of points to sample from each
+ object, out of all points that lie on the sampling grid.
+
+ Returns:
+ A dict with the following keys:
+ query_points:
+ A set of queries, randomly sampled from the video (with a bias toward
+ objects), of shape [num_points, 3]. Each point is [t, y, x], where
+ t is time. Points are in pixel/frame coordinates, i.e. within
+ [num_frames, height, width].
+ target_points:
+ The trajectory for each query point, of shape [num_points, num_frames, 3].
+ Each point is [x, y]. Points are in pixel/frame coordinates.
+ occlusion:
+ Occlusion flag for each point, of shape [num_points, num_frames]. This is
+ a boolean, where True means the point is occluded.
+ video:
+ The cropped video, normalized into the range [-1, 1]
+
+ """
+ shp = data['video'].shape.as_list()
+ num_frames = shp[0]
+ if any([s % sampling_stride != 0 for s in shp[:-1]]):
+ raise ValueError('All video dims must be a multiple of sampling_stride.')
+
+ bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
+ min_area = 0.3
+ max_area = 1.0
+ min_aspect_ratio = 0.5
+ max_aspect_ratio = 2.0
+ if random_crop:
+ crop_window = _get_distorted_bounding_box(
+ jpeg_shape=shp[1:4],
+ bbox=bbox,
+ min_object_covered=min_area,
+ aspect_ratio_range=(min_aspect_ratio, max_aspect_ratio),
+ area_range=(min_area, max_area),
+ max_attempts=20)
+ else:
+ crop_window = tf.constant([0, 0, shp[1], shp[2]],
+ dtype=tf.int32,
+ shape=[4])
+
+ query_points, target_points, occluded = track_points(
+ data['object_coordinates'], data['depth'],
+ data['metadata']['depth_range'], data['segmentations'],
+ data['normal'],
+ data['instances']['bboxes_3d'], data['instances']['quaternions'],
+ data['camera']['focal_length'],
+ data['camera']['positions'], data['camera']['quaternions'],
+ data['camera']['sensor_width'], crop_window, tracks_to_sample,
+ sampling_stride, max_seg_id, max_sampled_frac)
+ video = data['video']
+
+ shp = video.shape.as_list()
+ query_points.set_shape([tracks_to_sample, 3])
+ target_points.set_shape([tracks_to_sample, num_frames, 2])
+ occluded.set_shape([tracks_to_sample, num_frames])
+
+ # Crop the video to the sampled window, in a way that matches the coordinate
+ # frame produced by the track_points function.
+ crop_window = crop_window / (
+ np.array(shp[1:3] + shp[1:3]).astype(np.float32) - 1)
+ crop_window = tf.tile(crop_window[tf.newaxis, :], [num_frames, 1])
+ video = tf.image.crop_and_resize(
+ video,
+ tf.cast(crop_window, tf.float32),
+ tf.range(num_frames),
+ train_size,
+ )
+ if vflip:
+ video = video[:, ::-1, :, :]
+ target_points = target_points * np.array([1, -1])
+ query_points = query_points * np.array([1, -1, 1])
+ res = {
+ 'query_points': query_points,
+ 'target_points': target_points,
+ 'occluded': occluded,
+ 'video': video / (255. / 2.) - 1.,
+ }
+ return res
+
+
+def create_point_tracking_dataset(
+ data_dir="gs://kubric-public/tfds",
+ train_size=(512, 512),
+ shuffle=True,
+ shuffle_buffer_size=None,
+ split='train',
+ batch_dims=tuple(),
+ repeat=True,
+ vflip=False,
+ random_crop=True,
+ tracks_to_sample=2048,
+ sampling_stride=4,
+ max_seg_id=25,
+ max_sampled_frac=0.1,
+ num_parallel_point_extraction_calls=16,
+ **kwargs):
+ """Construct a dataset for point tracking using Kubric.
+
+ Args:
+ train_size: Tuple of 2 ints. Cropped output will be at this resolution
+ shuffle_buffer_size: Int. Size of the shuffle buffer
+ split: Which split to construct from Kubric. Can be 'train' or
+ 'validation'.
+ batch_dims: Sequence of ints. Add multiple examples into a batch of this
+ shape.
+ repeat: Bool. whether to repeat the dataset.
+ vflip: Bool. whether to vertically flip the dataset to test generalization.
+ random_crop: Bool. whether to randomly crop videos
+ tracks_to_sample: Int. Total number of tracks to sample per video.
+ sampling_stride: Int. For efficiency, query points are sampled from a
+ random grid of this stride.
+ max_seg_id: Int. The maximum segment id in the video. Note the size of
+ the TensorFlow graph is proportional to this number, so prefer small values.
+ max_sampled_frac: Float. The maximum fraction of points to sample from each
+ object, out of all points that lie on the sampling grid.
+ num_parallel_point_extraction_calls: Int. The num_parallel_calls for the
+ map function for point extraction.
+ **kwargs: additional args to pass to tfds.load.
+
+ Returns:
+ The dataset generator.
+ """
+ ds = tfds.load(
+ 'movi_f/512x512',
+ data_dir=data_dir,
+ shuffle_files=shuffle,
+ **kwargs)
+
+ ds = ds[split]
+ if repeat:
+ ds = ds.repeat()
+ ds = ds.map(
+ functools.partial(
+ add_tracks,
+ train_size=train_size,
+ vflip=vflip,
+ random_crop=random_crop,
+ tracks_to_sample=tracks_to_sample,
+ sampling_stride=sampling_stride,
+ max_seg_id=max_seg_id,
+ max_sampled_frac=max_sampled_frac),
+ num_parallel_calls=num_parallel_point_extraction_calls)
+ if shuffle_buffer_size is not None:
+ ds = ds.shuffle(shuffle_buffer_size)
+
+ for bs in batch_dims[::-1]:
+ ds = ds.batch(bs)
+
+ return ds
+
+
+def plot_tracks(rgb, points, occluded, trackgroup=None):
+ """Plot tracks with matplotlib."""
+ disp = []
+ cmap = plt.cm.hsv
+
+ z_list = np.arange(
+ points.shape[0]) if trackgroup is None else np.array(trackgroup)
+ # random permutation of the colors so nearby points in the list can get
+ # different colors
+ z_list = np.random.permutation(np.max(z_list) + 1)[z_list]
+ colors = cmap(z_list / (np.max(z_list) + 1))
+ figure_dpi = 64
+
+ for i in range(rgb.shape[0]):
+ fig = plt.figure(
+ figsize=(256 / figure_dpi, 256 / figure_dpi),
+ dpi=figure_dpi,
+ frameon=False,
+ facecolor='w')
+ ax = fig.add_subplot()
+ ax.axis('off')
+ ax.imshow(rgb[i])
+
+ valid = points[:, i, 0] > 0
+ valid = np.logical_and(valid, points[:, i, 0] < rgb.shape[2] - 1)
+ valid = np.logical_and(valid, points[:, i, 1] > 0)
+ valid = np.logical_and(valid, points[:, i, 1] < rgb.shape[1] - 1)
+
+ colalpha = np.concatenate([colors[:, :-1], 1 - occluded[:, i:i + 1]],
+ axis=1)
+ # Note: matplotlib uses pixel coordinates, not raster.
+ plt.scatter(
+ points[valid, i, 0] - 0.5,
+ points[valid, i, 1] - 0.5,
+ s=3,
+ c=colalpha[valid],
+ )
+
+ occ2 = occluded[:, i:i + 1]
+
+ colalpha = np.concatenate([colors[:, :-1], occ2], axis=1)
+
+ plt.scatter(
+ points[valid, i, 0],
+ points[valid, i, 1],
+ s=20,
+ facecolors='none',
+ edgecolors=colalpha[valid],
+ )
+
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+ plt.margins(0, 0)
+ fig.canvas.draw()
+ width, height = fig.get_size_inches() * fig.get_dpi()
+ img = np.frombuffer(
+ fig.canvas.tostring_rgb(),
+ dtype='uint8').reshape(int(height), int(width), 3)
+ disp.append(np.copy(img))
+ plt.close(fig)
+
+ return np.stack(disp, axis=0)
+
+
+def main():
+ ds = tfds.as_numpy(create_point_tracking_dataset(shuffle_buffer_size=None))
+ for i, data in enumerate(ds):
+ disp = plot_tracks(data['video'] * .5 + .5, data['target_points'],
+ data['occluded'])
+ media.write_video(f'{i}.mp4', disp, fps=10)
+ if i > 10:
+ break
+
+
+if __name__ == '__main__':
+ main()
diff --git a/data/dot_single_video/dot/data/tap_dataset.py b/data/dot_single_video/dot/data/tap_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f6f3188128615125439a607b9f1a85706a8d8a2
--- /dev/null
+++ b/data/dot_single_video/dot/data/tap_dataset.py
@@ -0,0 +1,230 @@
+# Copyright 2023 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import glob
+import torch
+import pickle as pkl
+import numpy as np
+import os.path as osp
+import mediapy as media
+from torch.utils.data import Dataset, DataLoader
+
+from PIL import Image
+from typing import Mapping, Tuple
+
+
+def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
+ """Resize a video to output_size."""
+ # If you have a GPU, consider replacing this with a GPU-enabled resize op,
+ # such as a jitted jax.image.resize. It will make things faster.
+ return media.resize_video(video, output_size)
+
+
+def sample_queries_first(
+ target_occluded: np.ndarray,
+ target_points: np.ndarray,
+ frames: np.ndarray,
+) -> Mapping[str, np.ndarray]:
+ """Package a set of frames and tracks for use in TAPNet evaluations.
+ Given a set of frames and tracks with no query points, use the first
+ visible point in each track as the query.
+ Args:
+ target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
+ where True indicates occluded.
+ target_points: Position, of shape [n_tracks, n_frames, 2], where each point
+ is [x,y] scaled between 0 and 1.
+ frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
+ -1 and 1.
+ Returns:
+ A dict with the keys:
+ video: Video tensor of shape [1, n_frames, height, width, 3]
+ query_points: Query points of shape [1, n_queries, 3] where
+ each point is [t, y, x] scaled to the range [-1, 1]
+ target_points: Target points of shape [1, n_queries, n_frames, 2] where
+ each point is [x, y] scaled to the range [-1, 1]
+ """
+ valid = np.sum(~target_occluded, axis=1) > 0
+ target_points = target_points[valid, :]
+ target_occluded = target_occluded[valid, :]
+
+ query_points = []
+ for i in range(target_points.shape[0]):
+ index = np.where(target_occluded[i] == 0)[0][0]
+ x, y = target_points[i, index, 0], target_points[i, index, 1]
+ query_points.append(np.array([index, y, x])) # [t, y, x]
+ query_points = np.stack(query_points, axis=0)
+
+ return {
+ "video": frames[np.newaxis, ...],
+ "query_points": query_points[np.newaxis, ...],
+ "target_points": target_points[np.newaxis, ...],
+ "occluded": target_occluded[np.newaxis, ...],
+ }
+
+
+def sample_queries_strided(
+ target_occluded: np.ndarray,
+ target_points: np.ndarray,
+ frames: np.ndarray,
+ query_stride: int = 5,
+) -> Mapping[str, np.ndarray]:
+ """Package a set of frames and tracks for use in TAPNet evaluations.
+
+ Given a set of frames and tracks with no query points, sample queries
+ strided every query_stride frames, ignoring points that are not visible
+ at the selected frames.
+
+ Args:
+ target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
+ where True indicates occluded.
+ target_points: Position, of shape [n_tracks, n_frames, 2], where each point
+ is [x,y] scaled between 0 and 1.
+ frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
+ -1 and 1.
+ query_stride: When sampling query points, search for un-occluded points
+ every query_stride frames and convert each one into a query.
+
+ Returns:
+ A dict with the keys:
+ video: Video tensor of shape [1, n_frames, height, width, 3]. The video
+ has floats scaled to the range [-1, 1].
+ query_points: Query points of shape [1, n_queries, 3] where
+ each point is [t, y, x] scaled to the range [-1, 1].
+ target_points: Target points of shape [1, n_queries, n_frames, 2] where
+ each point is [x, y] scaled to the range [-1, 1].
+ trackgroup: Index of the original track that each query point was
+ sampled from. This is useful for visualization.
+ """
+ tracks = []
+ occs = []
+ queries = []
+ trackgroups = []
+ total = 0
+ trackgroup = np.arange(target_occluded.shape[0])
+ for i in range(0, target_occluded.shape[1], query_stride):
+ mask = target_occluded[:, i] == 0
+ query = np.stack(
+ [
+ i * np.ones(target_occluded.shape[0:1]),
+ target_points[:, i, 1],
+ target_points[:, i, 0],
+ ],
+ axis=-1,
+ )
+ queries.append(query[mask])
+ tracks.append(target_points[mask])
+ occs.append(target_occluded[mask])
+ trackgroups.append(trackgroup[mask])
+ total += np.array(np.sum(target_occluded[:, i] == 0))
+
+ return {
+ "video": frames[np.newaxis, ...],
+ "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
+ "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
+ "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
+ "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
+ }
+
+
+class TapVid(Dataset):
+ def __init__(
+ self,
+ data_root,
+ split="davis",
+ query_mode="first",
+ resize_to_256=True
+ ):
+ self.split = split
+ self.resize_to_256 = resize_to_256
+ self.query_mode = query_mode
+ if self.split == "kinetics":
+ all_paths = glob.glob(osp.join(data_root, "*_of_0010.pkl"))
+ points_dataset = []
+ for pickle_path in all_paths:
+ with open(pickle_path, "rb") as f:
+ data = pkl.load(f)
+ points_dataset = points_dataset + data
+ self.points_dataset = points_dataset
+ else:
+ with open(data_root, "rb") as f:
+ self.points_dataset = pkl.load(f)
+ if self.split == "davis":
+ self.video_names = list(self.points_dataset.keys())
+ print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
+
+ def __getitem__(self, index):
+ if self.split == "davis":
+ video_name = self.video_names[index]
+ else:
+ video_name = index
+ video = self.points_dataset[video_name]
+ frames = video["video"]
+
+ if isinstance(frames[0], bytes):
+ # TAP-Vid frames are stored as JPEG bytes rather than `np.ndarray`s.
+ def decode(frame):
+ byteio = io.BytesIO(frame)
+ img = Image.open(byteio)
+ return np.array(img)
+
+ frames = np.array([decode(frame) for frame in frames])
+
+ target_points = self.points_dataset[video_name]["points"]
+ if self.resize_to_256:
+ frames = resize_video(frames, [256, 256])
+ target_points *= np.array([256, 256])
+ else:
+ target_points *= np.array([frames.shape[2], frames.shape[1]])
+
+ target_occ = self.points_dataset[video_name]["occluded"]
+ if self.query_mode == "first":
+ converted = sample_queries_first(target_occ, target_points, frames)
+ else:
+ converted = sample_queries_strided(target_occ, target_points, frames)
+ assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
+
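+ # Convert to torch: trajs is [T, N, 2] in pixel coordinates, visibles is [T, N]; tracks stacks them into [T, N, 3] with visibility as the last channel.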
+ trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
+
+ rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.
+ visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(1, 0) # T, N
+ query_points = torch.from_numpy(converted["query_points"])[0].float() # N, 3 as [t, y, x]
+ tracks = torch.cat([trajs, visibles[..., None]], dim=-1)
+
+ data = {
+ "video": rgbs,
+ "query_points": query_points,
+ "tracks": tracks
+ }
+
+ return data
+
+ def __len__(self):
+ return len(self.points_dataset)
+
+
+def create_point_tracking_dataset(args):
+ data_root = osp.join(args.data_root, f"tapvid_{args.split}")
+ if args.split in ["davis", "rgb_stacking"]:
+ data_root = osp.join(data_root, f"tapvid_{args.split}.pkl")
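+ # davis / rgb_stacking ship as a single pickle file; kinetics is sharded into *_of_0010.pkl chunks handled inside TapVid.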
+ dataset = TapVid(data_root, args.split, args.query_mode)
+ dataloader = DataLoader(
+ dataset,
+ batch_size=args.batch_size,
+ pin_memory=True,
+ shuffle=False,
+ num_workers=0,
+ drop_last=False,
+ )
+ return dataloader
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/__init__.py b/data/dot_single_video/dot/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..258216371af34cbf47fbbe36b49b566f858ecbd7
--- /dev/null
+++ b/data/dot_single_video/dot/models/__init__.py
@@ -0,0 +1,42 @@
+from .dense_optical_tracking import DenseOpticalTracker
+from .optical_flow import OpticalFlow
+from .point_tracking import PointTracker
+
+def create_model(args):
+ if args.model == "dot":
+ model = DenseOpticalTracker(
+ height=args.height,
+ width=args.width,
+ tracker_config=args.tracker_config,
+ tracker_path=args.tracker_path,
+ estimator_config=args.estimator_config,
+ estimator_path=args.estimator_path,
+ refiner_config=args.refiner_config,
+ refiner_path=args.refiner_path,
+ )
+ elif args.model == "pt":
+ model = PointTracker(
+ height=args.height,
+ width=args.width,
+ tracker_config=args.tracker_config,
+ tracker_path=args.tracker_path,
+ estimator_config=args.estimator_config,
+ estimator_path=args.estimator_path,
+ )
+ elif args.model == "ofe":
+ model = OpticalFlow(
+ height=args.height,
+ width=args.width,
+ config=args.estimator_config,
+ load_path=args.estimator_path,
+ )
+ elif args.model == "ofr":
+ model = OpticalFlow(
+ height=args.height,
+ width=args.width,
+ config=args.refiner_config,
+ load_path=args.refiner_path,
+ )
+ else:
+ raise ValueError(f"Unknown model name {args.model}")
+ return model
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/dense_optical_tracking.py b/data/dot_single_video/dot/models/dense_optical_tracking.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fce799ea0e2ea45830be36ee9809165596aade3
--- /dev/null
+++ b/data/dot_single_video/dot/models/dense_optical_tracking.py
@@ -0,0 +1,241 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from tqdm import tqdm
+from einops import rearrange, repeat
+
+from .optical_flow import OpticalFlow
+from .point_tracking import PointTracker
+from dot.utils.torch import get_grid
+
+
+class DenseOpticalTracker(nn.Module):
+ def __init__(self,
+ height=512,
+ width=512,
+ tracker_config="configs/cotracker2_patch_4_wind_8.json",
+ tracker_path="checkpoints/movi_f_cotracker2_patch_4_wind_8.pth",
+ estimator_config="configs/raft_patch_8.json",
+ estimator_path="checkpoints/cvo_raft_patch_8.pth",
+ refiner_config="configs/raft_patch_4_alpha.json",
+ refiner_path="checkpoints/movi_f_raft_patch_4_alpha.pth"):
+ super().__init__()
+ self.point_tracker = PointTracker(height, width, tracker_config, tracker_path, estimator_config, estimator_path)
+ self.optical_flow_refiner = OpticalFlow(height, width, refiner_config, refiner_path)
+ self.name = self.point_tracker.name + "_" + self.optical_flow_refiner.name
+ self.resolution = [height, width]
+
+ def forward(self, data, mode, **kwargs):
+ if mode == "flow_from_last_to_first_frame":
+ return self.get_flow_from_last_to_first_frame(data, **kwargs)
+ elif mode == "tracks_for_queries":
+ return self.get_tracks_for_queries(data, **kwargs)
+ elif mode == "tracks_from_first_to_every_other_frame":
+ return self.get_tracks_from_first_to_every_other_frame(data, **kwargs)
+ elif mode == "tracks_from_every_cell_in_every_frame":
+ return self.get_tracks_from_every_cell_in_every_frame(data, **kwargs)
+ else:
+ raise ValueError(f"Unknown mode {mode}")
+
+ def get_flow_from_last_to_first_frame(self, data, **kwargs):
+ B, T, C, h, w = data["video"].shape
+ init = self.point_tracker(data, mode="tracks_at_motion_boundaries", **kwargs)["tracks"]
+ init = torch.stack([init[..., 0] / (w - 1), init[..., 1] / (h - 1), init[..., 2]], dim=-1)
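+ # init holds coarse tracks sampled at motion boundaries, normalized to [0, 1]; the refiner then predicts dense flow from the last frame to the first given these sparse correspondences.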
+ data = {
+ "src_frame": data["video"][:, -1],
+ "tgt_frame": data["video"][:, 0],
+ "src_points": init[:, -1],
+ "tgt_points": init[:, 0]
+ }
+ pred = self.optical_flow_refiner(data, mode="flow_with_tracks_init", **kwargs)
+ pred["src_points"] = data["src_points"]
+ pred["tgt_points"] = data["tgt_points"]
+ return pred
+
+ def get_tracks_for_queries(self, data, **kwargs):
+ time_steps = data["video"].size(1)
+ query_points = data["query_points"]
+ video = data["video"]
+ S = query_points.size(1)
+ B, T, C, h, w = video.shape
+ H, W = self.resolution
+
+ init = self.point_tracker(data, mode="tracks_at_motion_boundaries", **kwargs)["tracks"]
+ init = torch.stack([init[..., 0] / (w - 1), init[..., 1] / (h - 1), init[..., 2]], dim=-1)
+
+ if h != H or w != W:
+ video = video.reshape(B * T, C, h, w)
+ video = F.interpolate(video, size=(H, W), mode="bilinear")
+ video = video.reshape(B, T, C, H, W)
+
+ feats = self.optical_flow_refiner({"video": video}, mode="feats", **kwargs)["feats"]
+
+ grid = get_grid(H, W, device=video.device)
+ src_steps = [int(v) for v in torch.unique(query_points[..., 0])]
+ tracks = torch.zeros(B, T, S, 3, device=video.device)
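+ # For every source frame that hosts a query, refine dense flow to all other frames, then read off the dense tracks at the (normalized) query locations.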
+ for src_step in tqdm(src_steps, desc="Refine source step", leave=False):
+ src_points = init[:, src_step]
+ src_feats = feats[:, src_step]
+ tracks_from_src = []
+ for tgt_step in tqdm(range(time_steps), desc="Refine target step", leave=False):
+ if src_step == tgt_step:
+ flow = torch.zeros(B, H, W, 2, device=video.device)
+ alpha = torch.ones(B, H, W, device=video.device)
+ else:
+ tgt_points = init[:, tgt_step]
+ tgt_feats = feats[:, tgt_step]
+ data = {
+ "src_feats": src_feats,
+ "tgt_feats": tgt_feats,
+ "src_points": src_points,
+ "tgt_points": tgt_points
+ }
+ pred = self.optical_flow_refiner(data, mode="flow_with_tracks_init", **kwargs)
+ flow, alpha = pred["flow"], pred["alpha"]
+ flow[..., 0] = flow[..., 0] / (W - 1)
+ flow[..., 1] = flow[..., 1] / (H - 1)
+ tracks_from_src.append(torch.cat([flow + grid, alpha[..., None]], dim=-1))
+ tracks_from_src = torch.stack(tracks_from_src, dim=1)
+ for b in range(B):
+ cur = query_points[b, :, 0] == src_step
+ if torch.any(cur):
+ cur_points = query_points[b, cur]
+ cur_x = cur_points[..., 2] / (w - 1)
+ cur_y = cur_points[..., 1] / (h - 1)
+ cur_tracks = dense_to_sparse_tracks(cur_x, cur_y, tracks_from_src[b], h, w)
+ tracks[b, :, cur] = cur_tracks
+ return {"tracks": tracks}
+
+ def get_tracks_from_first_to_every_other_frame(self, data, return_flow=False, **kwargs):
+ video = data["video"]
+ B, T, C, h, w = video.shape
+ H, W = self.resolution
+
+ if h != H or w != W:
+ video = video.reshape(B * T, C, h, w)
+ video = F.interpolate(video, size=(H, W), mode="bilinear")
+ video = video.reshape(B, T, C, H, W)
+
+ init = self.point_tracker(data, mode="tracks_at_motion_boundaries", **kwargs)["tracks"]
+ init = torch.stack([init[..., 0] / (w - 1), init[..., 1] / (h - 1), init[..., 2]], dim=-1)
+
+ grid = get_grid(H, W, device=video.device)
+ grid[..., 0] *= (W - 1)
+ grid[..., 1] *= (H - 1)
+ src_step = 0
+ src_points = init[:, src_step]
+ src_frame = video[:, src_step]
+ tracks = []
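+ # Refine flow from frame 0 to every frame; with return_flow=False the source pixel grid is added so outputs are absolute track positions instead of displacements.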
+ for tgt_step in tqdm(range(T), desc="Refine target step", leave=False):
+ if src_step == tgt_step:
+ flow = torch.zeros(B, H, W, 2, device=video.device)
+ alpha = torch.ones(B, H, W, device=video.device)
+ else:
+ tgt_points = init[:, tgt_step]
+ tgt_frame = video[:, tgt_step]
+ data = {
+ "src_frame": src_frame,
+ "tgt_frame": tgt_frame,
+ "src_points": src_points,
+ "tgt_points": tgt_points
+ }
+ pred = self.optical_flow_refiner(data, mode="flow_with_tracks_init", **kwargs)
+ flow, alpha = pred["flow"], pred["alpha"]
+ if return_flow:
+ tracks.append(torch.cat([flow, alpha[..., None]], dim=-1))
+ else:
+ tracks.append(torch.cat([flow + grid, alpha[..., None]], dim=-1)) # flow holds the 1->i pixel displacements, grid is the first frame's original pixel coordinates, alpha is the confidence
+ tracks = torch.stack(tracks, dim=1)
+ return {"tracks": tracks}
+
+ def get_tracks_from_every_cell_in_every_frame(self, data, cell_size=1, cell_time_steps=20, **kwargs):
+ video = data["video"]
+ B, T, C, h, w = video.shape
+ H, W = self.resolution
+ ch, cw, ct = h // cell_size, w // cell_size, min(T, cell_time_steps)
+
+ if h != H or w != W:
+ video = video.reshape(B * T, C, h, w)
+ video = F.interpolate(video, size=(H, W), mode="bilinear")
+ video = video.reshape(B, T, C, H, W)
+
+ init = self.point_tracker(data, mode="tracks_at_motion_boundaries", **kwargs)["tracks"]
+ init = torch.stack([init[..., 0] / (w - 1), init[..., 1] / (h - 1), init[..., 2]], dim=-1)
+
+ feats = self.optical_flow_refiner({"video": video}, mode="feats", **kwargs)["feats"]
+
+ grid = get_grid(H, W, device=video.device)
+ visited_cells = torch.zeros(B, T, ch, cw, device=video.device)
+ src_steps = torch.linspace(0, T - 1, T // ct).long()
+ tracks = [[] for _ in range(B)]
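+ # Iterate over source frames spaced ct apart, skipping frames whose spatial cells have already been covered by earlier tracks.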
+ for k, src_step in enumerate(tqdm(src_steps, desc="Refine source step", leave=False)):
+ if visited_cells[:, src_step].all():
+ continue
+ src_points = init[:, src_step]
+ src_feats = feats[:, src_step]
+ tracks_from_src = []
+ for tgt_step in tqdm(range(T), desc="Refine target step", leave=False):
+ if src_step == tgt_step:
+ flow = torch.zeros(B, H, W, 2, device=video.device)
+ alpha = torch.ones(B, H, W, device=video.device)
+ else:
+ tgt_points = init[:, tgt_step]
+ tgt_feats = feats[:, tgt_step]
+ data = {
+ "src_feats": src_feats,
+ "tgt_feats": tgt_feats,
+ "src_points": src_points,
+ "tgt_points": tgt_points
+ }
+ pred = self.optical_flow_refiner(data, mode="flow_with_tracks_init", **kwargs)
+ flow, alpha = pred["flow"], pred["alpha"]
+ flow[..., 0] = flow[..., 0] / (W - 1)
+ flow[..., 1] = flow[..., 1] / (H - 1)
+ tracks_from_src.append(torch.cat([flow + grid, alpha[..., None]], dim=-1))
+ tracks_from_src = torch.stack(tracks_from_src, dim=1)
+ for b in range(B):
+ src_cell = visited_cells[b, src_step]
+ if src_cell.all():
+ continue
+ cur_y, cur_x = (1 - src_cell).nonzero(as_tuple=True)
+ cur_x = (cur_x + 0.5) / cw
+ cur_y = (cur_y + 0.5) / ch
+ cur_tracks = dense_to_sparse_tracks(cur_x, cur_y, tracks_from_src[b], h, w)
+ visited_cells[b] = update_visited(visited_cells[b], cur_tracks, h, w, ch, cw)
+ tracks[b].append(cur_tracks)
+ tracks = [torch.cat(t, dim=1) for t in tracks]
+ return {"tracks": tracks}
+
+def dense_to_sparse_tracks(x, y, tracks, height, width):
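+ # Sample the dense track volume [T, H, W, 3] at normalized query locations (x, y in [0, 1]) via grid_sample; returns per-query tracks in pixel coordinates with a binarized visibility channel.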
+ h, w = height, width
+ T = tracks.size(0)
+ grid = torch.stack([x, y], dim=-1) * 2 - 1
+ grid = repeat(grid, "s c -> t s r c", t=T, r=1)
+ tracks = rearrange(tracks, "t h w c -> t c h w")
+ tracks = F.grid_sample(tracks, grid, align_corners=True, mode="bilinear")
+ tracks = rearrange(tracks[..., 0], "t c s -> t s c")
+ tracks[..., 0] = tracks[..., 0] * (w - 1)
+ tracks[..., 1] = tracks[..., 1] * (h - 1)
+ tracks[..., 2] = (tracks[..., 2] > 0).float()
+ return tracks
+
+def update_visited(visited_cells, tracks, height, width, cell_height, cell_width):
+ T = tracks.size(0)
+ h, w = height, width
+ ch, cw = cell_height, cell_width
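+ # Mark a cell as visited whenever a track lands inside it while visible, so later source frames only spawn tracks for uncovered cells.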
+ for tgt_step in range(T):
+ tgt_points = tracks[tgt_step]
+ tgt_vis = tgt_points[:, 2]
+ visited = tgt_points[tgt_vis.bool()]
+ if len(visited) > 0:
+ visited_x, visited_y = visited[:, 0], visited[:, 1]
+ visited_x = (visited_x / (w - 1) * cw).floor().long()
+ visited_y = (visited_y / (h - 1) * ch).floor().long()
+ valid = (visited_x >= 0) & (visited_x < cw) & (visited_y >= 0) & (visited_y < ch)
+ visited_x = visited_x[valid]
+ visited_y = visited_y[valid]
+ tgt_cell = visited_cells[tgt_step].view(-1)
+ tgt_cell[visited_y * cw + visited_x] = 1.
+ tgt_cell = tgt_cell.view_as(visited_cells[tgt_step])
+ visited_cells[tgt_step] = tgt_cell
+ return visited_cells
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/interpolation.py b/data/dot_single_video/dot/models/interpolation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2edb78a31e98d6cb050033d2b0da53dc9c89cf72
--- /dev/null
+++ b/data/dot_single_video/dot/models/interpolation.py
@@ -0,0 +1,53 @@
+import warnings
+import torch
+
+try:
+ from dot.utils import torch3d
+except ModuleNotFoundError:
+ torch3d = None
+
+if torch3d:
+ TORCH3D_AVAILABLE = True
+else:
+ TORCH3D_AVAILABLE = False
+
+
+def interpolate(src_points, tgt_points, grid, version="torch3d"):
+ B, S, _ = src_points.shape
+ H, W, _ = grid.shape
+
+ # For each point in a regular grid, find indices of nearest visible source point
+ grid = grid.view(1, H * W, 2).expand(B, -1, -1) # B HW 2
+ src_pos, src_alpha = src_points[..., :2], src_points[..., 2]
+ if version == "torch" or (version == "torch3d" and not TORCH3D_AVAILABLE):
+ if version == "torch3d":
+ warnings.warn(
+ "Torch3D is not available. For optimal speed and memory consumption, consider setting it up.",
+ stacklevel=2,
+ )
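+ # Pairwise squared distances between grid cells and source points via ||g||^2 + ||p||^2 - 2 p.g; occluded sources are masked with +inf before picking the nearest visible source for each cell.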
+ dis = (grid ** 2).sum(-1)[:, None] + (src_pos ** 2).sum(-1)[:, :, None] - 2 * src_pos @ grid.permute(0, 2, 1)
+ dis[src_alpha == 0] = float('inf')
+ _, idx = dis.min(dim=1)
+ idx = idx.view(B, H * W, 1)
+ elif version == "torch3d":
+ src_pos_packed = src_pos[src_alpha.bool()]
+ tgt_points_packed = tgt_points[src_alpha.bool()]
+ lengths = src_alpha.sum(dim=1).long()
+ max_length = int(lengths.max())
+ cum_lengths = lengths.cumsum(dim=0)
+ cum_lengths = torch.cat([torch.zeros_like(cum_lengths[:1]), cum_lengths[:-1]])
+ src_pos = torch3d.packed_to_padded(src_pos_packed, cum_lengths, max_length)
+ tgt_points = torch3d.packed_to_padded(tgt_points_packed, cum_lengths, max_length)
+ _, idx, _ = torch3d.knn_points(grid, src_pos, lengths2=lengths, return_nn=False)
+ idx = idx.view(B, H * W, 1)
+
+ # Use correspondences between source and target points to initialize the flow
+ tgt_pos, tgt_alpha = tgt_points[..., :2], tgt_points[..., 2]
+ flow = tgt_pos - src_pos
+ flow = torch.cat([flow, tgt_alpha[..., None]], dim=-1) # B S 3
+ flow = flow.gather(dim=1, index=idx.expand(-1, -1, flow.size(-1)))
+ flow = flow.view(B, H, W, -1)
+ flow, alpha = flow[..., :2], flow[..., 2]
+ flow[..., 0] = flow[..., 0] * (W - 1)
+ flow[..., 1] = flow[..., 1] * (H - 1)
+ return flow, alpha
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/optical_flow.py b/data/dot_single_video/dot/models/optical_flow.py
new file mode 100644
index 0000000000000000000000000000000000000000..69f7e186bf45daa98b28e16349fd85116d6125ff
--- /dev/null
+++ b/data/dot_single_video/dot/models/optical_flow.py
@@ -0,0 +1,91 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from .shelf import RAFT
+from .interpolation import interpolate
+from dot.utils.io import read_config
+from dot.utils.torch import get_grid, get_sobel_kernel
+
+
+class OpticalFlow(nn.Module):
+ def __init__(self, height, width, config, load_path):
+ super().__init__()
+ model_args = read_config(config)
+ model_dict = {"raft": RAFT}
+ self.model = model_dict[model_args.name](model_args)
+ self.name = model_args.name
+ if load_path is not None:
+ device = next(self.model.parameters()).device
+ self.model.load_state_dict(torch.load(load_path, map_location=device))
+ coarse_height, coarse_width = height // model_args.patch_size, width // model_args.patch_size
+ self.register_buffer("coarse_grid", get_grid(coarse_height, coarse_width))
+
+ def forward(self, data, mode, **kwargs):
+ if mode == "flow_with_tracks_init":
+ return self.get_flow_with_tracks_init(data, **kwargs)
+ elif mode == "motion_boundaries":
+ return self.get_motion_boundaries(data, **kwargs)
+ elif mode == "feats":
+ return self.get_feats(data, **kwargs)
+ elif mode == "tracks_for_queries":
+ return self.get_tracks_for_queries(data, **kwargs)
+ elif mode == "tracks_from_first_to_every_other_frame":
+ return self.get_tracks_from_first_to_every_other_frame(data, **kwargs)
+ elif mode == "flow_from_last_to_first_frame":
+ return self.get_flow_from_last_to_first_frame(data, **kwargs)
+ else:
+ raise ValueError(f"Unknown mode {mode}")
+
+ def get_motion_boundaries(self, data, boundaries_size=1, boundaries_dilation=4, boundaries_thresh=0.025, **kwargs):
+ eps = 1e-12
+ src_frame, tgt_frame = data["src_frame"], data["tgt_frame"]
+ K = boundaries_size * 2 + 1
+ D = boundaries_dilation
+ B, _, H, W = src_frame.shape
+ reflect = torch.nn.ReflectionPad2d(K // 2)
+ sobel_kernel = get_sobel_kernel(K).to(src_frame.device)
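+ # Motion boundaries are detected as large spatial gradients of the normalized flow: Sobel filtering, gradient magnitude, optional max-pool dilation, per-sample min-max normalization, then thresholding.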
+ flow, _ = self.model(src_frame, tgt_frame)
+ norm_flow = torch.stack([flow[..., 0] / (W - 1), flow[..., 1] / (H - 1)], dim=-1)
+ norm_flow = norm_flow.permute(0, 3, 1, 2).reshape(-1, 1, H, W)
+ boundaries = F.conv2d(reflect(norm_flow), sobel_kernel)
+ boundaries = ((boundaries ** 2).sum(dim=1, keepdim=True) + eps).sqrt()
+ boundaries = boundaries.view(-1, 2, H, W).mean(dim=1, keepdim=True)
+ if boundaries_dilation > 1:
+ boundaries = torch.nn.functional.max_pool2d(boundaries, kernel_size=D * 2, stride=1, padding=D)
+ boundaries = boundaries[:, :, -H:, -W:]
+ boundaries = boundaries[:, 0]
+ boundaries = boundaries - boundaries.reshape(B, -1).min(dim=1)[0].reshape(B, 1, 1)
+ boundaries = boundaries / boundaries.reshape(B, -1).max(dim=1)[0].reshape(B, 1, 1)
+ boundaries = boundaries > boundaries_thresh
+ return {"motion_boundaries": boundaries, "flow": flow}
+
+ def get_feats(self, data, **kwargs):
+ video = data["video"]
+ feats = []
+ for step in tqdm(range(video.size(1)), desc="Extract feats for frame", leave=False):
+ feats.append(self.model.encode(video[:, step]))
+ feats = torch.stack(feats, dim=1)
+ return {"feats": feats}
+
+ def get_flow_with_tracks_init(self, data, is_train=False, interpolation_version="torch3d", alpha_thresh=0.8, **kwargs):
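+ # Interpolate a coarse flow/alpha field from the sparse track correspondences, then refine it with the flow model; at inference, alpha is binarized with alpha_thresh.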
+ coarse_flow, coarse_alpha = interpolate(data["src_points"], data["tgt_points"], self.coarse_grid,
+ version=interpolation_version)
+ flow, alpha = self.model(src_frame=data["src_frame"] if "src_feats" not in data else None,
+ tgt_frame=data["tgt_frame"] if "tgt_feats" not in data else None,
+ src_feats=data["src_feats"] if "src_feats" in data else None,
+ tgt_feats=data["tgt_feats"] if "tgt_feats" in data else None,
+ coarse_flow=coarse_flow,
+ coarse_alpha=coarse_alpha,
+ is_train=is_train)
+ if not is_train:
+ alpha = (alpha > alpha_thresh).float()
+ return {"flow": flow, "alpha": alpha, "coarse_flow": coarse_flow, "coarse_alpha": coarse_alpha}
+
+ def get_tracks_for_queries(self, data, **kwargs):
+ raise NotImplementedError
+
+
+
+
diff --git a/data/dot_single_video/dot/models/point_tracking.py b/data/dot_single_video/dot/models/point_tracking.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea48ebd771ebd264c8e0897186bb8bbff4f4379f
--- /dev/null
+++ b/data/dot_single_video/dot/models/point_tracking.py
@@ -0,0 +1,132 @@
+from tqdm import tqdm
+import torch
+from torch import nn
+
+from .optical_flow import OpticalFlow
+from .shelf import CoTracker, CoTracker2, Tapir
+from dot.utils.io import read_config
+from dot.utils.torch import sample_points, sample_mask_points, get_grid
+
+
+class PointTracker(nn.Module):
+ def __init__(self, height, width, tracker_config, tracker_path, estimator_config, estimator_path):
+ super().__init__()
+ model_args = read_config(tracker_config)
+ model_dict = {
+ "cotracker": CoTracker,
+ "cotracker2": CoTracker2,
+ "tapir": Tapir,
+ "bootstapir": Tapir
+ }
+ self.name = model_args.name
+ self.model = model_dict[model_args.name](model_args)
+ if tracker_path is not None:
+ device = next(self.model.parameters()).device
+ self.model.load_state_dict(torch.load(tracker_path, map_location=device), strict=False)
+ self.optical_flow_estimator = OpticalFlow(height, width, estimator_config, estimator_path)
+
+ def forward(self, data, mode, **kwargs):
+ if mode == "tracks_at_motion_boundaries":
+ return self.get_tracks_at_motion_boundaries(data, **kwargs)
+ elif mode == "flow_from_last_to_first_frame":
+ return self.get_flow_from_last_to_first_frame(data, **kwargs)
+ else:
+ raise ValueError(f"Unknown mode {mode}")
+
+ def get_tracks_at_motion_boundaries(self, data, num_tracks=8192, sim_tracks=2048, sample_mode="all", **kwargs):
+ video = data["video"]
+ N, S = num_tracks, sim_tracks
+ B, T, _, H, W = video.shape
+ assert N % S == 0
+
+ # Define sampling strategy
+ if sample_mode == "all":
+ samples_per_step = [S // T for _ in range(T)]
+ samples_per_step[0] += S - sum(samples_per_step)
+ backward_tracking = True
+ flip = False
+ elif sample_mode == "first":
+ samples_per_step = [0 for _ in range(T)]
+ samples_per_step[0] += S
+ backward_tracking = False
+ flip = False
+ elif sample_mode == "last":
+ samples_per_step = [0 for _ in range(T)]
+ samples_per_step[0] += S
+ backward_tracking = False
+ flip = True
+ else:
+ raise ValueError(f"Unknown sample mode {sample_mode}")
+
+ if flip:
+ video = video.flip(dims=[1])
+
+ # Track batches of points
+ tracks = []
+ motion_boundaries = {}
+ cache_features = True
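+ # Each iteration samples S points near motion boundaries (estimated via optical flow between neighbouring frames) and tracks them; encoder features are cached after the first batch.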
+ for _ in tqdm(range(N // S), desc="Track batch of points", leave=False):
+ src_points = []
+ for src_step, src_samples in enumerate(samples_per_step):
+ if src_samples == 0:
+ continue
+ if src_step not in motion_boundaries:
+ tgt_step = src_step - 1 if src_step > 0 else src_step + 1
+ data = {"src_frame": video[:, src_step], "tgt_frame": video[:, tgt_step]}
+ pred = self.optical_flow_estimator(data, mode="motion_boundaries", **kwargs)
+ motion_boundaries[src_step] = pred["motion_boundaries"]
+ src_boundaries = motion_boundaries[src_step]
+ src_points.append(sample_points(src_step, src_boundaries, src_samples))
+ src_points = torch.cat(src_points, dim=1)
+ traj, vis = self.model(video, src_points, backward_tracking, cache_features)
+ tracks.append(torch.cat([traj, vis[..., None]], dim=-1))
+ cache_features = False
+ tracks = torch.cat(tracks, dim=2)
+
+ if flip:
+ tracks = tracks.flip(dims=[1])
+
+ return {"tracks": tracks}
+
+ def get_flow_from_last_to_first_frame(self, data, sim_tracks=2048, **kwargs):
+ video = data["video"]
+ video = video.flip(dims=[1])
+ src_step = 0 # The video has been flipped along the temporal axis, so the source frame is index 0
+ B, T, C, H, W = video.shape
+ S = sim_tracks
+ backward_tracking = False
+ cache_features = True
+ flow = get_grid(H, W, shape=[B]).cuda()
+ flow[..., 0] = flow[..., 0] * (W - 1)
+ flow[..., 1] = flow[..., 1] * (H - 1)
+ alpha = torch.zeros(B, H, W).cuda()
+ mask = torch.ones(H, W)
+ pbar = tqdm(total=H * W // S, desc="Track batch of points", leave=False)
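+ # Track every pixel of the (temporally flipped) video in batches of S query points until the whole H x W mask is exhausted, accumulating last-to-first flow and a visibility map alpha.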
+ while torch.any(mask):
+ points, (i, j) = sample_mask_points(src_step, mask, S)
+ idx = i * W + j
+ points = points.cuda()[None].expand(B, -1, -1)
+
+ traj, vis = self.model(video, points, backward_tracking, cache_features)
+ traj = traj[:, -1]
+ vis = vis[:, -1].float()
+
+ # Update mask
+ mask = mask.view(-1)
+ mask[idx] = 0
+ mask = mask.view(H, W)
+
+ # Update flow
+ flow = flow.view(B, -1, 2)
+ flow[:, idx] = traj - flow[:, idx]
+ flow = flow.view(B, H, W, 2)
+
+ # Update alpha
+ alpha = alpha.view(B, -1)
+ alpha[:, idx] = vis
+ alpha = alpha.view(B, H, W)
+
+ cache_features = False
+ pbar.update(1)
+ pbar.close()
+ return {"flow": flow, "alpha": alpha}
diff --git a/data/dot_single_video/dot/models/shelf/__init__.py b/data/dot_single_video/dot/models/shelf/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce20921a243fec9437682fc8ec34ee3955dc87f
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/__init__.py
@@ -0,0 +1,4 @@
+from .raft import RAFT
+from .cotracker import CoTracker
+from .cotracker2 import CoTracker2
+from .tapir import Tapir
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/shelf/cotracker.py b/data/dot_single_video/dot/models/shelf/cotracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e0a6d25a4ac32cbd2214cd37b955e2ab4cef99e
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker.py
@@ -0,0 +1,12 @@
+from torch import nn
+
+from .cotracker_utils.predictor import CoTrackerPredictor
+
+
+class CoTracker(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.model = CoTrackerPredictor(args.patch_size, args.wind_size)
+
+ def forward(self, video, queries, backward_tracking, cache_features=False):
+ return self.model(video, queries=queries, backward_tracking=backward_tracking, cache_features=cache_features)
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2.py b/data/dot_single_video/dot/models/shelf/cotracker2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0547bfdf4e44f016ec77d60e5f13102a4f0505ad
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2.py
@@ -0,0 +1,14 @@
+from torch import nn
+
+from .cotracker2_utils.predictor import CoTrackerPredictor
+
+
+class CoTracker2(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.model = CoTrackerPredictor(args.patch_size, args.wind_size)
+
+ def forward(self, video, queries, backward_tracking, cache_features=False):
+ return self.model(video, queries=queries, backward_tracking=backward_tracking, cache_features=cache_features)
+
+
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/LICENSE.md b/data/dot_single_video/dot/models/shelf/cotracker2_utils/LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..ba959871dca0f9b6775570410879e637de44d7b4
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/LICENSE.md
@@ -0,0 +1,399 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/build_cotracker.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/build_cotracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..b26aa4b91d7b9e8ad1822f8f4d12a065ee7b7157
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/build_cotracker.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from dot.models.shelf.cotracker2_utils.models.core.cotracker.cotracker import CoTracker2
+
+
+def build_cotracker(patch_size, wind_size):
+ cotracker = CoTracker2(stride=patch_size, window_len=wind_size, add_space_attn=True)
+ return cotracker
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/blocks.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..f64f945522b8f52876c7e86b4934f2fb1a949439
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/blocks.py
@@ -0,0 +1,367 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+from dot.models.shelf.cotracker2_utils.models.core.model_utils import bilinear_sampler
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class BasicEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
+ super(BasicEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = "instance"
+ self.in_planes = output_dim // 2
+
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+
+ self.conv1 = nn.Conv2d(
+ input_dim,
+ self.in_planes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ padding_mode="zeros",
+ )
+ self.relu1 = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
+ self.layer3 = self._make_layer(output_dim, stride=2)
+ self.layer4 = self._make_layer(output_dim, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ output_dim * 3 + output_dim // 4,
+ output_dim * 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.InstanceNorm2d)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ d = self.layer4(c)
+
+ def _bilinear_interpolate(x):
+ return F.interpolate(
+ x,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ a = _bilinear_interpolate(a)
+ b = _bilinear_interpolate(b)
+ c = _bilinear_interpolate(c)
+ d = _bilinear_interpolate(d)
+
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+ return x
+
+
+class CorrBlock:
+ def __init__(
+ self,
+ fmaps,
+ num_levels=4,
+ radius=4,
+ multiple_track_feats=False,
+ padding_mode="zeros",
+ ):
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.padding_mode = padding_mode
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+ self.multiple_track_feats = multiple_track_feats
+
+ self.fmaps_pyramid.append(fmaps)
+ for i in range(self.num_levels - 1):
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
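+ # At every pyramid level, bilinearly sample a (2r+1) x (2r+1) correlation patch around each track coordinate (scaled down by 2^i) and concatenate the patches across levels.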
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ *_, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
+
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corrs = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W),
+ coords_lvl,
+ padding_mode=self.padding_mode,
+ )
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
+ return out
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ if self.multiple_track_feats:
+ targets_split = targets.split(C // self.num_levels, dim=-1)
+ B, S, N, C = targets_split[0].shape
+
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ *_, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
+ if self.multiple_track_feats:
+ fmap1 = targets_split[i]
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ self.corrs_pyramid.append(corrs)
+
+
+class Attention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
+ super().__init__()
+ inner_dim = dim_head * num_heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head**-0.5
+ self.heads = num_heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
+ self.to_out = nn.Linear(inner_dim, query_dim)
+
+ def forward(self, x, context=None, attn_bias=None):
+ B, N1, C = x.shape
+ h = self.heads
+
+ q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ N2 = context.shape[1]
+ k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
+ v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
+
+ sim = (q @ k.transpose(-2, -1)) * self.scale
+
+ if attn_bias is not None:
+ sim = sim + attn_bias
+ attn = sim.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
+ return self.to_out(x)
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = Attention,
+ mlp_ratio=4.0,
+ **block_kwargs
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, mask=None):
+ attn_bias = mask
+ if mask is not None:
+ mask = (
+ (mask[:, None] * mask[:, :, None])
+ .unsqueeze(1)
+                .expand(-1, self.attn.heads, -1, -1)
+ )
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/cotracker.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/cotracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f06ff96fb8814d93480d0a653e1fd157fadcfea
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/cotracker.py
@@ -0,0 +1,507 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from dot.models.shelf.cotracker2_utils.models.core.model_utils import sample_features4d, sample_features5d
+from dot.models.shelf.cotracker2_utils.models.core.embeddings import (
+ get_2d_embedding,
+ get_1d_sincos_pos_embed_from_grid,
+ get_2d_sincos_pos_embed,
+)
+
+from dot.models.shelf.cotracker2_utils.models.core.cotracker.blocks import (
+ Mlp,
+ BasicEncoder,
+ AttnBlock,
+ CorrBlock,
+ Attention,
+)
+
+torch.manual_seed(0)
+
+
+class CoTracker2(nn.Module):
+ def __init__(
+ self,
+ window_len=8,
+ stride=4,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ model_resolution=(384, 512),
+ ):
+ super(CoTracker2, self).__init__()
+ self.window_len = window_len
+ self.stride = stride
+ self.hidden_dim = 256
+ self.latent_dim = 128
+ self.add_space_attn = add_space_attn
+ self.fnet = BasicEncoder(output_dim=self.latent_dim)
+ self.num_virtual_tracks = num_virtual_tracks
+ self.model_resolution = model_resolution
+ self.input_dim = 456
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=6,
+ time_depth=6,
+ input_dim=self.input_dim,
+ hidden_size=384,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=add_space_attn,
+ num_virtual_tracks=num_virtual_tracks,
+ )
+
+ time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)
+
+ self.register_buffer(
+ "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
+ )
+
+ self.register_buffer(
+ "pos_emb",
+ get_2d_sincos_pos_embed(
+ embed_dim=self.input_dim,
+ grid_size=(
+ model_resolution[0] // stride,
+ model_resolution[1] // stride,
+ ),
+ ),
+ )
+ self.norm = nn.GroupNorm(1, self.latent_dim)
+ self.track_feat_updater = nn.Sequential(
+ nn.Linear(self.latent_dim, self.latent_dim),
+ nn.GELU(),
+ )
+ self.vis_predictor = nn.Sequential(
+ nn.Linear(self.latent_dim, 1),
+ )
+
+ def forward_window(
+ self,
+ fmaps,
+ coords,
+ track_feat=None,
+ vis=None,
+ track_mask=None,
+ attention_mask=None,
+ iters=4,
+ ):
+ # B = batch size
+        # S = number of frames in the window
+ # N = number of tracks
+ # C = channels of a point feature vector
+ # E = positional embedding size
+ # LRR = local receptive field radius
+ # D = dimension of the transformer input tokens
+
+ # track_feat = B S N C
+ # vis = B S N 1
+ # track_mask = B S N 1
+ # attention_mask = B S N
+
+ B, S_init, N, __ = track_mask.shape
+ B, S, *_ = fmaps.shape
+
+ track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
+ track_mask_vis = (
+ torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+ )
+
+ corr_block = CorrBlock(
+ fmaps,
+ num_levels=4,
+ radius=3,
+ padding_mode="border",
+ )
+
+ sampled_pos_emb = (
+ sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
+ .reshape(B * N, self.input_dim)
+ .unsqueeze(1)
+        )  # B N E -> (B N) 1 E
+
+ coord_preds = []
+ for __ in range(iters):
+ coords = coords.detach() # B S N 2
+ corr_block.corr(track_feat)
+
+ # Sample correlation features around each point
+ fcorrs = corr_block.sample(coords) # (B N) S LRR
+
+ # Get the flow embeddings
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+ flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
+
+ track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
+ x = transformer_input + sampled_pos_emb + self.time_emb
+ x = x.view(B, N, S, -1) # (B N) S D -> B N S D
+
+ delta = self.updateformer(
+ x,
+ attention_mask.reshape(B * S, N), # B S N -> (B S) N
+ )
+
+ delta_coords = delta[..., :2].permute(0, 2, 1, 3)
+ coords = coords + delta_coords
+ coord_preds.append(coords * self.stride)
+
+ delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
+ track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
+ track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
+ track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
+ 0, 2, 1, 3
+ ) # (B N S) C -> B S N C
+
+ vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
+ return coord_preds, vis_pred
+
+ def get_track_feat(self, fmaps, queried_frames, queried_coords):
+ sample_frames = queried_frames[:, None, :, None]
+ sample_coords = torch.cat(
+ [
+ sample_frames,
+ queried_coords[:, None],
+ ],
+ dim=-1,
+ )
+ sample_track_feats = sample_features5d(fmaps, sample_coords)
+ return sample_track_feats
+
+ def init_video_online_processing(self):
+ self.online_ind = 0
+ self.online_track_feat = None
+ self.online_coords_predicted = None
+ self.online_vis_predicted = None
+
+ def forward(self, video, queries, iters=4, cached_feat=None, is_train=False, is_online=False):
+ """Predict tracks
+
+ Args:
+            video (FloatTensor[B, T, 3, H, W]): input videos.
+ queries (FloatTensor[B, N, 3]): point queries.
+ iters (int, optional): number of updates. Defaults to 4.
+ is_train (bool, optional): enables training mode. Defaults to False.
+ is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
+
+ Returns:
+ - coords_predicted (FloatTensor[B, T, N, 2]):
+ - vis_predicted (FloatTensor[B, T, N]):
+ - train_data: `None` if `is_train` is false, otherwise:
+ - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
+ - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
+ - mask (BoolTensor[B, T, N]):
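+
+        Example (illustrative sketch; shapes follow the argument descriptions above)::
+
+            >>> model = CoTracker2()
+            >>> video = torch.randn(1, 16, 3, 384, 512)        # B T C H W
+            >>> queries = torch.tensor([[[0., 200., 150.]]])   # frame index, x, y
+            >>> coords, vis, _ = model(video, queries, iters=4)
+            >>> coords.shape, vis.shape
+            (torch.Size([1, 16, 1, 2]), torch.Size([1, 16, 1]))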
+ """
+ B, T, C, H, W = video.shape
+ B, N, __ = queries.shape
+ S = self.window_len
+ device = queries.device
+
+ # B = batch size
+ # S = number of frames in the window of the padded video
+ # S_trimmed = actual number of frames in the window
+ # N = number of tracks
+ # C = color channels (3 for RGB)
+ # E = positional embedding size
+ # LRR = local receptive field radius
+ # D = dimension of the transformer input tokens
+
+ # video = B T C H W
+ # queries = B N 3
+ # coords_init = B S N 2
+ # vis_init = B S N 1
+
+ assert S >= 2 # A tracker needs at least two frames to track something
+ if is_online:
+ assert T <= S, "Online mode: video chunk must be <= window size."
+ assert self.online_ind is not None, "Call model.init_video_online_processing() first."
+ assert not is_train, "Training not supported in online mode."
+ step = S // 2 # How much the sliding window moves at every step
+ video = 2 * video - 1.0
+
+ # The first channel is the frame number
+ # The rest are the coordinates of points we want to track
+ queried_frames = queries[:, :, 0].long()
+
+ queried_coords = queries[..., 1:]
+ queried_coords = queried_coords / self.stride
+
+ # We store our predictions here
+ coords_predicted = torch.zeros((B, T, N, 2), device=device)
+ vis_predicted = torch.zeros((B, T, N), device=device)
+ if is_online:
+ if self.online_coords_predicted is None:
+ # Init online predictions with zeros
+ self.online_coords_predicted = coords_predicted
+ self.online_vis_predicted = vis_predicted
+ else:
+ # Pad online predictions with zeros for the current window
+ pad = min(step, T - step)
+ coords_predicted = F.pad(
+ self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
+ )
+ vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
+ all_coords_predictions, all_vis_predictions = [], []
+
+ # Pad the video so that an integer number of sliding windows fit into it
+ # TODO: we may drop this requirement because the transformer should not care
+ # TODO: pad the features instead of the video
+ pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
+ video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
+ B, -1, C, H, W
+ )
+
+ # Compute convolutional features for the video or for the current chunk in case of online mode
+ if cached_feat is None:
+ fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
+ B, -1, self.latent_dim, H // self.stride, W // self.stride
+ )
+ else:
+ _, _, c, h, w = cached_feat.shape
+ fmaps = F.pad(cached_feat.reshape(B, 1, T, c * h * w), (0, 0, 0, pad), "replicate").reshape(B, -1, c, h, w)
+
+ # We compute track features
+ track_feat = self.get_track_feat(
+ fmaps,
+ queried_frames - self.online_ind if is_online else queried_frames,
+ queried_coords,
+ ).repeat(1, S, 1, 1)
+ if is_online:
+ # We update track features for the current window
+ sample_frames = queried_frames[:, None, :, None] # B 1 N 1
+ left = 0 if self.online_ind == 0 else self.online_ind + step
+ right = self.online_ind + S
+ sample_mask = (sample_frames >= left) & (sample_frames < right)
+ if self.online_track_feat is None:
+ self.online_track_feat = torch.zeros_like(track_feat, device=device)
+ self.online_track_feat += track_feat * sample_mask
+ track_feat = self.online_track_feat.clone()
+ # We process ((num_windows - 1) * step + S) frames in total, so there are
+ # (ceil((T - S) / step) + 1) windows
+ num_windows = (T - S + step - 1) // step + 1
+ # We process only the current video chunk in the online mode
+ indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
+
+ coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
+ vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
+ for ind in indices:
+ # We copy over coords and vis for tracks that are queried
+ # by the end of the previous window, which is ind + overlap
+ if ind > 0:
+ overlap = S - step
+ copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
+ coords_prev = torch.nn.functional.pad(
+ coords_predicted[:, ind : ind + overlap] / self.stride,
+ (0, 0, 0, 0, 0, step),
+ "replicate",
+ ) # B S N 2
+ vis_prev = torch.nn.functional.pad(
+ vis_predicted[:, ind : ind + overlap, :, None].clone(),
+ (0, 0, 0, 0, 0, step),
+ "replicate",
+ ) # B S N 1
+ coords_init = torch.where(
+ copy_over.expand_as(coords_init), coords_prev, coords_init
+ )
+ vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
+
+ # The attention mask is 1 for the spatio-temporal points within
+ # a track which is updated in the current window
+ attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
+
+ # The track mask is 1 for the spatio-temporal points that actually
+            # need updating: only after being queried, and not if contained
+ # in a previous window
+ track_mask = (
+ queried_frames[:, None, :, None]
+ <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
+ ).contiguous() # B S N 1
+
+ if ind > 0:
+ track_mask[:, :overlap, :, :] = False
+
+ # Predict the coordinates and visibility for the current window
+ coords, vis = self.forward_window(
+ fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
+ coords=coords_init,
+ track_feat=attention_mask.unsqueeze(-1) * track_feat,
+ vis=vis_init,
+ track_mask=track_mask,
+ attention_mask=attention_mask,
+ iters=iters,
+ )
+
+ S_trimmed = T if is_online else min(T - ind, S) # accounts for last window duration
+ coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
+ vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
+ if is_train:
+ all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
+ all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
+
+ if is_online:
+ self.online_ind += step
+ self.online_coords_predicted = coords_predicted
+ self.online_vis_predicted = vis_predicted
+ vis_predicted = torch.sigmoid(vis_predicted)
+
+ if is_train:
+ mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
+ train_data = (all_coords_predictions, all_vis_predictions, mask)
+ else:
+ train_data = None
+
+ return coords_predicted, vis_predicted, train_data
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
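+
+    Tokens are arranged as (B, N, T, C): temporal attention runs over T for every
+    track, and, when `add_space_attn` is set, point tokens exchange information
+    with a small set of learned virtual tracks through cross-attention, which keeps
+    spatial mixing cheap when the number of tracks N is large.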
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(space_depth)
+ ]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ tokens = self.input_transform(input_tensor)
+ B, _, T, _ = tokens.shape
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (
+ i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
+ ):
+ space_tokens = (
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
+ ) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](
+ virtual_tokens, point_tokens, mask=mask
+ )
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](
+ point_tokens, virtual_tokens, mask=mask
+ )
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+ flow = self.flow_head(tokens)
+ return flow
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.cross_attn = Attention(
+ hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, context, mask=None):
+ if mask is not None:
+ if mask.shape[1] == x.shape[1]:
+ mask = mask[:, None, :, None].expand(
+ -1, self.cross_attn.heads, -1, context.shape[1]
+ )
+ else:
+ mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)
+
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.cross_attn(
+ self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
+ )
+ x = x + self.mlp(self.norm2(x))
+ return x
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/losses.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddfd8cbd4fea982b546ed82758ab75485316da43
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/cotracker/losses.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from dot.models.shelf.cotracker2_utils.models.core.model_utils import reduce_masked_mean
+
+EPS = 1e-6
+
+
+def balanced_ce_loss(pred, gt, valid=None):
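+    # Numerically stable logistic loss on the visibility logits; the means over
+    # positive and negative ground-truth labels are computed separately and summed,
+    # so that both classes contribute equally.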
+ total_balanced_loss = 0.0
+ for j in range(len(gt)):
+ B, S, N = gt[j].shape
+ # pred and gt are the same shape
+ for (a, b) in zip(pred[j].size(), gt[j].size()):
+ assert a == b # some shape mismatch!
+ # if valid is not None:
+ for (a, b) in zip(pred[j].size(), valid[j].size()):
+ assert a == b # some shape mismatch!
+
+ pos = (gt[j] > 0.95).float()
+ neg = (gt[j] < 0.05).float()
+
+ label = pos * 2.0 - 1.0
+ a = -label * pred[j]
+ b = F.relu(a)
+ loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
+
+ pos_loss = reduce_masked_mean(loss, pos * valid[j])
+ neg_loss = reduce_masked_mean(loss, neg * valid[j])
+
+ balanced_loss = pos_loss + neg_loss
+ total_balanced_loss += balanced_loss / float(N)
+ return total_balanced_loss
+
+
+def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
+ """Loss function defined over sequence of flow predictions"""
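+    # Refinement iteration i is weighted by gamma ** (n_predictions - i - 1),
+    # so later (more refined) predictions contribute more to the loss.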
+ total_flow_loss = 0.0
+ for j in range(len(flow_gt)):
+ B, S, N, D = flow_gt[j].shape
+ assert D == 2
+ B, S1, N = vis[j].shape
+ B, S2, N = valids[j].shape
+ assert S == S1
+ assert S == S2
+ n_predictions = len(flow_preds[j])
+ flow_loss = 0.0
+ for i in range(n_predictions):
+ i_weight = gamma ** (n_predictions - i - 1)
+ flow_pred = flow_preds[j][i]
+ i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
+ flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
+ flow_loss = flow_loss / n_predictions
+ total_flow_loss += flow_loss / float(N)
+ return total_flow_loss
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/embeddings.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ee4aeedeb68ef69d10667f4dc5b73e90335c34a
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/embeddings.py
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple, Union
+import torch
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim: int, grid_size: Union[int, Tuple[int, int]]
+) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
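+
+    Example (illustrative)::
+
+        >>> get_2d_sincos_pos_embed(64, (8, 12)).shape
+        torch.Size([1, 64, 8, 12])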
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
+
+
+def get_2d_sincos_pos_embed_from_grid(
+ embed_dim: int, grid: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(
+ embed_dim: int, pos: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
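+
+    Example (illustrative)::
+
+        >>> xy = torch.rand(1, 5, 2)
+        >>> get_2d_embedding(xy, 64, cat_coords=True).shape
+        torch.Size([1, 5, 130])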
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
+ if cat_coords:
+        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
+ return pe
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/model_utils.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..12afd4e5e143f548a42350616456b614a88dfb2c
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/core/model_utils.py
@@ -0,0 +1,256 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from typing import Optional, Tuple
+
+EPS = 1e-6
+
+
+def smart_cat(tensor1, tensor2, dim):
+ if tensor1 is None:
+ return tensor2
+ return torch.cat([tensor1, tensor2], dim=dim)
+
+
+def get_points_on_a_grid(
+ size: int,
+ extent: Tuple[float, ...],
+ center: Optional[Tuple[float, ...]] = None,
+ device: Optional[torch.device] = torch.device("cpu"),
+):
+ r"""Get a grid of points covering a rectangular region
+
+ `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
+    :attr:`size` grid of points distributed to cover a rectangular area
+ specified by `extent`.
+
+    The `extent` is a pair of integers :math:`(H,W)` specifying the height
+ and width of the rectangle.
+
+ Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
+ specifying the vertical and horizontal center coordinates. The center
+ defaults to the middle of the extent.
+
+ Points are distributed uniformly within the rectangle leaving a margin
+ :math:`m=W/64` from the border.
+
+ It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
+ points :math:`P_{ij}=(x_i, y_i)` where
+
+ .. math::
+ P_{ij} = \left(
+ c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
+ c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
+ \right)
+
+ Points are returned in row-major order.
+
+ Args:
+ size (int): grid size.
+        extent (tuple): height and width of the grid extent.
+ center (tuple, optional): grid center.
+ device (str, optional): Defaults to `"cpu"`.
+
+ Returns:
+ Tensor: grid.
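+
+    Example (illustrative)::
+
+        >>> get_points_on_a_grid(3, (384, 512)).shape
+        torch.Size([1, 9, 2])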
+ """
+ if size == 1:
+ return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
+
+ if center is None:
+ center = [extent[0] / 2, extent[1] / 2]
+
+ margin = extent[1] / 64
+ range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
+ range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(*range_y, size, device=device),
+ torch.linspace(*range_x, size, device=device),
+ indexing="ij",
+ )
+ return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
+
+
+def reduce_masked_mean(input, mask, dim=None, keepdim=False):
+ r"""Masked mean
+
+ `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
+ over a mask :attr:`mask`, returning
+
+ .. math::
+ \text{output} =
+ \frac
+ {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
+ {\epsilon + \sum_{i=1}^N \text{mask}_i}
+
+ where :math:`N` is the number of elements in :attr:`input` and
+ :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
+ division by zero.
+
+    `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
+ :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
+ Optionally, the dimension can be kept in the output by setting
+ :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
+ the same dimension as :attr:`input`.
+
+ The interface is similar to `torch.mean()`.
+
+ Args:
+        input (Tensor): input tensor.
+ mask (Tensor): mask.
+ dim (int, optional): Dimension to sum over. Defaults to None.
+ keepdim (bool, optional): Keep the summed dimension. Defaults to False.
+
+ Returns:
+ Tensor: mean tensor.
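+
+    Example (illustrative)::
+
+        >>> x = torch.tensor([1., 2., 3., 4.])
+        >>> m = torch.tensor([1., 1., 0., 0.])
+        >>> reduce_masked_mean(x, m)  # (1 + 2) / 2
+        tensor(1.5000)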
+ """
+
+ mask = mask.expand_as(input)
+
+ prod = input * mask
+
+ if dim is None:
+ numer = torch.sum(prod)
+ denom = torch.sum(mask)
+ else:
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
+ denom = torch.sum(mask, dim=dim, keepdim=keepdim)
+
+ mean = numer / (EPS + denom)
+ return mean
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+    left-most image pixel and :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+    the left-most pixel and :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
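+
+    Example (illustrative)::
+
+        >>> feat = torch.arange(16.).view(1, 1, 4, 4)
+        >>> pts = torch.tensor([[[[1.0, 2.0]]]])  # one (x, y) sample in pixel units
+        >>> bilinear_sampler(feat, pts).shape     # sampled value ~ feat[0, 0, 2, 1]
+        torch.Size([1, 1, 1, 1])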
+ """
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ coords = coords * torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
+ )
+ else:
+ coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
+
+ coords -= 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
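+
+    Example (illustrative)::
+
+        >>> feats = sample_features4d(torch.randn(1, 128, 96, 128), torch.rand(1, 10, 2) * 90)
+        >>> feats.shape
+        torch.Size([1, 10, 128])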
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(
+ B, -1, feats.shape[1] * feats.shape[3]
+ ) # B C R 1 -> B R C
+
+
+def sample_features5d(input, coords):
+ r"""Sample spatio-temporal features
+
+ `sample_features5d(input, coords)` works in the same way as
+ :func:`sample_features4d` but for spatio-temporal features and points:
+ :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
+    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal points :math:`(t_i,
+ x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
+
+ Args:
+ input (Tensor): spatio-temporal features.
+ coords (Tensor): spatio-temporal points.
+
+ Returns:
+ Tensor: sampled features.
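+
+    Example (illustrative)::
+
+        >>> coords = torch.rand(1, 4, 16, 3) * 7  # (t, x, y) triplets
+        >>> sample_features5d(torch.randn(1, 8, 128, 96, 128), coords).shape
+        torch.Size([1, 4, 16, 128])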
+ """
+
+ B, T, _, _, _ = input.shape
+
+ # B T C H W -> B C T H W
+ input = input.permute(0, 2, 1, 3, 4)
+
+ # B R1 R2 3 -> B R1 R2 1 3
+ coords = coords.unsqueeze(3)
+
+ # B C R1 R2 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 3, 1, 4).view(
+ B, feats.shape[2], feats.shape[3], feats.shape[1]
+ ) # B C R1 R2 1 -> B R1 R2 C
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/evaluation_predictor.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/evaluation_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..752dc013206c3ca220ccd86e2594a4ba652cb121
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/models/evaluation_predictor.py
@@ -0,0 +1,104 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+
+from dot.models.shelf.cotracker2_utils.models.core.cotracker.cotracker import CoTracker2
+from dot.models.shelf.cotracker2_utils.models.core.model_utils import get_points_on_a_grid
+
+
+class EvaluationPredictor(torch.nn.Module):
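+    """Evaluation wrapper around CoTracker2.
+
+    Resizes the input video to `interp_shape`, optionally processes one query point
+    at a time with a local support grid when `single_point=True`, and rescales the
+    predicted tracks back to the original resolution.
+    """
+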
+ def __init__(
+ self,
+ cotracker_model: CoTracker2,
+ interp_shape: Tuple[int, int] = (384, 512),
+ grid_size: int = 5,
+ local_grid_size: int = 8,
+ single_point: bool = True,
+ n_iters: int = 6,
+ ) -> None:
+ super(EvaluationPredictor, self).__init__()
+ self.grid_size = grid_size
+ self.local_grid_size = local_grid_size
+ self.single_point = single_point
+ self.interp_shape = interp_shape
+ self.n_iters = n_iters
+
+ self.model = cotracker_model
+ self.model.eval()
+
+ def forward(self, video, queries):
+ queries = queries.clone()
+ B, T, C, H, W = video.shape
+ B, N, D = queries.shape
+
+ assert D == 3
+
+ video = video.reshape(B * T, C, H, W)
+ video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
+ video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ device = video.device
+
+ queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
+ queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
+
+ if self.single_point:
+ traj_e = torch.zeros((B, T, N, 2), device=device)
+ vis_e = torch.zeros((B, T, N), device=device)
+ for pind in range((N)):
+ query = queries[:, pind : pind + 1]
+
+ t = query[0, 0, 0].long()
+
+ traj_e_pind, vis_e_pind = self._process_one_point(video, query)
+ traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
+ vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
+ else:
+ if self.grid_size > 0:
+ xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
+ xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
+ queries = torch.cat([queries, xy], dim=1) #
+
+ traj_e, vis_e, __ = self.model(
+ video=video,
+ queries=queries,
+ iters=self.n_iters,
+ )
+
+ traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
+ traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
+ return traj_e, vis_e
+
+ def _process_one_point(self, video, query):
+ t = query[0, 0, 0].long()
+
+ device = query.device
+ if self.local_grid_size > 0:
+ xy_target = get_points_on_a_grid(
+ self.local_grid_size,
+ (50, 50),
+ [query[0, 0, 2].item(), query[0, 0, 1].item()],
+ )
+
+ xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
+ device
+ ) #
+ query = torch.cat([query, xy_target], dim=1) #
+
+ if self.grid_size > 0:
+ xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
+ xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
+ query = torch.cat([query, xy], dim=1) #
+ # crop the video to start from the queried frame
+ query[0, 0, 0] = 0
+ traj_e_pind, vis_e_pind, __ = self.model(
+ video=video[:, t:], queries=query, iters=self.n_iters
+ )
+
+ return traj_e_pind, vis_e_pind
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/predictor.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..a24c5e52bd8bb25ff662fe6b29a936759b958b5a
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/predictor.py
@@ -0,0 +1,284 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from .models.core.model_utils import smart_cat, get_points_on_a_grid
+from .models.build_cotracker import build_cotracker
+
+
+class CoTrackerPredictor(torch.nn.Module):
+ def __init__(self, patch_size, wind_size):
+ super().__init__()
+ self.support_grid_size = 6
+ model = build_cotracker(patch_size, wind_size)
+ self.interp_shape = model.model_resolution
+ self.model = model
+ self.model.eval()
+ self.cached_feat = None
+
+ @torch.no_grad()
+ def forward(
+ self,
+ video, # (1, T, 3, H, W)
+ # input prompt types:
+        # - None. Dense tracks are computed in this case. You can adjust *grid_query_frame* to compute tracks starting from a specific frame.
+ # *backward_tracking=True* will compute tracks in both directions.
+ # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
+ # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
+        #   You can adjust *grid_query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
+ queries: torch.Tensor = None,
+ segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
+ grid_size: int = 0,
+ grid_query_frame: int = 0, # only for dense and regular grid tracks
+ backward_tracking: bool = False,
+ cache_features: bool = False,
+ ):
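+        # Illustrative call patterns (sketch):
+        #   tracks, vis = predictor(video)                   # dense tracks over the full frame
+        #   tracks, vis = predictor(video, grid_size=20)     # regular 20x20 grid from the first frame
+        #   tracks, vis = predictor(video, queries=queries)  # explicit (1, N, 3) queries in (t, x, y)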
+ if queries is None and grid_size == 0:
+ tracks, visibilities = self._compute_dense_tracks(
+ video,
+ grid_query_frame=grid_query_frame,
+ backward_tracking=backward_tracking,
+ )
+ else:
+ tracks, visibilities = self._compute_sparse_tracks(
+ video,
+ queries,
+ segm_mask,
+ grid_size,
+ add_support_grid=(grid_size == 0 or segm_mask is not None),
+ grid_query_frame=grid_query_frame,
+ backward_tracking=backward_tracking,
+ cache_features=cache_features
+ )
+
+ return tracks, visibilities
+
+ def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
+ *_, H, W = video.shape
+ grid_step = W // grid_size
+ grid_width = W // grid_step
+ grid_height = H // grid_step
+ tracks = visibilities = None
+ grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
+ grid_pts[0, :, 0] = grid_query_frame
+ for offset in range(grid_step * grid_step):
+ print(f"step {offset} / {grid_step * grid_step}")
+ ox = offset % grid_step
+ oy = offset // grid_step
+ grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
+ grid_pts[0, :, 2] = (
+ torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
+ )
+ tracks_step, visibilities_step = self._compute_sparse_tracks(
+ video=video,
+ queries=grid_pts,
+ backward_tracking=backward_tracking,
+ )
+ tracks = smart_cat(tracks, tracks_step, dim=2)
+ visibilities = smart_cat(visibilities, visibilities_step, dim=2)
+
+ return tracks, visibilities
+
+ def _compute_sparse_tracks(
+ self,
+ video,
+ queries,
+ segm_mask=None,
+ grid_size=0,
+ add_support_grid=False,
+ grid_query_frame=0,
+ backward_tracking=False,
+ cache_features=False,
+ ):
+ B, T, C, H, W = video.shape
+
+ video = video.reshape(B * T, C, H, W)
+ video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
+ video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ if cache_features:
+ h, w = self.interp_shape[0], self.interp_shape[1]
+ video_ = video.reshape(B * T, C, h, w)
+ video_ = 2 * video_ - 1.0
+ fmaps_ = self.model.fnet(video_)
+ fmaps_ = fmaps_.reshape(B, T, self.model.latent_dim, h // self.model.stride, w // self.model.stride)
+ self.cached_feat = fmaps_
+
+ if queries is not None:
+ B, N, D = queries.shape
+ assert D == 3
+ queries = queries.clone()
+ queries[:, :, 1:] *= queries.new_tensor(
+ [
+ (self.interp_shape[1] - 1) / (W - 1),
+ (self.interp_shape[0] - 1) / (H - 1),
+ ]
+ )
+ elif grid_size > 0:
+ grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
+ if segm_mask is not None:
+ segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
+ point_mask = segm_mask[0, 0][
+ (grid_pts[0, :, 1]).round().long().cpu(),
+ (grid_pts[0, :, 0]).round().long().cpu(),
+ ].bool()
+ grid_pts = grid_pts[:, point_mask]
+
+ queries = torch.cat(
+ [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
+ dim=2,
+ )
+
+ if add_support_grid:
+ grid_pts = get_points_on_a_grid(
+ self.support_grid_size, self.interp_shape, device=video.device
+ )
+ grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
+ grid_pts = grid_pts.repeat(B, 1, 1)
+ queries = torch.cat([queries, grid_pts], dim=1)
+
+ tracks, visibilities, __ = self.model.forward(
+ video=video,
+ queries=queries,
+ iters=6,
+ cached_feat=self.cached_feat
+ )
+
+ if backward_tracking:
+ tracks, visibilities = self._compute_backward_tracks(
+ video, queries, tracks, visibilities
+ )
+ if add_support_grid:
+ queries[:, -self.support_grid_size**2 :, 0] = T - 1
+ if add_support_grid:
+ tracks = tracks[:, :, : -self.support_grid_size**2]
+ visibilities = visibilities[:, :, : -self.support_grid_size**2]
+ thr = 0.9
+ visibilities = visibilities > thr
+
+ # correct query-point predictions
+ # see https://github.com/facebookresearch/co-tracker/issues/28
+
+ # TODO: batchify
+ for i in range(len(queries)):
+ queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
+ arange = torch.arange(0, len(queries_t))
+
+ # overwrite the predictions with the query points
+ tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
+
+ # correct visibilities, the query points should be visible
+ visibilities[i, queries_t, arange] = True
+
+ tracks *= tracks.new_tensor(
+ [(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
+ )
+ return tracks, visibilities
+
+ def _compute_backward_tracks(self, video, queries, tracks, visibilities):
+ inv_video = video.flip(1).clone()
+ inv_queries = queries.clone()
+ inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
+
+ if self.cached_feat is not None:
+ inv_feat = self.cached_feat.flip(1)
+ else:
+ inv_feat = None
+
+ inv_tracks, inv_visibilities, __ = self.model(
+ video=inv_video,
+ queries=inv_queries,
+ iters=6,
+ cached_feat=inv_feat
+ )
+
+ inv_tracks = inv_tracks.flip(1)
+ inv_visibilities = inv_visibilities.flip(1)
+ arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
+
+ mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
+
+ tracks[mask] = inv_tracks[mask]
+ visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
+ return tracks, visibilities
+
+
+class CoTrackerOnlinePredictor(torch.nn.Module):
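+    """Sliding-window (online) wrapper around CoTracker2.
+
+    Illustrative usage (sketch): call once with ``is_first_step=True`` to register
+    the queries (or a query grid), then feed the subsequent video chunks through
+    the same call; the underlying model advances its window by ``window_len // 2``
+    frames per step.
+    """
+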
+ def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
+ super().__init__()
+ self.support_grid_size = 6
+ model = build_cotracker(checkpoint)
+ self.interp_shape = model.model_resolution
+ self.step = model.window_len // 2
+ self.model = model
+ self.model.eval()
+
+ @torch.no_grad()
+ def forward(
+ self,
+ video_chunk,
+ is_first_step: bool = False,
+ queries: torch.Tensor = None,
+ grid_size: int = 10,
+ grid_query_frame: int = 0,
+ add_support_grid=False,
+ ):
+        B, T, C, H, W = video_chunk.shape
+        # Initialize online video processing and save queried points
+        # This needs to be done before processing *each new video*
+        if is_first_step:
+ self.model.init_video_online_processing()
+ if queries is not None:
+ B, N, D = queries.shape
+ assert D == 3
+ queries = queries.clone()
+ queries[:, :, 1:] *= queries.new_tensor(
+ [
+ (self.interp_shape[1] - 1) / (W - 1),
+ (self.interp_shape[0] - 1) / (H - 1),
+ ]
+ )
+ elif grid_size > 0:
+ grid_pts = get_points_on_a_grid(
+ grid_size, self.interp_shape, device=video_chunk.device
+ )
+ queries = torch.cat(
+ [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
+ dim=2,
+ )
+ if add_support_grid:
+ grid_pts = get_points_on_a_grid(
+ self.support_grid_size, self.interp_shape, device=video_chunk.device
+ )
+ grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
+ queries = torch.cat([queries, grid_pts], dim=1)
+ self.queries = queries
+ return (None, None)
+ video_chunk = video_chunk.reshape(B * T, C, H, W)
+ video_chunk = F.interpolate(
+ video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
+ )
+ video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ tracks, visibilities, __ = self.model(
+ video=video_chunk,
+ queries=self.queries,
+ iters=6,
+ is_online=True,
+ )
+ thr = 0.9
+ return (
+ tracks
+ * tracks.new_tensor(
+ [
+ (W - 1) / (self.interp_shape[1] - 1),
+ (H - 1) / (self.interp_shape[0] - 1),
+ ]
+ ),
+ visibilities > thr,
+ )
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/utils/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker2_utils/utils/visualizer.py b/data/dot_single_video/dot/models/shelf/cotracker2_utils/utils/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..22ba43afe038e0829b9e2fd17cc670d0231c510c
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker2_utils/utils/visualizer.py
@@ -0,0 +1,343 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import numpy as np
+import imageio
+import torch
+
+from matplotlib import cm
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import matplotlib.pyplot as plt
+from PIL import Image, ImageDraw
+
+
+def read_video_from_path(path):
+ try:
+ reader = imageio.get_reader(path)
+ except Exception as e:
+ print("Error opening video file: ", e)
+ return None
+ frames = []
+ for i, im in enumerate(reader):
+ frames.append(np.array(im))
+ return np.stack(frames)
+
+
+def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
+ # Create a draw object
+ draw = ImageDraw.Draw(rgb)
+ # Calculate the bounding box of the circle
+ left_up_point = (coord[0] - radius, coord[1] - radius)
+ right_down_point = (coord[0] + radius, coord[1] + radius)
+ # Draw the circle
+ draw.ellipse(
+ [left_up_point, right_down_point],
+ fill=tuple(color) if visible else None,
+ outline=tuple(color),
+ )
+ return rgb
+
+
+def draw_line(rgb, coord_y, coord_x, color, linewidth):
+ draw = ImageDraw.Draw(rgb)
+ draw.line(
+ (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
+ fill=tuple(color),
+ width=linewidth,
+ )
+ return rgb
+
+
+def add_weighted(rgb, alpha, original, beta, gamma):
+ return (rgb * alpha + original * beta + gamma).astype("uint8")
+
+
+class Visualizer:
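+    """Draws point tracks on top of a video.
+
+    Illustrative usage (sketch)::
+
+        vis = Visualizer(save_dir="./results", fps=10)
+        vis.visualize(video, tracks, visibility, filename="demo")
+
+    The result is written to ``<save_dir>/<filename>.mp4`` or, if a TensorBoard
+    writer is passed, added to it via ``add_video``.
+    """
+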
+ def __init__(
+ self,
+ save_dir: str = "./results",
+ grayscale: bool = False,
+ pad_value: int = 0,
+ fps: int = 10,
+ mode: str = "rainbow", # 'cool', 'optical_flow'
+ linewidth: int = 2,
+ show_first_frame: int = 10,
+ tracks_leave_trace: int = 0, # -1 for infinite
+ ):
+ self.mode = mode
+ self.save_dir = save_dir
+ if mode == "rainbow":
+ self.color_map = cm.get_cmap("gist_rainbow")
+ elif mode == "cool":
+ self.color_map = cm.get_cmap(mode)
+ self.show_first_frame = show_first_frame
+ self.grayscale = grayscale
+ self.tracks_leave_trace = tracks_leave_trace
+ self.pad_value = pad_value
+ self.linewidth = linewidth
+ self.fps = fps
+
+ def visualize(
+ self,
+ video: torch.Tensor, # (B,T,C,H,W)
+ tracks: torch.Tensor, # (B,T,N,2)
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
+ filename: str = "video",
+ writer=None, # tensorboard Summary Writer, used for visualization during training
+ step: int = 0,
+ query_frame: int = 0,
+ save_video: bool = True,
+ compensate_for_camera_motion: bool = False,
+ ):
+ if compensate_for_camera_motion:
+ assert segm_mask is not None
+ if segm_mask is not None:
+ coords = tracks[0, query_frame].round().long()
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
+
+ video = F.pad(
+ video,
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
+ "constant",
+ 255,
+ )
+ tracks = tracks + self.pad_value
+
+ if self.grayscale:
+ transform = transforms.Grayscale()
+ video = transform(video)
+ video = video.repeat(1, 1, 3, 1, 1)
+
+ res_video = self.draw_tracks_on_video(
+ video=video,
+ tracks=tracks,
+ visibility=visibility,
+ segm_mask=segm_mask,
+ gt_tracks=gt_tracks,
+ query_frame=query_frame,
+ compensate_for_camera_motion=compensate_for_camera_motion,
+ )
+ if save_video:
+ self.save_video(res_video, filename=filename, writer=writer, step=step)
+ return res_video
+
+ def save_video(self, video, filename, writer=None, step=0):
+ if writer is not None:
+ writer.add_video(
+ filename,
+ video.to(torch.uint8),
+ global_step=step,
+ fps=self.fps,
+ )
+ else:
+ os.makedirs(self.save_dir, exist_ok=True)
+ wide_list = list(video.unbind(1))
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
+
+ # Prepare the video file path
+ save_path = os.path.join(self.save_dir, f"{filename}.mp4")
+
+ # Create a writer object
+ video_writer = imageio.get_writer(save_path, fps=self.fps)
+
+ # Write frames to the video file
+ for frame in wide_list[2:-1]:
+ video_writer.append_data(frame)
+
+ video_writer.close()
+
+ print(f"Video saved to {save_path}")
+
+ def draw_tracks_on_video(
+ self,
+ video: torch.Tensor,
+ tracks: torch.Tensor,
+ visibility: torch.Tensor = None,
+ segm_mask: torch.Tensor = None,
+ gt_tracks=None,
+ query_frame: int = 0,
+ compensate_for_camera_motion=False,
+ ):
+ B, T, C, H, W = video.shape
+ _, _, N, D = tracks.shape
+
+ assert D == 2
+ assert C == 3
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
+ tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
+ if gt_tracks is not None:
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
+
+ res_video = []
+
+ # process input video
+ for rgb in video:
+ res_video.append(rgb.copy())
+ vector_colors = np.zeros((T, N, 3))
+
+ if self.mode == "optical_flow":
+ import flow_vis
+
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
+ elif segm_mask is None:
+ if self.mode == "rainbow":
+ y_min, y_max = (
+ tracks[query_frame, :, 1].min(),
+ tracks[query_frame, :, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ color = self.color_map(norm(tracks[query_frame, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+ else:
+ # color changes with time
+ for t in range(T):
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
+ vector_colors[t] = np.repeat(color, N, axis=0)
+ else:
+ if self.mode == "rainbow":
+ vector_colors[:, segm_mask <= 0, :] = 255
+
+ y_min, y_max = (
+ tracks[0, segm_mask > 0, 1].min(),
+ tracks[0, segm_mask > 0, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ if segm_mask[n] > 0:
+ color = self.color_map(norm(tracks[0, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+
+ else:
+ # color changes with segm class
+ segm_mask = segm_mask.cpu()
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
+ vector_colors = np.repeat(color[None], T, axis=0)
+
+ # draw tracks
+ if self.tracks_leave_trace != 0:
+ for t in range(query_frame + 1, T):
+ first_ind = (
+ max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
+ )
+ curr_tracks = tracks[first_ind : t + 1]
+ curr_colors = vector_colors[first_ind : t + 1]
+ if compensate_for_camera_motion:
+ diff = (
+ tracks[first_ind : t + 1, segm_mask <= 0]
+ - tracks[t : t + 1, segm_mask <= 0]
+ ).mean(1)[:, None]
+
+ curr_tracks = curr_tracks - diff
+ curr_tracks = curr_tracks[:, segm_mask > 0]
+ curr_colors = curr_colors[:, segm_mask > 0]
+
+ res_video[t] = self._draw_pred_tracks(
+ res_video[t],
+ curr_tracks,
+ curr_colors,
+ )
+ if gt_tracks is not None:
+ res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
+
+ # draw points
+ for t in range(query_frame, T):
+ img = Image.fromarray(np.uint8(res_video[t]))
+ for i in range(N):
+ coord = (tracks[t, i, 0], tracks[t, i, 1])
+                visible = True
+                if visibility is not None:
+                    visible = visibility[0, t, i]
+ if coord[0] != 0 and coord[1] != 0:
+ if not compensate_for_camera_motion or (
+ compensate_for_camera_motion and segm_mask[i] > 0
+ ):
+ img = draw_circle(
+ img,
+ coord=coord,
+ radius=int(self.linewidth * 2),
+ color=vector_colors[t, i].astype(int),
+                            visible=visible,
+ )
+ res_video[t] = np.array(img)
+
+ # construct the final rgb sequence
+ if self.show_first_frame > 0:
+ res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
+
+ def _draw_pred_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3
+        tracks: np.ndarray,  # T x N x 2
+ vector_colors: np.ndarray,
+ alpha: float = 0.5,
+ ):
+ T, N, _ = tracks.shape
+ rgb = Image.fromarray(np.uint8(rgb))
+ for s in range(T - 1):
+ vector_color = vector_colors[s]
+ original = rgb.copy()
+ alpha = (s / T) ** 2
+ for i in range(N):
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
+ if coord_y[0] != 0 and coord_y[1] != 0:
+ rgb = draw_line(
+ rgb,
+ coord_y,
+ coord_x,
+ vector_color[i].astype(int),
+ self.linewidth,
+ )
+ if self.tracks_leave_trace > 0:
+ rgb = Image.fromarray(
+ np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
+ )
+ rgb = np.array(rgb)
+ return rgb
+
+ def _draw_gt_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3,
+        gt_tracks: np.ndarray,  # T x N x 2
+ ):
+ T, N, _ = gt_tracks.shape
+ color = np.array((211, 0, 0))
+ rgb = Image.fromarray(np.uint8(rgb))
+ for t in range(T):
+ for i in range(N):
+                gt_track = gt_tracks[t][i]
+                # draw a red cross
+                if gt_track[0] > 0 and gt_track[1] > 0:
+                    length = self.linewidth * 3
+                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
+                    rgb = draw_line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        color,
+                        self.linewidth,
+                    )
+                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
+ rgb = draw_line(
+ rgb,
+ coord_y,
+ coord_x,
+ color,
+ self.linewidth,
+ )
+ rgb = np.array(rgb)
+ return rgb
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/LICENSE.md b/data/dot_single_video/dot/models/shelf/cotracker_utils/LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..ba959871dca0f9b6775570410879e637de44d7b4
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/LICENSE.md
@@ -0,0 +1,399 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More_considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+ d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/build_cotracker.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/build_cotracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..530d07d1df9ad602cf46744aaae35ae71455af00
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/build_cotracker.py
@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from dot.models.shelf.cotracker_utils.models.core.cotracker.cotracker import CoTracker
+
+
+def build_cotracker(
+ patch_size: int,
+ wind_size: int,
+):
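+    # Select one of the preset CoTracker configurations: `patch_size` is used as
+    # the feature-map stride of the encoder and `wind_size` as the length (in
+    # frames) of the sliding window processed by the model.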
+ if patch_size == 4 and wind_size == 8:
+ return build_cotracker_stride_4_wind_8()
+ elif patch_size == 4 and wind_size == 12:
+ return build_cotracker_stride_4_wind_12()
+ elif patch_size == 8 and wind_size == 16:
+ return build_cotracker_stride_8_wind_16()
+ else:
+        raise ValueError(f"Unknown model for patch size {patch_size} and window size {wind_size}")
+
+
+# model used to produce the results in the paper
+def build_cotracker_stride_4_wind_8(checkpoint=None):
+ return _build_cotracker(
+ stride=4,
+ sequence_len=8,
+ checkpoint=checkpoint,
+ )
+
+
+def build_cotracker_stride_4_wind_12(checkpoint=None):
+ return _build_cotracker(
+ stride=4,
+ sequence_len=12,
+ checkpoint=checkpoint,
+ )
+
+
+# the fastest model
+def build_cotracker_stride_8_wind_16(checkpoint=None):
+ return _build_cotracker(
+ stride=8,
+ sequence_len=16,
+ checkpoint=checkpoint,
+ )
+
+
+def _build_cotracker(
+ stride,
+ sequence_len,
+ checkpoint=None,
+):
+ cotracker = CoTracker(
+ stride=stride,
+ S=sequence_len,
+ add_space_attn=True,
+ space_depth=6,
+ time_depth=6,
+ )
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f, map_location="cpu")
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+ cotracker.load_state_dict(state_dict)
+ return cotracker
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/blocks.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8880b679aae33325222339fa1e618fd010ea84c4
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/blocks.py
@@ -0,0 +1,400 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+from timm.models.vision_transformer import Attention, Mlp
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(
+ planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
+ )
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class BasicEncoder(nn.Module):
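+    # Residual CNN feature encoder: stacked ResidualBlock stages are interpolated
+    # to a common resolution of 1/stride, concatenated along channels, and
+    # projected to `output_dim` feature channels.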
+ def __init__(
+ self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
+ ):
+ super(BasicEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = norm_fn
+ self.in_planes = 64
+
+ if self.norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
+ self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
+
+ elif self.norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(self.in_planes)
+ self.norm2 = nn.BatchNorm2d(output_dim * 2)
+
+ elif self.norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+
+ elif self.norm_fn == "none":
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(
+ input_dim,
+ self.in_planes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ padding_mode="zeros",
+ )
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.shallow = False
+ if self.shallow:
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=2)
+ self.layer3 = self._make_layer(128, stride=2)
+ self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
+ else:
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=2)
+ self.layer3 = self._make_layer(128, stride=2)
+ self.layer4 = self._make_layer(128, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ 128 + 128 + 96 + 64,
+ output_dim * 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ if self.shallow:
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ a = F.interpolate(
+ a,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ b = F.interpolate(
+ b,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ c = F.interpolate(
+ c,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ x = self.conv2(torch.cat([a, b, c], dim=1))
+ else:
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ d = self.layer4(c)
+ a = F.interpolate(
+ a,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ b = F.interpolate(
+ b,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ c = F.interpolate(
+ c,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ d = F.interpolate(
+ d,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+ return x
+
+
+class AttnBlock(nn.Module):
+ """
+    A pre-norm Transformer block: multi-head self-attention followed by an MLP.
+    (Adapted from the DiT codebase, without adaLN-Zero conditioning.)
+ """
+
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(
+ hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x):
+ x = x + self.attn(self.norm1(x))
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+def bilinear_sampler(img, coords, mode="bilinear", mask=False):
+ """Wrapper for grid_sample, uses pixel coordinates"""
+ H, W = img.shape[-2:]
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
+ # go to 0,1 then 0,2 then -1,1
+ xgrid = 2 * xgrid / (W - 1) - 1
+ ygrid = 2 * ygrid / (H - 1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, align_corners=True)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.float()
+
+ return img
+
+
+class CorrBlock:
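+    # Multi-scale correlation volume: `corr` computes per-level correlation maps
+    # between the current point features and an average-pooled feature pyramid,
+    # and `sample` bilinearly reads out a (2r+1) x (2r+1) neighborhood around each
+    # track position at every pyramid level.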
+ def __init__(self, fmaps, num_levels=4, radius=4):
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+
+ self.fmaps_pyramid.append(fmaps)
+ for i in range(self.num_levels - 1):
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ _, _, _, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
+ coords.device
+ )
+
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ return out.contiguous().float()
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for fmaps in self.fmaps_pyramid:
+ _, _, _, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W)
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W)
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ self.corrs_pyramid.append(corrs)
+
+
+class UpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=12,
+ time_depth=12,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ ):
+ super().__init__()
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_blocks = nn.ModuleList(
+ [
+ AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(space_depth)
+ ]
+ )
+ assert len(self.time_blocks) >= len(self.space_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor):
+ x = self.input_transform(input_tensor)
+
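+        # Alternate attention along the time axis (within each track) with
+        # attention across tracks; a space block is applied every
+        # len(time_blocks) // len(space_blocks) time iterations.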
+ j = 0
+ for i in range(len(self.time_blocks)):
+ B, N, T, _ = x.shape
+ x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
+ x_time = self.time_blocks[i](x_time)
+
+ x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
+ if self.add_space_attn and (
+ i % (len(self.time_blocks) // len(self.space_blocks)) == 0
+ ):
+ x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
+ x_space = self.space_blocks[j](x_space)
+ x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
+ j += 1
+
+ flow = self.flow_head(x)
+ return flow
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/cotracker.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/cotracker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ebea54d619fea0d0638cd7a8cc1d3071278325
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/cotracker.py
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from dot.models.shelf.cotracker_utils.models.core.cotracker.blocks import (
+ BasicEncoder,
+ CorrBlock,
+ UpdateFormer,
+)
+
+from dot.models.shelf.cotracker_utils.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
+from dot.models.shelf.cotracker_utils.models.core.embeddings import (
+ get_2d_embedding,
+ get_1d_sincos_pos_embed_from_grid,
+ get_2d_sincos_pos_embed,
+)
+
+
+torch.manual_seed(0)
+
+
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cpu"):
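+    # Return a grid_size x grid_size set of (x, y) query points spread over
+    # `interp_shape`, inset from the image border by one step and optionally
+    # shifted so that the grid is centred on `grid_center`.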
+ if grid_size == 1:
+ return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
+ None, None
+ ]
+
+ grid_y, grid_x = meshgrid2d(
+ 1, grid_size, grid_size, stack=False, norm=False, device=device
+ )
+ step = interp_shape[1] // 64
+ if grid_center[0] != 0 or grid_center[1] != 0:
+ grid_y = grid_y - grid_size / 2.0
+ grid_x = grid_x - grid_size / 2.0
+ grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
+ interp_shape[0] - step * 2
+ )
+ grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
+ interp_shape[1] - step * 2
+ )
+
+ grid_y = grid_y + grid_center[0]
+ grid_x = grid_x + grid_center[1]
+ xy = torch.stack([grid_x, grid_y], dim=-1).to(device)
+ return xy
+
+
+def sample_pos_embed(grid_size, embed_dim, coords):
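+    # Build a 2-D sin/cos positional embedding over the feature grid and
+    # bilinearly sample it at the initial coordinates of each track.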
+ pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
+ pos_embed = (
+ torch.from_numpy(pos_embed)
+ .reshape(grid_size[0], grid_size[1], embed_dim)
+ .float()
+ .unsqueeze(0)
+ .to(coords.device)
+ )
+ sampled_pos_embed = bilinear_sample2d(
+ pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
+ )
+ return sampled_pos_embed
+
+
+class CoTracker(nn.Module):
+ def __init__(
+ self,
+ S=8,
+ stride=8,
+ add_space_attn=True,
+ num_heads=8,
+ hidden_size=384,
+ space_depth=12,
+ time_depth=12,
+ ):
+ super(CoTracker, self).__init__()
+ self.S = S
+ self.stride = stride
+ self.hidden_dim = 256
+ self.latent_dim = latent_dim = 128
+ self.corr_levels = 4
+ self.corr_radius = 3
+ self.add_space_attn = add_space_attn
+ self.fnet = BasicEncoder(
+ output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
+ )
+
+ self.updateformer = UpdateFormer(
+ space_depth=space_depth,
+ time_depth=time_depth,
+ input_dim=456,
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ output_dim=latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=add_space_attn,
+ )
+
+ self.norm = nn.GroupNorm(1, self.latent_dim)
+ self.ffeat_updater = nn.Sequential(
+ nn.Linear(self.latent_dim, self.latent_dim),
+ nn.GELU(),
+ )
+ self.vis_predictor = nn.Sequential(
+ nn.Linear(self.latent_dim, 1),
+ )
+
+ def forward_iteration(
+ self,
+ fmaps,
+ coords_init,
+ feat_init=None,
+ vis_init=None,
+ track_mask=None,
+ iters=4,
+ ):
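+        # Refine the tracks of one window: at every iteration, sample the
+        # correlation pyramid at the current coordinates, feed flows, correlations
+        # and point features through the transformer, and apply the predicted
+        # coordinate and feature deltas.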
+ B, S_init, N, D = coords_init.shape
+ assert D == 2
+ assert B == 1
+
+ B, S, __, H8, W8 = fmaps.shape
+
+ device = fmaps.device
+
+ if S_init < S:
+ coords = torch.cat(
+ [coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
+ )
+ vis_init = torch.cat(
+ [vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
+ )
+ else:
+ coords = coords_init.clone()
+
+ fcorr_fn = CorrBlock(
+ fmaps, num_levels=self.corr_levels, radius=self.corr_radius
+ )
+
+ ffeats = feat_init.clone()
+
+ times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
+
+ pos_embed = sample_pos_embed(
+ grid_size=(H8, W8),
+ embed_dim=456,
+ coords=coords,
+ )
+ pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
+ times_embed = (
+ torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
+ .repeat(B, 1, 1)
+ .float()
+ .to(device)
+ )
+ coord_predictions = []
+
+ for __ in range(iters):
+ coords = coords.detach()
+ fcorr_fn.corr(ffeats)
+
+ fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
+ LRR = fcorrs.shape[3]
+
+ fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
+ flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+
+ flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
+ ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ if track_mask.shape[1] < vis_init.shape[1]:
+ track_mask = torch.cat(
+ [
+ track_mask,
+ torch.zeros_like(track_mask[:, 0]).repeat(
+ 1, vis_init.shape[1] - track_mask.shape[1], 1, 1
+ ),
+ ],
+ dim=1,
+ )
+ concat = (
+ torch.cat([track_mask, vis_init], dim=2)
+ .permute(0, 2, 1, 3)
+ .reshape(B * N, S, 2)
+ )
+
+ transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
+ x = transformer_input + pos_embed + times_embed
+
+ x = rearrange(x, "(b n) t d -> b n t d", b=B)
+
+ delta = self.updateformer(x)
+
+ delta = rearrange(delta, " b n t d -> (b n) t d")
+
+ delta_coords_ = delta[:, :, :2]
+ delta_feats_ = delta[:, :, 2:]
+
+ delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
+ ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
+
+ ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
+
+ ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
+ 0, 2, 1, 3
+ ) # B,S,N,C
+
+ coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
+ coord_predictions.append(coords * self.stride)
+
+ vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
+ B, S, N
+ )
+ return coord_predictions, vis_e, feat_init
+
+ def forward(self, rgbs, queries, iters=4, cached_feat=None, feat_init=None, is_train=False):
+ B, T, C, H, W = rgbs.shape
+ B, N, __ = queries.shape
+
+ device = rgbs.device
+ assert B == 1
+ # INIT for the first sequence
+        # We want to sort points by the first frame in which they are visible, so that they can be added to the tensor of tracked points consecutively
+ first_positive_inds = queries[:, :, 0].long()
+
+ __, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
+ inv_sort_inds = torch.argsort(sort_inds, dim=0)
+ first_positive_sorted_inds = first_positive_inds[0][sort_inds]
+
+ assert torch.allclose(
+ first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
+ )
+
+ coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
+ 1, self.S, 1, 1
+ ) / float(self.stride)
+
+ rgbs = 2 * rgbs - 1.0
+
+ traj_e = torch.zeros((B, T, N, 2), device=device)
+ vis_e = torch.zeros((B, T, N), device=device)
+
+ ind_array = torch.arange(T, device=device)
+ ind_array = ind_array[None, :, None].repeat(B, 1, N)
+
+ track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
+ # these are logits, so we initialize visibility with something that would give a value close to 1 after softmax
+ vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
+
+ ind = 0
+
+ track_mask_ = track_mask[:, :, sort_inds].clone()
+ coords_init_ = coords_init[:, :, sort_inds].clone()
+ vis_init_ = vis_init[:, :, sort_inds].clone()
+
+ prev_wind_idx = 0
+ fmaps_ = None
+ vis_predictions = []
+ coord_predictions = []
+ wind_inds = []
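+        # Slide a window of S frames over the video with an overlap of S // 2;
+        # estimates from the previous window initialise the overlapping half, and
+        # queries are appended once their start frame falls inside the window.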
+ while ind < T - self.S // 2:
+ rgbs_seq = rgbs[:, ind : ind + self.S]
+
+ S = S_local = rgbs_seq.shape[1]
+
+ if cached_feat is None:
+ if S < self.S:
+ rgbs_seq = torch.cat(
+ [rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
+ dim=1,
+ )
+ S = rgbs_seq.shape[1]
+ rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
+
+ if fmaps_ is None:
+ fmaps_ = self.fnet(rgbs_)
+ else:
+ fmaps_ = torch.cat(
+ [fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
+ )
+ fmaps = fmaps_.reshape(
+ B, S, self.latent_dim, H // self.stride, W // self.stride
+ )
+ else:
+ fmaps = cached_feat[:, ind : ind + self.S]
+ if S < self.S:
+ fmaps = torch.cat(
+ [fmaps, fmaps[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
+ dim=1,
+ )
+
+ curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
+ if curr_wind_points.shape[0] == 0:
+ ind = ind + self.S // 2
+ continue
+ wind_idx = curr_wind_points[-1] + 1
+
+ if wind_idx - prev_wind_idx > 0:
+ fmaps_sample = fmaps[
+ :, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
+ ]
+
+ feat_init_ = bilinear_sample2d(
+ fmaps_sample,
+ coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
+ coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
+ ).permute(0, 2, 1)
+
+ feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
+ feat_init = smart_cat(feat_init, feat_init_, dim=2)
+
+ if prev_wind_idx > 0:
+ new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
+
+ coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
+ coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
+ :, -1
+ ].repeat(1, self.S // 2, 1, 1)
+
+ new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
+ vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
+ vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
+ 1, self.S // 2, 1, 1
+ )
+
+ coords, vis, __ = self.forward_iteration(
+ fmaps=fmaps,
+ coords_init=coords_init_[:, :, :wind_idx],
+ feat_init=feat_init[:, :, :wind_idx],
+ vis_init=vis_init_[:, :, :wind_idx],
+ track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
+ iters=iters,
+ )
+ if is_train:
+ vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
+ coord_predictions.append([coord[:, :S_local] for coord in coords])
+ wind_inds.append(wind_idx)
+
+ traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
+ vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
+
+ track_mask_[:, : ind + self.S, :wind_idx] = 0.0
+ ind = ind + self.S // 2
+
+ prev_wind_idx = wind_idx
+
+ traj_e = traj_e[:, :, inv_sort_inds]
+ vis_e = vis_e[:, :, inv_sort_inds]
+
+ vis_e = torch.sigmoid(vis_e)
+
+ train_data = (
+ (vis_predictions, coord_predictions, wind_inds, sort_inds)
+ if is_train
+ else None
+ )
+ return traj_e, feat_init, vis_e, train_data
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/losses.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..41e381576a76b86e2a2f40de9dd11fca1750d199
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/cotracker/losses.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from dot.models.shelf.cotracker_utils.models.core.model_utils import reduce_masked_mean
+
+EPS = 1e-6
+
+
+def balanced_ce_loss(pred, gt, valid=None):
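+    # Sigmoid cross-entropy on visibility logits, averaged separately over
+    # confidently visible (gt > 0.95) and confidently occluded (gt < 0.05) points
+    # so that both classes contribute equally regardless of their frequency.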
+ total_balanced_loss = 0.0
+ for j in range(len(gt)):
+ B, S, N = gt[j].shape
+ # pred and gt are the same shape
+ for (a, b) in zip(pred[j].size(), gt[j].size()):
+ assert a == b # some shape mismatch!
+ # if valid is not None:
+ for (a, b) in zip(pred[j].size(), valid[j].size()):
+ assert a == b # some shape mismatch!
+
+ pos = (gt[j] > 0.95).float()
+ neg = (gt[j] < 0.05).float()
+
+ label = pos * 2.0 - 1.0
+ a = -label * pred[j]
+ b = F.relu(a)
+ loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
+
+ pos_loss = reduce_masked_mean(loss, pos * valid[j])
+ neg_loss = reduce_masked_mean(loss, neg * valid[j])
+
+ balanced_loss = pos_loss + neg_loss
+ total_balanced_loss += balanced_loss / float(N)
+ return total_balanced_loss
+
+
+def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
+ """Loss function defined over sequence of flow predictions"""
+ total_flow_loss = 0.0
+ for j in range(len(flow_gt)):
+ B, S, N, D = flow_gt[j].shape
+ assert D == 2
+ B, S1, N = vis[j].shape
+ B, S2, N = valids[j].shape
+ assert S == S1
+ assert S == S2
+ n_predictions = len(flow_preds[j])
+ flow_loss = 0.0
+ for i in range(n_predictions):
+ i_weight = gamma ** (n_predictions - i - 1)
+ flow_pred = flow_preds[j][i]
+ i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
+ flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
+ flow_loss = flow_loss / n_predictions
+ total_flow_loss += flow_loss / float(N)
+ return total_flow_loss
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/embeddings.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbcd86a55bda603b1729638de9ddb339cac42f84
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/embeddings.py
@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import numpy as np
+
+
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = np.arange(grid_size_h, dtype=np.float32)
+ grid_w = np.arange(grid_size_w, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token and extra_tokens > 0:
+ pos_embed = np.concatenate(
+ [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
+ )
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000 ** omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+def get_2d_embedding(xy, C, cat_coords=True):
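+    # Encode x and y separately with C/2 sine/cosine frequency pairs each and
+    # concatenate the two encodings, optionally prepending the raw coordinates.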
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+    pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, C*2
+    if cat_coords:
+        pe = torch.cat([xy, pe], dim=2)  # B, N, C*2+2
+ return pe
+
+
+def get_3d_embedding(xyz, C, cat_coords=True):
+ B, N, D = xyz.shape
+ assert D == 3
+
+ x = xyz[:, :, 0:1]
+ y = xyz[:, :, 1:2]
+ z = xyz[:, :, 2:3]
+ div_term = (
+ torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
+ pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe_z[:, :, 0::2] = torch.sin(z * div_term)
+ pe_z[:, :, 1::2] = torch.cos(z * div_term)
+
+ pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
+ if cat_coords:
+ pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
+ return pe
+
+
+def get_4d_embedding(xyzw, C, cat_coords=True):
+ B, N, D = xyzw.shape
+ assert D == 4
+
+ x = xyzw[:, :, 0:1]
+ y = xyzw[:, :, 1:2]
+ z = xyzw[:, :, 2:3]
+ w = xyzw[:, :, 3:4]
+ div_term = (
+ torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+ pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+ pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+ pe_z[:, :, 0::2] = torch.sin(z * div_term)
+ pe_z[:, :, 1::2] = torch.cos(z * div_term)
+
+ pe_w[:, :, 0::2] = torch.sin(w * div_term)
+ pe_w[:, :, 1::2] = torch.cos(w * div_term)
+
+    pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, C*4
+    if cat_coords:
+        pe = torch.cat([pe, xyzw], dim=2)  # B, N, C*4+4
+ return pe
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/model_utils.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e875f96a4bd3232707303d8c7fd8cff9368477d0
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/core/model_utils.py
@@ -0,0 +1,169 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+EPS = 1e-6
+
+
+def smart_cat(tensor1, tensor2, dim):
+ if tensor1 is None:
+ return tensor2
+ return torch.cat([tensor1, tensor2], dim=dim)
+
+
+def normalize_single(d):
+ # d is a whatever shape torch tensor
+ dmin = torch.min(d)
+ dmax = torch.max(d)
+ d = (d - dmin) / (EPS + (dmax - dmin))
+ return d
+
+
+def normalize(d):
+ # d is B x whatever. normalize within each element of the batch
+ out = torch.zeros(d.size())
+ if d.is_cuda:
+ out = out.cuda()
+ B = list(d.size())[0]
+ for b in list(range(B)):
+ out[b] = normalize_single(d[b])
+ return out
+
+
+def meshgrid2d(B, Y, X, stack=False, norm=False, device="cpu"):
+ # returns a meshgrid sized B x Y x X
+
+ grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
+ grid_y = torch.reshape(grid_y, [1, Y, 1])
+ grid_y = grid_y.repeat(B, 1, X)
+
+ grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
+ grid_x = torch.reshape(grid_x, [1, 1, X])
+ grid_x = grid_x.repeat(B, Y, 1)
+
+ if stack:
+ # note we stack in xy order
+ # (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
+ grid = torch.stack([grid_x, grid_y], dim=-1)
+ return grid
+ else:
+ return grid_y, grid_x
+
+
+def reduce_masked_mean(x, mask, dim=None, keepdim=False):
+ # x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
+ # returns shape-1
+ # axis can be a list of axes
+ for (a, b) in zip(x.size(), mask.size()):
+ assert a == b # some shape mismatch!
+ prod = x * mask
+ if dim is None:
+ numer = torch.sum(prod)
+ denom = EPS + torch.sum(mask)
+ else:
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
+ denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
+
+ mean = numer / denom
+ return mean
+
+
+def bilinear_sample2d(im, x, y, return_inbounds=False):
+ # x and y are each B, N
+ # output is B, C, N
+ if len(im.shape) == 5:
+ B, N, C, H, W = list(im.shape)
+ else:
+ B, C, H, W = list(im.shape)
+ N = list(x.shape)[1]
+
+ x = x.float()
+ y = y.float()
+ H_f = torch.tensor(H, dtype=torch.float32)
+ W_f = torch.tensor(W, dtype=torch.float32)
+
+    # inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
+
+    max_y = (H_f - 1).int()
+    max_x = (W_f - 1).int()
+
+    x0 = torch.floor(x).int()
+    x1 = x0 + 1
+    y0 = torch.floor(y).int()
+    y1 = y0 + 1
+
+    x0_clip = torch.clamp(x0, 0, max_x)
+    x1_clip = torch.clamp(x1, 0, max_x)
+    y0_clip = torch.clamp(y0, 0, max_y)
+    y1_clip = torch.clamp(y1, 0, max_y)
+    dim2 = W
+    dim1 = W * H
+
+    base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
+    base = torch.reshape(base, [B, 1]).repeat([1, N])
+
+    base_y0 = base + y0_clip * dim2
+    base_y1 = base + y1_clip * dim2
+
+    idx_y0_x0 = base_y0 + x0_clip
+    idx_y0_x1 = base_y0 + x1_clip
+    idx_y1_x0 = base_y1 + x0_clip
+    idx_y1_x1 = base_y1 + x1_clip
+
+    # use the flat indices to look up the four neighbouring pixels;
+    # im is B x C x H x W (or B x N x C x H x W), so move C to the last dim first
+    if len(im.shape) == 5:
+        im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
+        i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(0, 2, 1)
+        i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(0, 2, 1)
+        i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(0, 2, 1)
+        i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(0, 2, 1)
+    else:
+        im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
+        i_y0_x0 = im_flat[idx_y0_x0.long()]
+        i_y0_x1 = im_flat[idx_y0_x1.long()]
+        i_y1_x0 = im_flat[idx_y1_x0.long()]
+        i_y1_x1 = im_flat[idx_y1_x1.long()]
+
+    # bilinear weights from the fractional offsets
+    x0_f = x0.float()
+    x1_f = x1.float()
+    y0_f = y0.float()
+    y1_f = y1.float()
+
+    w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
+    w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
+    w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
+    w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
+
+    output = (
+        w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
+    )
+    output = output.view(B, -1, C)
+    output = output.permute(0, 2, 1)  # B, C, N
+
+    if return_inbounds:
+        x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
+        y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
+        inbounds = (x_valid & y_valid).float()
+        inbounds = inbounds.reshape(B, N)
+        return output, inbounds
+
+    return output  # B, C, N
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/models/evaluation_predictor.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/evaluation_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bcc7fa4b9cb6553745e03dff0b1030db630f469
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/models/evaluation_predictor.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+
+from dot.models.shelf.cotracker_utils.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
+
+
+class EvaluationPredictor(torch.nn.Module):
+ def __init__(
+ self,
+ cotracker_model: CoTracker,
+ interp_shape: Tuple[int, int] = (384, 512),
+ grid_size: int = 6,
+ local_grid_size: int = 6,
+ single_point: bool = True,
+ n_iters: int = 6,
+ ) -> None:
+ super(EvaluationPredictor, self).__init__()
+ self.grid_size = grid_size
+ self.local_grid_size = local_grid_size
+ self.single_point = single_point
+ self.interp_shape = interp_shape
+ self.n_iters = n_iters
+
+ self.model = cotracker_model
+ self.model.eval()
+
+ def forward(self, video, queries):
+ queries = queries.clone()
+ B, T, C, H, W = video.shape
+ B, N, D = queries.shape
+
+ assert D == 3
+ assert B == 1
+
+ rgbs = video.reshape(B * T, C, H, W)
+ rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
+ rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ device = rgbs.device
+
+ queries[:, :, 1] *= self.interp_shape[1] / W
+ queries[:, :, 2] *= self.interp_shape[0] / H
+
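+        # In single-point mode each query is tracked independently, together with
+        # its own local support grid; otherwise all queries are tracked jointly,
+        # optionally together with a global support grid.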
+ if self.single_point:
+ traj_e = torch.zeros((B, T, N, 2), device=device)
+ vis_e = torch.zeros((B, T, N), device=device)
+ for pind in range((N)):
+ query = queries[:, pind : pind + 1]
+
+ t = query[0, 0, 0].long()
+
+ traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
+ traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
+ vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
+ else:
+ if self.grid_size > 0:
+ xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+ xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+ device
+ ) #
+ queries = torch.cat([queries, xy], dim=1) #
+
+ traj_e, __, vis_e, __ = self.model(
+ rgbs=rgbs,
+ queries=queries,
+ iters=self.n_iters,
+ )
+
+ traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
+ traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
+ return traj_e, vis_e
+
+ def _process_one_point(self, rgbs, query):
+ t = query[0, 0, 0].long()
+
+ device = rgbs.device
+ if self.local_grid_size > 0:
+ xy_target = get_points_on_a_grid(
+ self.local_grid_size,
+ (50, 50),
+ [query[0, 0, 2], query[0, 0, 1]],
+ )
+
+ xy_target = torch.cat(
+ [torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
+ ) #
+ query = torch.cat([query, xy_target], dim=1).to(device) #
+
+ if self.grid_size > 0:
+ xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+ xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
+ query = torch.cat([query, xy], dim=1).to(device) #
+ # crop the video to start from the queried frame
+ query[0, 0, 0] = 0
+ traj_e_pind, __, vis_e_pind, __ = self.model(
+ rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
+ )
+
+ return traj_e_pind, vis_e_pind
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/predictor.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..574a58d70b261dbdbf2d884da42aabed52f5b91a
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/predictor.py
@@ -0,0 +1,203 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from tqdm import tqdm
+from .models.core.cotracker.cotracker import get_points_on_a_grid
+from .models.core.model_utils import smart_cat
+from .models.build_cotracker import build_cotracker
+
+
+class CoTrackerPredictor(torch.nn.Module):
+ def __init__(self, patch_size, wind_size):
+ super().__init__()
+ self.interp_shape = (384, 512)
+ self.support_grid_size = 6
+ model = build_cotracker(patch_size, wind_size)
+
+ self.model = model
+ self.model.eval()
+ self.cached_feat = None
+
+ @torch.no_grad()
+ def forward(
+ self,
+ video, # (1, T, 3, H, W)
+ # input prompt types:
+ # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
+ # *backward_tracking=True* will compute tracks in both directions.
+ # - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
+ # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
+ # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
+ queries: torch.Tensor = None,
+ segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
+ grid_size: int = 0,
+ grid_query_frame: int = 0, # only for dense and regular grid tracks
+ backward_tracking: bool = False,
+ cache_features: bool = False,
+ ):
+
+ if queries is None and grid_size == 0:
+ tracks, visibilities = self._compute_dense_tracks(
+ video,
+ grid_query_frame=grid_query_frame,
+ backward_tracking=backward_tracking,
+ )
+ else:
+ tracks, visibilities = self._compute_sparse_tracks(
+ video,
+ queries,
+ segm_mask,
+ grid_size,
+ add_support_grid=(grid_size == 0 or segm_mask is not None),
+ grid_query_frame=grid_query_frame,
+ backward_tracking=backward_tracking,
+ cache_features=cache_features,
+ )
+
+ return tracks, visibilities
+
+ def _compute_dense_tracks(
+ self, video, grid_query_frame, grid_size=30, backward_tracking=False
+ ):
+ *_, H, W = video.shape
+ grid_step = W // grid_size
+ grid_width = W // grid_step
+ grid_height = H // grid_step
+ tracks = visibilities = None
+ grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
+ grid_pts[0, :, 0] = grid_query_frame
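+ # sweep the coarse grid over every (ox, oy) offset inside a grid_step x grid_step
+ # cell so that, over all iterations, every pixel of the frame is queried once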
+ for offset in tqdm(range(grid_step * grid_step)):
+ ox = offset % grid_step
+ oy = offset // grid_step
+ grid_pts[0, :, 1] = (
+ torch.arange(grid_width).repeat(grid_height) * grid_step + ox
+ )
+ grid_pts[0, :, 2] = (
+ torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
+ )
+ tracks_step, visibilities_step = self._compute_sparse_tracks(
+ video=video,
+ queries=grid_pts,
+ backward_tracking=backward_tracking,
+ )
+ tracks = smart_cat(tracks, tracks_step, dim=2)
+ visibilities = smart_cat(visibilities, visibilities_step, dim=2)
+
+ return tracks, visibilities
+
+ def _compute_sparse_tracks(
+ self,
+ video,
+ queries,
+ segm_mask=None,
+ grid_size=0,
+ add_support_grid=False,
+ grid_query_frame=0,
+ backward_tracking=False,
+ cache_features=False,
+ ):
+ B, T, C, H, W = video.shape
+ assert B == 1
+
+ video = video.reshape(B * T, C, H, W)
+ video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
+ video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ if cache_features:
+ h, w = self.interp_shape[0], self.interp_shape[1]
+ video_ = video.reshape(B * T, C, h, w)
+ video_ = 2 * video_ - 1.0
+ fmaps_ = self.model.fnet(video_)
+ fmaps_ = fmaps_.reshape(B, T, self.model.latent_dim, h // self.model.stride, w // self.model.stride)
+ self.cached_feat = fmaps_
+
+ if queries is not None:
+ queries = queries.clone()
+ B, N, D = queries.shape
+ assert D == 3
+ queries[:, :, 1] *= self.interp_shape[1] / W
+ queries[:, :, 2] *= self.interp_shape[0] / H
+ elif grid_size > 0:
+ grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
+ if segm_mask is not None:
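+ # restrict the grid to points that fall inside the segmentation mask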
+ segm_mask = F.interpolate(
+ segm_mask, tuple(self.interp_shape), mode="nearest"
+ )
+ point_mask = segm_mask[0, 0][
+ (grid_pts[0, :, 1]).round().long().cpu(),
+ (grid_pts[0, :, 0]).round().long().cpu(),
+ ].bool()
+ grid_pts = grid_pts[:, point_mask]
+
+ queries = torch.cat(
+ [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
+ dim=2,
+ )
+
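+ # optionally append a regular grid of "support" points; they give the tracker
+ # extra context and are stripped from the outputs again further below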
+ if add_support_grid:
+ grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device)
+ grid_pts = torch.cat(
+ [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
+ )
+ queries = torch.cat([queries, grid_pts], dim=1)
+
+ tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6, cached_feat=self.cached_feat)
+
+ if backward_tracking:
+ tracks, visibilities = self._compute_backward_tracks(
+ video, queries, tracks, visibilities
+ )
+ if add_support_grid:
+ queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
+ if add_support_grid:
+ tracks = tracks[:, :, : -self.support_grid_size ** 2]
+ visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
+ thr = 0.9
+ visibilities = visibilities > thr
+
+ # correct query-point predictions
+ # see https://github.com/facebookresearch/co-tracker/issues/28
+
+ # TODO: batchify
+ for i in range(len(queries)):
+ queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
+ arange = torch.arange(0, len(queries_t))
+
+ # overwrite the predictions with the query points
+ tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
+
+ # correct visibilities, the query points should be visible
+ visibilities[i, queries_t, arange] = True
+
+ tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
+ tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
+ return tracks, visibilities
+
+ def _compute_backward_tracks(self, video, queries, tracks, visibilities):
+ inv_video = video.flip(1).clone()
+ inv_queries = queries.clone()
+ inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
+
+ if self.cached_feat is not None:
+ inv_feat = self.cached_feat.flip(1)
+ else:
+ inv_feat = None
+
+ inv_tracks, __, inv_visibilities, __ = self.model(
+ rgbs=inv_video, queries=inv_queries, iters=6, cached_feat=inv_feat
+ )
+
+ inv_tracks = inv_tracks.flip(1)
+ inv_visibilities = inv_visibilities.flip(1)
+
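+ # entries where the forward pass left tracks at zero (typically before a
+ # query's start frame) are filled with the time-reversed predictions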
+ mask = tracks == 0
+
+ tracks[mask] = inv_tracks[mask]
+ visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
+ return tracks, visibilities
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/utils/__init__.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4547e070da2f3ddc5bf2f466cb2242e6135c7dc3
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/data/dot_single_video/dot/models/shelf/cotracker_utils/utils/visualizer.py b/data/dot_single_video/dot/models/shelf/cotracker_utils/utils/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..17d565911fcd79f475944c0d3d34da7ee35edb11
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/cotracker_utils/utils/visualizer.py
@@ -0,0 +1,314 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import numpy as np
+import cv2
+import torch
+import flow_vis
+
+from matplotlib import cm
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+from moviepy.editor import ImageSequenceClip
+import matplotlib.pyplot as plt
+
+
+def read_video_from_path(path):
+ frames = []
+ cap = cv2.VideoCapture(path)
+ if not cap.isOpened():
+ print("Error opening video file")
+ else:
+ while cap.isOpened():
+ ret, frame = cap.read()
+ if ret:
+ frames.append(np.array(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
+ else:
+ break
+ cap.release()
+ return np.stack(frames)
+
+
+class Visualizer:
+ def __init__(
+ self,
+ save_dir: str = "./results",
+ grayscale: bool = False,
+ pad_value: int = 0,
+ fps: int = 10,
+ mode: str = "rainbow", # 'cool', 'optical_flow'
+ linewidth: int = 2,
+ show_first_frame: int = 10,
+ tracks_leave_trace: int = 0, # -1 for infinite
+ ):
+ self.mode = mode
+ self.save_dir = save_dir
+ if mode == "rainbow":
+ self.color_map = cm.get_cmap("gist_rainbow")
+ elif mode == "cool":
+ self.color_map = cm.get_cmap(mode)
+ self.show_first_frame = show_first_frame
+ self.grayscale = grayscale
+ self.tracks_leave_trace = tracks_leave_trace
+ self.pad_value = pad_value
+ self.linewidth = linewidth
+ self.fps = fps
+
+ def visualize(
+ self,
+ video: torch.Tensor, # (B,T,C,H,W)
+ tracks: torch.Tensor, # (B,T,N,2)
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
+ filename: str = "video",
+ writer=None, # tensorboard Summary Writer, used for visualization during training
+ step: int = 0,
+ query_frame: int = 0,
+ save_video: bool = True,
+ compensate_for_camera_motion: bool = False,
+ ):
+ if compensate_for_camera_motion:
+ assert segm_mask is not None
+ if segm_mask is not None:
+ coords = tracks[0, query_frame].round().long()
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
+
+ video = F.pad(
+ video,
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
+ "constant",
+ 255,
+ )
+ tracks = tracks + self.pad_value
+
+ if self.grayscale:
+ transform = transforms.Grayscale()
+ video = transform(video)
+ video = video.repeat(1, 1, 3, 1, 1)
+
+ res_video = self.draw_tracks_on_video(
+ video=video,
+ tracks=tracks,
+ visibility=visibility,
+ segm_mask=segm_mask,
+ gt_tracks=gt_tracks,
+ query_frame=query_frame,
+ compensate_for_camera_motion=compensate_for_camera_motion,
+ )
+ if save_video:
+ self.save_video(res_video, filename=filename, writer=writer, step=step)
+ return res_video
+
+ def save_video(self, video, filename, writer=None, step=0):
+ if writer is not None:
+ writer.add_video(
+ f"{filename}_pred_track",
+ video.to(torch.uint8),
+ global_step=step,
+ fps=self.fps,
+ )
+ else:
+ os.makedirs(self.save_dir, exist_ok=True)
+ wide_list = list(video.unbind(1))
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
+ clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
+
+ # Write the video file
+ save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
+ clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
+
+ print(f"Video saved to {save_path}")
+
+ def draw_tracks_on_video(
+ self,
+ video: torch.Tensor,
+ tracks: torch.Tensor,
+ visibility: torch.Tensor = None,
+ segm_mask: torch.Tensor = None,
+ gt_tracks=None,
+ query_frame: int = 0,
+ compensate_for_camera_motion=False,
+ ):
+ B, T, C, H, W = video.shape
+ _, _, N, D = tracks.shape
+
+ assert D == 2
+ assert C == 3
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
+ tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
+ if gt_tracks is not None:
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
+
+ res_video = []
+
+ # process input video
+ for rgb in video:
+ res_video.append(rgb.copy())
+
+ vector_colors = np.zeros((T, N, 3))
+ if self.mode == "optical_flow":
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
+ elif segm_mask is None:
+ if self.mode == "rainbow":
+ y_min, y_max = (
+ tracks[query_frame, :, 1].min(),
+ tracks[query_frame, :, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ color = self.color_map(norm(tracks[query_frame, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+ else:
+ # color changes with time
+ for t in range(T):
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
+ vector_colors[t] = np.repeat(color, N, axis=0)
+ else:
+ if self.mode == "rainbow":
+ vector_colors[:, segm_mask <= 0, :] = 255
+
+ y_min, y_max = (
+ tracks[0, segm_mask > 0, 1].min(),
+ tracks[0, segm_mask > 0, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ if segm_mask[n] > 0:
+ color = self.color_map(norm(tracks[0, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+
+ else:
+ # color changes with segm class
+ segm_mask = segm_mask.cpu()
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
+ vector_colors = np.repeat(color[None], T, axis=0)
+
+ # draw tracks
+ if self.tracks_leave_trace != 0:
+ for t in range(1, T):
+ first_ind = (
+ max(0, t - self.tracks_leave_trace)
+ if self.tracks_leave_trace >= 0
+ else 0
+ )
+ curr_tracks = tracks[first_ind : t + 1]
+ curr_colors = vector_colors[first_ind : t + 1]
+ if compensate_for_camera_motion:
+ diff = (
+ tracks[first_ind : t + 1, segm_mask <= 0]
+ - tracks[t : t + 1, segm_mask <= 0]
+ ).mean(1)[:, None]
+
+ curr_tracks = curr_tracks - diff
+ curr_tracks = curr_tracks[:, segm_mask > 0]
+ curr_colors = curr_colors[:, segm_mask > 0]
+
+ res_video[t] = self._draw_pred_tracks(
+ res_video[t],
+ curr_tracks,
+ curr_colors,
+ )
+ if gt_tracks is not None:
+ res_video[t] = self._draw_gt_tracks(
+ res_video[t], gt_tracks[first_ind : t + 1]
+ )
+
+ # draw points
+ for t in range(T):
+ for i in range(N):
+ coord = (tracks[t, i, 0], tracks[t, i, 1])
+ visibile = True
+ if visibility is not None:
+ visibile = visibility[0, t, i]
+ if coord[0] != 0 and coord[1] != 0:
+ if not compensate_for_camera_motion or (
+ compensate_for_camera_motion and segm_mask[i] > 0
+ ):
+
+ cv2.circle(
+ res_video[t],
+ coord,
+ int(self.linewidth * 2),
+ vector_colors[t, i].tolist(),
+ thickness=-1 if visibile else 1,
+ )
+
+ # construct the final rgb sequence
+ if self.show_first_frame > 0:
+ res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
+
+ def _draw_pred_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3
+ tracks: np.ndarray, # T x 2
+ vector_colors: np.ndarray,
+ alpha: float = 0.5,
+ ):
+ T, N, _ = tracks.shape
+
+ for s in range(T - 1):
+ vector_color = vector_colors[s]
+ original = rgb.copy()
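+ # fade: older segments of the trace (small s) are blended mostly back
+ # into the unmodified frame below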
+ alpha = (s / T) ** 2
+ for i in range(N):
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
+ if coord_y[0] != 0 and coord_y[1] != 0:
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ vector_color[i].tolist(),
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ if self.tracks_leave_trace > 0:
+ rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
+ return rgb
+
+ def _draw_gt_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3,
+ gt_tracks: np.ndarray, # T x 2
+ ):
+ T, N, _ = gt_tracks.shape
+ color = np.array((211.0, 0.0, 0.0))
+
+ for t in range(T):
+ for i in range(N):
+ gt_track = gt_tracks[t][i]
+ # draw a red cross
+ if gt_track[0] > 0 and gt_track[1] > 0:
+ length = self.linewidth * 3
+ coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
+ coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ color,
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
+ coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
+ cv2.line(
+ rgb,
+ coord_y,
+ coord_x,
+ color,
+ self.linewidth,
+ cv2.LINE_AA,
+ )
+ return rgb
diff --git a/data/dot_single_video/dot/models/shelf/raft.py b/data/dot_single_video/dot/models/shelf/raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..afb82d33f4b2fd93bc85325ddaef254e3eb6188b
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/raft.py
@@ -0,0 +1,139 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+from .raft_utils.update import BasicUpdateBlock
+from .raft_utils.extractor import BasicEncoder
+from .raft_utils.corr import CorrBlock
+from .raft_utils.utils import coords_grid
+
+
+class RAFT(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.fnet = BasicEncoder(output_dim=256, norm_fn=args.norm_fnet, dropout=0, patch_size=args.patch_size)
+ self.cnet = BasicEncoder(output_dim=256, norm_fn=args.norm_cnet, dropout=0, patch_size=args.patch_size)
+ self.update_block = BasicUpdateBlock(hidden_dim=128, patch_size=args.patch_size, refine_alpha=args.refine_alpha)
+ self.refine_alpha = args.refine_alpha
+ self.patch_size = args.patch_size
+ self.num_iter = args.num_iter
+
+ def encode(self, frame):
+ frame = frame * 2 - 1
+ fmap = self.fnet(frame)
+ cmap = self.cnet(frame)
+ feats = torch.cat([fmap, cmap], dim=1)
+ return feats.float()
+
+ def initialize_feats(self, feats, frame):
+ if feats is None:
+ feats = self.encode(frame)
+ fmap, cmap = feats.split([256, 256], dim=1)
+ return fmap, cmap
+
+ def initialize_flow(self, fmap, coarse_flow):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, _, h, w = fmap.shape
+ src_pts = coords_grid(N, h, w, device=fmap.device)
+
+ if coarse_flow is not None:
+ coarse_flow = coarse_flow.permute(0, 3, 1, 2)
+ # coarse_flow = torch.stack([coarse_flow[:, 0] * (w - 1), coarse_flow[:, 1] * (h - 1)], dim=1)
+ tgt_pts = src_pts + coarse_flow
+ else:
+ tgt_pts = src_pts
+
+ return src_pts, tgt_pts
+
+ def initialize_alpha(self, fmap, coarse_alpha):
+ N, _, h, w = fmap.shape
+ if coarse_alpha is None:
+ alpha = torch.ones(N, 1, h, w, device=fmap.device)
+ else:
+ alpha = coarse_alpha[:, None]
+ return alpha.logit(eps=1e-5)
+
+ def postprocess_alpha(self, alpha):
+ alpha = alpha[:, 0]
+ return alpha.sigmoid()
+
+ def postprocess_flow(self, flow):
+ # N, C, H, W = flow.shape
+ # flow = torch.stack([flow[:, 0] / (W - 1), flow[:, 1] / (H - 1)], dim=1)
+ flow = flow.permute(0, 2, 3, 1)
+ return flow
+
+ def upsample_flow(self, flow, mask):
+ """ Upsample flow field [H/P, W/P, 2] -> [H, W, 2] using convex combination """
+ N, _, H, W = flow.shape
+ mask = mask.view(N, 1, 9, self.patch_size, self.patch_size, H, W)
+ mask = torch.softmax(mask, dim=2)
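+ # each fine-resolution pixel is predicted as a convex combination of the
+ # 3x3 coarse-resolution neighbourhood around it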
+
+ up_flow = F.unfold(self.patch_size * flow, [3, 3], padding=1)
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, 2, self.patch_size * H, self.patch_size * W)
+
+ def upsample_alpha(self, alpha, mask):
+ """ Upsample alpha field [H/P, W/P, 1] -> [H, W, 1] using convex combination """
+ N, _, H, W = alpha.shape
+ mask = mask.view(N, 1, 9, self.patch_size, self.patch_size, H, W)
+ mask = torch.softmax(mask, dim=2)
+
+ up_alpha = F.unfold(alpha, [3, 3], padding=1)
+ up_alpha = up_alpha.view(N, 1, 9, 1, 1, H, W)
+
+ up_alpha = torch.sum(mask * up_alpha, dim=2)
+ up_alpha = up_alpha.permute(0, 1, 4, 2, 5, 3)
+ return up_alpha.reshape(N, 1, self.patch_size * H, self.patch_size * W)
+
+ def forward(self, src_frame=None, tgt_frame=None, src_feats=None, tgt_feats=None, coarse_flow=None, coarse_alpha=None,
+ is_train=False):
+ src_fmap, src_cmap = self.initialize_feats(src_feats, src_frame)
+ tgt_fmap, _ = self.initialize_feats(tgt_feats, tgt_frame)
+
+ corr_fn = CorrBlock(src_fmap, tgt_fmap)
+
+ net, inp = torch.split(src_cmap, [128, 128], dim=1)
+ net = torch.tanh(net)
+ inp = torch.relu(inp)
+
+ src_pts, tgt_pts = self.initialize_flow(src_fmap, coarse_flow)
+ alpha = self.initialize_alpha(src_fmap, coarse_alpha) if self.refine_alpha else None
+
+ flows_up = []
+ alphas_up = []
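+ # iterative refinement: look up correlation at the current target points,
+ # update the recurrent state, and predict residual flow (and alpha) updates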
+ for itr in range(self.num_iter):
+ tgt_pts = tgt_pts.detach()
+ if self.refine_alpha:
+ alpha = alpha.detach()
+
+ corr = corr_fn(tgt_pts)
+
+ flow = tgt_pts - src_pts
+ net, up_mask, delta_flow, up_mask_alpha, delta_alpha = self.update_block(net, inp, corr, flow, alpha)
+
+ # F(t+1) = F(t) + \Delta(t)
+ tgt_pts = tgt_pts + delta_flow
+ if self.refine_alpha:
+ alpha = alpha + delta_alpha
+
+ # upsample predictions
+ flow_up = self.upsample_flow(tgt_pts - src_pts, up_mask)
+ if self.refine_alpha:
+ alpha_up = self.upsample_alpha(alpha, up_mask_alpha)
+
+ if is_train or (itr == self.num_iter - 1):
+ flows_up.append(self.postprocess_flow(flow_up))
+ if self.refine_alpha:
+ alphas_up.append(self.postprocess_alpha(alpha_up))
+
+ flows_up = torch.stack(flows_up, dim=1)
+ alphas_up = torch.stack(alphas_up, dim=1) if self.refine_alpha else None
+ if not is_train:
+ flows_up = flows_up[:, 0]
+ alphas_up = alphas_up[:, 0] if self.refine_alpha else None
+ return flows_up, alphas_up
diff --git a/data/dot_single_video/dot/models/shelf/raft_utils/LICENSE b/data/dot_single_video/dot/models/shelf/raft_utils/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..bbbd9fc645b0d5235f0d937515868a9f020c72bf
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/raft_utils/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2020, princeton-vl
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/data/dot_single_video/dot/models/shelf/raft_utils/__init__.py b/data/dot_single_video/dot/models/shelf/raft_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/dot_single_video/dot/models/shelf/raft_utils/corr.py b/data/dot_single_video/dot/models/shelf/raft_utils/corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..504ceef7ee4271cbee1792bcf622ec9b40b6490d
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/raft_utils/corr.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn.functional as F
+
+from .utils import bilinear_sampler
+
+
+class CorrBlock:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+ self.corr_pyramid = []
+
+ # all pairs correlation
+ corr = CorrBlock.corr(fmap1, fmap2)
+
+ batch, h1, w1, dim, h2, w2 = corr.shape
+ corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
+
+ self.corr_pyramid.append(corr)
+ for i in range(self.num_levels - 1):
+ corr = F.avg_pool2d(corr, 2, stride=2)
+ self.corr_pyramid.append(corr)
+
+ def __call__(self, coords):
+ r = self.radius
+ coords = coords.permute(0, 2, 3, 1)
+ batch, h1, w1, _ = coords.shape
+
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corr = self.corr_pyramid[i]
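+ # sample an r-neighbourhood of correlation values around each target
+ # coordinate at this pyramid level's resolution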
+ dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1)
+
+ centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corr = bilinear_sampler(corr, coords_lvl)
+ corr = corr.view(batch, h1, w1, -1)
+ out_pyramid.append(corr)
+
+ out = torch.cat(out_pyramid, dim=-1)
+ return out.permute(0, 3, 1, 2).contiguous().float()
+
+ @staticmethod
+ def corr(fmap1, fmap2):
+ batch, dim, ht, wd = fmap1.shape
+ fmap1 = fmap1.view(batch, dim, ht * wd)
+ fmap2 = fmap2.view(batch, dim, ht * wd)
+
+ corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
+ corr = corr.view(batch, ht, wd, 1, ht, wd)
+ return corr / torch.sqrt(torch.tensor(dim).float())
diff --git a/data/dot_single_video/dot/models/shelf/raft_utils/extractor.py b/data/dot_single_video/dot/models/shelf/raft_utils/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b88cbee6381bfbc570ac14ae0fb5de5476b7302
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/raft_utils/extractor.py
@@ -0,0 +1,194 @@
+import torch
+import torch.nn as nn
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), norm3)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class BottleneckBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes // 4)
+ self.norm2 = nn.BatchNorm2d(planes // 4)
+ self.norm3 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes // 4)
+ self.norm2 = nn.InstanceNorm2d(planes // 4)
+ self.norm3 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ self.norm3 = nn.Sequential()
+ if not stride == 1:
+ self.norm4 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+ y = self.relu(self.norm3(self.conv3(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class BasicEncoder(nn.Module):
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, patch_size=8):
+ super().__init__()
+ assert patch_size in [4, 8]
+ if patch_size == 4:
+ stride1, stride2, stride3 = 1, 2, 2
+ else:
+ stride1, stride2, stride3 = 2, 2, 2
+
+ self.norm_fn = norm_fn
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(64)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(64)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=stride1, padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 64
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=stride2)
+ self.layer3 = self._make_layer(128, stride=stride3)
+
+ # output convolution
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+
+ # if input is list, combine batch dimension
+ is_list = isinstance(x, tuple) or isinstance(x, list)
+ if is_list:
+ batch_dim = x[0].shape[0]
+ x = torch.cat(x, dim=0)
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ x = self.conv2(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+
+ if is_list:
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+ return x
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/shelf/raft_utils/update.py b/data/dot_single_video/dot/models/shelf/raft_utils/update.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66358d4e9ed0cbfabdee098c0a451c330344833
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/raft_utils/update.py
@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256):
+ super().__init__()
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.conv2(self.relu(self.conv1(x)))
+
+
+class AlphaHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256):
+ super().__init__()
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(hidden_dim, 1, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.conv2(self.relu(self.conv1(x)))
+
+
+class SepConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192 + 128):
+ super().__init__()
+ self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
+ self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
+ self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2))
+
+ self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
+ self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
+ self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0))
+
+ def forward(self, h, x):
+ # horizontal
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz1(hx))
+ r = torch.sigmoid(self.convr1(hx))
+ q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
+ h = (1 - z) * h + z * q
+
+ # vertical
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz2(hx))
+ r = torch.sigmoid(self.convr2(hx))
+ q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
+ h = (1 - z) * h + z * q
+
+ return h
+
+
+class BasicMotionEncoder(nn.Module):
+ def __init__(self, refine_alpha, corr_levels=4, corr_radius=4):
+ super().__init__()
+ in_dim = 2 + (3 if refine_alpha else 0)
+ cor_planes = corr_levels * (2 * corr_radius + 1) ** 2
+ self.refine_alpha = refine_alpha
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
+ self.convf1 = nn.Conv2d(in_dim, 128, 7, padding=3)
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
+ self.conv = nn.Conv2d(64 + 192, 128 - in_dim, 3, padding=1)
+
+ def forward(self, flow, alpha, corr):
+ if self.refine_alpha:
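+ # pad flow (2) with alpha (1) and zeros (2) to match the 5 input channels
+ # expected by convf1 when refine_alpha is on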
+ flow = torch.cat([flow, alpha, torch.zeros_like(flow)], dim=1)
+ cor = F.relu(self.convc1(corr))
+ cor = F.relu(self.convc2(cor))
+ feat = F.relu(self.convf1(flow))
+ feat = F.relu(self.convf2(feat))
+ feat = torch.cat([cor, feat], dim=1)
+ feat = F.relu(self.conv(feat))
+ return torch.cat([feat, flow], dim=1)
+
+
+class BasicUpdateBlock(nn.Module):
+ def __init__(self, hidden_dim=128, patch_size=8, refine_alpha=False):
+ super().__init__()
+ self.refine_alpha = refine_alpha
+ self.encoder = BasicMotionEncoder(refine_alpha)
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
+
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+ self.mask = nn.Sequential(
+ nn.Conv2d(128, 256, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, patch_size * patch_size * 9, 1, padding=0)
+ )
+
+ if refine_alpha:
+ self.alpha_head = AlphaHead(hidden_dim, hidden_dim=256)
+ self.alpha_mask = nn.Sequential(
+ nn.Conv2d(128, 256, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, patch_size * patch_size * 9, 1, padding=0)
+ )
+
+ def forward(self, net, inp, corr, flow, alpha):
+ mot = self.encoder(flow, alpha, corr)
+ inp = torch.cat([inp, mot], dim=1)
+ net = self.gru(net, inp)
+
+ delta_flow = self.flow_head(net)
+ mask = .25 * self.mask(net)
+
+ delta_alpha, mask_alpha = None, None
+ if self.refine_alpha:
+ delta_alpha = self.alpha_head(net)
+ mask_alpha = .25 * self.alpha_mask(net)
+
+ return net, mask, delta_flow, mask_alpha, delta_alpha
diff --git a/data/dot_single_video/dot/models/shelf/raft_utils/utils.py b/data/dot_single_video/dot/models/shelf/raft_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..74bd51a442cce371fc493baa22520e8b8f67a477
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/raft_utils/utils.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy import interpolate
+
+
+class InputPadder:
+ """ Pads images such that dimensions are divisible by 8 """
+ def __init__(self, dims, mode='sintel'):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+ pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
+ if mode == 'sintel':
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
+ else:
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+ def unpad(self,x):
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+ return x[..., c[0]:c[1], c[2]:c[3]]
+
+def forward_interpolate(flow):
+ flow = flow.detach().cpu().numpy()
+ dx, dy = flow[0], flow[1]
+
+ ht, wd = dx.shape
+ x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht), indexing="xy")
+
+ x1 = x0 + dx
+ y1 = y0 + dy
+
+ x1 = x1.reshape(-1)
+ y1 = y1.reshape(-1)
+ dx = dx.reshape(-1)
+ dy = dy.reshape(-1)
+
+ valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
+ x1 = x1[valid]
+ y1 = y1[valid]
+ dx = dx[valid]
+ dy = dy[valid]
+
+ flow_x = interpolate.griddata((x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
+
+ flow_y = interpolate.griddata((x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
+
+ flow = np.stack([flow_x, flow_y], axis=0)
+ return torch.from_numpy(flow).float()
+
+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+ """ Wrapper for grid_sample, uses pixel coordinates """
+ H, W = img.shape[-2:]
+ xgrid, ygrid = coords.split([1,1], dim=-1)
+ xgrid = 2*xgrid/(W-1) - 1
+ ygrid = 2*ygrid/(H-1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, align_corners=True)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.float()
+
+ return img
+
+
+def coords_grid(batch, ht, wd, device):
+ coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing="ij")
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
+
+def upflow8(flow, mode='bilinear'):
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
diff --git a/data/dot_single_video/dot/models/shelf/tapir.py b/data/dot_single_video/dot/models/shelf/tapir.py
new file mode 100644
index 0000000000000000000000000000000000000000..22c8ed74a26ef8864cb039553c53887a8c629258
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/tapir.py
@@ -0,0 +1,33 @@
+from torch import nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .tapir_utils.tapir_model import TAPIR
+
+class Tapir(nn.Module):
+ def __init__(self, args):
+ super().__init__()
+ self.model = TAPIR(pyramid_level=args.pyramid_level,
+ softmax_temperature=args.softmax_temperature,
+ extra_convs=args.extra_convs)
+
+ def forward(self, video, queries, backward_tracking, cache_features=False):
+ # Preprocess video
+ video = video * 2 - 1 # conversion from [0, 1] to [-1, 1]
+ video = rearrange(video, "b t c h w -> b t h w c")
+
+ # Preprocess queries
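+ # reorder (t, x, y) queries to the (t, y, x) convention used by TAPIR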
+ queries = queries[..., [0, 2, 1]]
+
+ # Inference
+ outputs = self.model(video, queries, cache_features=cache_features)
+ tracks, occlusions, expected_dist = outputs['tracks'], outputs['occlusion'], outputs['expected_dist']
+
+ # Postprocess tracks
+ tracks = rearrange(tracks, "b s t c -> b t s c")
+
+ # Postprocess visibility
+ visibles = (1 - F.sigmoid(occlusions)) * (1 - F.sigmoid(expected_dist)) > 0.5
+ visibles = rearrange(visibles, "b s t -> b t s")
+
+ return tracks, visibles
\ No newline at end of file
diff --git a/data/dot_single_video/dot/models/shelf/tapir_utils/LICENSE b/data/dot_single_video/dot/models/shelf/tapir_utils/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..75b52484ea471f882c29e02693b4f02dba175b5e
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/tapir_utils/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/data/dot_single_video/dot/models/shelf/tapir_utils/nets.py b/data/dot_single_video/dot/models/shelf/tapir_utils/nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..67bb0dbdb89705ba4bfd18b6d74a696836605ec4
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/tapir_utils/nets.py
@@ -0,0 +1,382 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Pytorch neural network definitions."""
+
+from typing import Sequence, Union
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
+class ExtraConvBlock(nn.Module):
+ """Additional convolution block."""
+
+ def __init__(
+ self,
+ channel_dim,
+ channel_multiplier,
+ ):
+ super().__init__()
+ self.channel_dim = channel_dim
+ self.channel_multiplier = channel_multiplier
+
+ self.layer_norm = nn.LayerNorm(
+ normalized_shape=channel_dim, elementwise_affine=True#, bias=True
+ )
+ self.conv = nn.Conv2d(
+ self.channel_dim * 3,
+ self.channel_dim * self.channel_multiplier,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+ self.conv_1 = nn.Conv2d(
+ self.channel_dim * self.channel_multiplier,
+ self.channel_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ def forward(self, x):
+ x = self.layer_norm(x)
+ x = x.permute(0, 3, 1, 2)
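+ # concatenate each frame with its (edge-clamped) temporal neighbours so the
+ # convolution sees short-range temporal context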
+ prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
+ next_frame = torch.cat([x[1:], x[-1:]], dim=0)
+ resid = torch.cat([x, prev_frame, next_frame], axis=1)
+ resid = self.conv(resid)
+ resid = F.gelu(resid, approximate='tanh')
+ x += self.conv_1(resid)
+ x = x.permute(0, 2, 3, 1)
+ return x
+
+
+class ExtraConvs(nn.Module):
+ """Additional CNN."""
+
+ def __init__(
+ self,
+ num_layers=5,
+ channel_dim=256,
+ channel_multiplier=4,
+ ):
+ super().__init__()
+ self.num_layers = num_layers
+ self.channel_dim = channel_dim
+ self.channel_multiplier = channel_multiplier
+
+ self.blocks = nn.ModuleList()
+ for _ in range(self.num_layers):
+ self.blocks.append(
+ ExtraConvBlock(self.channel_dim, self.channel_multiplier)
+ )
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class ConvChannelsMixer(nn.Module):
+ """Linear activation block for PIPs's MLP Mixer."""
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.mlp2_up = nn.Linear(in_channels, in_channels * 4)
+ self.mlp2_down = nn.Linear(in_channels * 4, in_channels)
+
+ def forward(self, x):
+ x = self.mlp2_up(x)
+ x = F.gelu(x, approximate='tanh')
+ x = self.mlp2_down(x)
+ return x
+
+
+class PIPsConvBlock(nn.Module):
+ """Convolutional block for PIPs's MLP Mixer."""
+
+ def __init__(self, in_channels, kernel_shape=3):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(
+ normalized_shape=in_channels, elementwise_affine=True#, bias=False
+ )
+ self.mlp1_up = nn.Conv1d(
+ in_channels, in_channels * 4, kernel_shape, 1, 1, groups=in_channels
+ )
+ self.mlp1_up_1 = nn.Conv1d(
+ in_channels * 4,
+ in_channels * 4,
+ kernel_shape,
+ 1,
+ 1,
+ groups=in_channels * 4,
+ )
+ self.layer_norm_1 = nn.LayerNorm(
+ normalized_shape=in_channels, elementwise_affine=True#, bias=False
+ )
+ self.conv_channels_mixer = ConvChannelsMixer(in_channels)
+
+ def forward(self, x):
+ to_skip = x
+ x = self.layer_norm(x)
+
+ x = x.permute(0, 2, 1)
+ x = self.mlp1_up(x)
+ x = F.gelu(x, approximate='tanh')
+ x = self.mlp1_up_1(x)
+ x = x.permute(0, 2, 1)
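+ # fold the 4x channel expansion back to the input width by summing strided slices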
+ x = x[..., 0::4] + x[..., 1::4] + x[..., 2::4] + x[..., 3::4]
+
+ x = x + to_skip
+ to_skip = x
+ x = self.layer_norm_1(x)
+ x = self.conv_channels_mixer(x)
+
+ x = x + to_skip
+ return x
+
+
+class PIPSMLPMixer(nn.Module):
+ """Depthwise-conv version of PIPs's MLP Mixer."""
+
+ def __init__(
+ self,
+ input_channels: int,
+ output_channels: int,
+ hidden_dim: int = 512,
+ num_blocks: int = 12,
+ kernel_shape: int = 3,
+ ):
+ """Inits Mixer module.
+
+ A depthwise-convolutional version of a MLP Mixer for processing images.
+
+ Args:
+ input_channels (int): The number of input channels.
+ output_channels (int): The number of output channels.
+ hidden_dim (int, optional): The dimension of the hidden layer. Defaults
+ to 512.
+ num_blocks (int, optional): The number of convolution blocks in the
+ mixer. Defaults to 12.
+ kernel_shape (int, optional): The size of the kernel in the convolution
+ blocks. Defaults to 3.
+ """
+
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_blocks = num_blocks
+ self.linear = nn.Linear(input_channels, self.hidden_dim)
+ self.layer_norm = nn.LayerNorm(
+ normalized_shape=hidden_dim, elementwise_affine=True#, bias=False
+ )
+ self.linear_1 = nn.Linear(hidden_dim, output_channels)
+ self.blocks = nn.ModuleList([
+ PIPsConvBlock(hidden_dim, kernel_shape) for _ in range(num_blocks)
+ ])
+
+ def forward(self, x):
+ x = self.linear(x)
+ for block in self.blocks:
+ x = block(x)
+
+ x = self.layer_norm(x)
+ x = self.linear_1(x)
+ return x
+
+
+class BlockV2(nn.Module):
+ """ResNet V2 block."""
+
+ def __init__(
+ self,
+ channels_in: int,
+ channels_out: int,
+ stride: Union[int, Sequence[int]],
+ use_projection: bool,
+ ):
+ super().__init__()
+ self.padding = (1, 1, 1, 1)
+ # Handle asymmetric padding created by padding="SAME" in JAX/LAX
+ if stride == 1:
+ self.padding = (1, 1, 1, 1)
+ elif stride == 2:
+ self.padding = (0, 2, 0, 2)
+ else:
+ raise ValueError(
+ 'Check correct padding using padtype_to_pads in jax._src.lax.lax'
+ )
+
+ self.use_projection = use_projection
+ if self.use_projection:
+ self.proj_conv = nn.Conv2d(
+ in_channels=channels_in,
+ out_channels=channels_out,
+ kernel_size=1,
+ stride=stride,
+ padding=0,
+ bias=False,
+ )
+
+ self.bn_0 = nn.InstanceNorm2d(
+ num_features=channels_in,
+ eps=1e-05,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=False,
+ )
+ self.conv_0 = nn.Conv2d(
+ in_channels=channels_in,
+ out_channels=channels_out,
+ kernel_size=3,
+ stride=stride,
+ padding=0,
+ bias=False,
+ )
+
+ self.conv_1 = nn.Conv2d(
+ in_channels=channels_out,
+ out_channels=channels_out,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False,
+ )
+ self.bn_1 = nn.InstanceNorm2d(
+ num_features=channels_out,
+ eps=1e-05,
+ momentum=0.1,
+ affine=True,
+ track_running_stats=False,
+ )
+
+ def forward(self, inputs):
+ x = shortcut = inputs
+
+ x = self.bn_0(x)
+ x = torch.relu(x)
+ if self.use_projection:
+ shortcut = self.proj_conv(x)
+
+ x = self.conv_0(F.pad(x, self.padding))
+
+ x = self.bn_1(x)
+ x = torch.relu(x)
+ # no issues with padding here as this layer always has stride 1
+ x = self.conv_1(x)
+
+ return x + shortcut
+
+
+class BlockGroup(nn.Module):
+ """Higher level block for ResNet implementation."""
+
+ def __init__(
+ self,
+ channels_in: int,
+ channels_out: int,
+ num_blocks: int,
+ stride: Union[int, Sequence[int]],
+ use_projection: bool,
+ ):
+ super().__init__()
+ blocks = []
+ for i in range(num_blocks):
+ blocks.append(
+ BlockV2(
+ channels_in=channels_in if i == 0 else channels_out,
+ channels_out=channels_out,
+ stride=(1 if i else stride),
+ use_projection=(i == 0 and use_projection),
+ )
+ )
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, inputs):
+ out = inputs
+ for block in self.blocks:
+ out = block(out)
+ return out
+
+
+class ResNet(nn.Module):
+ """ResNet model."""
+
+ def __init__(
+ self,
+ blocks_per_group: Sequence[int],
+ channels_per_group: Sequence[int] = (64, 128, 256, 512),
+ use_projection: Sequence[bool] = (True, True, True, True),
+ strides: Sequence[int] = (1, 2, 2, 2),
+ ):
+ """Initializes a ResNet model with customizable layers and configurations.
+
+ This constructor allows defining the architecture of a ResNet model by
+ setting the number of blocks, channels, projection usage, and strides for
+ each group of blocks within the network. It provides flexibility in
+ creating various ResNet configurations.
+
+ Args:
+ blocks_per_group: A sequence of 4 integers, each indicating the number
+ of residual blocks in each group.
+ channels_per_group: A sequence of 4 integers, each specifying the number
+ of output channels for the blocks in each group. Defaults to (64, 128,
+ 256, 512).
+ use_projection: A sequence of 4 booleans, each indicating whether to use
+ a projection shortcut (True) or an identity shortcut (False) in each
+ group. Defaults to (True, True, True, True).
+ strides: A sequence of 4 integers, each specifying the stride size for
+ the convolutions in each group. Defaults to (1, 2, 2, 2).
+
+ The ResNet model created will have 4 groups, with each group's
+ architecture defined by the corresponding elements in these sequences.
+ """
+ super().__init__()
+
+ self.initial_conv = nn.Conv2d(
+ in_channels=3,
+ out_channels=channels_per_group[0],
+ kernel_size=(7, 7),
+ stride=2,
+ padding=0,
+ bias=False,
+ )
+
+ block_groups = []
+ for i, _ in enumerate(strides):
+ block_groups.append(
+ BlockGroup(
+ channels_in=channels_per_group[i - 1] if i > 0 else 64,
+ channels_out=channels_per_group[i],
+ num_blocks=blocks_per_group[i],
+ stride=strides[i],
+ use_projection=use_projection[i],
+ )
+ )
+ self.block_groups = nn.ModuleList(block_groups)
+
+ def forward(self, inputs):
+ result = {}
+ out = inputs
+ out = self.initial_conv(F.pad(out, (2, 4, 2, 4)))
+ result['initial_conv'] = out
+
+ for block_id, block_group in enumerate(self.block_groups):
+ out = block_group(out)
+ result[f'resnet_unit_{block_id}'] = out
+
+ return result
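+
+# Usage sketch (comment only, not executed): the ResNet returns a dict of
+# intermediate activations keyed 'initial_conv' and 'resnet_unit_0' ...
+# 'resnet_unit_3'. TAPIR uses 'resnet_unit_1' as high-resolution features
+# (stride 4 w.r.t. the input) and 'resnet_unit_3' as low-resolution features,
+# e.g.
+#   backbone = ResNet(blocks_per_group=(2, 2, 2, 2),
+#                     channels_per_group=(64, 128, 256, 256),
+#                     strides=(1, 2, 2, 1))
+#   feats = backbone(torch.zeros(1, 3, 256, 256))
+#   feats['resnet_unit_1'].shape  # torch.Size([1, 128, 64, 64])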
diff --git a/data/dot_single_video/dot/models/shelf/tapir_utils/tapir_model.py b/data/dot_single_video/dot/models/shelf/tapir_utils/tapir_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..84d0cba28a73dcbd7dd28b836f51271f113d9511
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/tapir_utils/tapir_model.py
@@ -0,0 +1,712 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""TAPIR models definition."""
+
+import functools
+from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from . import nets
+from . import utils
+
+
+class FeatureGrids(NamedTuple):
+ """Feature grids for a video, used to compute trajectories.
+
+ These are per-frame outputs of the encoding resnet.
+
+ Attributes:
+ lowres: Low-resolution features, one for each resolution; 256 channels.
+ hires: High-resolution features, one for each resolution; 128 channels.
+ resolutions: Resolutions used for trajectory computation. There will be one
+ entry for the initialization, and then an entry for each PIPs refinement
+ resolution.
+ """
+
+ lowres: Sequence[torch.Tensor]
+ hires: Sequence[torch.Tensor]
+ resolutions: Sequence[Tuple[int, int]]
+
+
+class QueryFeatures(NamedTuple):
+ """Query features used to compute trajectories.
+
+ These are sampled from the query frames and are a full descriptor of the
+ tracked points. They can be acquired from a query image and then reused in a
+ separate video.
+
+ Attributes:
+ lowres: Low-resolution features, one for each resolution; each has shape
+ [batch, num_query_points, 256]
+ hires: High-resolution features, one for each resolution; each has shape
+ [batch, num_query_points, 128]
+ resolutions: Resolutions used for trajectory computation. There will be one
+ entry for the initialization, and then an entry for each PIPs refinement
+ resolution.
+ """
+
+ lowres: Sequence[torch.Tensor]
+ hires: Sequence[torch.Tensor]
+ resolutions: Sequence[Tuple[int, int]]
+
+
+class TAPIR(nn.Module):
+ """TAPIR model."""
+
+ def __init__(
+ self,
+ bilinear_interp_with_depthwise_conv: bool = False,
+ num_pips_iter: int = 4,
+ pyramid_level: int = 1,
+ mixer_hidden_dim: int = 512,
+ num_mixer_blocks: int = 12,
+ mixer_kernel_shape: int = 3,
+ patch_size: int = 7,
+ softmax_temperature: float = 20.0,
+ parallelize_query_extraction: bool = False,
+ initial_resolution: Tuple[int, int] = (256, 256),
+ blocks_per_group: Sequence[int] = (2, 2, 2, 2),
+ feature_extractor_chunk_size: int = 10,
+ extra_convs: bool = True,
+ ):
+ super().__init__()
+
+ self.highres_dim = 128
+ self.lowres_dim = 256
+ self.bilinear_interp_with_depthwise_conv = (
+ bilinear_interp_with_depthwise_conv
+ )
+ self.parallelize_query_extraction = parallelize_query_extraction
+
+ self.num_pips_iter = num_pips_iter
+ self.pyramid_level = pyramid_level
+ self.patch_size = patch_size
+ self.softmax_temperature = softmax_temperature
+ self.initial_resolution = tuple(initial_resolution)
+ self.feature_extractor_chunk_size = feature_extractor_chunk_size
+
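+ # Note: the backbone configuration below is hard-coded and shadows the
+ # `blocks_per_group` constructor argument; likewise `mixer_hidden_dim`,
+ # `num_mixer_blocks` and `mixer_kernel_shape` are not forwarded to the
+ # PIPs mixer, which therefore uses its default settings.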
+ highres_dim = 128
+ lowres_dim = 256
+ strides = (1, 2, 2, 1)
+ blocks_per_group = (2, 2, 2, 2)
+ channels_per_group = (64, highres_dim, 256, lowres_dim)
+ use_projection = (True, True, True, True)
+
+ self.resnet_torch = nets.ResNet(
+ blocks_per_group=blocks_per_group,
+ channels_per_group=channels_per_group,
+ use_projection=use_projection,
+ strides=strides,
+ )
+ self.torch_cost_volume_track_mods = nn.ModuleDict({
+ 'hid1': torch.nn.Conv2d(1, 16, 3, 1, 1),
+ 'hid2': torch.nn.Conv2d(16, 1, 3, 1, 1),
+ 'hid3': torch.nn.Conv2d(16, 32, 3, 2, 0),
+ 'hid4': torch.nn.Linear(32, 16),
+ 'occ_out': torch.nn.Linear(16, 2),
+ })
+ dim = 4 + self.highres_dim + self.lowres_dim
+ input_dim = dim + (self.pyramid_level + 2) * 49
+ self.torch_pips_mixer = nets.PIPSMLPMixer(input_dim, dim)
+
+ if extra_convs:
+ self.extra_convs = nets.ExtraConvs()
+ else:
+ self.extra_convs = None
+
+ self.cached_feats = None
+
+ def forward(
+ self,
+ video: torch.Tensor,
+ query_points: torch.Tensor,
+ is_training: bool = False,
+ query_chunk_size: Optional[int] = 512,
+ get_query_feats: bool = False,
+ refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
+ cache_features: bool = False,
+ ) -> Mapping[str, torch.Tensor]:
+ """Runs a forward pass of the model.
+
+ Args:
+ video: A 5-D tensor representing a batch of sequences of images.
+ query_points: The query points for which we compute tracks.
+ is_training: Whether we are training.
+ query_chunk_size: When computing cost volumes, break the queries into
+ chunks of this size to save memory.
+ get_query_feats: Return query features for other losses like contrastive.
+ Not supported in the current version.
+ refinement_resolutions: A list of (height, width) tuples. Refinement will
+ be repeated at each specified resolution, in order to achieve high
+ accuracy on resolutions higher than what TAPIR was trained on. If None,
+ reasonable refinement resolutions will be inferred from the input video
+ size.
+ cache_features: If True, recompute the feature grids for this video and
+ store them in self.cached_feats for reuse; if False and a cache already
+ exists, the cached grids are used instead of recomputing them.
+
+ Returns:
+ A dict of outputs, including:
+ occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
+ where higher indicates more likely to be occluded.
+ tracks: predicted point locations, of shape
+ [batch, num_queries, num_frames, 2], where each point is [x, y]
+ in raster coordinates
+ expected_dist: uncertainty estimate logits, of shape
+ [batch, num_queries, num_frames], where higher indicates more likely
+ to be far from the correct answer.
+ """
+ if get_query_feats:
+ raise ValueError('Get query feats not supported in TAPIR.')
+
+ if self.cached_feats is None or cache_features:
+ feature_grids = self.get_feature_grids(
+ video,
+ is_training,
+ refinement_resolutions,
+ )
+ else:
+ feature_grids = self.cached_feats
+
+ if cache_features:
+ self.cached_feats = feature_grids
+
+ query_features = self.get_query_features(
+ video,
+ is_training,
+ query_points,
+ feature_grids,
+ refinement_resolutions,
+ )
+
+ trajectories = self.estimate_trajectories(
+ video.shape[-3:-1],
+ is_training,
+ feature_grids,
+ query_features,
+ query_points,
+ query_chunk_size,
+ )
+
+ p = self.num_pips_iter
+ out = dict(
+ occlusion=torch.mean(
+ torch.stack(trajectories['occlusion'][p::p]), dim=0
+ ),
+ tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
+ expected_dist=torch.mean(
+ torch.stack(trajectories['expected_dist'][p::p]), dim=0
+ ),
+ unrefined_occlusion=trajectories['occlusion'][:-1],
+ unrefined_tracks=trajectories['tracks'][:-1],
+ unrefined_expected_dist=trajectories['expected_dist'][:-1],
+ )
+
+ return out
+
+ def get_query_features(
+ self,
+ video: torch.Tensor,
+ is_training: bool,
+ query_points: torch.Tensor,
+ feature_grids: Optional[FeatureGrids] = None,
+ refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
+ ) -> QueryFeatures:
+ """Computes query features, which can be used for estimate_trajectories.
+
+ Args:
+ video: A 5-D tensor representing a batch of sequences of images.
+ is_training: Whether we are training.
+ query_points: The query points for which we compute tracks.
+ feature_grids: If passed, we'll use these feature grids rather than
+ computing new ones.
+ refinement_resolutions: A list of (height, width) tuples. Refinement will
+ be repeated at each specified resolution, in order to achieve high
+ accuracy on resolutions higher than what TAPIR was trained on. If None,
+ reasonable refinement resolutions will be inferred from the input video
+ size.
+
+ Returns:
+ A QueryFeatures object which contains the required features for every
+ required resolution.
+ """
+
+ if feature_grids is None:
+ feature_grids = self.get_feature_grids(
+ video,
+ is_training=is_training,
+ refinement_resolutions=refinement_resolutions,
+ )
+
+ feature_grid = feature_grids.lowres
+ hires_feats = feature_grids.hires
+ resize_im_shape = feature_grids.resolutions
+
+ shape = video.shape
+ # shape is [batch_size, time, height, width, channels]; the 'tyx'
+ # conversion below needs [time, height, width]
+ curr_resolution = (-1, -1)
+ query_feats = []
+ hires_query_feats = []
+ for i, resolution in enumerate(resize_im_shape):
+ if utils.is_same_res(curr_resolution, resolution):
+ query_feats.append(query_feats[-1])
+ hires_query_feats.append(hires_query_feats[-1])
+ continue
+ position_in_grid = utils.convert_grid_coordinates(
+ query_points,
+ shape[1:4],
+ feature_grid[i].shape[1:4],
+ coordinate_format='tyx',
+ )
+ position_in_grid_hires = utils.convert_grid_coordinates(
+ query_points,
+ shape[1:4],
+ hires_feats[i].shape[1:4],
+ coordinate_format='tyx',
+ )
+
+ interp_features = utils.map_coordinates_3d(
+ feature_grid[i], position_in_grid
+ )
+ hires_interp = utils.map_coordinates_3d(
+ hires_feats[i], position_in_grid_hires
+ )
+
+ hires_query_feats.append(hires_interp)
+ query_feats.append(interp_features)
+
+ return QueryFeatures(
+ tuple(query_feats), tuple(hires_query_feats), tuple(resize_im_shape)
+ )
+
+ def get_feature_grids(
+ self,
+ video: torch.Tensor,
+ is_training: bool,
+ refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
+ ) -> FeatureGrids:
+ """Computes feature grids.
+
+ Args:
+ video: A 5-D tensor representing a batch of sequences of images.
+ is_training: Whether we are training.
+ refinement_resolutions: A list of (height, width) tuples. Refinement will
+ be repeated at each specified resolution, to achieve high accuracy on
+ resolutions higher than what TAPIR was trained on. If None, reasonable
+ refinement resolutions will be inferred from the input video size.
+
+ Returns:
+ A FeatureGrids object containing the required features for every
+ required resolution. Note that there will be one more feature grid
+ than there are refinement_resolutions, because there is always a
+ feature grid computed for TAP-Net initialization.
+ """
+ del is_training
+ if refinement_resolutions is None:
+ refinement_resolutions = utils.generate_default_resolutions(
+ video.shape[2:4], self.initial_resolution
+ )
+
+ all_required_resolutions = [self.initial_resolution]
+ all_required_resolutions.extend(refinement_resolutions)
+
+ feature_grid = []
+ hires_feats = []
+ resize_im_shape = []
+ curr_resolution = (-1, -1)
+
+ latent = None
+ hires = None
+ video_resize = None
+ for resolution in all_required_resolutions:
+ if resolution[0] % 8 != 0 or resolution[1] % 8 != 0:
+ raise ValueError('Image resolution must be a multiple of 8.')
+
+ if not utils.is_same_res(curr_resolution, resolution):
+ if utils.is_same_res(curr_resolution, video.shape[-3:-1]):
+ video_resize = video
+ else:
+ video_resize = utils.bilinear(video, resolution)
+
+ curr_resolution = resolution
+ n, f, h, w, c = video_resize.shape
+ video_resize = video_resize.view(n * f, h, w, c).permute(0, 3, 1, 2)
+
+ if self.feature_extractor_chunk_size > 0:
+ latent_list = []
+ hires_list = []
+ chunk_size = self.feature_extractor_chunk_size
+ for start_idx in range(0, video_resize.shape[0], chunk_size):
+ video_chunk = video_resize[start_idx:start_idx + chunk_size]
+ resnet_out = self.resnet_torch(video_chunk)
+
+ u3 = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1).detach()
+ latent_list.append(u3)
+ u1 = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1).detach()
+ hires_list.append(u1)
+
+ latent = torch.cat(latent_list, dim=0)
+ hires = torch.cat(hires_list, dim=0)
+
+ else:
+ resnet_out = self.resnet_torch(video_resize)
+ latent = resnet_out['resnet_unit_3'].permute(0, 2, 3, 1).detach()
+ hires = resnet_out['resnet_unit_1'].permute(0, 2, 3, 1).detach()
+
+ if self.extra_convs:
+ latent = self.extra_convs(latent)
+
+ latent = latent / torch.sqrt(
+ torch.maximum(
+ torch.sum(torch.square(latent), axis=-1, keepdims=True),
+ torch.tensor(1e-12, device=latent.device),
+ )
+ )
+ hires = hires / torch.sqrt(
+ torch.maximum(
+ torch.sum(torch.square(hires), axis=-1, keepdims=True),
+ torch.tensor(1e-12, device=hires.device),
+ )
+ )
+
+ feature_grid.append(latent[None, ...])
+ hires_feats.append(hires[None, ...])
+ resize_im_shape.append(video_resize.shape[2:4])
+
+ return FeatureGrids(
+ tuple(feature_grid), tuple(hires_feats), tuple(resize_im_shape)
+ )
+
+ def estimate_trajectories(
+ self,
+ video_size: Tuple[int, int],
+ is_training: bool,
+ feature_grids: FeatureGrids,
+ query_features: QueryFeatures,
+ query_points_in_video: Optional[torch.Tensor],
+ query_chunk_size: Optional[int] = None,
+ ) -> Mapping[str, Any]:
+ """Estimates trajectories given features for a video and query features.
+
+ Args:
+ video_size: A 2-tuple containing the original [height, width] of the
+ video. Predictions will be scaled with respect to this resolution.
+ is_training: Whether we are training.
+ feature_grids: a FeatureGrids object computed for the given video.
+ query_features: a QueryFeatures object computed for the query points.
+ query_points_in_video: If provided, assume that the query points come from
+ the same video as feature_grids, and therefore constrain the resulting
+ trajectories to (approximately) pass through them.
+ query_chunk_size: When computing cost volumes, break the queries into
+ chunks of this size to save memory.
+
+ Returns:
+ A dict of outputs, including:
+ occlusion: Occlusion logits, of shape [batch, num_queries, num_frames]
+ where higher indicates more likely to be occluded.
+ tracks: predicted point locations, of shape
+ [batch, num_queries, num_frames, 2], where each point is [x, y]
+ in raster coordinates
+ expected_dist: uncertainty estimate logits, of shape
+ [batch, num_queries, num_frames], where higher indicates more likely
+ to be far from the correct answer.
+ """
+ del is_training
+
+ def train2orig(x):
+ return utils.convert_grid_coordinates(
+ x,
+ self.initial_resolution[::-1],
+ video_size[::-1],
+ coordinate_format='xy',
+ )
+
+ occ_iters = []
+ pts_iters = []
+ expd_iters = []
+ num_iters = self.num_pips_iter * (len(feature_grids.lowres) - 1)
+ for _ in range(num_iters + 1):
+ occ_iters.append([])
+ pts_iters.append([])
+ expd_iters.append([])
+
+ infer = functools.partial(
+ self.tracks_from_cost_volume,
+ im_shp=feature_grids.lowres[0].shape[0:2]
+ + self.initial_resolution
+ + (3,),
+ )
+
+ num_queries = query_features.lowres[0].shape[1]
+ perm = torch.randperm(num_queries)
+ inv_perm = torch.zeros_like(perm)
+ inv_perm[perm] = torch.arange(num_queries)
+
+ for ch in range(0, num_queries, query_chunk_size):
+ perm_chunk = perm[ch: ch + query_chunk_size]
+ chunk = query_features.lowres[0][:, perm_chunk]
+
+ if query_points_in_video is not None:
+ infer_query_points = query_points_in_video[
+ :, perm[ch: ch + query_chunk_size]
+ ]
+ num_frames = feature_grids.lowres[0].shape[1]
+ infer_query_points = utils.convert_grid_coordinates(
+ infer_query_points,
+ (num_frames,) + video_size,
+ (num_frames,) + self.initial_resolution,
+ coordinate_format='tyx',
+ )
+ else:
+ infer_query_points = None
+
+ points, occlusion, expected_dist = infer(
+ chunk,
+ feature_grids.lowres[0],
+ infer_query_points,
+ )
+ pts_iters[0].append(train2orig(points))
+ occ_iters[0].append(occlusion)
+ expd_iters[0].append(expected_dist)
+
+ mixer_feats = None
+ for i in range(num_iters):
+ feature_level = i // self.num_pips_iter + 1
+ queries = [
+ query_features.hires[feature_level][:, perm_chunk],
+ query_features.lowres[feature_level][:, perm_chunk],
+ ]
+ for _ in range(self.pyramid_level):
+ queries.append(queries[-1])
+ pyramid = [
+ feature_grids.hires[feature_level],
+ feature_grids.lowres[feature_level],
+ ]
+ for _ in range(self.pyramid_level):
+ pyramid.append(
+ F.avg_pool3d(
+ pyramid[-1],
+ kernel_size=(2, 2, 1),
+ stride=(2, 2, 1),
+ padding=0,
+ )
+ )
+
+ refined = self.refine_pips(
+ queries,
+ None,
+ pyramid,
+ points,
+ occlusion,
+ expected_dist,
+ orig_hw=self.initial_resolution,
+ last_iter=mixer_feats,
+ mixer_iter=i,
+ resize_hw=feature_grids.resolutions[feature_level],
+ )
+ points, occlusion, expected_dist, mixer_feats = refined
+ pts_iters[i + 1].append(train2orig(points))
+ occ_iters[i + 1].append(occlusion)
+ expd_iters[i + 1].append(expected_dist)
+ if (i + 1) % self.num_pips_iter == 0:
+ mixer_feats = None
+ expected_dist = expd_iters[0][-1]
+ occlusion = occ_iters[0][-1]
+
+ occlusion = []
+ points = []
+ expd = []
+ for i, _ in enumerate(occ_iters):
+ occlusion.append(torch.cat(occ_iters[i], dim=1)[:, inv_perm])
+ points.append(torch.cat(pts_iters[i], dim=1)[:, inv_perm])
+ expd.append(torch.cat(expd_iters[i], dim=1)[:, inv_perm])
+
+ out = dict(
+ occlusion=occlusion,
+ tracks=points,
+ expected_dist=expd,
+ )
+ return out
+
+ def refine_pips(
+ self,
+ target_feature,
+ frame_features,
+ pyramid,
+ pos_guess,
+ occ_guess,
+ expd_guess,
+ orig_hw,
+ last_iter=None,
+ mixer_iter=0.0,
+ resize_hw=None,
+ ):
+ del frame_features
+ del mixer_iter
+ orig_h, orig_w = orig_hw
+ resized_h, resized_w = resize_hw
+ corrs_pyr = []
+ assert len(target_feature) == len(pyramid)
+ for pyridx, (query, grid) in enumerate(zip(target_feature, pyramid)):
+ # note: interp needs [y,x]
+ coords = utils.convert_grid_coordinates(
+ pos_guess, (orig_w, orig_h), grid.shape[-2:-4:-1]
+ )
+ coords = torch.flip(coords, dims=(-1,))
+ last_iter_query = None
+ if last_iter is not None:
+ if pyridx == 0:
+ last_iter_query = last_iter[..., : self.highres_dim]
+ else:
+ last_iter_query = last_iter[..., self.highres_dim:]
+
+ ctxy, ctxx = torch.meshgrid(
+ torch.arange(-3, 4), torch.arange(-3, 4), indexing='ij'
+ )
+ ctx = torch.stack([ctxy, ctxx], dim=-1)
+ ctx = ctx.reshape(-1, 2).to(coords.device)
+ coords2 = coords.unsqueeze(3) + ctx.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+ neighborhood = utils.map_coordinates_2d(grid, coords2)
+
+ # s is spatial context size
+ if last_iter_query is None:
+ patches = torch.einsum('bnfsc,bnc->bnfs', neighborhood, query)
+ else:
+ patches = torch.einsum(
+ 'bnfsc,bnfc->bnfs', neighborhood, last_iter_query
+ )
+
+ corrs_pyr.append(patches)
+ corrs_pyr = torch.concatenate(corrs_pyr, dim=-1)
+
+ corrs_chunked = corrs_pyr
+ pos_guess_input = pos_guess
+ occ_guess_input = occ_guess[..., None]
+ expd_guess_input = expd_guess[..., None]
+
+ # mlp_input is batch, num_points, num_chunks, frames_per_chunk, channels
+ if last_iter is None:
+ both_feature = torch.cat([target_feature[0], target_feature[1]], axis=-1)
+ mlp_input_features = torch.tile(
+ both_feature.unsqueeze(2), (1, 1, corrs_chunked.shape[-2], 1)
+ )
+ else:
+ mlp_input_features = last_iter
+
+ pos_guess_input = torch.zeros_like(pos_guess_input)
+
+ mlp_input = torch.cat(
+ [
+ pos_guess_input,
+ occ_guess_input,
+ expd_guess_input,
+ mlp_input_features,
+ corrs_chunked,
+ ],
+ axis=-1,
+ )
+ x = utils.einshape('bnfc->(bn)fc', mlp_input)
+ res = self.torch_pips_mixer(x.float())
+ res = utils.einshape('(bn)fc->bnfc', res, b=mlp_input.shape[0])
+
+ pos_update = utils.convert_grid_coordinates(
+ res[..., :2].detach(),
+ (resized_w, resized_h),
+ (orig_w, orig_h),
+ )
+ return (
+ pos_update + pos_guess,
+ res[..., 2] + occ_guess,
+ res[..., 3] + expd_guess,
+ res[..., 4:] + (mlp_input_features if last_iter is None else last_iter),
+ )
+
+ def tracks_from_cost_volume(
+ self,
+ interp_feature: torch.Tensor,
+ feature_grid: torch.Tensor,
+ query_points: Optional[torch.Tensor],
+ im_shp=None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Converts features into tracks by computing a cost volume.
+
+ The computed cost volume will have shape
+ [batch, num_queries, time, height, width], which can be very
+ memory intensive.
+
+ Args:
+ interp_feature: A tensor of features for each query point, of shape
+ [batch, num_queries, channels].
+ feature_grid: A tensor of features for the video, of shape [batch, time,
+ height, width, channels].
+ query_points: When computing tracks, we assume these points are given as
+ ground truth and we reproduce them exactly. This is a set of points of
+ shape [batch, num_points, 3], where each entry is [t, y, x] in frame/
+ raster coordinates.
+ im_shp: The shape of the original image, i.e., [batch, num_frames,
+ height, width, 3].
+
+ Returns:
+ A 3-tuple of the inferred points (of shape
+ [batch, num_points, num_frames, 2] where each point is [x, y]),
+ the inferred occlusion logits (of shape [batch, num_points, num_frames],
+ where higher means more likely to be occluded), and the expected-distance
+ logits (same shape, where higher means the prediction is more likely to
+ be far from the correct answer).
+ """
+
+ mods = self.torch_cost_volume_track_mods
+ cost_volume = torch.einsum(
+ 'bnc,bthwc->tbnhw',
+ interp_feature,
+ feature_grid,
+ )
+
+ shape = cost_volume.shape
+ batch_size, num_points = cost_volume.shape[1:3]
+ cost_volume = utils.einshape('tbnhw->(tbn)hw1', cost_volume)
+
+ cost_volume = cost_volume.permute(0, 3, 1, 2)
+ occlusion = mods['hid1'](cost_volume)
+ occlusion = torch.nn.functional.relu(occlusion)
+
+ pos = mods['hid2'](occlusion)
+ pos = pos.permute(0, 2, 3, 1)
+ pos_rshp = utils.einshape('(tb)hw1->t(b)hw1', pos, t=shape[0])
+
+ pos = utils.einshape(
+ 't(bn)hw1->bnthw', pos_rshp, b=batch_size, n=num_points
+ )
+ pos_sm = pos.reshape(pos.size(0), pos.size(1), pos.size(2), -1)
+ softmaxed = F.softmax(pos_sm * self.softmax_temperature, dim=-1)
+ pos = softmaxed.view_as(pos)
+
+ points = utils.heatmaps_to_points(pos, im_shp, query_points=query_points)
+
+ occlusion = torch.nn.functional.pad(occlusion, (0, 2, 0, 2))
+ occlusion = mods['hid3'](occlusion)
+ occlusion = torch.nn.functional.relu(occlusion)
+ occlusion = torch.mean(occlusion, dim=(-1, -2))
+ occlusion = mods['hid4'](occlusion)
+ occlusion = torch.nn.functional.relu(occlusion)
+ occlusion = mods['occ_out'](occlusion)
+
+ expected_dist = utils.einshape(
+ '(tbn)1->bnt', occlusion[..., 1:2], n=shape[2], t=shape[0]
+ )
+ occlusion = utils.einshape(
+ '(tbn)1->bnt', occlusion[..., 0:1], n=shape[2], t=shape[0]
+ )
+ return points, occlusion, expected_dist
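+
+# Usage sketch (comment only, not executed; the variable names are
+# hypothetical and the shapes follow the docstrings above):
+#   model = TAPIR()
+#   video = torch.zeros(1, 24, 256, 256, 3)   # [batch, frames, height, width, 3]
+#   queries = torch.zeros(1, 100, 3)          # [batch, points, (t, y, x)]
+#   out = model(video, queries)
+#   out['tracks'].shape                       # [1, 100, 24, 2], points as (x, y)
+#   out['occlusion'].shape                    # [1, 100, 24], occlusion logits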
diff --git a/data/dot_single_video/dot/models/shelf/tapir_utils/utils.py b/data/dot_single_video/dot/models/shelf/tapir_utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1b68db3a9dc35c02ec10d691c8ce598016a446
--- /dev/null
+++ b/data/dot_single_video/dot/models/shelf/tapir_utils/utils.py
@@ -0,0 +1,317 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Pytorch model utilities."""
+
+from typing import Any, Sequence, Union
+from einshape.src import abstract_ops
+from einshape.src import backend
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
+ """Resizes a 5D tensor using bilinear interpolation.
+
+ Args:
+ x: A 5D tensor of shape (B, T, H, W, C) where B is batch size, T is
+ time, H is height, W is width, and C is the number of channels.
+ resolution: The target resolution as a tuple (new_height, new_width).
+
+ Returns:
+ The resized tensor.
+ """
+ b, t, h, w, c = x.size()
+ x = x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
+ x = F.interpolate(x, size=resolution, mode='bilinear', align_corners=False)
+ b, _, h, w = x.size()
+ x = x.reshape(b, t, c, h, w).permute(0, 1, 3, 4, 2)
+ return x
+
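+# Example (comment only, not executed): resizing a (B, T, H, W, C) video
+# tensor with the helper above:
+#   video = torch.zeros(1, 8, 256, 320, 3)
+#   small = bilinear(video, (128, 160))  # -> torch.Size([1, 8, 128, 160, 3])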
+
+def map_coordinates_3d(
+ feats: torch.Tensor, coordinates: torch.Tensor
+) -> torch.Tensor:
+ """Maps 3D coordinates to corresponding features using bilinear interpolation.
+
+ Args:
+ feats: A 5D tensor of features with shape (B, W, H, D, C), where B is batch
+ size, W is width, H is height, D is depth, and C is the number of
+ channels.
+ coordinates: A 3D tensor of coordinates with shape (B, N, 3), where N is the
+ number of coordinates and the last dimension represents (W, H, D)
+ coordinates.
+
+ Returns:
+ The mapped features tensor.
+ """
+ x = feats.permute(0, 4, 1, 2, 3)
+ y = coordinates[:, :, None, None, :].float()
+ y[..., 0] += 0.5
+ y = 2 * (y / torch.tensor(x.shape[2:], device=y.device)) - 1
+ y = torch.flip(y, dims=(-1,))
+ out = (
+ F.grid_sample(
+ x, y, mode='bilinear', align_corners=False, padding_mode='border'
+ )
+ .squeeze(dim=(3, 4))
+ .permute(0, 2, 1)
+ )
+ return out
+
+
+def map_coordinates_2d(
+ feats: torch.Tensor, coordinates: torch.Tensor
+) -> torch.Tensor:
+ """Maps 2D coordinates to feature maps using bilinear interpolation.
+
+ The function performs bilinear interpolation on the feature maps (`feats`)
+ at the specified `coordinates`, which are given in pixel units and are
+ rescaled internally to the [-1, 1] range expected by `grid_sample`. The
+ result is a tensor of sampled features corresponding to these coordinates.
+
+ Args:
+ feats (Tensor): A 5D tensor of shape (N, T, H, W, C) representing feature
+ maps, where N is the batch size, T is the number of frames, H and W are
+ height and width, and C is the number of channels.
+ coordinates (Tensor): A 5D tensor of shape (N, P, T, S, XY) representing
+ coordinates, where N is the batch size, P is the number of points, T is
+ the number of frames, S is the number of samples, and XY represents the 2D
+ coordinates.
+
+ Returns:
+ Tensor: A 5D tensor of the sampled features corresponding to the
+ given coordinates, of shape (N, P, T, S, C).
+ """
+ n, t, h, w, c = feats.shape
+ x = feats.permute(0, 1, 4, 2, 3).view(n * t, c, h, w)
+
+ n, p, t, s, xy = coordinates.shape
+ y = coordinates.permute(0, 2, 1, 3, 4).view(n * t, p, s, xy)
+ y = 2 * (y / h) - 1
+ y = torch.flip(y, dims=(-1,)).float()
+
+ out = F.grid_sample(
+ x, y, mode='bilinear', align_corners=False, padding_mode='zeros'
+ )
+ _, c, _, _ = out.shape
+ out = out.permute(0, 2, 3, 1).view(n, t, p, s, c).permute(0, 2, 1, 3, 4)
+
+ return out
+
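+# Shape sketch (comment only, not executed): with feature maps of shape
+# (N, T, H, W, C) and pixel-space sample locations of shape (N, P, T, S, 2),
+# map_coordinates_2d returns sampled features of shape (N, P, T, S, C), e.g.
+#   feats = torch.zeros(1, 8, 64, 64, 32)
+#   coords = torch.zeros(1, 16, 8, 49, 2)
+#   out = map_coordinates_2d(feats, coords)  # -> torch.Size([1, 16, 8, 49, 32])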
+
+def soft_argmax_heatmap_batched(softmax_val, threshold=5):
+ """Test if two image resolutions are the same."""
+ b, h, w, d1, d2 = softmax_val.shape
+ y, x = torch.meshgrid(
+ torch.arange(d1, device=softmax_val.device),
+ torch.arange(d2, device=softmax_val.device),
+ indexing='ij',
+ )
+ coords = torch.stack([x + 0.5, y + 0.5], dim=-1).to(softmax_val.device)
+ softmax_val_flat = softmax_val.reshape(b, h, w, -1)
+ argmax_pos = torch.argmax(softmax_val_flat, dim=-1)
+
+ pos = coords.reshape(-1, 2)[argmax_pos]
+ valid = (
+ torch.sum(
+ torch.square(
+ coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :]
+ ),
+ dim=-1,
+ keepdims=True,
+ )
+ < threshold**2
+ )
+
+ weighted_sum = torch.sum(
+ coords[None, None, None, :, :, :]
+ * valid
+ * softmax_val[:, :, :, :, :, None],
+ dim=(3, 4),
+ )
+ sum_of_weights = torch.maximum(
+ torch.sum(valid * softmax_val[:, :, :, :, :, None], dim=(3, 4)),
+ torch.tensor(1e-12, device=softmax_val.device),
+ )
+ return weighted_sum / sum_of_weights
+
+
+def heatmaps_to_points(
+ all_pairs_softmax,
+ image_shape,
+ threshold=5,
+ query_points=None,
+):
+ """Convert heatmaps to points using soft argmax."""
+
+ out_points = soft_argmax_heatmap_batched(all_pairs_softmax, threshold)
+ feature_grid_shape = all_pairs_softmax.shape[1:]
+ # Note: out_points is now [x, y]; we need to divide by [width, height].
+ # image_shape[3] is width and image_shape[2] is height.
+ out_points = convert_grid_coordinates(
+ out_points.detach(),
+ feature_grid_shape[3:1:-1],
+ image_shape[3:1:-1],
+ )
+ assert feature_grid_shape[1] == image_shape[1]
+ if query_points is not None:
+ # The [..., 0:1] is because we only care about the frame index.
+ query_frame = convert_grid_coordinates(
+ query_points.detach(),
+ image_shape[1:4],
+ feature_grid_shape[1:4],
+ coordinate_format='tyx',
+ )[..., 0:1]
+
+ query_frame = torch.round(query_frame)
+ frame_indices = torch.arange(image_shape[1], device=query_frame.device)[
+ None, None, :
+ ]
+ is_query_point = query_frame == frame_indices
+
+ is_query_point = is_query_point[:, :, :, None]
+ out_points = (
+ out_points * ~is_query_point
+ + torch.flip(query_points[:, :, None], dims=(-1,))[..., 0:2]
+ * is_query_point
+ )
+
+ return out_points
+
+
+def is_same_res(r1, r2):
+ """Test if two image resolutions are the same."""
+ return all([x == y for x, y in zip(r1, r2)])
+
+
+def convert_grid_coordinates(
+ coords: torch.Tensor,
+ input_grid_size: Sequence[int],
+ output_grid_size: Sequence[int],
+ coordinate_format: str = 'xy',
+) -> torch.Tensor:
+ """Convert grid coordinates to correct format."""
+ if isinstance(input_grid_size, tuple):
+ input_grid_size = torch.tensor(input_grid_size, device=coords.device)
+ if isinstance(output_grid_size, tuple):
+ output_grid_size = torch.tensor(output_grid_size, device=coords.device)
+
+ if coordinate_format == 'xy':
+ if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
+ raise ValueError(
+ 'If coordinate_format is xy, the shapes must be length 2.'
+ )
+ elif coordinate_format == 'tyx':
+ if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
+ raise ValueError(
+ 'If coordinate_format is tyx, the shapes must be length 3.'
+ )
+ if input_grid_size[0] != output_grid_size[0]:
+ raise ValueError('converting frame count is not supported.')
+ else:
+ raise ValueError('Recognized coordinate formats are xy and tyx.')
+
+ position_in_grid = coords
+ position_in_grid = position_in_grid * output_grid_size / input_grid_size
+
+ return position_in_grid
+
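+# Example (comment only, not executed): rescaling (x, y) points from a
+# 32x32 feature grid to a 256x256 image, with grid sizes given in
+# (width, height) order for the default 'xy' format:
+#   pts = torch.tensor([[8.0, 16.0]])
+#   convert_grid_coordinates(pts, (32, 32), (256, 256))  # -> [[64., 128.]]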
+
+class _TorchBackend(backend.Backend[torch.Tensor]):
+ """Einshape backend implementation for PyTorch tensors."""
+
+ # https://github.com/vacancy/einshape/blob/main/einshape/src/pytorch/pytorch_ops.py
+
+ def reshape(self, x: torch.Tensor, op: abstract_ops.Reshape) -> torch.Tensor:
+ return x.reshape(op.shape)
+
+ def transpose(
+ self, x: torch.Tensor, op: abstract_ops.Transpose
+ ) -> torch.Tensor:
+ return x.permute(op.perm)
+
+ def broadcast(
+ self, x: torch.Tensor, op: abstract_ops.Broadcast
+ ) -> torch.Tensor:
+ shape = op.transform_shape(x.shape)
+ for axis_position in sorted(op.axis_sizes.keys()):
+ x = x.unsqueeze(axis_position)
+ return x.expand(shape)
+
+
+def einshape(
+ equation: str, value: Union[torch.Tensor, Any], **index_sizes: int
+) -> torch.Tensor:
+ """Reshapes `value` according to the given Shape Equation.
+
+ Args:
+ equation: The Shape Equation specifying the index regrouping and reordering.
+ value: Input tensor, or tensor-like object.
+ **index_sizes: Sizes of indices, where they cannot be inferred from
+ `input_shape`.
+
+ Returns:
+ Tensor derived from `value` by reshaping as specified by `equation`.
+ """
+ if not isinstance(value, torch.Tensor):
+ value = torch.tensor(value)
+ return _TorchBackend().exec(equation, value, value.shape, **index_sizes)
+
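+# Example (comment only, not executed): einshape groups and splits axes via a
+# shape equation, e.g. flattening the batch and point axes and restoring them:
+#   x = torch.zeros(2, 5, 24, 388)
+#   flat = einshape('bnfc->(bn)fc', x)          # -> (10, 24, 388)
+#   back = einshape('(bn)fc->bnfc', flat, b=2)  # -> (2, 5, 24, 388)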
+
+def generate_default_resolutions(full_size, train_size, num_levels=None):
+ """Generate a list of logarithmically-spaced resolutions.
+
+ Generated resolutions are between train_size and full_size, inclusive, with
+ num_levels different resolutions total. Useful for generating the input to
+ refinement_resolutions in PIPs.
+
+ Args:
+ full_size: 2-tuple of ints. The full image size desired.
+ train_size: 2-tuple of ints. The smallest refinement level. Should
+ typically match the training resolution, which is (256, 256) for TAPIR.
+ num_levels: number of levels. Typically each resolution should be less than
+ twice the size of prior resolutions.
+
+ Returns:
+ A list of resolutions.
+ """
+ if all([x == y for x, y in zip(train_size, full_size)]):
+ return [train_size]
+
+ if num_levels is None:
+ size_ratio = np.array(full_size) / np.array(train_size)
+ num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1)
+
+ if num_levels <= 1:
+ return [train_size]
+
+ h, w = full_size[0:2]
+ if h % 8 != 0 or w % 8 != 0:
+ print(
+ 'Warning: output size is not a multiple of 8. Final layer '
+ + 'will round size down.'
+ )
+ ll_h, ll_w = train_size[0:2]
+
+ sizes = []
+ for i in range(num_levels):
+ size = (
+ int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8,
+ int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8,
+ )
+ sizes.append(size)
+ return sizes
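+
+# Worked example (comment only, not executed): for a 480x856 video and the
+# TAPIR training size (256, 256), the size ratio is about (1.9, 3.3), so
+# num_levels = ceil(log2(3.34)) + 1 = 3 and the schedule of 8-aligned,
+# log-spaced resolutions is [(256, 256), (344, 464), (480, 856)].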
diff --git a/data/dot_single_video/dot/utils/__init__.py b/data/dot_single_video/dot/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/data/dot_single_video/dot/utils/io.py b/data/dot_single_video/dot/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d3ee176861043fcee5e7f52aa2ea0c6b4c75c3f
--- /dev/null
+++ b/data/dot_single_video/dot/utils/io.py
@@ -0,0 +1,129 @@
+import os
+import argparse
+from PIL import Image
+from glob import glob
+import numpy as np
+import json
+import torch
+import torchvision
+from torch.nn import functional as F
+
+
+def create_folder(path, verbose=False, exist_ok=True, safe=True):
+ if os.path.exists(path) and not exist_ok:
+ if not safe:
+ raise OSError
+ return False
+ try:
+ os.makedirs(path)
+ except OSError:
+ if not safe:
+ raise
+ return False
+ if verbose:
+ print(f"Created folder: {path}")
+ return True
+
+
+def read_video(path, start_step=0, time_steps=None, channels="first", exts=("jpg", "png"), resolution=None):
+ if path.endswith(".mp4"):
+ video = read_video_from_file(path, start_step, time_steps, channels, resolution)
+ else:
+ video = read_video_from_folder(path, start_step, time_steps, channels, resolution, exts)
+ return video
+
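+# Usage sketch (comment only, not executed; the paths are illustrative):
+# read_video dispatches on the path type, e.g.
+#   video = read_video("clip.mp4")                       # (T, C, H, W), values in [0, 1]
+#   frames = read_video("frames_dir", channels="last")   # (T, H, W, C), values in [0, 1]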
+
+def read_video_from_file(path, start_step, time_steps, channels, resolution):
+ video, _, _ = torchvision.io.read_video(path, output_format="TCHW", pts_unit="sec")
+ if time_steps is None:
+ time_steps = len(video) - start_step
+ video = video[start_step: start_step + time_steps]
+ # Convert to float before interpolation (bilinear resizing is not defined
+ # for uint8 tensors) and scale to [0, 1].
+ video = video.float() / 255.
+ if resolution is not None:
+ video = F.interpolate(video, size=resolution, mode="bilinear")
+ if channels == "last":
+ video = video.permute(0, 2, 3, 1)
+ return video
+
+
+def read_video_from_folder(path, start_step, time_steps, channels, resolution, exts):
+ paths = []
+ for ext in exts:
+ paths += glob(os.path.join(path, f"*.{ext}"))
+ paths = sorted(paths)
+ if time_steps is None:
+ time_steps = len(paths) - start_step
+ video = []
+ for step in range(start_step, start_step + time_steps):
+ frame = read_frame(paths[step], resolution, channels)
+ video.append(frame)
+ video = torch.stack(video)
+ return video
+
+
+def read_frame(path, resolution=None, channels="first"):
+ frame = Image.open(path).convert('RGB')
+ frame = np.array(frame)
+ frame = frame.astype(np.float32)
+ frame = frame / 255
+ frame = torch.from_numpy(frame)
+ frame = frame.permute(2, 0, 1)
+ if resolution is not None:
+ frame = F.interpolate(frame[None], size=resolution, mode="bilinear")[0]
+ if channels == "last":
+ frame = frame.permute(1, 2, 0)
+ return frame
+
+
+def write_video(video, path, channels="first", zero_padded=True, ext="png", dtype="torch"):
+ if dtype == "numpy":
+ video = torch.from_numpy(video)
+ if path.endswith(".mp4"):
+ write_video_to_file(video, path, channels)
+ else:
+ write_video_to_folder(video, path, channels, zero_padded, ext)
+
+
+def write_video_to_file(video, path, channels):
+ create_folder(os.path.dirname(path))
+ if channels == "first":
+ video = video.permute(0, 2, 3, 1)
+ video = (video.cpu() * 255.).to(torch.uint8)
+ torchvision.io.write_video(path, video, 24, "h264", options={"pix_fmt": "yuv420p", "crf": "23"})
+ return video
+
+
+def write_video_to_folder(video, path, channels, zero_padded, ext):
+ create_folder(path)
+ time_steps = video.shape[0]
+ for step in range(time_steps):
+ pad = "0" * (len(str(time_steps)) - len(str(step))) if zero_padded else ""
+ frame_path = os.path.join(path, f"{pad}{step}.{ext}")
+ write_frame(video[step], frame_path, channels)
+
+
+def write_frame(frame, path, channels="first"):
+ create_folder(os.path.dirname(path))
+ frame = frame.cpu().numpy()
+ if channels == "first":
+ frame = np.transpose(frame, (1, 2, 0))
+ frame = np.clip(np.round(frame * 255), 0, 255).astype(np.uint8)
+ frame = Image.fromarray(frame)
+ frame.save(path)
+
+
+def read_tracks(path):
+ return np.load(path)
+
+
+def write_tracks(tracks, path):
+ np.save(path, tracks)
+
+
+def read_config(path):
+ with open(path, 'r') as f:
+ config = json.load(f)
+ args = argparse.Namespace(**config)
+ return args
+
+
diff --git a/data/dot_single_video/dot/utils/log.py b/data/dot_single_video/dot/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..12de66c0fbbbcabe34049ca8988aaaa6a335cb3a
--- /dev/null
+++ b/data/dot_single_video/dot/utils/log.py
@@ -0,0 +1,57 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+
+import torch
+from torchvision.utils import make_grid
+from torch.utils.tensorboard import SummaryWriter
+
+from dot.utils.plot import to_rgb
+
+
+def detach(tensor):
+ if isinstance(tensor, torch.Tensor):
+ return tensor.detach().cpu()
+ return tensor
+
+
+def number(tensor):
+ if isinstance(tensor, torch.Tensor) and tensor.isnan().any():
+ return torch.zeros_like(tensor)
+ return tensor
+
+
+class Logger():
+ def __init__(self, args):
+ self.writer = SummaryWriter(args.log_path)
+ self.factor = args.log_factor
+ self.world_size = args.world_size
+
+ def log_scalar(self, name, scalar, global_iter):
+ if scalar is not None:
+ if type(scalar) == list:
+ for i, x in enumerate(scalar):
+ self.log_scalar(f"{name}_{i}", x, global_iter)
+ else:
+ self.writer.add_scalar(name, number(detach(scalar)), global_iter)
+
+ def log_scalars(self, name, scalars, global_iter):
+ for s in scalars:
+ self.log_scalar(f"{name}/{s}", scalars[s], global_iter)
+
+ def log_image(self, name, tensor, mode, nrow, global_iter, pos=None, occ=None):
+ tensor = detach(tensor)
+ tensor = to_rgb(tensor, mode, pos, occ)
+ grid = make_grid(tensor, nrow=nrow, normalize=False, value_range=[0, 1], pad_value=0)
+ grid = torch.nn.functional.interpolate(grid[None], scale_factor=self.factor)[0]
+ self.writer.add_image(name, grid, global_iter)
+
+ def log_video(self, name, tensor, mode, nrow, global_iter, fps=4, pos=None, occ=None):
+ tensor = detach(tensor)
+ tensor = to_rgb(tensor, mode, pos, occ, is_video=True)
+ grid = []
+ for i in range(tensor.shape[1]):
+ grid.append(make_grid(tensor[:, i], nrow=nrow, normalize=False, value_range=[0, 1], pad_value=0))
+ grid = torch.stack(grid, dim=0)
+ grid = torch.nn.functional.interpolate(grid, scale_factor=self.factor)
+ grid = grid[None]
+ self.writer.add_video(name, grid, global_iter, fps=fps)
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/metrics/__init__.py b/data/dot_single_video/dot/utils/metrics/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c20d4718a6848cd4444fc59cacb223a1aec7eb0c
--- /dev/null
+++ b/data/dot_single_video/dot/utils/metrics/__init__.py
@@ -0,0 +1,7 @@
+def save_metrics(metrics, path):
+ names = list(metrics.keys())
+ num_values = len(metrics[names[0]])
+ with open(path, "w") as f:
+ f.write(",".join(names) + "\n")
+ for i in range(num_values):
+ f.write(",".join([str(metrics[name][i]) for name in names]) + "\n")
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/metrics/cvo_metrics.py b/data/dot_single_video/dot/utils/metrics/cvo_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..e177c9a94159d3377cd278fee4725be83620f2ba
--- /dev/null
+++ b/data/dot_single_video/dot/utils/metrics/cvo_metrics.py
@@ -0,0 +1,32 @@
+import torch
+
+
+def compute_metrics(gt, pred, time):
+ epe_all, epe_occ, epe_vis = get_epe(pred["flow"], gt["flow"], gt["alpha"])
+ iou = get_iou(gt["alpha"], pred["alpha"])
+ metrics = {
+ "epe_all": epe_all.cpu().numpy(),
+ "epe_occ": epe_occ.cpu().numpy(),
+ "epe_vis": epe_vis.cpu().numpy(),
+ "iou": iou.cpu().numpy(),
+ "time": time
+ }
+ return metrics
+
+
+def get_epe(pred, label, vis):
+ diff = torch.norm(pred - label, p=2, dim=-1, keepdim=True)
+ epe_all = torch.mean(diff, dim=(1, 2, 3))
+ vis = vis[..., None]
+ epe_occ = torch.sum(diff * (1 - vis), dim=(1, 2, 3)) / torch.sum((1 - vis), dim=(1, 2, 3))
+ epe_vis = torch.sum((diff * vis), dim=(1, 2, 3)) / torch.sum(vis, dim=(1, 2, 3))
+ return epe_all, epe_occ, epe_vis
+
+
+def get_iou(vis1, vis2):
+ occ1 = (1 - vis1).bool()
+ occ2 = (1 - vis2).bool()
+ intersection = (occ1 & occ2).float().sum(dim=[1, 2])
+ union = (occ1 | occ2).float().sum(dim=[1, 2])
+ iou = intersection / union
+ return iou
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/metrics/tap_metrics.py b/data/dot_single_video/dot/utils/metrics/tap_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..d296459b0bfc9ea616ef7143c50da8088c77c326
--- /dev/null
+++ b/data/dot_single_video/dot/utils/metrics/tap_metrics.py
@@ -0,0 +1,152 @@
+import numpy as np
+from typing import Mapping
+
+
+def compute_metrics(gt, pred, time, query_mode):
+ query_points = gt["query_points"].cpu().numpy()
+ gt_tracks = gt["tracks"][..., :2].permute(0, 2, 1, 3).cpu().numpy()
+ gt_occluded = (1 - gt["tracks"][..., 2]).permute(0, 2, 1).cpu().numpy()
+ pred_tracks = pred["tracks"][..., :2].permute(0, 2, 1, 3).cpu().numpy()
+ pred_occluded = (1 - pred["tracks"][..., 2]).permute(0, 2, 1).cpu().numpy()
+
+ metrics = compute_tapvid_metrics(
+ query_points,
+ gt_occluded,
+ gt_tracks,
+ pred_occluded,
+ pred_tracks,
+ query_mode=query_mode
+ )
+
+ metrics["time"] = time
+
+ return metrics
+
+
+def compute_tapvid_metrics(
+ query_points: np.ndarray,
+ gt_occluded: np.ndarray,
+ gt_tracks: np.ndarray,
+ pred_occluded: np.ndarray,
+ pred_tracks: np.ndarray,
+ query_mode: str,
+) -> Mapping[str, np.ndarray]:
+ """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
+ See the TAP-Vid paper for details on the metric computation. All inputs are
+ given in raster coordinates. The first three arguments should be the direct
+ outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
+ The paper metrics assume these are scaled relative to 256x256 images.
+ pred_occluded and pred_tracks are your algorithm's predictions.
+ This function takes a batch of inputs, and computes metrics separately for
+ each video. The metrics for the full benchmark are a simple mean of the
+ metrics across the full set of videos. These numbers are between 0 and 1,
+ but the paper multiplies them by 100 to ease reading.
+ Args:
+ query_points: The query points, in the format [t, y, x]. Their size is
+ [b, n, 3], where b is the batch size and n is the number of queries.
+ gt_occluded: A boolean array of shape [b, n, t], where t is the number
+ of frames. True indicates that the point is occluded.
+ gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
+ in the format [x, y]
+ pred_occluded: A boolean array of predicted occlusions, in the same
+ format as gt_occluded.
+ pred_tracks: An array of track predictions from your algorithm, in the
+ same format as gt_tracks.
+ query_mode: Either 'first' or 'strided', depending on how queries are
+ sampled. If 'first', we assume the prior knowledge that all points
+ before the query point are occluded, and these are removed from the
+ evaluation.
+ Returns:
+ A dict with the following keys:
+ occlusion_accuracy: Accuracy at predicting occlusion.
+ pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
+ predicted to be within the given pixel threshold, ignoring occlusion
+ prediction.
+ jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
+ threshold
+ average_pts_within_thresh: average across pts_within_{x}
+ average_jaccard: average across jaccard_{x}
+ """
+
+ metrics = {}
+ # Fixed bug is described in:
+ # https://github.com/facebookresearch/co-tracker/issues/20
+ eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
+
+ if query_mode == "first":
+ # evaluate frames after the query frame
+ query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
+ elif query_mode == "strided":
+ # evaluate all frames except the query frame
+ query_frame_to_eval_frames = 1 - eye
+ else:
+ raise ValueError("Unknown query mode " + query_mode)
+
+ query_frame = query_points[..., 0]
+ query_frame = np.round(query_frame).astype(np.int32)
+ evaluation_points = query_frame_to_eval_frames[query_frame] > 0
+
+ # Occlusion accuracy is simply how often the predicted occlusion equals the
+ # ground truth.
+ occ_acc = np.sum(
+ np.equal(pred_occluded, gt_occluded) & evaluation_points,
+ axis=(1, 2),
+ ) / np.sum(evaluation_points)
+ metrics["occlusion_accuracy"] = occ_acc
+
+ # Next, convert the predictions and ground truth positions into pixel
+ # coordinates.
+ visible = np.logical_not(gt_occluded)
+ pred_visible = np.logical_not(pred_occluded)
+ all_frac_within = []
+ all_jaccard = []
+ for thresh in [1, 2, 4, 8, 16]:
+ # True positives are points that are within the threshold and where both
+ # the prediction and the ground truth are listed as visible.
+ within_dist = np.sum(
+ np.square(pred_tracks - gt_tracks),
+ axis=-1,
+ ) < np.square(thresh)
+ is_correct = np.logical_and(within_dist, visible)
+
+ # Compute the frac_within_threshold, which is the fraction of points
+ # within the threshold among points that are visible in the ground truth,
+ # ignoring whether they're predicted to be visible.
+ count_correct = np.sum(
+ is_correct & evaluation_points,
+ axis=(1, 2),
+ )
+ count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
+ frac_correct = count_correct / count_visible_points
+ metrics["pts_within_" + str(thresh)] = frac_correct
+ all_frac_within.append(frac_correct)
+
+ true_positives = np.sum(
+ is_correct & pred_visible & evaluation_points, axis=(1, 2)
+ )
+
+ # The denominator of the jaccard metric is the true positives plus
+ # false positives plus false negatives. However, note that true positives
+ # plus false negatives is simply the number of points in the ground truth
+ # which is easier to compute than trying to compute all three quantities.
+ # Thus we just add the number of points in the ground truth to the number
+ # of false positives.
+ #
+ # False positives are simply points that are predicted to be visible,
+ # but the ground truth is not visible or too far from the prediction.
+ gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
+ false_positives = (~visible) & pred_visible
+ false_positives = false_positives | ((~within_dist) & pred_visible)
+ false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
+ jaccard = true_positives / (gt_positives + false_positives)
+ metrics["jaccard_" + str(thresh)] = jaccard
+ all_jaccard.append(jaccard)
+ metrics["average_jaccard"] = np.mean(
+ np.stack(all_jaccard, axis=1),
+ axis=1,
+ )
+ metrics["average_pts_within_thresh"] = np.mean(
+ np.stack(all_frac_within, axis=1),
+ axis=1,
+ )
+ return metrics
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/options/base_options.py b/data/dot_single_video/dot/utils/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e5ef6e5559dfe0284418170050525c32ac8126
--- /dev/null
+++ b/data/dot_single_video/dot/utils/options/base_options.py
@@ -0,0 +1,73 @@
+import argparse
+import random
+from datetime import datetime
+
+
+def str2bool(v):
+ if isinstance(v, bool):
+ return v
+ if v.lower() in ('yes', 'true', 't', 'y', '1'):
+ return True
+ elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+ return False
+ else:
+ raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+class BaseOptions:
+ def initialize(self, parser):
+ parser.add_argument("--name", type=str)
+ parser.add_argument("--model", type=str, default="dot", choices=["dot", "of", "pt"])
+ parser.add_argument("--datetime", type=str, default=None)
+ parser.add_argument("--data_root", type=str)
+ parser.add_argument("--height", type=int, default=512)
+ parser.add_argument("--width", type=int, default=512)
+ parser.add_argument("--aspect_ratio", type=float, default=1)
+ parser.add_argument("--batch_size", type=int)
+ parser.add_argument("--num_tracks", type=int, default=2048)
+ parser.add_argument("--sim_tracks", type=int, default=2048)
+ parser.add_argument("--alpha_thresh", type=float, default=0.8)
+ parser.add_argument("--is_train", type=str2bool, nargs='?', const=True, default=False)
+
+ # Parallelization
+ parser.add_argument('--worker_idx', type=int, default=0)
+ parser.add_argument("--num_workers", type=int, default=2)
+
+ # Optical flow estimator
+ parser.add_argument("--estimator_config", type=str, default="configs/raft_patch_8.json")
+ parser.add_argument("--estimator_path", type=str, default="checkpoints/cvo_raft_patch_8.pth")
+ parser.add_argument("--flow_mode", type=str, default="direct", choices=["direct", "chain", "warm_start"])
+
+ # Optical flow refiner
+ parser.add_argument("--refiner_config", type=str, default="configs/raft_patch_4_alpha.json")
+ parser.add_argument("--refiner_path", type=str, default="checkpoints/movi_f_raft_patch_4_alpha.pth")
+
+ # Point tracker
+ parser.add_argument("--tracker_config", type=str, default="configs/cotracker2_patch_4_wind_8.json")
+ parser.add_argument("--tracker_path", type=str, default="checkpoints/movi_f_cotracker2_patch_4_wind_8.pth")
+ parser.add_argument("--sample_mode", type=str, default="all", choices=["all", "first", "last"])
+
+ # Dense optical tracker
+ parser.add_argument("--cell_size", type=int, default=1)
+ parser.add_argument("--cell_time_steps", type=int, default=20)
+
+ # Interpolation
+ parser.add_argument("--interpolation_version", type=str, default="torch3d", choices=["torch3d", "torch"])
+ return parser
+
+ def parse_args(self):
+ parser = argparse.ArgumentParser()
+ parser = self.initialize(parser)
+ args = parser.parse_args()
+ if args.datetime is None:
+ args.datetime = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+ name = f"{args.datetime}_{args.name}_{args.model}"
+ if hasattr(args, 'split'):
+ name += f"_{args.split}"
+ args.checkpoint_path = f"checkpoints/{name}"
+ args.log_path = f"logs/{name}"
+ args.result_path = f"results/{name}"
+ if hasattr(args, 'world_size'):
+ args.batch_size = args.batch_size // args.world_size
+ args.master_port = f'{10000 + random.randrange(1, 10000)}'
+ return args
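+
+# Usage sketch (comment only, not executed): option subclasses extend
+# initialize() and call parse_args(), which also derives the checkpoint, log
+# and result paths from the run name, e.g.
+#   args = DemoOptions().parse_args()
+#   args.log_path  # logs/<datetime>_demo_dot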
diff --git a/data/dot_single_video/dot/utils/options/demo_options.py b/data/dot_single_video/dot/utils/options/demo_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..07f01eb06db9ef288fb408d376c0941d78d1baf8
--- /dev/null
+++ b/data/dot_single_video/dot/utils/options/demo_options.py
@@ -0,0 +1,23 @@
+from .base_options import BaseOptions, str2bool
+
+
+class DemoOptions(BaseOptions):
+ def initialize(self, parser):
+ BaseOptions.initialize(self, parser)
+ parser.add_argument("--inference_mode", type=str, default="tracks_from_first_to_every_other_frame")
+ parser.add_argument("--visualization_modes", type=str, nargs="+", default=["overlay", "spaghetti_last_static"])
+ parser.add_argument("--video_path", type=str, default="orange.mp4")
+ parser.add_argument("--mask_path", type=str, default="orange.png")
+ parser.add_argument("--save_tracks", type=str2bool, nargs='?', const=True, default=False)
+ parser.add_argument("--recompute_tracks", type=str2bool, nargs='?', const=True, default=False)
+ parser.add_argument("--overlay_factor", type=float, default=0.75)
+ parser.add_argument("--rainbow_mode", type=str, default="left_right", choices=["left_right", "up_down"])
+ parser.add_argument("--save_mode", type=str, default="video", choices=["image", "video"])
+ parser.add_argument("--spaghetti_radius", type=float, default=1.5)
+ parser.add_argument("--spaghetti_length", type=int, default=40)
+ parser.add_argument("--spaghetti_grid", type=int, default=30)
+ parser.add_argument("--spaghetti_scale", type=float, default=2)
+ parser.add_argument("--spaghetti_every", type=int, default=10)
+ parser.add_argument("--spaghetti_dropout", type=float, default=0)
+ parser.set_defaults(data_root="datasets/demo", name="demo", batch_size=1, height=480, width=856, num_tracks=8192)
+ return parser
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/options/preprocess_options.py b/data/dot_single_video/dot/utils/options/preprocess_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf90cb5f65b887a0b8e4a2a0f53cf269cff6062
--- /dev/null
+++ b/data/dot_single_video/dot/utils/options/preprocess_options.py
@@ -0,0 +1,13 @@
+from .base_options import BaseOptions, str2bool
+
+
+class PreprocessOptions(BaseOptions):
+ def initialize(self, parser):
+ BaseOptions.initialize(self, parser)
+ parser.add_argument("--extract_movi_f", type=str2bool, nargs='?', const=True, default=False)
+ parser.add_argument("--save_tracks", type=str2bool, nargs='?', const=True, default=False)
+ parser.add_argument('--download_path', type=str, default="gs://kubric-public/tfds")
+ parser.add_argument('--num_videos', type=int, default=11000)
+ parser.set_defaults(data_root="datasets/kubric/movi_f", name="preprocess", num_workers=2, num_tracks=2048,
+ model="pt")
+ return parser
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/options/test_cvo_options.py b/data/dot_single_video/dot/utils/options/test_cvo_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a1d634c40ebaaf6c28a7513028652c0da4d75be
--- /dev/null
+++ b/data/dot_single_video/dot/utils/options/test_cvo_options.py
@@ -0,0 +1,15 @@
+from .base_options import BaseOptions, str2bool
+
+
+class TestOptions(BaseOptions):
+ def initialize(self, parser):
+ BaseOptions.initialize(self, parser)
+ parser.add_argument("--split", type=str, choices=["clean", "final", "extended"], default="clean")
+ parser.add_argument("--filter", type=str2bool, nargs='?', const=True, default=True)
+ parser.add_argument('--filter_indices', type=int, nargs="+",
+ default=[70, 77, 93, 96, 140, 143, 162, 172, 174, 179, 187, 215, 236, 284, 285, 293, 330,
+ 358, 368, 402, 415, 458, 483, 495, 534])
+ parser.add_argument('--plot_indices', type=int, nargs="+", default=[])
+ parser.set_defaults(data_root="datasets/kubric/cvo", name="test_cvo", batch_size=1, num_workers=0,
+ sample_mode="last")
+ return parser
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/options/test_tap_options.py b/data/dot_single_video/dot/utils/options/test_tap_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cc3ca5a82f752e77496721ef51a2f66f02ddc0b
--- /dev/null
+++ b/data/dot_single_video/dot/utils/options/test_tap_options.py
@@ -0,0 +1,11 @@
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+ def initialize(self, parser):
+ BaseOptions.initialize(self, parser)
+ parser.add_argument("--split", type=str, choices=["davis", "rgb_stacking", "kinetics"], default="davis")
+ parser.add_argument("--query_mode", type=str, default="first", choices=["first", "strided"])
+ parser.add_argument('--plot_indices', type=int, nargs="+", default=[])
+ parser.set_defaults(data_root="datasets/tap", name="test_tap", batch_size=1, num_workers=0, num_tracks=8192)
+ return parser
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/options/train_options.py b/data/dot_single_video/dot/utils/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..215b491f5f35faddd9737e24b0597f59c8257661
--- /dev/null
+++ b/data/dot_single_video/dot/utils/options/train_options.py
@@ -0,0 +1,27 @@
+from .base_options import BaseOptions
+
+
+class TrainOptions(BaseOptions):
+ def initialize(self, parser):
+ BaseOptions.initialize(self, parser)
+ parser.add_argument("--in_track_name", type=str, default="cotracker")
+ parser.add_argument("--out_track_name", type=str, default="ground_truth")
+ parser.add_argument("--num_in_tracks", type=int, default=2048)
+ parser.add_argument("--num_out_tracks", type=int, default=2048)
+ parser.add_argument("--batch_size_valid", type=int, default=4)
+ parser.add_argument("--train_iter", type=int, default=1000000)
+ parser.add_argument("--log_iter", type=int, default=10000)
+ parser.add_argument("--log_factor", type=float, default=1.)
+ parser.add_argument("--print_iter", type=int, default=100)
+ parser.add_argument("--valid_iter", type=int, default=10000)
+ parser.add_argument("--num_valid_batches", type=int, default=24)
+ parser.add_argument("--save_iter", type=int, default=1000)
+ parser.add_argument("--lr", type=float, default=0.0001)
+ parser.add_argument("--world_size", type=int, default=1)
+ parser.add_argument("--valid_ratio", type=float, default=0.01)
+ parser.add_argument("--lambda_motion_loss", type=float, default=1.)
+ parser.add_argument("--lambda_visibility_loss", type=float, default=1.)
+ parser.add_argument("--optimizer_path", type=str, default=None)
+ parser.set_defaults(data_root="datasets/kubric/movi_f", name="train", batch_size=8, refiner_path=None,
+ is_train=True, model="ofr")
+ return parser
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/plot.py b/data/dot_single_video/dot/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..de035d7a02e49662dd0918fdd10e5e52cb087960
--- /dev/null
+++ b/data/dot_single_video/dot/utils/plot.py
@@ -0,0 +1,197 @@
+import os.path as osp
+import matplotlib
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from dot.utils.io import create_folder
+
+
+def to_rgb(tensor, mode, tracks=None, is_video=False, to_torch=True, reshape_as_video=False):
+ if isinstance(tensor, list):
+ tensor = torch.stack(tensor)
+ tensor = tensor.cpu().numpy()
+ if is_video:
+ batch_size, time_steps = tensor.shape[:2]
+ if mode == "flow":
+ height, width = tensor.shape[-3: -1]
+ tensor = np.reshape(tensor, (-1, height, width, 2))
+ tensor = flow_to_rgb(tensor)
+ elif mode == "mask":
+ height, width = tensor.shape[-2:]
+ tensor = np.reshape(tensor, (-1, 1, height, width))
+ tensor = np.repeat(tensor, 3, axis=1)
+ else:
+ height, width = tensor.shape[-2:]
+ tensor = np.reshape(tensor, (-1, 3, height, width))
+ if tracks is not None:
+ samples = tracks.size(-2)
+ tracks = tracks.cpu().numpy()
+ tracks = np.reshape(tracks, (-1, samples, 3))
+ traj, occ = tracks[..., :2], 1 - tracks[..., 2]
+ if is_video:
+ tensor = np.reshape(tensor, (-1, time_steps, 3, height, width))
+ traj = np.reshape(traj, (-1, time_steps, samples, 2))
+ occ = np.reshape(occ, (-1, time_steps, samples))
+ new_tensor = []
+ for t in range(time_steps):
+ pos_t = traj[:, t]
+ occ_t = occ[:, t]
+ new_tensor.append(plot_tracks(tensor[:, t], pos_t, occ_t, tracks=traj[:, :t + 1]))
+ tensor = np.stack(new_tensor, axis=1)
+ else:
+ tensor = plot_tracks(tensor, traj, occ)
+ if is_video and reshape_as_video:
+ tensor = np.reshape(tensor, (batch_size, time_steps, 3, height, width))
+ else:
+ tensor = np.reshape(tensor, (-1, 3, height, width))
+ if to_torch:
+ tensor = torch.from_numpy(tensor)
+ return tensor
+
+
+def flow_to_rgb(flow, transparent=False):
+ flow = np.copy(flow)
+ H, W = flow.shape[-3: -1]
+ mul = 20.
+ scaling = mul / (H ** 2 + W ** 2) ** 0.5
+ direction = (np.arctan2(flow[..., 0], flow[..., 1]) + np.pi) / (2 * np.pi)
+ norm = np.linalg.norm(flow, axis=-1)
+ magnitude = np.clip(norm * scaling, 0., 1.)
+ saturation = np.ones_like(direction)
+ if transparent:
+ hsv = np.stack([direction, saturation, np.ones_like(magnitude)], axis=-1)
+ else:
+ hsv = np.stack([direction, saturation, magnitude], axis=-1)
+ rgb = matplotlib.colors.hsv_to_rgb(hsv)
+ rgb = np.moveaxis(rgb, -1, -3)
+ if transparent:
+ return np.concatenate([rgb, np.expand_dims(magnitude, axis=-3)], axis=-3)
+ return rgb
+
+
+def plot_tracks(rgb, points, occluded, tracks=None, trackgroup=None):
+ """Plot tracks with matplotlib.
+ Adapted from: https://github.com/google-research/kubric/blob/main/challenges/point_tracking/dataset.py"""
+ rgb = rgb.transpose(0, 2, 3, 1)
+ _, height, width, _ = rgb.shape
+ points = points.transpose(1, 0, 2).copy() # clone, otherwise it updates points array
+ # points[..., 0] *= (width - 1)
+ # points[..., 1] *= (height - 1)
+ if tracks is not None:
+ tracks = tracks.copy()
+ # tracks[..., 0] *= (width - 1)
+ # tracks[..., 1] *= (height - 1)
+ if occluded is not None:
+ occluded = occluded.transpose(1, 0)
+ disp = []
+ cmap = plt.cm.hsv
+
+ z_list = np.arange(points.shape[0]) if trackgroup is None else np.array(trackgroup)
+ # random permutation of the colors so nearby points in the list can get different colors
+ np.random.seed(0)
+ z_list = np.random.permutation(np.max(z_list) + 1)[z_list]
+ colors = cmap(z_list / (np.max(z_list) + 1))
+ figure_dpi = 64
+
+ for i in range(rgb.shape[0]):
+ fig = plt.figure(
+ figsize=(width / figure_dpi, height / figure_dpi),
+ dpi=figure_dpi,
+ frameon=False,
+ facecolor='w')
+ ax = fig.add_subplot()
+ ax.axis('off')
+ ax.imshow(rgb[i])
+
+ valid = points[:, i, 0] > 0
+ valid = np.logical_and(valid, points[:, i, 0] < rgb.shape[2] - 1)
+ valid = np.logical_and(valid, points[:, i, 1] > 0)
+ valid = np.logical_and(valid, points[:, i, 1] < rgb.shape[1] - 1)
+
+ if occluded is not None:
+ colalpha = np.concatenate([colors[:, :-1], 1 - occluded[:, i:i + 1]], axis=1)
+ else:
+ colalpha = colors[:, :-1]
+ # Note: matplotlib uses pixel coordinates, not raster.
+ ax.scatter(
+ points[valid, i, 0] - 0.5,
+ points[valid, i, 1] - 0.5,
+ s=3,
+ c=colalpha[valid],
+ )
+
+ if tracks is not None:
+ for j in range(tracks.shape[2]):
+ track_color = colors[j] # Use a different color for each track
+ x = tracks[i, :, j, 0]
+ y = tracks[i, :, j, 1]
+ valid_track = x > 0
+ valid_track = np.logical_and(valid_track, x < rgb.shape[2] - 1)
+ valid_track = np.logical_and(valid_track, y > 0)
+ valid_track = np.logical_and(valid_track, y < rgb.shape[1] - 1)
+ ax.plot(x[valid_track] - 0.5, y[valid_track] - 0.5, color=track_color, marker=None)
+
+ if occluded is not None:
+ occ2 = occluded[:, i:i + 1]
+
+ colalpha = np.concatenate([colors[:, :-1], occ2], axis=1)
+
+ ax.scatter(
+ points[valid, i, 0],
+ points[valid, i, 1],
+ s=20,
+ facecolors='none',
+ edgecolors=colalpha[valid],
+ )
+
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
+ plt.margins(0, 0)
+ fig.canvas.draw()
+ width, height = fig.get_size_inches() * fig.get_dpi()
+ img = np.frombuffer(
+ fig.canvas.tostring_rgb(),
+ dtype='uint8').reshape(int(height), int(width), 3)
+ disp.append(np.copy(img))
+ plt.close("all")
+
+ return np.stack(disp, axis=0).astype(float).transpose(0, 3, 1, 2) / 255 # TODO : inconsistent
+
+
+def plot_points(src_frame, tgt_frame, src_points, tgt_points, save_path, max_points=256):
+ _, H, W = src_frame.shape
+ src_frame = src_frame.permute(1, 2, 0).cpu().numpy()
+ tgt_frame = tgt_frame.permute(1, 2, 0).cpu().numpy()
+ src_points = src_points.cpu().numpy()
+ tgt_points = tgt_points.cpu().numpy()
+ src_pos, src_alpha = src_points[..., :2], src_points[..., 2]
+ tgt_pos, tgt_alpha = tgt_points[..., :2], tgt_points[..., 2]
+ src_pos = np.stack([src_pos[..., 0] * (W - 1), src_pos[..., 1] * (H - 1)], axis=-1)
+ tgt_pos = np.stack([tgt_pos[..., 0] * (W - 1), tgt_pos[..., 1] * (H - 1)], axis=-1)
+
+ plt.figure()
+ ax = plt.gca()
+ P = 10
+ plt.imshow(np.concatenate((src_frame, np.ones_like(src_frame[:, :P]), tgt_frame), axis=1))
+ indices = np.random.choice(len(src_pos), size=min(max_points, len(src_pos)), replace=False)
+ for i in indices:
+ if src_alpha[i] == 1:
+ ax.scatter(src_pos[i, 0], src_pos[i, 1], s=5, c="black", marker='x')
+ else:
+ ax.scatter(src_pos[i, 0], src_pos[i, 1], s=5, linewidths=1.5, c="black", marker='o')
+ ax.scatter(src_pos[i, 0], src_pos[i, 1], s=2.5, c="white", marker='o')
+ if tgt_alpha[i] == 1:
+ ax.scatter(W + P + tgt_pos[i, 0], tgt_pos[i, 1], s=5, c="black", marker='x')
+ else:
+ ax.scatter(W + P + tgt_pos[i, 0], tgt_pos[i, 1], s=5, linewidths=1.5, c="black", marker='o')
+ ax.scatter(W + P + tgt_pos[i, 0], tgt_pos[i, 1], s=2.5, c="white", marker='o')
+
+ plt.plot([src_pos[i, 0], W + P + tgt_pos[i, 0]], [src_pos[i, 1], tgt_pos[i, 1]], linewidth=0.5, c="black")
+
+ # Save
+ ax.axis('off')
+ plt.tight_layout()
+ plt.subplots_adjust(wspace=0)
+ create_folder(osp.dirname(save_path))
+ plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+ plt.close()
\ No newline at end of file
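
As a quick orientation to `flow_to_rgb` above: flow direction is mapped to hue via `arctan2`, and magnitude, scaled by `20 / sqrt(H**2 + W**2)` and clipped to `[0, 1]`, is mapped to value. A minimal sketch of the expected shapes, assuming only NumPy and the module above:

```python
# Minimal sketch: a constant flow field renders as a single hue.
import numpy as np
from dot.utils.plot import flow_to_rgb

H, W = 64, 64
flow = np.zeros((H, W, 2), dtype=np.float32)
flow[..., 0] = 10.0                      # constant displacement in one flow channel
rgb = flow_to_rgb(flow)                  # hue = direction, value = clipped magnitude
print(rgb.shape, rgb.min(), rgb.max())   # (3, 64, 64), values within [0, 1]
```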
diff --git a/data/dot_single_video/dot/utils/torch.py b/data/dot_single_video/dot/utils/torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dd06e5ac671aee78d40523f00b759ef5abbc288
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch.py
@@ -0,0 +1,133 @@
+import numpy as np
+import torch
+import torch.distributed as dist
+
+
+def reduce(tensor, world_size):
+ if isinstance(tensor, torch.Tensor):
+ tensor = tensor.clone()
+ dist.all_reduce(tensor, dist.ReduceOp.SUM)
+ tensor.div_(world_size)
+ return tensor
+
+
+def expand(mask, num=1):
+ # mask: ... H W
+ # -----------------
+ # mask: ... H W
+ for _ in range(num):
+ mask[..., 1:, :] = mask[..., 1:, :] | mask[..., :-1, :]
+ mask[..., :-1, :] = mask[..., :-1, :] | mask[..., 1:, :]
+ mask[..., :, 1:] = mask[..., :, 1:] | mask[..., :, :-1]
+ mask[..., :, :-1] = mask[..., :, :-1] | mask[..., :, 1:]
+ return mask
+
+
+def differentiate(mask):
+ # mask: ... H W
+ # -----------------
+ # diff: ... H W
+ diff = torch.zeros_like(mask).bool()
+ diff_y = mask[..., 1:, :] != mask[..., :-1, :]
+ diff_x = mask[..., :, 1:] != mask[..., :, :-1]
+ diff[..., 1:, :] = diff[..., 1:, :] | diff_y
+ diff[..., :-1, :] = diff[..., :-1, :] | diff_y
+ diff[..., :, 1:] = diff[..., :, 1:] | diff_x
+ diff[..., :, :-1] = diff[..., :, :-1] | diff_x
+ return diff
+
+
+def sample_points(step, boundaries, num_samples):
+ if boundaries.ndim == 3:
+ points = []
+ for boundaries_k in boundaries:
+ points_k = sample_points(step, boundaries_k, num_samples)
+ points.append(points_k)
+ points = torch.stack(points)
+ else:
+ H, W = boundaries.shape
+ boundary_points, _ = sample_mask_points(step, boundaries, num_samples // 2)
+ num_boundary_points = boundary_points.shape[0]
+ num_random_points = num_samples - num_boundary_points
+ random_points = sample_random_points(step, H, W, num_random_points)
+ random_points = random_points.to(boundary_points.device)
+ points = torch.cat((boundary_points, random_points), dim=0)
+ return points
+
+
+def sample_mask_points(step, mask, num_points):
+ num_nonzero = int(mask.sum())
+ i, j = torch.nonzero(mask, as_tuple=True)
+ if num_points < num_nonzero:
+ sample = np.random.choice(num_nonzero, size=num_points, replace=False)
+ i, j = i[sample], j[sample]
+ t = torch.ones_like(i) * step
+ x, y = j, i
+ points = torch.stack((t, x, y), dim=-1) # [num_points, 3]
+ return points.float(), (i, j)
+
+
+def sample_random_points(step, height, width, num_points):
+ x = torch.randint(width, size=[num_points])
+ y = torch.randint(height, size=[num_points])
+ t = torch.ones(num_points) * step
+ points = torch.stack((t, x, y), dim=-1) # [num_points, 3]
+ return points.float()
+
+
+def get_grid(height, width, shape=None, dtype="torch", device="cpu", align_corners=True, normalize=True):
+ H, W = height, width
+ S = shape if shape else []
+ if align_corners:
+ x = torch.linspace(0, 1, W, device=device)
+ y = torch.linspace(0, 1, H, device=device)
+ if not normalize:
+ x = x * (W - 1)
+ y = y * (H - 1)
+ else:
+ x = torch.linspace(0.5 / W, 1.0 - 0.5 / W, W, device=device)
+ y = torch.linspace(0.5 / H, 1.0 - 0.5 / H, H, device=device)
+ if not normalize:
+ x = x * W
+ y = y * H
+ x_view, y_view, exp = [1 for _ in S] + [1, -1], [1 for _ in S] + [-1, 1], S + [H, W]
+ x = x.view(*x_view).expand(*exp)
+ y = y.view(*y_view).expand(*exp)
+ grid = torch.stack([x, y], dim=-1)
+ if dtype == "numpy":
+ grid = grid.numpy()
+ return grid
+
+
+def get_sobel_kernel(kernel_size):
+ K = kernel_size
+ sobel = torch.tensor(list(range(K))) - K // 2
+ sobel_x, sobel_y = sobel.view(-1, 1), sobel.view(1, -1)
+ sum_xy = sobel_x ** 2 + sobel_y ** 2
+ sum_xy[sum_xy == 0] = 1
+ sobel_x, sobel_y = sobel_x / sum_xy, sobel_y / sum_xy
+ sobel_kernel = torch.stack([sobel_x.unsqueeze(0), sobel_y.unsqueeze(0)], dim=0)
+ return sobel_kernel
+
+
+def to_device(data, device):
+ data = {k: v.to(device) for k, v in data.items()}
+ return data
+
+
+def get_alpha_consistency(bflow, fflow, thresh_1=0.01, thresh_2=0.5, thresh_mul=1):
+ norm = lambda x: x.pow(2).sum(dim=-1).sqrt()
+ B, H, W, C = bflow.shape
+
+ mag = norm(fflow) + norm(bflow)
+ grid = get_grid(H, W, shape=[B], device=fflow.device)
+ grid[..., 0] = grid[..., 0] + bflow[..., 0] / (W - 1)
+ grid[..., 1] = grid[..., 1] + bflow[..., 1] / (H - 1)
+ grid = grid * 2 - 1
+ fflow_warped = torch.nn.functional.grid_sample(fflow.permute(0, 3, 1, 2), grid, mode="bilinear", align_corners=True)
+ flow_diff = bflow + fflow_warped.permute(0, 2, 3, 1)
+ occ_thresh = thresh_1 * mag + thresh_2
+ occ_thresh = occ_thresh * thresh_mul
+ alpha = norm(flow_diff) < occ_thresh
+ alpha = alpha.float()
+ return alpha
\ No newline at end of file
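
The forward/backward consistency check in `get_alpha_consistency` above warps the forward flow by the backward flow and marks a pixel visible when `|bflow + warp(fflow)|` stays below `thresh_1 * (|fflow| + |bflow|) + thresh_2`. A small sketch with synthetic, mutually consistent flows (values and shapes are illustrative only):

```python
# Sketch: perfectly consistent constant flows should be marked (mostly) visible.
import torch
from dot.utils.torch import get_alpha_consistency

B, H, W = 1, 32, 32
fflow = torch.zeros(B, H, W, 2)   # forward flow, frame t -> t+1
bflow = torch.zeros(B, H, W, 2)   # backward flow, frame t+1 -> t
fflow[..., 0] = 2.0
bflow[..., 0] = -2.0              # exactly cancels the forward flow
alpha = get_alpha_consistency(bflow, fflow)
print(alpha.shape, alpha.mean())  # (1, 32, 32); mean close to 1 away from borders
```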
diff --git a/data/dot_single_video/dot/utils/torch3d/LICENSE b/data/dot_single_video/dot/utils/torch3d/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..caab102f8b1bb5578bea0395d1a3c8dd62da6308
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/LICENSE
@@ -0,0 +1,30 @@
+BSD License
+
+For PyTorch3D software
+
+Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification,
+are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+ * Neither the name Meta nor the names of its contributors may be used to
+ endorse or promote products derived from this software without specific
+ prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/data/dot_single_video/dot/utils/torch3d/__init__.py b/data/dot_single_video/dot/utils/torch3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28195c1fbca7697dc428da32823467d9c4fc427d
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/__init__.py
@@ -0,0 +1,2 @@
+from .knn import knn_points
+from .packed_to_padded import packed_to_padded
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/build.ninja b/data/dot_single_video/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/build.ninja
new file mode 100644
index 0000000000000000000000000000000000000000..29a603cfd9c35f4ab0399c9e501551e2cae251e8
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/build.ninja
@@ -0,0 +1,36 @@
+ninja_required_version = 1.3
+cxx = c++
+nvcc = /usr/local/cuda/bin/nvcc
+
+cflags = -pthread -B /mnt/zhongwei/subapp/miniconda3/envs/torch2/compiler_compat -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -fPIC -O2 -isystem /mnt/zhongwei/subapp/miniconda3/envs/torch2/include -fPIC -O2 -isystem /mnt/zhongwei/subapp/miniconda3/envs/torch2/include -fPIC -DWITH_CUDA -DTHRUST_IGNORE_CUB_VERSION_CHECK -I/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include/TH -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include/THC -I/usr/local/cuda/include -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/include/python3.10 -c
+post_cflags = -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0
+cuda_cflags = -DWITH_CUDA -DTHRUST_IGNORE_CUB_VERSION_CHECK -I/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include/TH -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/lib/python3.10/site-packages/torch/include/THC -I/usr/local/cuda/include -I/mnt/zhongwei/subapp/miniconda3/envs/torch2/include/python3.10 -c
+cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_80,code=compute_80 -gencode=arch=compute_80,code=sm_80
+cuda_dlink_post_cflags =
+ldflags =
+
+rule compile
+ command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags
+ depfile = $out.d
+ deps = gcc
+
+rule cuda_compile
+ depfile = $out.d
+ deps = gcc
+ command = $nvcc --generate-dependencies-with-compile --dependency-output $out.d $cuda_cflags -c $in -o $out $cuda_post_cflags
+
+
+
+
+
+build /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/ext.o: compile /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/ext.cpp
+build /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/knn/knn.o: cuda_compile /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/knn/knn.cu
+build /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/knn/knn_cpu.o: compile /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/knn/knn_cpu.cpp
+build /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.o: cuda_compile /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu
+build /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/build/temp.linux-x86_64-cpython-310/mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.o: compile /mnt/zhongwei/zhongwei/all_good_tools/dot_all/24_06_06/dot_single_video/dot_ori/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp
+
+
+
+
+
+
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/ext.cpp b/data/dot_single_video/dot/utils/torch3d/csrc/ext.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ab0337e5db86f2aaea9eaf3d7049fd40d98d4884
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/ext.cpp
@@ -0,0 +1,23 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// clang-format off
+#include <torch/extension.h>
+// clang-format on
+#include "knn/knn.h"
+#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("packed_to_padded", &PackedToPadded);
+ m.def("padded_to_packed", &PaddedToPacked);
+#ifdef WITH_CUDA
+ m.def("knn_check_version", &KnnCheckVersion);
+#endif
+ m.def("knn_points_idx", &KNearestNeighborIdx);
+ m.def("knn_points_backward", &KNearestNeighborBackward);
+}
\ No newline at end of file
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn.cu b/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn.cu
new file mode 100644
index 0000000000000000000000000000000000000000..633065c991eaa036f1b2041000c81a2638430a1a
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn.cu
@@ -0,0 +1,584 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <float.h>
+#include <iostream>
+#include <tuple>
+
+#include "utils/dispatch.cuh"
+#include "utils/mink.cuh"
+
+// A chunk of work is blocksize-many points of P1.
+// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
+// call (1+(P1-1)/blocksize) chunks_per_cloud
+// These chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
+// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
+// blocksize*(i%chunks_per_cloud).
+
+template <typename scalar_t>
+__global__ void KNearestNeighborKernelV0(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t D,
+ const size_t K,
+ const size_t norm) {
+ // Store both dists and indices for knn in global memory.
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ int offset = n * P1 * K + p1 * K;
+ int64_t length2 = lengths2[n];
+ MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ // Find the distance between points1[n, p1] and points[n, p2]
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
+ scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
+ scalar_t diff = coord1 - coord2;
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ }
+}
+
+template <typename scalar_t, int64_t D>
+__global__ void KNearestNeighborKernelV1(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t K,
+ const size_t norm) {
+ // Same idea as the previous version, but hoist D into a template argument
+ // so we can cache the current point in a thread-local array. We still store
+ // the current best K dists and indices in global memory, so this should work
+ // for very large K and fairly large D.
+ scalar_t cur_point[D];
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ for (int d = 0; d < D; ++d) {
+ cur_point[d] = points1[n * P1 * D + p1 * D + d];
+ }
+ int offset = n * P1 * K + p1 * K;
+ int64_t length2 = lengths2[n];
+ MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ // Find the distance between cur_point and points[n, p2]
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ }
+}
+
+// This is a shim functor to allow us to dispatch using DispatchKernel1D
+template <typename scalar_t, int64_t D>
+struct KNearestNeighborV1Functor {
+ static void run(
+ size_t blocks,
+ size_t threads,
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t K,
+ const size_t norm) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
+ points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
+ }
+};
+
+template <typename scalar_t, int64_t D, int64_t K>
+__global__ void KNearestNeighborKernelV2(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const int64_t N,
+ const int64_t P1,
+ const int64_t P2,
+ const size_t norm) {
+ // Same general implementation as V1, but also hoist K into a template arg.
+ scalar_t cur_point[D];
+ scalar_t min_dists[K];
+ int min_idxs[K];
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ for (int d = 0; d < D; ++d) {
+ cur_point[d] = points1[n * P1 * D + p1 * D + d];
+ }
+ int64_t length2 = lengths2[n];
+ MinK<scalar_t, int> mink(min_dists, min_idxs, K);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ int offset = n * P2 * D + p2 * D + d;
+ scalar_t diff = cur_point[d] - points2[offset];
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ for (int k = 0; k < mink.size(); ++k) {
+ idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
+ dists[n * P1 * K + p1 * K + k] = min_dists[k];
+ }
+ }
+}
+
+// This is a shim so we can dispatch using DispatchKernel2D
+template <typename scalar_t, int64_t D, int64_t K>
+struct KNearestNeighborKernelV2Functor {
+ static void run(
+ size_t blocks,
+ size_t threads,
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const int64_t N,
+ const int64_t P1,
+ const int64_t P2,
+ const size_t norm) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
+ points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
+ }
+};
+
+template <typename scalar_t, int64_t D, int64_t K>
+__global__ void KNearestNeighborKernelV3(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t norm) {
+ // Same idea as V2, but use register indexing for thread-local arrays.
+ // Enabling sorting for this version leads to huge slowdowns; I suspect
+ // that it forces min_dists into local memory rather than registers.
+ // As a result this version is always unsorted.
+ scalar_t cur_point[D];
+ scalar_t min_dists[K];
+ int min_idxs[K];
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ for (int d = 0; d < D; ++d) {
+ cur_point[d] = points1[n * P1 * D + p1 * D + d];
+ }
+ int64_t length2 = lengths2[n];
+ RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ int offset = n * P2 * D + p2 * D + d;
+ scalar_t diff = cur_point[d] - points2[offset];
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ for (int k = 0; k < mink.size(); ++k) {
+ idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
+ dists[n * P1 * K + p1 * K + k] = min_dists[k];
+ }
+ }
+}
+
+// This is a shim so we can dispatch using DispatchKernel2D
+template <typename scalar_t, int64_t D, int64_t K>
+struct KNearestNeighborKernelV3Functor {
+ static void run(
+ size_t blocks,
+ size_t threads,
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t norm) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
+ points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
+ }
+};
+
+constexpr int V1_MIN_D = 1;
+constexpr int V1_MAX_D = 32;
+
+constexpr int V2_MIN_D = 1;
+constexpr int V2_MAX_D = 8;
+constexpr int V2_MIN_K = 1;
+constexpr int V2_MAX_K = 32;
+
+constexpr int V3_MIN_D = 1;
+constexpr int V3_MAX_D = 8;
+constexpr int V3_MIN_K = 1;
+constexpr int V3_MAX_K = 4;
+
+bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
+ return min <= x && x <= max;
+}
+
+bool KnnCheckVersion(int version, const int64_t D, const int64_t K) {
+ if (version == 0) {
+ return true;
+ } else if (version == 1) {
+ return InBounds(V1_MIN_D, D, V1_MAX_D);
+ } else if (version == 2) {
+ return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
+ } else if (version == 3) {
+ return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
+ }
+ return false;
+}
+
+int ChooseVersion(const int64_t D, const int64_t K) {
+ for (int version = 3; version >= 1; version--) {
+ if (KnnCheckVersion(version, D, K)) {
+ return version;
+ }
+ }
+ return 0;
+}
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K,
+ int version) {
+ // Check inputs are on the same device
+ at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+ lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
+ at::CheckedFrom c = "KNearestNeighborIdxCuda";
+ at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
+ at::checkAllSameType(c, {p1_t, p2_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(p1.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const auto N = p1.size(0);
+ const auto P1 = p1.size(1);
+ const auto P2 = p2.size(1);
+ const auto D = p2.size(2);
+ const int64_t K_64 = K;
+
+ TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
+
+ TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
+ auto long_dtype = lengths1.options().dtype(at::kLong);
+ auto idxs = at::zeros({N, P1, K}, long_dtype);
+ auto dists = at::zeros({N, P1, K}, p1.options());
+
+ if (idxs.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(idxs, dists);
+ }
+
+ if (version < 0) {
+ version = ChooseVersion(D, K);
+ } else if (!KnnCheckVersion(version, D, K)) {
+ int new_version = ChooseVersion(D, K);
+ std::cout << "WARNING: Requested KNN version " << version
+ << " is not compatible with D = " << D << "; K = " << K
+ << ". Falling back to version = " << new_version << std::endl;
+ version = new_version;
+ }
+
+ // At this point we should have a valid version no matter what data the user
+ // gave us. But we can check once more to be sure; however this time
+ // assert fail since failing at this point means we have a bug in our version
+ // selection or checking code.
+ AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version");
+
+ const size_t threads = 256;
+ const size_t blocks = 256;
+ if (version == 0) {
+ AT_DISPATCH_FLOATING_TYPES(
+ p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ D,
+ K,
+ norm);
+ }));
+ } else if (version == 1) {
+ AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ DispatchKernel1D<
+ KNearestNeighborV1Functor,
+ scalar_t,
+ V1_MIN_D,
+ V1_MAX_D>(
+ D,
+ blocks,
+ threads,
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ K,
+ norm);
+ }));
+ } else if (version == 2) {
+ AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ DispatchKernel2D<
+ KNearestNeighborKernelV2Functor,
+ scalar_t,
+ V2_MIN_D,
+ V2_MAX_D,
+ V2_MIN_K,
+ V2_MAX_K>(
+ D,
+ K_64,
+ blocks,
+ threads,
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ norm);
+ }));
+ } else if (version == 3) {
+ AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ DispatchKernel2D<
+ KNearestNeighborKernelV3Functor,
+ scalar_t,
+ V3_MIN_D,
+ V3_MAX_D,
+ V3_MIN_K,
+ V3_MAX_K>(
+ D,
+ K_64,
+ blocks,
+ threads,
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ norm);
+ }));
+ }
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(idxs, dists);
+}
+
+// ------------------------------------------------------------- //
+// Backward Operators //
+// ------------------------------------------------------------- //
+
+// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
+// Currently, support is for floats only.
+__global__ void KNearestNeighborBackwardKernel(
+ const float* __restrict__ p1, // (N, P1, D)
+ const float* __restrict__ p2, // (N, P2, D)
+ const int64_t* __restrict__ lengths1, // (N,)
+ const int64_t* __restrict__ lengths2, // (N,)
+ const int64_t* __restrict__ idxs, // (N, P1, K)
+ const float* __restrict__ grad_dists, // (N, P1, K)
+ float* __restrict__ grad_p1, // (N, P1, D)
+ float* __restrict__ grad_p2, // (N, P2, D)
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t K,
+ const size_t D,
+ const size_t norm) {
+ const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const size_t stride = gridDim.x * blockDim.x;
+
+ for (size_t i = tid; i < N * P1 * K * D; i += stride) {
+ const size_t n = i / (P1 * K * D); // batch index
+ size_t rem = i % (P1 * K * D);
+ const size_t p1_idx = rem / (K * D); // index of point in p1
+ rem = rem % (K * D);
+ const size_t k = rem / D; // k-th nearest neighbor
+ const size_t d = rem % D; // d-th dimension in the feature vector
+
+ const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
+ const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
+ if ((p1_idx < num1) && (k < num2)) {
+ const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
+ // index of point in p2 corresponding to the k-th nearest neighbor
+ const size_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
+ // If the index is the pad value of -1 then ignore it
+ if (p2_idx == -1) {
+ continue;
+ }
+ float diff = 0.0;
+ if (norm == 1) {
+ float sign =
+ (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
+ ? 1.0
+ : -1.0;
+ diff = grad_dist * sign;
+ } else { // norm is 2
+ diff = 2.0 * grad_dist *
+ (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+ }
+ atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
+ atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
+ }
+ }
+}
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ int norm,
+ const at::Tensor& grad_dists) {
+ // Check inputs are on the same device
+ at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+ lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
+ idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
+ at::CheckedFrom c = "KNearestNeighborBackwardCuda";
+ at::checkAllSameGPU(
+ c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
+ at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(p1.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const auto N = p1.size(0);
+ const auto P1 = p1.size(1);
+ const auto P2 = p2.size(1);
+ const auto D = p2.size(2);
+ const auto K = idxs.size(2);
+
+ TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
+ TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
+ TORCH_CHECK(
+ idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
+ TORCH_CHECK(grad_dists.size(0) == N);
+ TORCH_CHECK(grad_dists.size(1) == P1);
+ TORCH_CHECK(grad_dists.size(2) == K);
+
+ auto grad_p1 = at::zeros({N, P1, D}, p1.options());
+ auto grad_p2 = at::zeros({N, P2, D}, p2.options());
+
+ if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(grad_p1, grad_p2);
+ }
+
+ const int blocks = 64;
+ const int threads = 512;
+
+ KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
+ p1.contiguous().data_ptr<float>(),
+ p2.contiguous().data_ptr<float>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ idxs.contiguous().data_ptr<int64_t>(),
+ grad_dists.contiguous().data_ptr<float>(),
+ grad_p1.data_ptr<float>(),
+ grad_p2.data_ptr<float>(),
+ N,
+ P1,
+ P2,
+ K,
+ D,
+ norm);
+
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(grad_p1, grad_p2);
+}
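
The chunk-scheduling comment at the top of `knn.cu` is easy to misread, so here is the same index arithmetic written out in plain Python: each chunk covers `blockDim.x` points of one cloud, and block `b` walks chunks `b, b + gridSize, b + 2 * gridSize, ...`. This is only an illustration of the mapping, not part of the extension:

```python
# Illustration of the chunk -> (cloud, start_point) mapping used by the kernels above.
def chunk_assignment(N, P1, block_dim, grid_size):
    chunks_per_cloud = 1 + (P1 - 1) // block_dim
    chunks_to_do = N * chunks_per_cloud
    for block in range(grid_size):
        for chunk in range(block, chunks_to_do, grid_size):
            cloud = chunk // chunks_per_cloud
            start_point = block_dim * (chunk % chunks_per_cloud)
            yield block, cloud, start_point

# e.g. 2 clouds of 1000 points, blockDim.x = 256, gridSize = 3:
for block, cloud, start in chunk_assignment(2, 1000, 256, 3):
    pass  # block `block` handles points [start, start + 256) of cloud `cloud`
```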
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn.h b/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn.h
new file mode 100644
index 0000000000000000000000000000000000000000..c27126cf52ac273f8a46313571648cb7b1fdf1f5
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn.h
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/extension.h>
+#include <tuple>
+#include "utils/pytorch3d_cutils.h"
+
+// Compute indices of K nearest neighbors in pointcloud p2 to points
+// in pointcloud p1.
+//
+// Args:
+// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
+// containing P1 points of dimension D.
+// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
+// containing P2 points of dimension D.
+// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
+// K: int giving the number of nearest points to return.
+// version: Integer telling which implementation to use.
+//
+// Returns:
+// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
+// p1_neighbor_idx[n, i, k] = j means that the kth nearest
+// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+// It is padded with zeros so that it can be used easily in a later
+// gather() operation.
+//
+// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
+// distance from each point p1[n, p, :] to its K neighbors
+// p2[n, p1_neighbor_idx[n, p, k], :].
+
+// CPU implementation.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K);
+
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K,
+ const int version);
+
+// Implementation which is exposed.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K,
+ const int version) {
+ if (p1.is_cuda() || p2.is_cuda()) {
+#ifdef WITH_CUDA
+ CHECK_CUDA(p1);
+ CHECK_CUDA(p2);
+ return KNearestNeighborIdxCuda(
+ p1, p2, lengths1, lengths2, norm, K, version);
+#else
+ AT_ERROR("Not compiled with GPU support.");
+#endif
+ }
+ return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
+}
+
+// Compute gradients with respect to p1 and p2
+//
+// Args:
+// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
+// containing P1 points of dimension D.
+// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
+// containing P2 points of dimension D.
+// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
+// p1_neighbor_idx[n, i, k] = j means that the kth nearest
+// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+// It is padded with zeros so that it can be used easily in a later
+// gather() operation. This is computed from the forward pass.
+// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
+// grad_dists: FloatTensor of shape (N, P1, K) which contains the input
+// gradients.
+//
+// Returns:
+// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
+// wrt p1.
+// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
+// wrt p2.
+
+// CPU implementation.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists);
+
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists);
+
+// Implementation which is exposed.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists) {
+ if (p1.is_cuda() || p2.is_cuda()) {
+#ifdef WITH_CUDA
+ CHECK_CUDA(p1);
+ CHECK_CUDA(p2);
+ return KNearestNeighborBackwardCuda(
+ p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
+#else
+ AT_ERROR("Not compiled with GPU support.");
+#endif
+ }
+ return KNearestNeighborBackwardCpu(
+ p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
+}
+
+// Utility to check whether a KNN version can be used.
+//
+// Args:
+// version: Integer in the range 0 <= version <= 3 indicating one of our
+// KNN implementations.
+// D: Number of dimensions for the input and query point clouds
+// K: Number of neighbors to be found
+//
+// Returns:
+// Whether the indicated KNN version can be used.
+bool KnnCheckVersion(int version, const int64_t D, const int64_t K);
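
The doc comment above fully specifies what `KNearestNeighborIdx` returns, so a brute-force PyTorch reference makes a handy sanity check. The sketch below reproduces the documented semantics for `norm=2` and full-length clouds; it is not the extension API itself (the compiled op is reached through the `knn_points` wrapper imported in `dot/utils/torch3d/__init__.py`):

```python
# Brute-force reference matching the documented output of KNearestNeighborIdx
# for norm=2 and clouds with no padding.
import torch

def knn_reference(p1, p2, K):
    # p1: (N, P1, D), p2: (N, P2, D)
    d2 = torch.cdist(p1, p2) ** 2                      # squared L2, (N, P1, P2)
    dists, idxs = d2.topk(K, dim=-1, largest=False)    # K smallest per p1 point
    return idxs, dists                                 # both (N, P1, K)

p1, p2 = torch.rand(2, 128, 3), torch.rand(2, 256, 3)
idxs, dists = knn_reference(p1, p2, K=4)
print(idxs.shape, dists.shape)
```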
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn_cpu.cpp b/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..896e6f6ab2c952f214fb537cadd10423a8fb1663
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/knn/knn_cpu.cpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <torch/extension.h>
+#include <queue>
+#include <tuple>
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K) {
+ const int N = p1.size(0);
+ const int P1 = p1.size(1);
+ const int D = p1.size(2);
+
+ auto long_opts = lengths1.options().dtype(torch::kInt64);
+ torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
+ torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
+
+ auto p1_a = p1.accessor<float, 3>();
+ auto p2_a = p2.accessor<float, 3>();
+ auto lengths1_a = lengths1.accessor<int64_t, 1>();
+ auto lengths2_a = lengths2.accessor<int64_t, 1>();
+ auto idxs_a = idxs.accessor<int64_t, 3>();
+ auto dists_a = dists.accessor<float, 3>();
+
+ for (int n = 0; n < N; ++n) {
+ const int64_t length1 = lengths1_a[n];
+ const int64_t length2 = lengths2_a[n];
+ for (int64_t i1 = 0; i1 < length1; ++i1) {
+ // Use a priority queue to store (distance, index) tuples.
+ std::priority_queue<std::tuple<float, int64_t>> q;
+ for (int64_t i2 = 0; i2 < length2; ++i2) {
+ float dist = 0;
+ for (int d = 0; d < D; ++d) {
+ float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
+ if (norm == 1) {
+ dist += abs(diff);
+ } else { // norm is 2 (default)
+ dist += diff * diff;
+ }
+ }
+ int size = static_cast<int>(q.size());
+ if (size < K || dist < std::get<0>(q.top())) {
+ q.emplace(dist, i2);
+ if (size >= K) {
+ q.pop();
+ }
+ }
+ }
+ while (!q.empty()) {
+ auto t = q.top();
+ q.pop();
+ const int k = q.size();
+ dists_a[n][i1][k] = std::get<0>(t);
+ idxs_a[n][i1][k] = std::get<1>(t);
+ }
+ }
+ }
+ return std::make_tuple(idxs, dists);
+}
+
+// ------------------------------------------------------------- //
+// Backward Operators //
+// ------------------------------------------------------------- //
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists) {
+ const int N = p1.size(0);
+ const int P1 = p1.size(1);
+ const int D = p1.size(2);
+ const int P2 = p2.size(1);
+ const int K = idxs.size(2);
+
+ torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
+ torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
+
+ auto p1_a = p1.accessor<float, 3>();
+ auto p2_a = p2.accessor<float, 3>();
+ auto lengths1_a = lengths1.accessor<int64_t, 1>();
+ auto lengths2_a = lengths2.accessor<int64_t, 1>();
+ auto idxs_a = idxs.accessor<int64_t, 3>();
+ auto grad_dists_a = grad_dists.accessor<float, 3>();
+ auto grad_p1_a = grad_p1.accessor<float, 3>();
+ auto grad_p2_a = grad_p2.accessor<float, 3>();
+
+ for (int n = 0; n < N; ++n) {
+ const int64_t length1 = lengths1_a[n];
+ int64_t length2 = lengths2_a[n];
+ length2 = (length2 < K) ? length2 : K;
+ for (int64_t i1 = 0; i1 < length1; ++i1) {
+ for (int64_t k = 0; k < length2; ++k) {
+ const int64_t i2 = idxs_a[n][i1][k];
+ // If the index is the pad value of -1 then ignore it
+ if (i2 == -1) {
+ continue;
+ }
+ for (int64_t d = 0; d < D; ++d) {
+ float diff = 0.0;
+ if (norm == 1) {
+ float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
+ diff = grad_dists_a[n][i1][k] * sign;
+ } else { // norm is 2 (default)
+ diff = 2.0f * grad_dists_a[n][i1][k] *
+ (p1_a[n][i1][d] - p2_a[n][i2][d]);
+ }
+ grad_p1_a[n][i1][d] += diff;
+ grad_p2_a[n][i2][d] += -1.0f * diff;
+ }
+ }
+ }
+ }
+ return std::make_tuple(grad_p1, grad_p2);
+}
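
The backward pass above hand-codes the gradient of the squared distance: for `norm=2`, `d(dist)/d(p1) = 2 * grad_dist * (p1 - p2)` and the negative of that for `p2`; for `norm=1` it is the sign times `grad_dist`. A tiny autograd check of the L2 rule on a single point pair (illustrative only):

```python
# Autograd check of the L2 backward rule implemented in KNearestNeighborBackwardCpu.
import torch

p1 = torch.rand(3, requires_grad=True)
p2 = torch.rand(3, requires_grad=True)
dist = ((p1 - p2) ** 2).sum()     # squared L2 distance, as in the forward pass
dist.backward()
assert torch.allclose(p1.grad, 2 * (p1 - p2).detach())
assert torch.allclose(p2.grad, -2 * (p1 - p2).detach())
```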
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu b/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu
new file mode 100644
index 0000000000000000000000000000000000000000..24c05ad54eef1ae1fa53f8caf2c8022832214a55
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu
@@ -0,0 +1,241 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+// Kernel for inputs_packed of shape (F, D), where D > 1
+template <typename scalar_t>
+__global__ void PackedToPaddedKernel(
+ const scalar_t* __restrict__ inputs_packed,
+ const int64_t* __restrict__ first_idxs,
+ scalar_t* __restrict__ inputs_padded,
+ const size_t batch_size,
+ const size_t max_size,
+ const size_t num_inputs,
+ const size_t D) {
+ // Batch elements split evenly across blocks (num blocks = batch_size) and
+ // values for each element split across threads in the block. Each thread adds
+ // the values of its respective input elements to the global inputs_padded
+ // tensor.
+ const size_t tid = threadIdx.x;
+ const size_t batch_idx = blockIdx.x;
+
+ const int64_t start = first_idxs[batch_idx];
+ const int64_t end =
+ batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
+ const int num = end - start;
+ for (size_t f = tid; f < num; f += blockDim.x) {
+ for (size_t j = 0; j < D; ++j) {
+ inputs_padded[batch_idx * max_size * D + f * D + j] =
+ inputs_packed[(start + f) * D + j];
+ }
+ }
+}
+
+// Kernel for inputs of shape (F, 1)
+template <typename scalar_t>
+__global__ void PackedToPaddedKernelD1(
+ const scalar_t* __restrict__ inputs_packed,
+ const int64_t* __restrict__ first_idxs,
+ scalar_t* __restrict__ inputs_padded,
+ const size_t batch_size,
+ const size_t max_size,
+ const size_t num_inputs) {
+ // Batch elements split evenly across blocks (num blocks = batch_size) and
+ // values for each element split across threads in the block. Each thread adds
+ // the values of its respective input elements to the global inputs_padded
+ // tensor.
+ const size_t tid = threadIdx.x;
+ const size_t batch_idx = blockIdx.x;
+
+ const int64_t start = first_idxs[batch_idx];
+ const int64_t end =
+ batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
+ const int num = end - start;
+ for (size_t f = tid; f < num; f += blockDim.x) {
+ inputs_padded[batch_idx * max_size + f] = inputs_packed[start + f];
+ }
+}
+
+// Kernel for inputs_padded of shape (B, F, D), where D > 1
+template <typename scalar_t>
+__global__ void PaddedToPackedKernel(
+ const scalar_t* __restrict__ inputs_padded,
+ const int64_t* __restrict__ first_idxs,
+ scalar_t* __restrict__ inputs_packed,
+ const size_t batch_size,
+ const size_t max_size,
+ const size_t num_inputs,
+ const size_t D) {
+ // Batch elements split evenly across blocks (num blocks = batch_size) and
+ // values for each element split across threads in the block. Each thread adds
+ // the values of its respective input elements to the global inputs_packed
+ // tensor.
+ const size_t tid = threadIdx.x;
+ const size_t batch_idx = blockIdx.x;
+
+ const int64_t start = first_idxs[batch_idx];
+ const int64_t end =
+ batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
+ const int num = end - start;
+ for (size_t f = tid; f < num; f += blockDim.x) {
+ for (size_t j = 0; j < D; ++j) {
+ inputs_packed[(start + f) * D + j] =
+ inputs_padded[batch_idx * max_size * D + f * D + j];
+ }
+ }
+}
+
+// Kernel for inputs_padded of shape (B, F, 1)
+template <typename scalar_t>
+__global__ void PaddedToPackedKernelD1(
+ const scalar_t* __restrict__ inputs_padded,
+ const int64_t* __restrict__ first_idxs,
+ scalar_t* __restrict__ inputs_packed,
+ const size_t batch_size,
+ const size_t max_size,
+ const size_t num_inputs) {
+ // Batch elements split evenly across blocks (num blocks = batch_size) and
+ // values for each element split across threads in the block. Each thread adds
+ // the values of its respective input elements to the global inputs_packed
+ // tensor.
+ const size_t tid = threadIdx.x;
+ const size_t batch_idx = blockIdx.x;
+
+ const int64_t start = first_idxs[batch_idx];
+ const int64_t end =
+ batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
+ const int num = end - start;
+ for (size_t f = tid; f < num; f += blockDim.x) {
+ inputs_packed[start + f] = inputs_padded[batch_idx * max_size + f];
+ }
+}
+
+at::Tensor PackedToPaddedCuda(
+ const at::Tensor inputs_packed,
+ const at::Tensor first_idxs,
+ const int64_t max_size) {
+ // Check inputs are on the same device
+ at::TensorArg inputs_packed_t{inputs_packed, "inputs_packed", 1},
+ first_idxs_t{first_idxs, "first_idxs", 2};
+ at::CheckedFrom c = "PackedToPaddedCuda";
+ at::checkAllSameGPU(c, {inputs_packed_t, first_idxs_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(inputs_packed.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const int64_t num_inputs = inputs_packed.size(0);
+ const int64_t batch_size = first_idxs.size(0);
+
+ TORCH_CHECK(
+ inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
+ const int64_t D = inputs_packed.size(1);
+ at::Tensor inputs_padded =
+ at::zeros({batch_size, max_size, D}, inputs_packed.options());
+
+ if (inputs_padded.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return inputs_padded;
+ }
+
+ const int threads = 512;
+ const int blocks = batch_size;
+ if (D == 1) {
+ AT_DISPATCH_FLOATING_TYPES(
+ inputs_packed.scalar_type(), "packed_to_padded_d1_kernel", ([&] {
+ PackedToPaddedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
+ inputs_packed.contiguous().data_ptr<scalar_t>(),
+ first_idxs.contiguous().data_ptr<int64_t>(),
+ inputs_padded.data_ptr<scalar_t>(),
+ batch_size,
+ max_size,
+ num_inputs);
+ }));
+ } else {
+ AT_DISPATCH_FLOATING_TYPES(
+ inputs_packed.scalar_type(), "packed_to_padded_kernel", ([&] {
+ PackedToPaddedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ inputs_packed.contiguous().data_ptr<scalar_t>(),
+ first_idxs.contiguous().data_ptr<int64_t>(),
+ inputs_padded.data_ptr<scalar_t>(),
+ batch_size,
+ max_size,
+ num_inputs,
+ D);
+ }));
+ }
+
+ AT_CUDA_CHECK(cudaGetLastError());
+ return inputs_padded;
+}
+
+at::Tensor PaddedToPackedCuda(
+ const at::Tensor inputs_padded,
+ const at::Tensor first_idxs,
+ const int64_t num_inputs) {
+ // Check inputs are on the same device
+ at::TensorArg inputs_padded_t{inputs_padded, "inputs_padded", 1},
+ first_idxs_t{first_idxs, "first_idxs", 2};
+ at::CheckedFrom c = "PaddedToPackedCuda";
+ at::checkAllSameGPU(c, {inputs_padded_t, first_idxs_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(inputs_padded.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const int64_t batch_size = inputs_padded.size(0);
+ const int64_t max_size = inputs_padded.size(1);
+
+ TORCH_CHECK(batch_size == first_idxs.size(0), "sizes mismatch");
+ TORCH_CHECK(
+ inputs_padded.dim() == 3,
+ "inputs_padded must be a 3-dimensional tensor");
+ const int64_t D = inputs_padded.size(2);
+
+ at::Tensor inputs_packed =
+ at::zeros({num_inputs, D}, inputs_padded.options());
+
+ if (inputs_packed.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return inputs_packed;
+ }
+
+ const int threads = 512;
+ const int blocks = batch_size;
+
+ if (D == 1) {
+ AT_DISPATCH_FLOATING_TYPES(
+ inputs_padded.scalar_type(), "padded_to_packed_d1_kernel", ([&] {
+ PaddedToPackedKernelD1<scalar_t><<<blocks, threads, 0, stream>>>(
+ inputs_padded.contiguous().data_ptr<scalar_t>(),
+ first_idxs.contiguous().data_ptr<int64_t>(),
+ inputs_packed.data_ptr<scalar_t>(),
+ batch_size,
+ max_size,
+ num_inputs);
+ }));
+ } else {
+ AT_DISPATCH_FLOATING_TYPES(
+ inputs_padded.scalar_type(), "padded_to_packed_kernel", ([&] {
+ PaddedToPackedKernel<scalar_t><<<blocks, threads, 0, stream>>>(
+ inputs_padded.contiguous().data_ptr<scalar_t>(),
+ first_idxs.contiguous().data_ptr<int64_t>(),
+ inputs_packed.data_ptr<scalar_t>(),
+ batch_size,
+ max_size,
+ num_inputs,
+ D);
+ }));
+ }
+
+ AT_CUDA_CHECK(cudaGetLastError());
+ return inputs_packed;
+}
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h b/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..97ad2e3ee95acff21106a9f5312c6ccf1f095e45
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/extension.h>
+#include "utils/pytorch3d_cutils.h"
+
+// PackedToPadded
+// Converts a packed tensor into a padded tensor, restoring the batch dimension.
+// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
+//
+// Inputs:
+// inputs_packed: FloatTensor of shape (F, D), representing the packed batch
+// tensor, e.g. areas for faces in a batch of meshes.
+// first_idxs: LongTensor of shape (N,) where N is the number of
+// elements in the batch and `first_idxs[i] = f`
+// means that the inputs for batch element i begin at
+// `inputs[f]`.
+// max_size: Max length of an element in the batch.
+// Returns:
+// inputs_padded: FloatTensor of shape (N, max_size, D) where max_size is max
+// of `sizes`. The values for batch element i which start at
+// `inputs_packed[first_idxs[i]]` will be copied to
+// `inputs_padded[i, :]`, with zeros padding out the extra
+// inputs.
+//
+
+// PaddedToPacked
+// Converts a padded tensor into a packed tensor.
+// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
+//
+// Inputs:
+// inputs_padded: FloatTensor of shape (N, max_size, D), representing the
+// padded tensor, e.g. areas for faces in a batch of meshes.
+// first_idxs: LongTensor of shape (N,) where N is the number of
+// elements in the batch and `first_idxs[i] = f`
+// means that the inputs for batch element i begin at
+// `inputs_packed[f]`.
+// num_inputs: Number of packed entries (= F)
+// Returns:
+// inputs_packed: FloatTensor of shape (F, D), where
+// `inputs_packed[first_idx[i]:] = inputs_padded[i, :]`.
+//
+//
+
+// Cpu implementation.
+at::Tensor PackedToPaddedCpu(
+ const at::Tensor inputs_packed,
+ const at::Tensor first_idxs,
+ const int64_t max_size);
+
+// Cpu implementation.
+at::Tensor PaddedToPackedCpu(
+ const at::Tensor inputs_padded,
+ const at::Tensor first_idxs,
+ const int64_t num_inputs);
+
+#ifdef WITH_CUDA
+// Cuda implementation.
+at::Tensor PackedToPaddedCuda(
+ const at::Tensor inputs_packed,
+ const at::Tensor first_idxs,
+ const int64_t max_size);
+
+// Cuda implementation.
+at::Tensor PaddedToPackedCuda(
+ const at::Tensor inputs_padded,
+ const at::Tensor first_idxs,
+ const int64_t num_inputs);
+#endif
+
+// Implementation which is exposed.
+at::Tensor PackedToPadded(
+ const at::Tensor inputs_packed,
+ const at::Tensor first_idxs,
+ const int64_t max_size) {
+ if (inputs_packed.is_cuda()) {
+#ifdef WITH_CUDA
+ CHECK_CUDA(inputs_packed);
+ CHECK_CUDA(first_idxs);
+ return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
+#else
+ AT_ERROR("Not compiled with GPU support.");
+#endif
+ }
+ return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
+}
+
+// Implementation which is exposed.
+at::Tensor PaddedToPacked(
+ const at::Tensor inputs_padded,
+ const at::Tensor first_idxs,
+ const int64_t num_inputs) {
+ if (inputs_padded.is_cuda()) {
+#ifdef WITH_CUDA
+ CHECK_CUDA(inputs_padded);
+ CHECK_CUDA(first_idxs);
+ return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
+#else
+ AT_ERROR("Not compiled with GPU support.");
+#endif
+ }
+ return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
+}
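A quick orientation for this vendored header: the sketch below is illustrative only and is not part of PyTorch3D; it rebuilds the same layout transform with plain ATen ops on toy sizes, assuming libtorch is available. `PackedToPadded` restores the batch dimension from a packed tensor plus `first_idxs`; `PaddedToPacked` inverts it given `num_inputs`.

```cpp
#include <iostream>
#include <torch/torch.h>

// Toy batch of N=2 elements with 2 and 3 rows, stored packed as (F=5, D=2).
// first_idxs = [0, 2] marks where each element's rows begin in the packed tensor.
// PackedToPadded(packed, first_idxs, /*max_size=*/3) would return the (2, 3, 2)
// zero-padded tensor assembled below; PaddedToPacked(padded, first_idxs, 5)
// would recover `packed`.
int main() {
  auto packed = torch::arange(10, torch::kFloat).reshape({5, 2});
  auto first_idxs = torch::tensor({0, 2}, torch::kLong);

  auto padded = torch::zeros({2, 3, 2});
  padded[0].narrow(0, 0, 2).copy_(packed.narrow(0, 0, 2));  // element 0: rows 0..1
  padded[1].narrow(0, 0, 3).copy_(packed.narrow(0, 2, 3));  // element 1: rows 2..4
  std::cout << padded << std::endl;
  return 0;
}
```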
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp b/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..34a002a8ed582a73c759b1ced2ca05d1507cb3f3
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <torch/extension.h>
+
+at::Tensor PackedToPaddedCpu(
+ const at::Tensor inputs_packed,
+ const at::Tensor first_idxs,
+ const int64_t max_size) {
+ const int64_t num_inputs = inputs_packed.size(0);
+ const int64_t batch_size = first_idxs.size(0);
+
+ AT_ASSERTM(
+ inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
+ const int64_t D = inputs_packed.size(1);
+
+ torch::Tensor inputs_padded =
+ torch::zeros({batch_size, max_size, D}, inputs_packed.options());
+
+ auto inputs_packed_a = inputs_packed.accessor<float, 2>();
+ auto first_idxs_a = first_idxs.accessor<int64_t, 1>();
+ auto inputs_padded_a = inputs_padded.accessor<float, 3>();
+
+ for (int b = 0; b < batch_size; ++b) {
+ const int64_t start = first_idxs_a[b];
+ const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs;
+ const int64_t num = end - start;
+ for (int i = 0; i < num; ++i) {
+ for (int j = 0; j < D; ++j) {
+ inputs_padded_a[b][i][j] = inputs_packed_a[start + i][j];
+ }
+ }
+ }
+ return inputs_padded;
+}
+
+at::Tensor PaddedToPackedCpu(
+ const at::Tensor inputs_padded,
+ const at::Tensor first_idxs,
+ const int64_t num_inputs) {
+ const int64_t batch_size = inputs_padded.size(0);
+
+ AT_ASSERTM(
+ inputs_padded.dim() == 3, "inputs_padded must be a 3-dimensional tensor");
+ const int64_t D = inputs_padded.size(2);
+
+ torch::Tensor inputs_packed =
+ torch::zeros({num_inputs, D}, inputs_padded.options());
+
+ auto inputs_padded_a = inputs_padded.accessor<float, 3>();
+ auto first_idxs_a = first_idxs.accessor<int64_t, 1>();
+ auto inputs_packed_a = inputs_packed.accessor<float, 2>();
+
+ for (int b = 0; b < batch_size; ++b) {
+ const int64_t start = first_idxs_a[b];
+ const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs;
+ const int64_t num = end - start;
+ for (int i = 0; i < num; ++i) {
+ for (int j = 0; j < D; ++j) {
+ inputs_packed_a[start + i][j] = inputs_padded_a[b][i][j];
+ }
+ }
+ }
+ return inputs_packed;
+}
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/utils/dispatch.cuh b/data/dot_single_video/dot/utils/torch3d/csrc/utils/dispatch.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..eff9521630a2298c3f3d29b3b07adcdc2d44db8a
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/utils/dispatch.cuh
@@ -0,0 +1,357 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// This file provides utilities for dispatching to specialized versions of
+// functions. This is especially useful for CUDA kernels, since specializing
+// them to particular input sizes can often allow the compiler to unroll loops
+// and place arrays into registers, which can give huge performance speedups.
+//
+// As an example, suppose we have the following function which is specialized
+// based on a compile-time int64_t value:
+//
+// template <typename T, int64_t x>
+// struct SquareOffset {
+// static void run(T y) {
+// T val = x * x + y;
+// std::cout << val << std::endl;
+// }
+// }
+//
+// This function takes one compile-time argument x, and one run-time argument y.
+// We might want to compile specialized versions of this for x=0, x=1, etc and
+// then dispatch to the correct one based on the runtime value of x.
+// One simple way to achieve this is with a lookup table:
+//
+// template <typename T>
+// void DispatchSquareOffset(const int64_t x, T y) {
+// if (x == 0) {
+// SquareOffset<T, 0>::run(y);
+// } else if (x == 1) {
+// SquareOffset<T, 1>::run(y);
+// } else if (x == 2) {
+// SquareOffset<T, 2>::run(y);
+// }
+// }
+//
+// This function takes both x and y as run-time arguments, and dispatches to
+// different specialized versions of SquareOffset based on the run-time value
+// of x. This works, but it's tedious and error-prone. If we want to change the
+// set of x values for which we provide compile-time specializations, then we
+// will need to do a lot of tedious editing of the dispatch function. Also, if we
+// want to provide compile-time specializations for a function other than
+// SquareOffset, we will need to duplicate the entire lookup table.
+//
+// To solve these problems, we can use the DispatchKernel1D function provided by
+// this file instead:
+//
+// template <typename T>
+// void DispatchSquareOffset(const int64_t x, T y) {
+// constexpr int64_t xmin = 0;
+// constexpr int64_t xmax = 2;
+// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
+// }
+//
+// DispatchKernel1D uses template metaprogramming to compile specialized
+// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
+// then dispatches to the correct one based on the run-time value of x. If we
+// want to change the range of x values for which SquareOffset is specialized
+// at compile-time, then all we have to do is change the values of the
+// compile-time constants xmin and xmax.
+//
+// This file also allows us to similarly dispatch functions that depend on two
+// compile-time int64_t values, using the DispatchKernel2D function like this:
+//
+// template <typename T, int64_t x, int64_t y>
+// struct Sum {
+// static void run(T z, T w) {
+// T val = x + y + z + w;
+// std::cout << val << std::endl;
+// }
+// }
+//
+// template <typename T>
+// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
+// constexpr int64_t xmin = 1;
+// constexpr int64_t xmax = 3;
+// constexpr int64_t ymin = 2;
+// constexpr int64_t ymax = 5;
+// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
+// }
+//
+// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
+// compile specialized versions of sum for all values of (x, y) with
+// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
+// specialized version based on the runtime values of x and y.
+
+// Define some helper structs in an anonymous namespace.
+namespace {
+
+// 1D dispatch: general case.
+// Kernel is the function we want to dispatch to; it should take a typename and
+// an int64_t as template args, and it should define a static void function
+// run which takes any number of arguments of any type.
+// In order to dispatch, we will take an additional template argument curN,
+// and increment it via template recursion until it is equal to the run-time
+// argument N.
+template <
+ template <typename, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t curN,
+ typename... Args>
+struct DispatchKernelHelper1D {
+ static void run(const int64_t N, Args... args) {
+ if (curN == N) {
+ // The compile-time value curN is equal to the run-time value N, so we
+ // can dispatch to the run method of the Kernel.
+ Kernel<T, curN>::run(args...);
+ } else if (curN < N) {
+ // Increment curN via template recursion
+ DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
+ N, args...);
+ }
+ // We shouldn't get here -- throw an error?
+ }
+};
+
+// 1D dispatch: Specialization when curN == maxN
+// We need this base case to avoid infinite template recursion.
+template <
+ template <typename, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ typename... Args>
+struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
+ static void run(const int64_t N, Args... args) {
+ if (N == maxN) {
+ Kernel<T, maxN>::run(args...);
+ }
+ // We shouldn't get here -- throw an error?
+ }
+};
+
+// 2D dispatch, general case.
+// This is similar to the 1D case: we take additional template args curN and
+// curM, and increment them via template recursion until they are equal to
+// the run-time values of N and M, at which point we dispatch to the run
+// method of the kernel.
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t curN,
+ int64_t minM,
+ int64_t maxM,
+ int64_t curM,
+ typename... Args>
+struct DispatchKernelHelper2D {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (curN == N && curM == M) {
+ Kernel<T, curN, curM>::run(args...);
+ } else if (curN < N && curM < M) {
+ // Increment both curN and curM. This isn't strictly necessary; we could
+ // just increment one or the other at each step. But this helps to cut
+ // on the number of recursive calls we make.
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN + 1,
+ minM,
+ maxM,
+ curM + 1,
+ Args...>::run(N, M, args...);
+ } else if (curN < N) {
+ // Increment curN only
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN + 1,
+ minM,
+ maxM,
+ curM,
+ Args...>::run(N, M, args...);
+ } else if (curM < M) {
+ // Increment curM only
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN,
+ minM,
+ maxM,
+ curM + 1,
+ Args...>::run(N, M, args...);
+ }
+ }
+};
+
+// 2D dispatch, specialization for curN == maxN
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t minM,
+ int64_t maxM,
+ int64_t curM,
+ typename... Args>
+struct DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ maxN,
+ minM,
+ maxM,
+ curM,
+ Args...> {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (maxN == N && curM == M) {
+ Kernel<T, maxN, curM>::run(args...);
+ } else if (curM < maxM) {
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ maxN,
+ minM,
+ maxM,
+ curM + 1,
+ Args...>::run(N, M, args...);
+ }
+ // We should not get here -- throw an error?
+ }
+};
+
+// 2D dispatch, specialization for curM == maxM
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t curN,
+ int64_t minM,
+ int64_t maxM,
+ typename... Args>
+struct DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN,
+ minM,
+ maxM,
+ maxM,
+ Args...> {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (curN == N && maxM == M) {
+ Kernel<T, curN, maxM>::run(args...);
+ } else if (curN < maxN) {
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN + 1,
+ minM,
+ maxM,
+ maxM,
+ Args...>::run(N, M, args...);
+ }
+ // We should not get here -- throw an error?
+ }
+};
+
+// 2D dispatch, specialization for curN == maxN, curM == maxM
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t minM,
+ int64_t maxM,
+ typename... Args>
+struct DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ maxN,
+ minM,
+ maxM,
+ maxM,
+ Args...> {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (maxN == N && maxM == M) {
+ Kernel<T, maxN, maxM>::run(args...);
+ }
+ // We should not get here -- throw an error?
+ }
+};
+
+} // namespace
+
+// This is the function we expect users to call to dispatch to 1D functions
+template <
+ template <typename, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ typename... Args>
+void DispatchKernel1D(const int64_t N, Args... args) {
+ if (minN <= N && N <= maxN) {
+ // Kick off the template recursion by calling the Helper with curN = minN
+ DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
+ N, args...);
+ }
+ // Maybe throw an error if we tried to dispatch outside the allowed range?
+}
+
+// This is the function we expect users to call to dispatch to 2D functions
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t minM,
+ int64_t maxM,
+ typename... Args>
+void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
+ if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
+ // Kick off the template recursion by calling the Helper with curN = minN
+ // and curM = minM
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ minN,
+ minM,
+ maxM,
+ minM,
+ Args...>::run(N, M, args...);
+ }
+ // Maybe throw an error if we tried to dispatch outside the specified range?
+}
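Because this header contains only host-side template metaprogramming (no device constructs), it can be exercised from plain C++. The sketch below is an illustration that mirrors the SquareOffset example from the comment at the top of the file; the include path and file name are assumptions.

```cpp
// toy_dispatch.cpp -- illustrative use of DispatchKernel1D.
#include <cstdint>
#include <iostream>
#include "dispatch.cuh"  // assumed to be on the include path

// Functor specialized on a compile-time x, as in the example comment above.
template <typename T, int64_t x>
struct SquareOffset {
  static void run(T y) {
    std::cout << x * x + y << std::endl;  // x is a compile-time constant here
  }
};

template <typename T>
void DispatchSquareOffset(const int64_t x, T y) {
  constexpr int64_t xmin = 0;
  constexpr int64_t xmax = 4;
  // Instantiates SquareOffset<T, 0> ... SquareOffset<T, 4> and picks one at runtime.
  DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
}

int main() {
  DispatchSquareOffset<float>(3, 0.5f);  // prints 9.5
  return 0;
}
```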
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/utils/float_math.cuh b/data/dot_single_video/dot/utils/torch3d/csrc/utils/float_math.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..9678eee8a2d544c4078933803416c90c895e071b
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/utils/float_math.cuh
@@ -0,0 +1,153 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <thrust/tuple.h>
+
+// Set epsilon
+#ifdef _MSC_VER
+#define vEpsilon 1e-8f
+#else
+const auto vEpsilon = 1e-8;
+#endif
+
+// Common functions and operators for float2.
+
+__device__ inline float2 operator-(const float2& a, const float2& b) {
+ return make_float2(a.x - b.x, a.y - b.y);
+}
+
+__device__ inline float2 operator+(const float2& a, const float2& b) {
+ return make_float2(a.x + b.x, a.y + b.y);
+}
+
+__device__ inline float2 operator/(const float2& a, const float2& b) {
+ return make_float2(a.x / b.x, a.y / b.y);
+}
+
+__device__ inline float2 operator/(const float2& a, const float b) {
+ return make_float2(a.x / b, a.y / b);
+}
+
+__device__ inline float2 operator*(const float2& a, const float2& b) {
+ return make_float2(a.x * b.x, a.y * b.y);
+}
+
+__device__ inline float2 operator*(const float a, const float2& b) {
+ return make_float2(a * b.x, a * b.y);
+}
+
+__device__ inline float FloatMin3(const float a, const float b, const float c) {
+ return fminf(a, fminf(b, c));
+}
+
+__device__ inline float FloatMax3(const float a, const float b, const float c) {
+ return fmaxf(a, fmaxf(b, c));
+}
+
+__device__ inline float dot(const float2& a, const float2& b) {
+ return a.x * b.x + a.y * b.y;
+}
+
+// Backward pass for the dot product.
+// Args:
+// a, b: Coordinates of two points.
+// grad_dot: Upstream gradient for the output.
+//
+// Returns:
+// tuple of gradients for each of the input points:
+// (float2 grad_a, float2 grad_b)
+//
+__device__ inline thrust::tuple<float2, float2>
+DotBackward(const float2& a, const float2& b, const float& grad_dot) {
+ return thrust::make_tuple(grad_dot * b, grad_dot * a);
+}
+
+__device__ inline float sum(const float2& a) {
+ return a.x + a.y;
+}
+
+// Common functions and operators for float3.
+
+__device__ inline float3 operator-(const float3& a, const float3& b) {
+ return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
+}
+
+__device__ inline float3 operator+(const float3& a, const float3& b) {
+ return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
+}
+
+__device__ inline float3 operator/(const float3& a, const float3& b) {
+ return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
+}
+
+__device__ inline float3 operator/(const float3& a, const float b) {
+ return make_float3(a.x / b, a.y / b, a.z / b);
+}
+
+__device__ inline float3 operator*(const float3& a, const float3& b) {
+ return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
+}
+
+__device__ inline float3 operator*(const float a, const float3& b) {
+ return make_float3(a * b.x, a * b.y, a * b.z);
+}
+
+__device__ inline float dot(const float3& a, const float3& b) {
+ return a.x * b.x + a.y * b.y + a.z * b.z;
+}
+
+__device__ inline float sum(const float3& a) {
+ return a.x + a.y + a.z;
+}
+
+__device__ inline float3 cross(const float3& a, const float3& b) {
+ return make_float3(
+ a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
+}
+
+__device__ inline thrust::tuple<float3, float3>
+cross_backward(const float3& a, const float3& b, const float3& grad_cross) {
+ const float grad_ax = -grad_cross.y * b.z + grad_cross.z * b.y;
+ const float grad_ay = grad_cross.x * b.z - grad_cross.z * b.x;
+ const float grad_az = -grad_cross.x * b.y + grad_cross.y * b.x;
+ const float3 grad_a = make_float3(grad_ax, grad_ay, grad_az);
+
+ const float grad_bx = grad_cross.y * a.z - grad_cross.z * a.y;
+ const float grad_by = -grad_cross.x * a.z + grad_cross.z * a.x;
+ const float grad_bz = grad_cross.x * a.y - grad_cross.y * a.x;
+ const float3 grad_b = make_float3(grad_bx, grad_by, grad_bz);
+
+ return thrust::make_tuple(grad_a, grad_b);
+}
+
+__device__ inline float norm(const float3& a) {
+ return sqrt(dot(a, a));
+}
+
+__device__ inline float3 normalize(const float3& a) {
+ return a / (norm(a) + vEpsilon);
+}
+
+__device__ inline float3 normalize_backward(
+ const float3& a,
+ const float3& grad_normz) {
+ const float a_norm = norm(a) + vEpsilon;
+ const float3 out = a / a_norm;
+
+ const float grad_ax = grad_normz.x * (1.0f - out.x * out.x) / a_norm +
+ grad_normz.y * (-out.x * out.y) / a_norm +
+ grad_normz.z * (-out.x * out.z) / a_norm;
+ const float grad_ay = grad_normz.x * (-out.x * out.y) / a_norm +
+ grad_normz.y * (1.0f - out.y * out.y) / a_norm +
+ grad_normz.z * (-out.y * out.z) / a_norm;
+ const float grad_az = grad_normz.x * (-out.x * out.z) / a_norm +
+ grad_normz.y * (-out.y * out.z) / a_norm +
+ grad_normz.z * (1.0f - out.z * out.z) / a_norm;
+ return make_float3(grad_ax, grad_ay, grad_az);
+}
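As a reading aid: `normalize_backward` above is the vector-Jacobian product of the normalization $\hat{a} = a / (\lVert a\rVert + \epsilon)$. Written out (a sketch of the math the code implements, with $g$ the upstream gradient `grad_normz` and the epsilon absorbed into the norm term):

```latex
\frac{\partial \hat{a}_i}{\partial a_j} = \frac{\delta_{ij} - \hat{a}_i \hat{a}_j}{\lVert a \rVert + \epsilon},
\qquad
\nabla_a L = \frac{\bigl(I - \hat{a}\,\hat{a}^{\top}\bigr)\, g}{\lVert a \rVert + \epsilon},
```

which matches the three `grad_a*` expressions term by term.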
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/utils/geometry_utils.cuh b/data/dot_single_video/dot/utils/torch3d/csrc/utils/geometry_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..940dbb2c60a3a1c36b8620d86b540c32bc137537
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/utils/geometry_utils.cuh
@@ -0,0 +1,792 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <float.h>
+#include <math.h>
+#include <thrust/tuple.h>
+#include "float_math.cuh"
+
+// Set epsilon for preventing floating point errors and division by 0.
+#ifdef _MSC_VER
+#define kEpsilon 1e-8f
+#else
+const auto kEpsilon = 1e-8;
+#endif
+
+// ************************************************************* //
+// vec2 utils //
+// ************************************************************* //
+
+// Determines whether a point p is on the right side of a 2D line segment
+// given by the end points v0, v1.
+//
+// Args:
+// p: vec2 Coordinates of a point.
+// v0, v1: vec2 Coordinates of the end points of the edge.
+//
+// Returns:
+// area: The signed area of the parallelogram given by the vectors
+// A = p - v0
+// B = v1 - v0
+//
+__device__ inline float
+EdgeFunctionForward(const float2& p, const float2& v0, const float2& v1) {
+ return (p.x - v0.x) * (v1.y - v0.y) - (p.y - v0.y) * (v1.x - v0.x);
+}
+
+// Backward pass for the edge function returning partial derivatives for each
+// of the input points.
+//
+// Args:
+// p: vec2 Coordinates of a point.
+// v0, v1: vec2 Coordinates of the end points of the edge.
+// grad_edge: Upstream gradient for output from edge function.
+//
+// Returns:
+// tuple of gradients for each of the input points:
+// (float2 d_edge_dp, float2 d_edge_dv0, float2 d_edge_dv1)
+//
+__device__ inline thrust::tuple<float2, float2, float2> EdgeFunctionBackward(
+ const float2& p,
+ const float2& v0,
+ const float2& v1,
+ const float& grad_edge) {
+ const float2 dedge_dp = make_float2(v1.y - v0.y, v0.x - v1.x);
+ const float2 dedge_dv0 = make_float2(p.y - v1.y, v1.x - p.x);
+ const float2 dedge_dv1 = make_float2(v0.y - p.y, p.x - v0.x);
+ return thrust::make_tuple(
+ grad_edge * dedge_dp, grad_edge * dedge_dv0, grad_edge * dedge_dv1);
+}
+
+// The forward pass for computing the barycentric coordinates of a point
+// relative to a triangle.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1, v2: Coordinates of the triangle vertices.
+//
+// Returns
+// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
+//
+__device__ inline float3 BarycentricCoordsForward(
+ const float2& p,
+ const float2& v0,
+ const float2& v1,
+ const float2& v2) {
+ const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
+ const float w0 = EdgeFunctionForward(p, v1, v2) / area;
+ const float w1 = EdgeFunctionForward(p, v2, v0) / area;
+ const float w2 = EdgeFunctionForward(p, v0, v1) / area;
+ return make_float3(w0, w1, w2);
+}
+
+// The backward pass for computing the barycentric coordinates of a point
+// relative to a triangle.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1, v2: (x, y) coordinates of the triangle vertices.
+// grad_bary_upstream: vec3 Upstream gradient for each of the
+// barycentric coordinates [grad_w0, grad_w1, grad_w2].
+//
+// Returns
+// tuple of gradients for each of the triangle vertices:
+// (float2 grad_v0, float2 grad_v1, float2 grad_v2)
+//
+__device__ inline thrust::tuple<float2, float2, float2, float2>
+BarycentricCoordsBackward(
+ const float2& p,
+ const float2& v0,
+ const float2& v1,
+ const float2& v2,
+ const float3& grad_bary_upstream) {
+ const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
+ const float area2 = pow(area, 2.0f);
+ const float e0 = EdgeFunctionForward(p, v1, v2);
+ const float e1 = EdgeFunctionForward(p, v2, v0);
+ const float e2 = EdgeFunctionForward(p, v0, v1);
+
+ const float grad_w0 = grad_bary_upstream.x;
+ const float grad_w1 = grad_bary_upstream.y;
+ const float grad_w2 = grad_bary_upstream.z;
+
+ // Calculate component of the gradient from each of w0, w1 and w2.
+ // e.g. for w0:
+ // dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
+ // + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
+ const float dw0_darea = -e0 / (area2);
+ const float dw0_e0 = 1 / area;
+ const float dloss_d_w0area = grad_w0 * dw0_darea;
+ const float dloss_e0 = grad_w0 * dw0_e0;
+ auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
+ auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
+ const float2 dw0_p = thrust::get<0>(de0_dv);
+ const float2 dw0_dv0 = thrust::get<1>(dw0area_dv);
+ const float2 dw0_dv1 = thrust::get<1>(de0_dv) + thrust::get<2>(dw0area_dv);
+ const float2 dw0_dv2 = thrust::get<2>(de0_dv) + thrust::get<0>(dw0area_dv);
+
+ const float dw1_darea = -e1 / (area2);
+ const float dw1_e1 = 1 / area;
+ const float dloss_d_w1area = grad_w1 * dw1_darea;
+ const float dloss_e1 = grad_w1 * dw1_e1;
+ auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
+ auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
+ const float2 dw1_p = thrust::get<0>(de1_dv);
+ const float2 dw1_dv0 = thrust::get<2>(de1_dv) + thrust::get<1>(dw1area_dv);
+ const float2 dw1_dv1 = thrust::get<2>(dw1area_dv);
+ const float2 dw1_dv2 = thrust::get<1>(de1_dv) + thrust::get<0>(dw1area_dv);
+
+ const float dw2_darea = -e2 / (area2);
+ const float dw2_e2 = 1 / area;
+ const float dloss_d_w2area = grad_w2 * dw2_darea;
+ const float dloss_e2 = grad_w2 * dw2_e2;
+ auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
+ auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
+ const float2 dw2_p = thrust::get<0>(de2_dv);
+ const float2 dw2_dv0 = thrust::get<1>(de2_dv) + thrust::get<1>(dw2area_dv);
+ const float2 dw2_dv1 = thrust::get<2>(de2_dv) + thrust::get<2>(dw2area_dv);
+ const float2 dw2_dv2 = thrust::get<0>(dw2area_dv);
+
+ const float2 dbary_p = dw0_p + dw1_p + dw2_p;
+ const float2 dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
+ const float2 dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
+ const float2 dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
+
+ return thrust::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
+}
+
+// Forward pass for applying perspective correction to barycentric coordinates.
+//
+// Args:
+// bary: Screen-space barycentric coordinates for a point
+// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
+//
+// Returns
+// World-space barycentric coordinates
+//
+__device__ inline float3 BarycentricPerspectiveCorrectionForward(
+ const float3& bary,
+ const float z0,
+ const float z1,
+ const float z2) {
+ const float w0_top = bary.x * z1 * z2;
+ const float w1_top = z0 * bary.y * z2;
+ const float w2_top = z0 * z1 * bary.z;
+ const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon);
+ const float w0 = w0_top / denom;
+ const float w1 = w1_top / denom;
+ const float w2 = w2_top / denom;
+ return make_float3(w0, w1, w2);
+}
+
+// Backward pass for applying perspective correction to barycentric coordinates.
+//
+// Args:
+// bary: Screen-space barycentric coordinates for a point
+// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
+// grad_out: Upstream gradient of the loss with respect to the corrected
+// barycentric coordinates.
+//
+// Returns a tuple of:
+// grad_bary: Downstream gradient of the loss with respect to the
+// uncorrected barycentric coordinates.
+// grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
+// to the z-coordinates of the triangle verts
+__device__ inline thrust::tuple<float3, float, float, float>
+BarycentricPerspectiveCorrectionBackward(
+ const float3& bary,
+ const float z0,
+ const float z1,
+ const float z2,
+ const float3& grad_out) {
+ // Recompute forward pass
+ const float w0_top = bary.x * z1 * z2;
+ const float w1_top = z0 * bary.y * z2;
+ const float w2_top = z0 * z1 * bary.z;
+ const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon);
+
+ // Now do backward pass
+ const float grad_denom_top =
+ -w0_top * grad_out.x - w1_top * grad_out.y - w2_top * grad_out.z;
+ const float grad_denom = grad_denom_top / (denom * denom);
+ const float grad_w0_top = grad_denom + grad_out.x / denom;
+ const float grad_w1_top = grad_denom + grad_out.y / denom;
+ const float grad_w2_top = grad_denom + grad_out.z / denom;
+ const float grad_bary_x = grad_w0_top * z1 * z2;
+ const float grad_bary_y = grad_w1_top * z0 * z2;
+ const float grad_bary_z = grad_w2_top * z0 * z1;
+ const float3 grad_bary = make_float3(grad_bary_x, grad_bary_y, grad_bary_z);
+ const float grad_z0 = grad_w1_top * bary.y * z2 + grad_w2_top * bary.z * z1;
+ const float grad_z1 = grad_w0_top * bary.x * z2 + grad_w2_top * bary.z * z0;
+ const float grad_z2 = grad_w0_top * bary.x * z1 + grad_w1_top * bary.y * z0;
+ return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
+}
+
+// Clip negative barycentric coordinates to 0.0 and renormalize so
+// the barycentric coordinates for a point sum to 1. When the blur_radius
+// is greater than 0, a face will still be recorded as overlapping a pixel
+// if the pixel is outside the face. In this case at least one of the
+// barycentric coordinates for the pixel relative to the face will be negative.
+// Clipping will ensure that the texture and z buffer are interpolated
+// correctly.
+//
+// Args
+// bary: (w0, w1, w2) barycentric coordinates which can be outside the
+// range [0, 1].
+//
+// Returns
+// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
+// satisfy the condition: sum(w0, w1, w2) = 1.0.
+//
+__device__ inline float3 BarycentricClipForward(const float3 bary) {
+ float3 w = make_float3(0.0f, 0.0f, 0.0f);
+ // Clamp lower bound only
+ w.x = max(bary.x, 0.0);
+ w.y = max(bary.y, 0.0);
+ w.z = max(bary.z, 0.0);
+ float w_sum = w.x + w.y + w.z;
+ w_sum = fmaxf(w_sum, 1e-5);
+ w.x /= w_sum;
+ w.y /= w_sum;
+ w.z /= w_sum;
+
+ return w;
+}
+
+// Backward pass for barycentric coordinate clipping.
+//
+// Args
+// bary: (w0, w1, w2) barycentric coordinates which can be outside the
+// range [0, 1].
+// grad_baryclip_upstream: vec3 Upstream gradient for each of the clipped
+// barycentric coordinates [grad_w0, grad_w1, grad_w2].
+//
+// Returns
+// vec3 of gradients for the unclipped barycentric coordinates:
+// (grad_w0, grad_w1, grad_w2)
+//
+__device__ inline float3 BarycentricClipBackward(
+ const float3 bary,
+ const float3 grad_baryclip_upstream) {
+ // Redo some of the forward pass calculations
+ float3 w = make_float3(0.0f, 0.0f, 0.0f);
+ // Clamp lower bound only
+ w.x = max(bary.x, 0.0);
+ w.y = max(bary.y, 0.0);
+ w.z = max(bary.z, 0.0);
+ float w_sum = w.x + w.y + w.z;
+
+ float3 grad_bary = make_float3(1.0f, 1.0f, 1.0f);
+ float3 grad_clip = make_float3(1.0f, 1.0f, 1.0f);
+ float3 grad_sum = make_float3(1.0f, 1.0f, 1.0f);
+
+ // Check if sum was clipped.
+ float grad_sum_clip = 1.0f;
+ if (w_sum < 1e-5) {
+ grad_sum_clip = 0.0f;
+ w_sum = 1e-5;
+ }
+
+ // Check if any of bary values have been clipped.
+ if (bary.x < 0.0f) {
+ grad_clip.x = 0.0f;
+ }
+ if (bary.y < 0.0f) {
+ grad_clip.y = 0.0f;
+ }
+ if (bary.z < 0.0f) {
+ grad_clip.z = 0.0f;
+ }
+
+ // Gradients of the sum.
+ grad_sum.x = -w.x / (pow(w_sum, 2.0f)) * grad_sum_clip;
+ grad_sum.y = -w.y / (pow(w_sum, 2.0f)) * grad_sum_clip;
+ grad_sum.z = -w.z / (pow(w_sum, 2.0f)) * grad_sum_clip;
+
+ // Gradients for each of the bary coordinates including the cross terms
+ // from the sum.
+ grad_bary.x = grad_clip.x *
+ (grad_baryclip_upstream.x * (1.0f / w_sum + grad_sum.x) +
+ grad_baryclip_upstream.y * (grad_sum.y) +
+ grad_baryclip_upstream.z * (grad_sum.z));
+
+ grad_bary.y = grad_clip.y *
+ (grad_baryclip_upstream.y * (1.0f / w_sum + grad_sum.y) +
+ grad_baryclip_upstream.x * (grad_sum.x) +
+ grad_baryclip_upstream.z * (grad_sum.z));
+
+ grad_bary.z = grad_clip.z *
+ (grad_baryclip_upstream.z * (1.0f / w_sum + grad_sum.z) +
+ grad_baryclip_upstream.x * (grad_sum.x) +
+ grad_baryclip_upstream.y * (grad_sum.y));
+
+ return grad_bary;
+}
+
+// Return minimum distance between line segment (v1 - v0) and point p.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1: Coordinates of the end points of the line segment.
+//
+// Returns:
+// squared distance to the boundary of the triangle.
+//
+__device__ inline float
+PointLineDistanceForward(const float2& p, const float2& a, const float2& b) {
+ const float2 ba = b - a;
+ float l2 = dot(ba, ba);
+ float t = dot(ba, p - a) / l2;
+ if (l2 <= kEpsilon) {
+ return dot(p - b, p - b);
+ }
+ t = __saturatef(t); // clamp to the interval [+0.0, 1.0]
+ const float2 p_proj = a + t * ba;
+ const float2 d = (p_proj - p);
+ return dot(d, d); // squared distance
+}
+
+// Backward pass for point to line distance in 2D.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1: Coordinates of the end points of the line segment.
+// grad_dist: Upstream gradient for the distance.
+//
+// Returns:
+// tuple of gradients for each of the input points:
+// (float2 grad_p, float2 grad_v0, float2 grad_v1)
+//
+__device__ inline thrust::tuple<float2, float2, float2>
+PointLineDistanceBackward(
+ const float2& p,
+ const float2& v0,
+ const float2& v1,
+ const float& grad_dist) {
+ // Redo some of the forward pass calculations.
+ const float2 v1v0 = v1 - v0;
+ const float2 pv0 = p - v0;
+ const float t_bot = dot(v1v0, v1v0);
+ const float t_top = dot(v1v0, pv0);
+ float tt = t_top / t_bot;
+ tt = __saturatef(tt);
+ const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
+ const float2 d = p - p_proj;
+ const float dist = sqrt(dot(d, d));
+
+ const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
+ const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
+ const float2 grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);
+
+ return thrust::make_tuple(grad_p, grad_v0, grad_v1);
+}
+
+// The forward pass for calculating the shortest distance between a point
+// and a triangle.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1, v2: Coordinates of the three triangle vertices.
+//
+// Returns:
+// shortest squared distance from a point to a triangle.
+//
+__device__ inline float PointTriangleDistanceForward(
+ const float2& p,
+ const float2& v0,
+ const float2& v1,
+ const float2& v2) {
+ // Compute distance to all 3 edges of the triangle and return the min.
+ const float e01_dist = PointLineDistanceForward(p, v0, v1);
+ const float e02_dist = PointLineDistanceForward(p, v0, v2);
+ const float e12_dist = PointLineDistanceForward(p, v1, v2);
+ const float edge_dist = fminf(fminf(e01_dist, e02_dist), e12_dist);
+ return edge_dist;
+}
+
+// Backward pass for point triangle distance.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1, v2: Coordinates of the three triangle vertices.
+// grad_dist: Upstream gradient for the distance.
+//
+// Returns:
+// tuple of gradients for each of the triangle vertices:
+// (float2 grad_v0, float2 grad_v1, float2 grad_v2)
+//
+__device__ inline thrust::tuple<float2, float2, float2, float2>
+PointTriangleDistanceBackward(
+ const float2& p,
+ const float2& v0,
+ const float2& v1,
+ const float2& v2,
+ const float& grad_dist) {
+ // Compute distance to all 3 edges of the triangle.
+ const float e01_dist = PointLineDistanceForward(p, v0, v1);
+ const float e02_dist = PointLineDistanceForward(p, v0, v2);
+ const float e12_dist = PointLineDistanceForward(p, v1, v2);
+
+ // Initialize output tensors.
+ float2 grad_v0 = make_float2(0.0f, 0.0f);
+ float2 grad_v1 = make_float2(0.0f, 0.0f);
+ float2 grad_v2 = make_float2(0.0f, 0.0f);
+ float2 grad_p = make_float2(0.0f, 0.0f);
+
+ // Find which edge is the closest and return PointLineDistanceBackward for
+ // that edge.
+ if (e01_dist <= e02_dist && e01_dist <= e12_dist) {
+ // Closest edge is v1 - v0.
+ auto grad_e01 = PointLineDistanceBackward(p, v0, v1, grad_dist);
+ grad_p = thrust::get<0>(grad_e01);
+ grad_v0 = thrust::get<1>(grad_e01);
+ grad_v1 = thrust::get<2>(grad_e01);
+ } else if (e02_dist <= e01_dist && e02_dist <= e12_dist) {
+ // Closest edge is v2 - v0.
+ auto grad_e02 = PointLineDistanceBackward(p, v0, v2, grad_dist);
+ grad_p = thrust::get<0>(grad_e02);
+ grad_v0 = thrust::get<1>(grad_e02);
+ grad_v2 = thrust::get<2>(grad_e02);
+ } else if (e12_dist <= e01_dist && e12_dist <= e02_dist) {
+ // Closest edge is v2 - v1.
+ auto grad_e12 = PointLineDistanceBackward(p, v1, v2, grad_dist);
+ grad_p = thrust::get<0>(grad_e12);
+ grad_v1 = thrust::get<1>(grad_e12);
+ grad_v2 = thrust::get<2>(grad_e12);
+ }
+
+ return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
+}
+
+// ************************************************************* //
+// vec3 utils //
+// ************************************************************* //
+
+// Computes the area of a triangle (v0, v1, v2).
+//
+// Args:
+// v0, v1, v2: vec3 coordinates of the triangle vertices
+//
+// Returns
+// area: float: The area of the triangle
+//
+__device__ inline float
+AreaOfTriangle(const float3& v0, const float3& v1, const float3& v2) {
+ float3 p0 = v1 - v0;
+ float3 p1 = v2 - v0;
+
+ // compute the magnitude (hypotenuse form) of the cross product (p0 x p1)
+ float dd = hypot(
+ p0.y * p1.z - p0.z * p1.y,
+ hypot(p0.z * p1.x - p0.x * p1.z, p0.x * p1.y - p0.y * p1.x));
+
+ return dd / 2.0;
+}
+
+// Computes the barycentric coordinates of a point p relative
+// to a triangle (v0, v1, v2), i.e. p = w0 * v0 + w1 * v1 + w2 * v2
+// s.t. w0 + w1 + w2 = 1.0
+//
+// NOTE that this function assumes that p lives on the space spanned
+// by (v0, v1, v2).
+// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
+// and throw an error if check fails
+//
+// Args:
+// p: vec3 coordinates of a point
+// v0, v1, v2: vec3 coordinates of the triangle vertices
+//
+// Returns
+// bary: (w0, w1, w2) barycentric coordinates
+//
+__device__ inline float3 BarycentricCoords3Forward(
+ const float3& p,
+ const float3& v0,
+ const float3& v1,
+ const float3& v2) {
+ float3 p0 = v1 - v0;
+ float3 p1 = v2 - v0;
+ float3 p2 = p - v0;
+
+ const float d00 = dot(p0, p0);
+ const float d01 = dot(p0, p1);
+ const float d11 = dot(p1, p1);
+ const float d20 = dot(p2, p0);
+ const float d21 = dot(p2, p1);
+
+ const float denom = d00 * d11 - d01 * d01 + kEpsilon;
+ const float w1 = (d11 * d20 - d01 * d21) / denom;
+ const float w2 = (d00 * d21 - d01 * d20) / denom;
+ const float w0 = 1.0f - w1 - w2;
+
+ return make_float3(w0, w1, w2);
+}
+
+// Checks whether the point p is inside the triangle (v0, v1, v2).
+// A point is inside the triangle, if all barycentric coordinates
+// wrt the triangle are >= 0 & <= 1.
+// If the triangle is degenerate, aka line or point, then return False.
+//
+// NOTE that this function assumes that p lives on the space spanned
+// by (v0, v1, v2).
+// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
+// and throw an error if check fails
+//
+// Args:
+// p: vec3 coordinates of a point
+// v0, v1, v2: vec3 coordinates of the triangle vertices
+// min_triangle_area: triangles less than this size are considered
+// points/lines, IsInsideTriangle returns False
+//
+// Returns:
+// inside: bool indicating whether p is inside the triangle
+//
+__device__ inline bool IsInsideTriangle(
+ const float3& p,
+ const float3& v0,
+ const float3& v1,
+ const float3& v2,
+ const double min_triangle_area) {
+ bool inside;
+ if (AreaOfTriangle(v0, v1, v2) < min_triangle_area) {
+ inside = 0;
+ } else {
+ float3 bary = BarycentricCoords3Forward(p, v0, v1, v2);
+ bool x_in = 0.0f <= bary.x && bary.x <= 1.0f;
+ bool y_in = 0.0f <= bary.y && bary.y <= 1.0f;
+ bool z_in = 0.0f <= bary.z && bary.z <= 1.0f;
+ inside = x_in && y_in && z_in;
+ }
+ return inside;
+}
+
+// Computes the minimum squared Euclidean distance between the point p
+// and the segment spanned by (v0, v1).
+// To find this we parametrize p as: x(t) = v0 + t * (v1 - v0)
+// and find t which minimizes (x(t) - p) ^ 2.
+// Note that p does not need to live in the space spanned by (v0, v1)
+//
+// Args:
+// p: vec3 coordinates of a point
+// v0, v1: vec3 coordinates of start and end of segment
+//
+// Returns:
+// dist: the minimum squared distance of p from segment (v0, v1)
+//
+
+__device__ inline float
+PointLine3DistanceForward(const float3& p, const float3& v0, const float3& v1) {
+ const float3 v1v0 = v1 - v0;
+ const float3 pv0 = p - v0;
+ const float t_bot = dot(v1v0, v1v0);
+ const float t_top = dot(pv0, v1v0);
+ // if t_bot small, then v0 == v1, set tt to 0.
+ float tt = (t_bot < kEpsilon) ? 0.0f : (t_top / t_bot);
+
+ tt = __saturatef(tt); // clamps to [0, 1]
+
+ const float3 p_proj = v0 + tt * v1v0;
+ const float3 diff = p - p_proj;
+ const float dist = dot(diff, diff);
+ return dist;
+}
+
+// Backward function of the minimum squared Euclidean distance between the point
+// p and the line segment (v0, v1).
+//
+// Args:
+// p: vec3 coordinates of a point
+// v0, v1: vec3 coordinates of start and end of segment
+// grad_dist: Float of the gradient wrt dist
+//
+// Returns:
+// tuple of gradients for the point and line segment (v0, v1):
+// (float3 grad_p, float3 grad_v0, float3 grad_v1)
+
+__device__ inline thrust::tuple<float3, float3, float3>
+PointLine3DistanceBackward(
+ const float3& p,
+ const float3& v0,
+ const float3& v1,
+ const float& grad_dist) {
+ const float3 v1v0 = v1 - v0;
+ const float3 pv0 = p - v0;
+ const float t_bot = dot(v1v0, v1v0);
+ const float t_top = dot(v1v0, pv0);
+
+ float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
+ float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
+ float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
+
+ const float tt = t_top / t_bot;
+
+ if (t_bot < kEpsilon) {
+ // if t_bot small, then v0 == v1,
+ // and dist = 0.5 * dot(pv0, pv0) + 0.5 * dot(pv1, pv1)
+ grad_p = grad_dist * 2.0f * pv0;
+ grad_v0 = -0.5f * grad_p;
+ grad_v1 = grad_v0;
+ } else if (tt < 0.0f) {
+ grad_p = grad_dist * 2.0f * pv0;
+ grad_v0 = -1.0f * grad_p;
+ // no gradients wrt v1
+ } else if (tt > 1.0f) {
+ grad_p = grad_dist * 2.0f * (p - v1);
+ grad_v1 = -1.0f * grad_p;
+ // no gradients wrt v0
+ } else {
+ const float3 p_proj = v0 + tt * v1v0;
+ const float3 diff = p - p_proj;
+ const float3 grad_base = grad_dist * 2.0f * diff;
+ grad_p = grad_base - dot(grad_base, v1v0) * v1v0 / t_bot;
+ const float3 dtt_v0 = (-1.0f * v1v0 - pv0 + 2.0f * tt * v1v0) / t_bot;
+ grad_v0 = (-1.0f + tt) * grad_base - dot(grad_base, v1v0) * dtt_v0;
+ const float3 dtt_v1 = (pv0 - 2.0f * tt * v1v0) / t_bot;
+ grad_v1 = -dot(grad_base, v1v0) * dtt_v1 - tt * grad_base;
+ }
+
+ return thrust::make_tuple(grad_p, grad_v0, grad_v1);
+}
+
+// Computes the squared distance of a point p relative to a triangle (v0, v1,
+// v2). If the point's projection p0 on the plane spanned by (v0, v1, v2) is
+// inside the triangle with vertices (v0, v1, v2), then the returned value is
+// the squared distance of p to its projection p0. Otherwise, the returned value
+// is the smallest squared distance of p from the line segments (v0, v1), (v0,
+// v2) and (v1, v2).
+//
+// Args:
+// p: vec3 coordinates of a point
+// v0, v1, v2: vec3 coordinates of the triangle vertices
+// min_triangle_area: triangles less than this size are considered
+// points/lines, IsInsideTriangle returns False
+//
+// Returns:
+// dist: Float of the squared distance
+//
+
+__device__ inline float PointTriangle3DistanceForward(
+ const float3& p,
+ const float3& v0,
+ const float3& v1,
+ const float3& v2,
+ const double min_triangle_area) {
+ float3 normal = cross(v2 - v0, v1 - v0);
+ const float norm_normal = norm(normal);
+ normal = normalize(normal);
+
+ // p0 is the projection of p on the plane spanned by (v0, v1, v2)
+ // i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
+ const float t = dot(v0 - p, normal);
+ const float3 p0 = p + t * normal;
+
+ bool is_inside = IsInsideTriangle(p0, v0, v1, v2, min_triangle_area);
+ float dist = 0.0f;
+
+ if ((is_inside) && (norm_normal > kEpsilon)) {
+ // if projection p0 is inside triangle spanned by (v0, v1, v2)
+ // then distance is equal to norm(p0 - p)^2
+ dist = t * t;
+ } else {
+ const float e01 = PointLine3DistanceForward(p, v0, v1);
+ const float e02 = PointLine3DistanceForward(p, v0, v2);
+ const float e12 = PointLine3DistanceForward(p, v1, v2);
+
+ dist = (e01 > e02) ? e02 : e01;
+ dist = (dist > e12) ? e12 : dist;
+ }
+
+ return dist;
+}
+
+// The backward pass for computing the squared distance of a point
+// to the triangle (v0, v1, v2).
+//
+// Args:
+// p: xyz coordinates of a point
+// v0, v1, v2: xyz coordinates of the triangle vertices
+// grad_dist: Float of the gradient wrt dist
+// min_triangle_area: triangles less than this size are considered
+// points/lines, IsInsideTriangle returns False
+//
+// Returns:
+// tuple of gradients for the point and triangle:
+// (float3 grad_p, float3 grad_v0, float3 grad_v1, float3 grad_v2)
+//
+
+__device__ inline thrust::tuple<float3, float3, float3, float3>
+PointTriangle3DistanceBackward(
+ const float3& p,
+ const float3& v0,
+ const float3& v1,
+ const float3& v2,
+ const float& grad_dist,
+ const double min_triangle_area) {
+ const float3 v2v0 = v2 - v0;
+ const float3 v1v0 = v1 - v0;
+ const float3 v0p = v0 - p;
+ float3 raw_normal = cross(v2v0, v1v0);
+ const float norm_normal = norm(raw_normal);
+ float3 normal = normalize(raw_normal);
+
+ // p0 is the projection of p on the plane spanned by (v0, v1, v2)
+ // i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
+ const float t = dot(v0 - p, normal);
+ const float3 p0 = p + t * normal;
+ const float3 diff = t * normal;
+
+ bool is_inside = IsInsideTriangle(p0, v0, v1, v2, min_triangle_area);
+
+ float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
+ float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
+ float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
+ float3 grad_v2 = make_float3(0.0f, 0.0f, 0.0f);
+
+ if ((is_inside) && (norm_normal > kEpsilon)) {
+ // derivative of dist wrt p
+ grad_p = -2.0f * grad_dist * t * normal;
+ // derivative of dist wrt normal
+ const float3 grad_normal = 2.0f * grad_dist * t * (v0p + diff);
+ // derivative of dist wrt raw_normal
+ const float3 grad_raw_normal = normalize_backward(raw_normal, grad_normal);
+ // derivative of dist wrt v2v0 and v1v0
+ const auto grad_cross = cross_backward(v2v0, v1v0, grad_raw_normal);
+ const float3 grad_cross_v2v0 = thrust::get<0>(grad_cross);
+ const float3 grad_cross_v1v0 = thrust::get<1>(grad_cross);
+ grad_v0 =
+ grad_dist * 2.0f * t * normal - (grad_cross_v2v0 + grad_cross_v1v0);
+ grad_v1 = grad_cross_v1v0;
+ grad_v2 = grad_cross_v2v0;
+ } else {
+ const float e01 = PointLine3DistanceForward(p, v0, v1);
+ const float e02 = PointLine3DistanceForward(p, v0, v2);
+ const float e12 = PointLine3DistanceForward(p, v1, v2);
+
+ if ((e01 <= e02) && (e01 <= e12)) {
+ // e01 is smallest
+ const auto grads = PointLine3DistanceBackward(p, v0, v1, grad_dist);
+ grad_p = thrust::get<0>(grads);
+ grad_v0 = thrust::get<1>(grads);
+ grad_v1 = thrust::get<2>(grads);
+ } else if ((e02 <= e01) && (e02 <= e12)) {
+ // e02 is smallest
+ const auto grads = PointLine3DistanceBackward(p, v0, v2, grad_dist);
+ grad_p = thrust::get<0>(grads);
+ grad_v0 = thrust::get<1>(grads);
+ grad_v2 = thrust::get<2>(grads);
+ } else if ((e12 <= e01) && (e12 <= e02)) {
+ // e12 is smallest
+ const auto grads = PointLine3DistanceBackward(p, v1, v2, grad_dist);
+ grad_p = thrust::get<0>(grads);
+ grad_v1 = thrust::get<1>(grads);
+ grad_v2 = thrust::get<2>(grads);
+ }
+ }
+
+ return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
+}
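For readers following the math in this file, a brief aside: the 2D barycentric helpers above use the standard edge-function construction, and the perspective correction is the usual $1/z$ weighting written with the denominators cleared, i.e.

```latex
E(p; a, b) = (p_x - a_x)(b_y - a_y) - (p_y - a_y)(b_x - a_x), \qquad
w_0 = \frac{E(p; v_1, v_2)}{E(v_2; v_0, v_1)}, \quad
w_1 = \frac{E(p; v_2, v_0)}{E(v_2; v_0, v_1)}, \quad
w_2 = \frac{E(p; v_0, v_1)}{E(v_2; v_0, v_1)},
```

```latex
w_i' = \frac{w_i / z_i}{\sum_j w_j / z_j}
     = \frac{w_i \prod_{j \neq i} z_j}{\sum_k w_k \prod_{j \neq k} z_j},
```

with the denominators stabilized by `kEpsilon` in the code.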
diff --git a/data/dot_single_video/dot/utils/torch3d/csrc/utils/geometry_utils.h b/data/dot_single_video/dot/utils/torch3d/csrc/utils/geometry_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..ce6e37cc4c1dbf444e04bbc302d1900b85d4834f
--- /dev/null
+++ b/data/dot_single_video/dot/utils/torch3d/csrc/utils/geometry_utils.h
@@ -0,0 +1,823 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <tuple>
+#include "vec2.h"
+#include "vec3.h"
+
+// Set epsilon for preventing floating point errors and division by 0.
+const auto kEpsilon = 1e-8;
+
+// Determines whether a point p is on the right side of a 2D line segment
+// given by the end points v0, v1.
+//
+// Args:
+// p: vec2 Coordinates of a point.
+// v0, v1: vec2 Coordinates of the end points of the edge.
+//
+// Returns:
+// area: The signed area of the parallelogram given by the vectors
+// A = p - v0
+// B = v1 - v0
+//
+// v1 ________
+// /\ /
+// A / \ /
+// / \ /
+// v0 /______\/
+// B p
+//
+// The area can also be interpreted as the cross product A x B.
+// If the sign of the area is positive, the point p is on the
+// right side of the edge. Negative area indicates the point is on
+// the left side of the edge. i.e. for an edge v1 - v0:
+//
+// v1
+// /
+// /
+// - / +
+// /
+// /
+// v0
+//
+template <typename T>
+T EdgeFunctionForward(const vec2<T>& p, const vec2<T>& v0, const vec2<T>& v1) {
+ const T edge = (p.x - v0.x) * (v1.y - v0.y) - (p.y - v0.y) * (v1.x - v0.x);
+ return edge;
+}
+
+// Backward pass for the edge function returning partial derivatives for each
+// of the input points.
+//
+// Args:
+// p: vec2 Coordinates of a point.
+// v0, v1: vec2 Coordinates of the end points of the edge.
+// grad_edge: Upstream gradient for output from edge function.
+//
+// Returns:
+// tuple of gradients for each of the input points:
+// (vec2 d_edge_dp, vec2 d_edge_dv0, vec2 d_edge_dv1)
+//
+template <typename T>
+inline std::tuple<vec2<T>, vec2<T>, vec2<T>> EdgeFunctionBackward(
+ const vec2<T>& p,
+ const vec2<T>& v0,
+ const vec2<T>& v1,
+ const T grad_edge) {
+ const vec2<T> dedge_dp(v1.y - v0.y, v0.x - v1.x);
+ const vec2<T> dedge_dv0(p.y - v1.y, v1.x - p.x);
+ const vec2<T> dedge_dv1(v0.y - p.y, p.x - v0.x);
+ return std::make_tuple(
+ grad_edge * dedge_dp, grad_edge * dedge_dv0, grad_edge * dedge_dv1);
+}
+
+// The forward pass for computing the barycentric coordinates of a point
+// relative to a triangle.
+// Ref:
+// https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-rendering-a-triangle/barycentric-coordinates
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1, v2: Coordinates of the triangle vertices.
+//
+// Returns
+// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
+//
+template <typename T>
+vec3<T> BarycentricCoordinatesForward(
+ const vec2<T>& p,
+ const vec2<T>& v0,
+ const vec2<T>& v1,
+ const vec2<T>& v2) {
+ const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
+ const T w0 = EdgeFunctionForward(p, v1, v2) / area;
+ const T w1 = EdgeFunctionForward(p, v2, v0) / area;
+ const T w2 = EdgeFunctionForward(p, v0, v1) / area;
+ return vec3<T>(w0, w1, w2);
+}
+
+// The backward pass for computing the barycentric coordinates of a point
+// relative to a triangle.
+//
+// Args:
+// p: Coordinates of a point.
+// v0, v1, v2: (x, y) coordinates of the triangle vertices.
+// grad_bary_upstream: vec3 Upstream gradient for each of the
+// barycentric coordinates [grad_w0, grad_w1, grad_w2].
+//
+// Returns
+// tuple of gradients for each of the triangle vertices:
+// (vec2 grad_v0, vec2 grad_v1, vec2 grad_v2)
+//
+template <typename T>
+inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>> BarycentricCoordsBackward(
+ const vec2<T>& p,
+ const vec2<T>& v0,
+ const vec2<T>& v1,
+ const vec2<T>& v2,
+ const vec3<T>& grad_bary_upstream) {
+ const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
+ const T area2 = pow(area, 2.0f);
+ const T area_inv = 1.0f / area;
+ const T e0 = EdgeFunctionForward(p, v1, v2);
+ const T e1 = EdgeFunctionForward(p, v2, v0);
+ const T e2 = EdgeFunctionForward(p, v0, v1);
+
+ const T grad_w0 = grad_bary_upstream.x;
+ const T grad_w1 = grad_bary_upstream.y;
+ const T grad_w2 = grad_bary_upstream.z;
+
+ // Calculate component of the gradient from each of w0, w1 and w2.
+ // e.g. for w0:
+ // dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
+ // + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
+ const T dw0_darea = -e0 / (area2);
+ const T dw0_e0 = area_inv;
+ const T dloss_d_w0area = grad_w0 * dw0_darea;
+ const T dloss_e0 = grad_w0 * dw0_e0;
+ auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
+ auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
+ const vec2<T> dw0_p = std::get<0>(de0_dv);
+ const vec2<T> dw0_dv0 = std::get<1>(dw0area_dv);
+ const vec2<T> dw0_dv1 = std::get<1>(de0_dv) + std::get<2>(dw0area_dv);
+ const vec2<T> dw0_dv2 = std::get<2>(de0_dv) + std::get<0>(dw0area_dv);
+
+ const T dw1_darea = -e1 / (area2);
+ const T dw1_e1 = area_inv;
+ const T dloss_d_w1area = grad_w1 * dw1_darea;
+ const T dloss_e1 = grad_w1 * dw1_e1;
+ auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
+ auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
+ const vec2<T> dw1_p = std::get<0>(de1_dv);
+ const vec2<T> dw1_dv0 = std::get<2>(de1_dv) + std::get<1>(dw1area_dv);
+ const vec2<T> dw1_dv1 = std::get<2>(dw1area_dv);
+ const vec2<T> dw1_dv2 = std::get<1>(de1_dv) + std::get<0>(dw1area_dv);
+
+ const T dw2_darea = -e2 / (area2);
+ const T dw2_e2 = area_inv;
+ const T dloss_d_w2area = grad_w2 * dw2_darea;
+ const T dloss_e2 = grad_w2 * dw2_e2;
+ auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
+ auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
+ const vec2<T> dw2_p = std::get<0>(de2_dv);
+ const vec2<T> dw2_dv0 = std::get<1>(de2_dv) + std::get<1>(dw2area_dv);
+ const vec2<T> dw2_dv1 = std::get<2>(de2_dv) + std::get<2>(dw2area_dv);
+ const vec2<T> dw2_dv2 = std::get<0>(dw2area_dv);
+
+ const vec2<T> dbary_p = dw0_p + dw1_p + dw2_p;
+ const vec2<T> dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
+ const vec2<T> dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
+ const vec2<T> dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
+
+ return std::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
+}
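+
+// Structure of the gradients above: each coordinate has the form
+// w_i = e_i / area, so by the quotient rule
+//   dw_i/de_i = 1 / area   and   dw_i/darea = -e_i / area^2,
+// which is what dw*_e* and dw*_darea encode; the vertex gradients are then
+// accumulated through EdgeFunctionBackward for both the numerator edge and
+// the shared denominator edge (v2, v0, v1).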
+
+// Forward pass for applying perspective correction to barycentric coordinates.
+//
+// Args:
+// bary: Screen-space barycentric coordinates for a point
+// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
+//
+// Returns
+// World-space barycentric coordinates
+//
+template <typename T>
+inline vec3<T> BarycentricPerspectiveCorrectionForward(
+ const vec3<T>& bary,
+ const T z0,
+ const T z1,
+ const T z2) {
+ const T w0_top = bary.x * z1 * z2;
+ const T w1_top = bary.y * z0 * z2;
+ const T w2_top = bary.z * z0 * z1;
+ const T denom = std::max<T>(w0_top + w1_top + w2_top, kEpsilon);
+ const T w0 = w0_top / denom;
+ const T w1 = w1_top / denom;
+ const T w2 = w2_top / denom;
+ return vec3<T>(w0, w1, w2);
+}
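+
+// Intuition check: with screen-space bary = (0.5, 0.5, 0.0) and depths
+// z0 = 1, z1 = 3, z2 = 1, the numerators are (1.5, 0.5, 0.0), the denominator
+// is 2.0, and the corrected coordinates are (0.75, 0.25, 0.0): the vertex
+// closer to the camera (smaller z) receives the larger weight.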
+
+// Backward pass for applying perspective correction to barycentric coordinates.
+//
+// Args:
+// bary: Screen-space barycentric coordinates for a point
+// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
+// grad_out: Upstream gradient of the loss with respect to the corrected
+// barycentric coordinates.
+//
+// Returns a tuple of:
+// grad_bary: Downstream gradient of the loss with respect to the
+// uncorrected barycentric coordinates.
+// grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
+// to the z-coordinates of the triangle verts
+template <typename T>
+inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
+ const vec3<T>& bary,
+ const T z0,
+ const T z1,
+ const T z2,
+ const vec3<T>& grad_out) {
+ // Recompute forward pass
+ const T w0_top = bary.x * z1 * z2;
+ const T w1_top = bary.y * z0 * z2;
+ const T w2_top = bary.z * z0 * z1;
+ const T denom = std::max<T>(w0_top + w1_top + w2_top, kEpsilon);
+
+ // Now do backward pass
+ const T grad_denom_top =
+ -w0_top * grad_out.x - w1_top * grad_out.y - w2_top * grad_out.z;
+ const T grad_denom = grad_denom_top / (denom * denom);
+ const T grad_w0_top = grad_denom + grad_out.x / denom;
+ const T grad_w1_top = grad_denom + grad_out.y / denom;
+ const T grad_w2_top = grad_denom + grad_out.z / denom;
+ const T grad_bary_x = grad_w0_top * z1 * z2;
+ const T grad_bary_y = grad_w1_top * z0 * z2;
+ const T grad_bary_z = grad_w2_top * z0 * z1;
+ const vec3<T> grad_bary(grad_bary_x, grad_bary_y, grad_bary_z);
+ const T grad_z0 = grad_w1_top * bary.y * z2 + grad_w2_top * bary.z * z1;
+ const T grad_z1 = grad_w0_top * bary.x * z2 + grad_w2_top * bary.z * z0;
+ const T grad_z2 = grad_w0_top * bary.x * z1 + grad_w1_top * bary.y * z0;
+ return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
+}
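+
+// Derivation note: since w_i = w_i_top / denom with denom = sum_j w_j_top,
+//   dL/d(w_i_top) = grad_out_i / denom - (sum_j w_j_top * grad_out_j) / denom^2,
+// which is exactly grad_w*_top = grad_denom + grad_out.* / denom above. The
+// bary and z gradients then follow from each w_i_top being the product of one
+// barycentric component and two of the depths.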
+
+// Clip negative barycentric coordinates to 0.0 and renormalize so
+// the barycentric coordinates for a point sum to 1. When the blur_radius
+// is greater than 0, a face will still be recorded as overlapping a pixel
+// if the pixel is outside the face. In this case at least one of the
+// barycentric coordinates for the pixel relative to the face will be negative.
+// Clipping will ensure that the texture and z buffer are interpolated
+// correctly.
+//
+// Args
+// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
+//
+// Returns
+// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1] which
+// satisfy the condition: sum(w0, w1, w2) = 1.0.
+//
+template <typename T>
+vec3<T> BarycentricClipForward(const vec3<T> bary) {
+ vec3<T> w(0.0f, 0.0f, 0.0f);
+ // Only clamp negative values to 0.0.
+ // No need to clamp values > 1.0 as they will be renormalized.
+ w.x = std::max(bary.x, 0.0f);
+ w.y = std::max(bary.y, 0.0f);
+ w.z = std::max(bary.z, 0.0f);
+ float w_sum = w.x + w.y + w.z;
+ w_sum = std::fmaxf(w_sum, 1e-5);
+ w.x /= w_sum;
+ w.y /= w_sum;
+ w.z /= w_sum;
+ return w;
+}
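+
+// Example: bary = (-0.2, 0.6, 0.6) clamps to (0.0, 0.6, 0.6), which sums to
+// 1.2 and renormalizes to (0.0, 0.5, 0.5), so downstream interpolation of
+// textures and the z buffer never sees negative weights.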
+
+// Backward pass for barycentric coordinate clipping.
+//
+// Args
+// bary: (w0, w1, w2) barycentric coordinates which can contain values < 0.
+// grad_baryclip_upstream: vec3 Upstream gradient for each of the clipped
+// barycentric coordinates [grad_w0, grad_w1, grad_w2].
+//
+// Returns
+// vec3 of gradients for the unclipped barycentric coordinates:
+// (grad_w0, grad_w1, grad_w2)
+//
+template <typename T>
+vec3<T> BarycentricClipBackward(
+ const vec3<T> bary,
+ const vec3<T> grad_baryclip_upstream) {
+ // Redo some of the forward pass calculations
+ vec3<T> w(0.0f, 0.0f, 0.0f);
+ w.x = std::max(bary.x, 0.0f);
+ w.y = std::max(bary.y, 0.0f);
+ w.z = std::max(bary.z, 0.0f);
+ float w_sum = w.x + w.y + w.z;
+
+ vec3<T> grad_bary(1.0f, 1.0f, 1.0f);
+ vec3<T> grad_clip(1.0f, 1.0f, 1.0f);
+ vec3